Introduction to the Algorithm
A system can be in one of N different, unobservable states (i.e., we never know which state the system is actually in). The system also has a finite number of possible observable "outputs", which depend on the actual (unobservable) state of the system.
The input of the Viterbi algorithm is a list of observations in a time sequence. The algorithm calculates the most probable sequence of states (one state per time frame) that could have produced those observations.
What is also given, besides the list of observations, is the following:
- the initial probability distribution of the (unobservable) states
- for each state, the probability that it transitions into each other state (including itself)
- for each state, the probability that each observation can be observed in that state.
For more details, see e.g. Wikipedia.
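For reference, the recurrence that the implementation below computes can be written as follows. The notation is chosen here for illustration and does not appear in the code: π(s) is the initial probability of state s, a_{s's} is the probability of a transition from s' to s, b_s(o) is the probability of observing o in state s, T is the number of observations, and V_t(s) is the probability of the most likely state sequence that ends in s after observing o_0, ..., o_t. In the code, V corresponds to stateProbsForObservations and ψ to previousStatesForObservations; the last line is the backtracking done in finish().

$$
\begin{aligned}
V_0(s) &= \pi(s)\, b_s(o_0) \\
V_t(s) &= \max_{s'} \big( V_{t-1}(s')\, a_{s's} \big)\, b_s(o_t), \qquad
\psi_t(s) = \arg\max_{s'} V_{t-1}(s')\, a_{s's} \\
s^*_{T-1} &= \arg\max_{s} V_{T-1}(s), \qquad s^*_{t-1} = \psi_t(s^*_t)
\end{aligned}
$$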
Some remarks on design
Besides having a correct implementation of the algorithm, I mostly had these goals in mind:
- the interface should be easy to use and not error-prone
- the data should be validated before starting the calculation
- the state of the algorithm should be observable after each step (hence the nextStep, getProbabilitiesForObservations and getPreviousStatesObservations methods); however, it should also be easy to get the end result, which is what the calculate method is for
Objectives of this review
While any suggestions/remarks are welcome, below are the points I'm most interested in.
Implementation
- Do you see any errors in the implementation of the Viterbi algorithm? (In other words, cases where valid input would give an erroneous result.)
- Can you think of any invalid inputs that are not detected by the validation?
- Is there a better way to list all the enum constants corresponding to the states/observations (in particular, one that does not require a collection with at least one element)?
- Could the API be improved? (E.g. to make it more intuitive, easily usable, less error-prone, etc.)
- What is your opinion about the place of validation? What would be the pros of putting it into the model, instead of the machine? (I put it into the machine, because I wanted to keep the model as dumb as possible, and have all the logic in the machine.)
Tests
- Is there a way to better organize the state/observation enums corresponding to the test cases? (Unfortunately, in Java, at least before Java 16, it is not possible to define an enum within a method, but that would not be a full solution either, as some enums are shared among test cases.)
- Can you suggest any further test cases for the algorithm itself? (For now, I did not test the "stepping" logic, that will come later.)
- Do you see any superfluous tests?
The code
N.B.: I include only the relevant portions. A full, working version can be found on GitHub. As an external library, the code uses Guava (and JUnit/Hamcrest for testing).
Implementation
public static class ViterbiModel<S extends Enum<S>, T extends Enum<T>> {
public final ImmutableMap<S, Double> initialDistributions;
public final ImmutableTable<S, S, Double> transitionProbabilities;
public final ImmutableTable<S, T, Double> emissionProbabilities;
private ViterbiModel(ImmutableMap<S, Double> initialDistributions,
ImmutableTable<S, S, Double> transitionProbabilities,
ImmutableTable<S, T, Double> emissionProbabilities) {
this.initialDistributions = checkNotNull(initialDistributions);
this.transitionProbabilities = checkNotNull(transitionProbabilities);
this.emissionProbabilities = checkNotNull(emissionProbabilities);
}
public static <S extends Enum<S>, T extends Enum<T>> Builder<S, T> builder() {
return new Builder<>();
}
public static class Builder<S extends Enum<S>, T extends Enum<T>> {
private ImmutableMap<S, Double> initialDistributions;
private ImmutableTable.Builder<S, S, Double> transitionProbabilities = ImmutableTable.builder();
private ImmutableTable.Builder<S, T, Double> emissionProbabilities = ImmutableTable.builder();
public ViterbiModel<S, T> build() {
return new ViterbiModel<S, T>(immutableEnumMap(initialDistributions), transitionProbabilities.build(), emissionProbabilities.build());
}
public Builder<S, T> withInitialDistributions(ImmutableMap<S, Double> initialDistributions) {
this.initialDistributions = initialDistributions;
return this;
}
public Builder<S, T> withTransitionProbability(S src, S dest, Double prob) {
transitionProbabilities.put(src, dest, prob);
return this;
}
public Builder<S, T> withEmissionProbability(S state, T emission, Double prob) {
emissionProbabilities.put(state, emission, prob);
return this;
}
}
}
public static class ViterbiMachine<S extends Enum<S>, T extends Enum<T>> {
private final List<S> possibleStates;
private final List<T> possibleObservations;
private final ViterbiModel<S, T> model;
private final ImmutableList<T> observations;
private Table<S, Integer, Double> stateProbsForObservations = HashBasedTable.create();
private Table<S, Integer, Optional<S>> previousStatesForObservations = HashBasedTable.create();
private int step;
public ViterbiMachine(ViterbiModel<S, T> model, ImmutableList<T> observations) {
this.model = checkNotNull(model);
this.observations = checkNotNull(observations);
try {
possibleStates = ImmutableList.copyOf(getPossibleStates());
} catch (IllegalStateException ise) {
throw new IllegalArgumentException("empty states enum, or no explicit initial distribution provided", ise);
}
try {
possibleObservations = ImmutableList.copyOf(getPossibleObservations());
} catch (IllegalStateException ise) {
throw new IllegalArgumentException("empty observations enum, or no explicit observations provided", ise);
}
validate();
initialize();
}
private void validate() {
if (model.initialDistributions.size() != possibleStates.size()) {
throw new IllegalArgumentException("model.initialDistributions.size() = " + model.initialDistributions.size());
}
double sumInitProbs = 0.0;
for (double prob: model.initialDistributions.values()) {
sumInitProbs += prob;
}
if (!doublesEqual(sumInitProbs, 1.0)) {
throw new IllegalArgumentException("the sum of initial distributions should be 1.0, was " + sumInitProbs);
}
if (observations.size() < 1) {
// should not happen (observations size already checked when retrieving possible enum values),
// only added for the sake of completeness
throw new IllegalArgumentException("at least one observation should be provided, " + observations.size() + " given");
}
if (model.transitionProbabilities.size() < 1) {
throw new IllegalArgumentException("at least one transition probability should be provided, " + model.transitionProbabilities.size() + " given");
}
for (S row : possibleStates) {
double sumRowProbs = 0.0;
for (double prob : rowOrDefault(model.transitionProbabilities, row, ImmutableMap.<S, Double>of()).values()) {
sumRowProbs += prob;
}
if (!doublesEqual(sumRowProbs, 1.0)) {
throw new IllegalArgumentException("sum of transition probabilities for each state should be one, was " + sumRowProbs + " for state " + row);
}
}
if (model.emissionProbabilities.size() < 1) {
throw new IllegalArgumentException("at least one emission probability should be provided, " + model.emissionProbabilities.size() + " given");
}
for (S row : possibleStates) {
double sumRowProbs = 0.0;
for (double prob : rowOrDefault(model.emissionProbabilities, row, ImmutableMap.<T, Double>of()).values()) {
sumRowProbs += prob;
}
if (!doublesEqual(sumRowProbs, 1.0)) {
throw new IllegalArgumentException("sum of emission probabilities for each state should be one, was " + sumRowProbs + " for state " + row);
}
}
}
private static <S, T, V> V getOrDefault(Table<S, T, V> table, S key1, T key2, V defaultValue) {
V ret = table.get(key1, key2);
if (ret == null) {
ret = defaultValue;
}
return ret;
}
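// Guava's Table.row() returns an empty map (never null) for a row key with no mappings,
// so the null check below is purely defensive.
private static <S, T, V> Map<T, V> rowOrDefault(Table<S, T, V> table, S key, Map<T, V> defaultValue) {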
Map<T, V> ret = table.row(key);
if (ret == null) {
ret = defaultValue;
}
return ret;
}
private void initialize() {
final T firstObservation = observations.get(0);
for (S state : possibleStates) {
stateProbsForObservations.put(state, 0, model.initialDistributions.getOrDefault(state, 0.0) * getOrDefault(model.emissionProbabilities, state, firstObservation, 0.0));
previousStatesForObservations.put(state, 0, Optional.<S>empty());
}
step = 1;
}
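// Advances one time step: for each state, keeps the predecessor that maximizes
// (previous probability * transition probability), then multiplies by the emission probability.
public void nextStep() {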
if (step >= observations.size()) {
throw new IllegalStateException("already finished last step");
}
for (S state : possibleStates) {
double maxProb = 0.0;
Optional<S> prevStateWithMaxProb = Optional.empty();
for (S state2 : possibleStates) {
double prob = getOrDefault(stateProbsForObservations, state2, step - 1, 0.0) * getOrDefault(model.transitionProbabilities, state2, state, 0.0);
if (prob > maxProb) {
maxProb = prob;
prevStateWithMaxProb = Optional.of(state2);
}
}
stateProbsForObservations.put(state, step, maxProb * getOrDefault(model.emissionProbabilities, state, observations.get(step), 0.0));
previousStatesForObservations.put(state, step, prevStateWithMaxProb);
}
++step;
}
public ImmutableTable<S, Integer, Double> getProbabilitiesForObservations() {
return ImmutableTable.copyOf(stateProbsForObservations);
}
public ImmutableTable<S, Integer, Optional<S>> getPreviousStatesObservations() {
return ImmutableTable.copyOf(previousStatesForObservations);
}
public List<S> finish() {
if (step != observations.size()) {
throw new IllegalStateException("step = " + step);
}
S stateWithMaxProb = possibleStates.get(0);
double maxProb = stateProbsForObservations.get(stateWithMaxProb, observations.size() - 1);
for (S state : possibleStates) {
double prob = stateProbsForObservations.get(state, observations.size() - 1);
if (prob > maxProb) {
maxProb = prob;
stateWithMaxProb = state;
}
}
List<S> result = new ArrayList<>();
for (int i = observations.size() - 1; i >= 0; --i) {
result.add(stateWithMaxProb);
stateWithMaxProb = previousStatesForObservations.get(stateWithMaxProb, i).orElse(null);
}
return Lists.reverse(result);
}
public List<S> calculate() {
for (int i = 0; i < observations.size() - 1; ++i) {
nextStep();
}
return finish();
}
private S[] getPossibleStates() {
return getEnumsFromIterator(model.initialDistributions.keySet().iterator());
}
private T[] getPossibleObservations() {
return getEnumsFromIterator(observations.iterator());
}
private static <X extends Enum<X>> X[] getEnumsFromIterator(Iterator<X> it) {
if (!it.hasNext()) {
throw new IllegalStateException("iterator should have at least one element");
}
Enum<X> val1 = it.next();
return val1.getDeclaringClass().getEnumConstants();
}
private static boolean doublesEqual(double d1, double d2) {
return Math.abs(d1 - d2) < 0.0000001;
}
}
Tests
public class ViterbiTest {
@Rule
public ExpectedException thrown = ExpectedException.none();
enum ZeroStatesZeroObservationsState { };
enum ZeroStatesZeroObservationsObservation { };
@Test
public void zeroStatesZeroObservationsIsNotOk() {
ViterbiModel<ZeroStatesZeroObservationsState, ZeroStatesZeroObservationsObservation> model = ViterbiModel.<ZeroStatesZeroObservationsState, ZeroStatesZeroObservationsObservation>builder()
.withInitialDistributions(ImmutableMap.<ZeroStatesZeroObservationsState, Double>builder()
.build())
.build();
ImmutableList<ZeroStatesZeroObservationsObservation> observations = ImmutableList.of();
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("empty states enum, or no explicit initial distribution provided");
new ViterbiMachine<>(model, observations);
}
enum ZeroStatesOneObservationState { };
enum ZeroStatesOneObservationObservation { OBSERVATION0 };
@Test
public void zeroStatesOneObservationIsNotOk() {
ViterbiModel<ZeroStatesOneObservationState, ZeroStatesOneObservationObservation> model = ViterbiModel.<ZeroStatesOneObservationState, ZeroStatesOneObservationObservation>builder()
.withInitialDistributions(ImmutableMap.<ZeroStatesOneObservationState, Double>builder()
.build())
.build();
ImmutableList<ZeroStatesOneObservationObservation> observations = ImmutableList.of();
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("empty states enum, or no explicit initial distribution provided");
new ViterbiMachine<>(model, observations);
}
enum OneStateZeroObservationsState { STATE0 };
enum OneStateZeroObservationsObservation { };
@Test
public void oneStateZeroObservationsIsNotOk() {
ViterbiModel<OneStateZeroObservationsState, OneStateZeroObservationsObservation> model = ViterbiModel.<OneStateZeroObservationsState, OneStateZeroObservationsObservation>builder()
.withInitialDistributions(ImmutableMap.<OneStateZeroObservationsState, Double>builder()
.put(OneStateZeroObservationsState.STATE0, 1.0)
.build())
.build();
ImmutableList<OneStateZeroObservationsObservation> observations = ImmutableList.of();
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("empty observations enum, or no explicit observations provided");
new ViterbiMachine<>(model, observations);
}
enum OneStateOneObservationState { STATE0 };
enum OneStateOneObservationObservation { OBSERVATION0 };
@Test
public void oneStateOneObservationIsOk() {
ViterbiModel<OneStateOneObservationState, OneStateOneObservationObservation> model = ViterbiModel.<OneStateOneObservationState, OneStateOneObservationObservation>builder()
.withInitialDistributions(ImmutableMap.<OneStateOneObservationState, Double>builder()
.put(OneStateOneObservationState.STATE0, 1.0)
.build())
.withTransitionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationState.STATE0, 1.0)
.withEmissionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationObservation.OBSERVATION0, 1.0)
.build();
ImmutableList<OneStateOneObservationObservation> observations = ImmutableList.of(OneStateOneObservationObservation.OBSERVATION0);
ViterbiMachine<OneStateOneObservationState, OneStateOneObservationObservation> machine = new ViterbiMachine<>(model, observations);
List<OneStateOneObservationState> states = machine.calculate();
final List<OneStateOneObservationState> expected = ImmutableList.of(OneStateOneObservationState.STATE0);
assertThat(states, is(expected));
}
@Test
public void oneStateOneObservationMissingInitialDistributionIsNotOk() {
ViterbiModel<OneStateOneObservationState, OneStateOneObservationObservation> model = ViterbiModel.<OneStateOneObservationState, OneStateOneObservationObservation>builder()
.withInitialDistributions(ImmutableMap.<OneStateOneObservationState, Double>builder()
.build())
.withTransitionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationState.STATE0, 1.0)
.withEmissionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationObservation.OBSERVATION0, 1.0)
.build();
ImmutableList<OneStateOneObservationObservation> observations = ImmutableList.of(OneStateOneObservationObservation.OBSERVATION0);
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("empty states enum, or no explicit initial distribution provided");
new ViterbiMachine<>(model, observations);
}
@Test
public void oneStateOneObservationMissingObservationsIsNotOk() {
ViterbiModel<OneStateOneObservationState, OneStateOneObservationObservation> model = ViterbiModel.<OneStateOneObservationState, OneStateOneObservationObservation>builder()
.withInitialDistributions(ImmutableMap.<OneStateOneObservationState, Double>builder()
.put(OneStateOneObservationState.STATE0, 1.0)
.build())
.withTransitionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationState.STATE0, 1.0)
.withEmissionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationObservation.OBSERVATION0, 1.0)
.build();
ImmutableList<OneStateOneObservationObservation> observations = ImmutableList.of();
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("empty observations enum, or no explicit observations provided");
new ViterbiMachine<>(model, observations);
}
@Test
public void oneStateOneObservationSumInitialDistribNotOneIsNotOk() {
ViterbiModel<OneStateOneObservationState, OneStateOneObservationObservation> model = ViterbiModel.<OneStateOneObservationState, OneStateOneObservationObservation>builder()
.withInitialDistributions(ImmutableMap.<OneStateOneObservationState, Double>builder()
.put(OneStateOneObservationState.STATE0, 1.1)
.build())
.withTransitionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationState.STATE0, 1.0)
.withEmissionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationObservation.OBSERVATION0, 1.0)
.build();
ImmutableList<OneStateOneObservationObservation> observations = ImmutableList.of(OneStateOneObservationObservation.OBSERVATION0);
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("the sum of initial distributions should be 1.0, was 1.1");
new ViterbiMachine<>(model, observations);
}
@Test
public void oneStateOneObservationNoTransitionProbabilitiesIsNotOk() {
ViterbiModel<OneStateOneObservationState, OneStateOneObservationObservation> model = ViterbiModel.<OneStateOneObservationState, OneStateOneObservationObservation>builder()
.withInitialDistributions(ImmutableMap.<OneStateOneObservationState, Double>builder()
.put(OneStateOneObservationState.STATE0, 1.0)
.build())
.withEmissionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationObservation.OBSERVATION0, 1.0)
.build();
ImmutableList<OneStateOneObservationObservation> observations = ImmutableList.of(OneStateOneObservationObservation.OBSERVATION0);
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("at least one transition probability should be provided, 0 given");
new ViterbiMachine<>(model, observations);
}
@Test
public void oneStateOneObservationSumTransitionProbabilitiesNotOneIsNotOk() {
ViterbiModel<OneStateOneObservationState, OneStateOneObservationObservation> model = ViterbiModel.<OneStateOneObservationState, OneStateOneObservationObservation>builder()
.withInitialDistributions(ImmutableMap.<OneStateOneObservationState, Double>builder()
.put(OneStateOneObservationState.STATE0, 1.0)
.build())
.withTransitionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationState.STATE0, 1.1)
.withEmissionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationObservation.OBSERVATION0, 1.0)
.build();
ImmutableList<OneStateOneObservationObservation> observations = ImmutableList.of(OneStateOneObservationObservation.OBSERVATION0);
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("sum of transition probabilities for each state should be one, was 1.1 for state STATE0");
new ViterbiMachine<>(model, observations);
}
@Test
public void oneStateOneObservationZeroEmissionProbabilitiesIsNotOk() {
ViterbiModel<OneStateOneObservationState, OneStateOneObservationObservation> model = ViterbiModel.<OneStateOneObservationState, OneStateOneObservationObservation>builder()
.withInitialDistributions(ImmutableMap.<OneStateOneObservationState, Double>builder()
.put(OneStateOneObservationState.STATE0, 1.0)
.build())
.withTransitionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationState.STATE0, 1.0)
.build();
ImmutableList<OneStateOneObservationObservation> observations = ImmutableList.of(OneStateOneObservationObservation.OBSERVATION0);
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("at least one emission probability should be provided, 0 given");
new ViterbiMachine<>(model, observations);
}
@Test
public void oneStateOneObservationSumEmissionProbabilitiesNotOneIsNotOk() {
ViterbiModel<OneStateOneObservationState, OneStateOneObservationObservation> model = ViterbiModel.<OneStateOneObservationState, OneStateOneObservationObservation>builder()
.withInitialDistributions(ImmutableMap.<OneStateOneObservationState, Double>builder()
.put(OneStateOneObservationState.STATE0, 1.0)
.build())
.withTransitionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationState.STATE0, 1.0)
.withEmissionProbability(OneStateOneObservationState.STATE0, OneStateOneObservationObservation.OBSERVATION0, 1.1)
.build();
ImmutableList<OneStateOneObservationObservation> observations = ImmutableList.of(OneStateOneObservationObservation.OBSERVATION0);
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("sum of emission probabilities for each state should be one, was 1.1 for state STATE0");
new ViterbiMachine<>(model, observations);
}
enum OneStateTwoObservationsState { STATE0 };
enum OneStateTwoObservationsObservation { OBSERVATION0, OBSERVATION1 };
@Test
public void oneStateTwoObservationsIsOk() {
ViterbiModel<OneStateTwoObservationsState, OneStateTwoObservationsObservation> model = ViterbiModel.<OneStateTwoObservationsState, OneStateTwoObservationsObservation>builder()
.withInitialDistributions(ImmutableMap.<OneStateTwoObservationsState, Double>builder()
.put(OneStateTwoObservationsState.STATE0, 1.0)
.build())
.withTransitionProbability(OneStateTwoObservationsState.STATE0, OneStateTwoObservationsState.STATE0, 1.0)
.withEmissionProbability(OneStateTwoObservationsState.STATE0, OneStateTwoObservationsObservation.OBSERVATION0, 0.4)
.withEmissionProbability(OneStateTwoObservationsState.STATE0, OneStateTwoObservationsObservation.OBSERVATION1, 0.6)
.build();
ImmutableList<OneStateTwoObservationsObservation> observations = ImmutableList.of(OneStateTwoObservationsObservation.OBSERVATION1, OneStateTwoObservationsObservation.OBSERVATION1);
ViterbiMachine<OneStateTwoObservationsState, OneStateTwoObservationsObservation> machine = new ViterbiMachine<>(model, observations);
List<OneStateTwoObservationsState> states = machine.calculate();
final List<OneStateTwoObservationsState> expected = ImmutableList.of(OneStateTwoObservationsState.STATE0, OneStateTwoObservationsState.STATE0);
assertThat(states, is(expected));
}
enum TwoStatesOneObservationState { STATE0, STATE1 };
enum TwoStatesOneObservationObservation { OBSERVATION0 };
@Test
public void twoStatesOneObservationIsOk() {
ViterbiModel<TwoStatesOneObservationState, TwoStatesOneObservationObservation> model = ViterbiModel.<TwoStatesOneObservationState, TwoStatesOneObservationObservation>builder()
.withInitialDistributions(ImmutableMap.<TwoStatesOneObservationState, Double>builder()
.put(TwoStatesOneObservationState.STATE0, 0.6)
.put(TwoStatesOneObservationState.STATE1, 0.4)
.build())
.withTransitionProbability(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationState.STATE0, 0.7)
.withTransitionProbability(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationState.STATE1, 0.3)
.withTransitionProbability(TwoStatesOneObservationState.STATE1, TwoStatesOneObservationState.STATE0, 0.4)
.withTransitionProbability(TwoStatesOneObservationState.STATE1, TwoStatesOneObservationState.STATE1, 0.6)
.withEmissionProbability(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationObservation.OBSERVATION0, 1.0)
.withEmissionProbability(TwoStatesOneObservationState.STATE1, TwoStatesOneObservationObservation.OBSERVATION0, 1.0)
.build();
ImmutableList<TwoStatesOneObservationObservation> observations = ImmutableList.of(TwoStatesOneObservationObservation.OBSERVATION0, TwoStatesOneObservationObservation.OBSERVATION0);
ViterbiMachine<TwoStatesOneObservationState, TwoStatesOneObservationObservation> machine = new ViterbiMachine<>(model, observations);
List<TwoStatesOneObservationState> states = machine.calculate();
final List<TwoStatesOneObservationState> expected = ImmutableList.of(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationState.STATE0);
assertThat(states, is(expected));
}
@Test
public void twoStatesOneObservationTransitionsOmittedForOneStateIsNotOk() {
ViterbiModel<TwoStatesOneObservationState, TwoStatesOneObservationObservation> model = ViterbiModel.<TwoStatesOneObservationState, TwoStatesOneObservationObservation>builder()
.withInitialDistributions(ImmutableMap.<TwoStatesOneObservationState, Double>builder()
.put(TwoStatesOneObservationState.STATE0, 0.6)
.put(TwoStatesOneObservationState.STATE1, 0.4)
.build())
.withTransitionProbability(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationState.STATE0, 0.7)
.withTransitionProbability(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationState.STATE1, 0.3)
.withEmissionProbability(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationObservation.OBSERVATION0, 1.0)
.withEmissionProbability(TwoStatesOneObservationState.STATE1, TwoStatesOneObservationObservation.OBSERVATION0, 1.0)
.build();
ImmutableList<TwoStatesOneObservationObservation> observations = ImmutableList.of(TwoStatesOneObservationObservation.OBSERVATION0, TwoStatesOneObservationObservation.OBSERVATION0);
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("sum of transition probabilities for each state should be one, was 0.0 for state STATE1");
new ViterbiMachine<>(model, observations);
}
@Test
public void twoStatesOneObservationEmissionsOmittedForOneStateIsNotOk() {
ViterbiModel<TwoStatesOneObservationState, TwoStatesOneObservationObservation> model = ViterbiModel.<TwoStatesOneObservationState, TwoStatesOneObservationObservation>builder()
.withInitialDistributions(ImmutableMap.<TwoStatesOneObservationState, Double>builder()
.put(TwoStatesOneObservationState.STATE0, 0.6)
.put(TwoStatesOneObservationState.STATE1, 0.4)
.build())
.withTransitionProbability(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationState.STATE0, 0.7)
.withTransitionProbability(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationState.STATE1, 0.3)
.withTransitionProbability(TwoStatesOneObservationState.STATE1, TwoStatesOneObservationState.STATE0, 0.4)
.withTransitionProbability(TwoStatesOneObservationState.STATE1, TwoStatesOneObservationState.STATE1, 0.6)
.withEmissionProbability(TwoStatesOneObservationState.STATE0, TwoStatesOneObservationObservation.OBSERVATION0, 1.0)
.build();
ImmutableList<TwoStatesOneObservationObservation> observations = ImmutableList.of(TwoStatesOneObservationObservation.OBSERVATION0, TwoStatesOneObservationObservation.OBSERVATION0);
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("sum of emission probabilities for each state should be one, was 0.0 for state STATE1");
new ViterbiMachine<>(model, observations);
}
enum TwoStatesTwoObservationsState { STATE0, STATE1 };
enum TwoStatesTwoObservationsObservation { OBSERVATION0, OBSERVATION1 };
@Test
public void twoStatesTwoObservationsIsOk() {
ViterbiModel<TwoStatesTwoObservationsState, TwoStatesTwoObservationsObservation> model = ViterbiModel.<TwoStatesTwoObservationsState, TwoStatesTwoObservationsObservation>builder()
.withInitialDistributions(ImmutableMap.<TwoStatesTwoObservationsState, Double>builder()
.put(TwoStatesTwoObservationsState.STATE0, 0.6)
.put(TwoStatesTwoObservationsState.STATE1, 0.4)
.build())
.withTransitionProbability(TwoStatesTwoObservationsState.STATE0, TwoStatesTwoObservationsState.STATE0, 0.7)
.withTransitionProbability(TwoStatesTwoObservationsState.STATE0, TwoStatesTwoObservationsState.STATE1, 0.3)
.withTransitionProbability(TwoStatesTwoObservationsState.STATE1, TwoStatesTwoObservationsState.STATE0, 0.4)
.withTransitionProbability(TwoStatesTwoObservationsState.STATE1, TwoStatesTwoObservationsState.STATE1, 0.6)
.withEmissionProbability(TwoStatesTwoObservationsState.STATE0, TwoStatesTwoObservationsObservation.OBSERVATION0, 0.6)
.withEmissionProbability(TwoStatesTwoObservationsState.STATE0, TwoStatesTwoObservationsObservation.OBSERVATION1, 0.4)
.withEmissionProbability(TwoStatesTwoObservationsState.STATE1, TwoStatesTwoObservationsObservation.OBSERVATION0, 0.6)
.withEmissionProbability(TwoStatesTwoObservationsState.STATE1, TwoStatesTwoObservationsObservation.OBSERVATION1, 0.4)
.build();
ImmutableList<TwoStatesTwoObservationsObservation> observations = ImmutableList.of(TwoStatesTwoObservationsObservation.OBSERVATION0, TwoStatesTwoObservationsObservation.OBSERVATION0);
ViterbiMachine<TwoStatesTwoObservationsState, TwoStatesTwoObservationsObservation> machine = new ViterbiMachine<>(model, observations);
List<TwoStatesTwoObservationsState> states = machine.calculate();
final List<TwoStatesTwoObservationsState> expected = ImmutableList.of(TwoStatesTwoObservationsState.STATE0, TwoStatesTwoObservationsState.STATE0);
assertThat(states, is(expected));
}
enum WikipediaState { HEALTHY, FEVER };
enum WikipediaObservation { OK, COLD, DIZZY };
@Test
public void wikipediaSample() {
ViterbiModel<WikipediaState, WikipediaObservation> model = ViterbiModel.<WikipediaState, WikipediaObservation>builder()
.withInitialDistributions(ImmutableMap.<WikipediaState, Double>builder()
.put(WikipediaState.HEALTHY, 0.6)
.put(WikipediaState.FEVER, 0.4)
.build())
.withTransitionProbability(WikipediaState.HEALTHY, WikipediaState.HEALTHY, 0.7)
.withTransitionProbability(WikipediaState.HEALTHY, WikipediaState.FEVER, 0.3)
.withTransitionProbability(WikipediaState.FEVER, WikipediaState.HEALTHY, 0.4)
.withTransitionProbability(WikipediaState.FEVER, WikipediaState.FEVER, 0.6)
.withEmissionProbability(WikipediaState.HEALTHY, WikipediaObservation.OK, 0.5)
.withEmissionProbability(WikipediaState.HEALTHY, WikipediaObservation.COLD, 0.4)
.withEmissionProbability(WikipediaState.HEALTHY, WikipediaObservation.DIZZY, 0.1)
.withEmissionProbability(WikipediaState.FEVER, WikipediaObservation.OK, 0.1)
.withEmissionProbability(WikipediaState.FEVER, WikipediaObservation.COLD, 0.3)
.withEmissionProbability(WikipediaState.FEVER, WikipediaObservation.DIZZY, 0.6)
.build();
ImmutableList<WikipediaObservation> observations = ImmutableList.of(WikipediaObservation.OK, WikipediaObservation.COLD, WikipediaObservation.DIZZY);
ViterbiMachine<WikipediaState, WikipediaObservation> machine = new ViterbiMachine<>(model, observations);
List<WikipediaState> states = machine.calculate();
final List<WikipediaState> expected = ImmutableList.of(WikipediaState.HEALTHY, WikipediaState.HEALTHY, WikipediaState.FEVER);
assertThat(states, is(expected));
}
// ... SNIP
}
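As a hand check of the expected value in wikipediaSample, applying the recurrence V_t(s) = max over s' of V_{t-1}(s') a_{s's} b_s(o_t) to the observations OK, COLD, DIZZY gives the following (H = HEALTHY, F = FEVER; the predecessor chosen by the max is shown after the arrow):

$$
\begin{aligned}
V_0(H) &= 0.6 \cdot 0.5 = 0.3, & V_0(F) &= 0.4 \cdot 0.1 = 0.04 \\
V_1(H) &= \max(0.3 \cdot 0.7,\ 0.04 \cdot 0.4) \cdot 0.4 = 0.084 \quad (\leftarrow H), & V_1(F) &= \max(0.3 \cdot 0.3,\ 0.04 \cdot 0.6) \cdot 0.3 = 0.027 \quad (\leftarrow H) \\
V_2(H) &= \max(0.084 \cdot 0.7,\ 0.027 \cdot 0.4) \cdot 0.1 = 0.00588 \quad (\leftarrow H), & V_2(F) &= \max(0.084 \cdot 0.3,\ 0.027 \cdot 0.6) \cdot 0.6 = 0.01512 \quad (\leftarrow H)
\end{aligned}
$$

Since V_2(F) > V_2(H), the path ends in FEVER, and following the predecessors back gives HEALTHY, HEALTHY, FEVER, which is the expected list in the test.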
One more remark regarding the API
This API might seem verbose, but it is the best I could come up with so far. I had previously tried more concise ones, but they were more error-prone and also harder to manage for a large (4-5+) number of states/observations.
For reference, here are the previous attempts at the API:
public static int [] viterbi(int numStates, int numObservations,
double [] initialDistrib,
double [][] transitionProbs, double [][] emissionProbs,
int [] observations) // --> causes huge/unmanageable arrays
public static List<String> viterbi(Set<String> states,
Set<String> emissions,
Map<Key<String>, Double> transitionProbs,
Map<Key<String>, Double> emissionProbs,
Map<String, Double> initProbs,
List<String> observations) // --> a bit better, but not type safe
Answers
If you updated your repo to contain a make/Ant/Maven/Gradle build file, I would be able to easily change and run your code. Without being able to reproduce your build environment, I can only make some general comments.
Don't roll your own builder
Consider using Google's CallBuilder library to save a lot of boilerplate code. The library makes it easy to make a builder simply by annotating your constructor. You may have to implement a custom "style" class to duplicate the exact behavior you have in your custom builder; however, I think it's worthwhile. Using code generation to make builders saves you from a lot of repetitive, error-prone code, and helps enforce consistent builder interfaces across your project.
In fact, writing CallBuilder style classes for all of the Guava data structures would be an extremely cool and useful project. But that is beyond the scope of this algorithm.
Make the constructor for ViterbiModel more accepting
Something like:
private ViterbiModel(Map<? extends S, Double> initialDistributions,
Table<? extends S, ? extends S, Double> transitionProbabilities,
Table<? extends S, ? extends T, Double> emissionProbabilities)
Then, inside the constructor, use the ImmutableMap.copyOf and ImmutableTable.copyOf methods to make and store immutable copies. These same changes need to be extended appropriately to the builder.
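A minimal sketch of the idea, using only the JDK so that it stays self-contained: Map.copyOf (Java 10+) plays the role of Guava's ImmutableMap.copyOf, and a nested Map stands in for the Table. The class name AcceptingModel is hypothetical; note that the nested copy here is shallow for brevity.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-in for ViterbiModel: accepts maps keyed by S or any subtype of S
// and stores unmodifiable copies, so later changes to the caller's maps have no effect.
final class AcceptingModel<S> {
    final Map<S, Double> initialDistributions;
    final Map<S, Map<S, Double>> transitionProbabilities;

    AcceptingModel(Map<? extends S, Double> initialDistributions,
                   Map<? extends S, ? extends Map<? extends S, Double>> transitionProbabilities) {
        // Map.copyOf(Map<? extends K, ? extends V>) returns an unmodifiable Map<K, V>.
        this.initialDistributions = Map.copyOf(initialDistributions);
        Map<S, Map<S, Double>> rows = new HashMap<>();
        transitionProbabilities.forEach((src, row) -> rows.put(src, Map.copyOf(row)));
        this.transitionProbabilities = Map.copyOf(rows);
    }
}
```

With this signature, for example, a model over Number can be built directly from maps keyed by Integer.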
Make a ViterbiObservations class
It should contain the list of observations and provide a Builder, for uniformity with the ViterbiModel class.
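One possible shape for such a class is sketched below. It assumes the "at least one observation" check moves into its constructor, and it drops the enum bound on T for brevity; the method names add, build and asList are illustrative, not taken from the original code.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical ViterbiObservations: an immutable, validated, non-empty observation sequence.
final class ViterbiObservations<T> {
    private final List<T> observations;

    private ViterbiObservations(List<T> observations) {
        if (observations.isEmpty()) {
            throw new IllegalArgumentException("at least one observation should be provided");
        }
        this.observations = List.copyOf(observations); // Java 10+, rejects null elements
    }

    List<T> asList() {
        return observations;
    }

    static <T> Builder<T> builder() {
        return new Builder<>();
    }

    static final class Builder<T> {
        private final List<T> observations = new ArrayList<>();

        Builder<T> add(T observation) {
            observations.add(observation);
            return this;
        }

        ViterbiObservations<T> build() {
            return new ViterbiObservations<>(observations);
        }
    }
}
```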
Perform validation in constructors
Validate the ViterbiModel and ViterbiObservations objects separately, in their respective constructors. Failing early in such a case is an important way to communicate with the user: if they manage to create a ViterbiModel without any exception being thrown, it should be a valid one.
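Sketched with JDK types only (a Map of rows instead of a Guava Table), a model that validates itself on construction could look like this. ValidatedModel is a hypothetical name, and the tolerance is the one used by doublesEqual in the question; as in the original validate(), a missing row is reported as summing to 0.0.

```java
import java.util.Map;

// Hypothetical model that cannot be constructed in an invalid state:
// the initial distribution and every transition/emission row must sum to 1.
final class ValidatedModel<S, T> {
    private static final double EPSILON = 0.0000001;

    final Map<S, Double> initial;
    final Map<S, Map<S, Double>> transitions;
    final Map<S, Map<T, Double>> emissions;

    ValidatedModel(Map<S, Double> initial, Map<S, Map<S, Double>> transitions,
                   Map<S, Map<T, Double>> emissions) {
        checkSumsToOne("initial distributions", initial);
        for (S state : initial.keySet()) {
            checkSumsToOne("transition probabilities for " + state, transitions.getOrDefault(state, Map.of()));
            checkSumsToOne("emission probabilities for " + state, emissions.getOrDefault(state, Map.of()));
        }
        // Shallow unmodifiable copies; inner rows are not copied in this sketch.
        this.initial = Map.copyOf(initial);
        this.transitions = Map.copyOf(transitions);
        this.emissions = Map.copyOf(emissions);
    }

    private static void checkSumsToOne(String what, Map<?, Double> row) {
        double sum = 0.0;
        for (double p : row.values()) {
            sum += p;
        }
        if (Math.abs(sum - 1.0) >= EPSILON) {
            throw new IllegalArgumentException("sum of " + what + " should be 1.0, was " + sum);
        }
    }
}
```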
Be more accepting of generic types
You should have
ViterbiMachine(ViterbiModel<S, ? super T> model, ImmutableList<T> observations)
since a series of child observations could be emitted in a model consisting of parent types.
Extend ImmutableTable
The getOrDefault and rowOrDefault methods you wrote are nice. However, they should belong to the table class itself. So, extend ImmutableTable to a class that has these methods.
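One caveat, as far as I can tell: Guava's immutable collection classes have no public or protected constructors, so they cannot be subclassed outside Guava itself. The same goal (lookups with defaults living on the table type) can be reached by composition. A minimal sketch with a hypothetical DefaultingTable that wraps a JDK nested map:

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

// Hypothetical wrapper (composition instead of inheritance): owns a row -> (column -> value)
// map and offers the question's two "OrDefault" lookups as instance methods.
final class DefaultingTable<R, C, V> {
    private final Map<R, Map<C, V>> rows = new HashMap<>();

    DefaultingTable<R, C, V> put(R row, C column, V value) {
        rows.computeIfAbsent(row, r -> new HashMap<>()).put(column, value);
        return this;
    }

    // Returns a read-only view of the row, or defaultValue if the row has no entries.
    Map<C, V> rowOrDefault(R row, Map<C, V> defaultValue) {
        Map<C, V> found = rows.get(row);
        return found == null ? defaultValue : Collections.unmodifiableMap(found);
    }

    V getOrDefault(R row, C column, V defaultValue) {
        return rowOrDefault(row, Collections.emptyMap()).getOrDefault(column, defaultValue);
    }
}
```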
Inline the initialize() method
It is unclear why this is not part of the constructor.
Make a utility class
Some of your smaller functions have very little to do with ViterbiMachines. Move them to another class.
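For instance, the numeric helpers could live in a small final class with a private constructor. ProbabilityUtils, sumOf and isDistribution are hypothetical names; the tolerance is the one from the question's doublesEqual. A sumOf helper like this would also replace the three near-identical summing loops in validate().

```java
import java.util.Collection;

// Hypothetical utility class: stateless numeric helpers with nothing Viterbi-specific about them.
final class ProbabilityUtils {
    static final double TOLERANCE = 0.0000001;

    private ProbabilityUtils() {
        // not instantiable
    }

    static boolean doublesEqual(double d1, double d2) {
        return Math.abs(d1 - d2) < TOLERANCE;
    }

    static double sumOf(Collection<Double> values) {
        double sum = 0.0;
        for (double v : values) {
            sum += v;
        }
        return sum;
    }

    // True if the values sum to 1 within TOLERANCE.
    static boolean isDistribution(Collection<Double> values) {
        return doublesEqual(sumOf(values), 1.0);
    }
}
```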
Do not force S, T to be enum types
I don't understand why these need to be enums. Might someone want to create a ViterbiMachine where, say, the states are integers and the outputs are strings? Surely your code could allow this.
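To illustrate that nothing in the recurrence depends on enums, here is a compact sketch of the same algorithm over arbitrary state and observation types, exercised below with Integer states and String observations. It is a sketch under several assumptions, not a drop-in replacement: GenericViterbi is a hypothetical name, the model is passed as plain JDK maps (missing entries count as probability 0), there is no validation, and it multiplies raw probabilities, so very long sequences could underflow to 0 (log-probabilities would avoid that).

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical generic Viterbi: S and T may be any types with proper equals/hashCode.
final class GenericViterbi {
    private GenericViterbi() {}

    static <S, T> List<S> mostLikelyPath(List<S> states, Map<S, Double> initial,
                                         Map<S, Map<S, Double>> transition,
                                         Map<S, Map<T, Double>> emission, List<T> observations) {
        int n = observations.size();
        List<Map<S, Double>> prob = new ArrayList<>(); // prob.get(t).get(s) = V_t(s)
        List<Map<S, S>> back = new ArrayList<>();      // back.get(t).get(s) = best predecessor of s
        Map<S, Double> first = new HashMap<>();
        for (S s : states) {
            first.put(s, initial.getOrDefault(s, 0.0) * p(emission, s, observations.get(0)));
        }
        prob.add(first);
        back.add(new HashMap<>());
        for (int t = 1; t < n; t++) {
            Map<S, Double> cur = new HashMap<>();
            Map<S, S> ptr = new HashMap<>();
            for (S s : states) {
                S best = null;
                double bestProb = -1.0; // below any probability, so a predecessor is always chosen
                for (S prev : states) {
                    double candidate = prob.get(t - 1).get(prev) * p(transition, prev, s);
                    if (candidate > bestProb) {
                        bestProb = candidate;
                        best = prev;
                    }
                }
                cur.put(s, bestProb * p(emission, s, observations.get(t)));
                ptr.put(s, best);
            }
            prob.add(cur);
            back.add(ptr);
        }
        // Pick the most probable final state, then follow the back-pointers.
        S last = states.get(0);
        for (S s : states) {
            if (prob.get(n - 1).get(s) > prob.get(n - 1).get(last)) {
                last = s;
            }
        }
        List<S> path = new ArrayList<>();
        path.add(last);
        for (int t = n - 1; t > 0; t--) {
            last = back.get(t).get(last);
            path.add(last);
        }
        Collections.reverse(path);
        return path;
    }

    private static <K, C> double p(Map<K, Map<C, Double>> table, K row, C column) {
        return table.getOrDefault(row, Map.of()).getOrDefault(column, 0.0);
    }
}
```

With the Wikipedia parameters from the question's wikipediaSample test, encoded as states 0 (healthy) and 1 (fever) and observations "ok", "cold", "dizzy", this returns [0, 0, 1], the same path as the enum-based test.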
Some nitpicks on top of the already excellent review by Benjamin:
- Use the standard library's Objects.requireNonNull instead of Guava's checkNotNull.
- The way you obtain the enum values feels a bit contrived. Consider using values() instead. It's a bit annoying that Java's generics are as weak as they are, making it impossible to "just" call X.values() (which is guaranteed to exist)...
- You could replace the ImmutableMap dependency with the standard library as well (compare this SO answer). This would allow you to use EnumMap for marginally better performance.
- You could furthermore replace the Tables stateProbsForObservations and previousStatesForObservations with maps of the types Map<S, Double[]> and Map<S, Optional<S>[]>, respectively. These can be backed by EnumMap again, resulting in a further decrease in memory footprint and increase in performance. Again, for most uses the gain is only marginal.
- I dislike the use of exceptions for flow control and validation in the constructor of the ViterbiMachine. To avoid this, you could check the preconditions for the operations you are performing explicitly, instead of relying on downstream methods to fail with a certain exception. YMMV :)
- I dislike that the library does not expose getOrDefault and rowOrDefault, but that's really not something you can fix :/
- Nothing in the API indicates that after calling nextStep() even once, calculate() will throw an IllegalStateException. I would try to avoid making the retrieval of results prone to illegal state exceptions.
- I would also expect to be able to call calculate() multiple times, but that's because I'm a sucker for caching and smart & lazy calculator classes. I just enjoy implementing these...
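Several of these points can be combined in one sketch. It assumes the caller passes a Class<S> token (the original derives the enum from the first key instead), which gives all constants via Class.getEnumConstants() without needing a non-empty collection. Per-state rows are EnumMap<S, double[]> indexed by time step, and the path is computed once and cached, so calculate() can be called repeatedly. CachedViterbi and the Health example enum are hypothetical; the model is supplied as functions and there is no validation.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumMap;
import java.util.List;
import java.util.function.ToDoubleBiFunction;
import java.util.function.ToDoubleFunction;

// Hypothetical sketch: states from a Class token, EnumMap-backed rows, lazily cached result.
final class CachedViterbi<S extends Enum<S>, T> {
    private final Class<S> stateType;
    private final S[] states;
    private final ToDoubleFunction<S> initial;
    private final ToDoubleBiFunction<S, S> transition;
    private final ToDoubleBiFunction<S, T> emission;
    private final List<T> observations;
    private List<S> cachedPath; // null until calculate() is first called

    CachedViterbi(Class<S> stateType, ToDoubleFunction<S> initial, ToDoubleBiFunction<S, S> transition,
                  ToDoubleBiFunction<S, T> emission, List<T> observations) {
        this.stateType = stateType;
        this.states = stateType.getEnumConstants(); // same constants, same order as S.values()
        this.initial = initial;
        this.transition = transition;
        this.emission = emission;
        this.observations = List.copyOf(observations);
    }

    // Safe to call any number of times: the first call computes, later calls reuse the result.
    List<S> calculate() {
        if (cachedPath == null) {
            cachedPath = compute();
        }
        return cachedPath;
    }

    private List<S> compute() {
        int n = observations.size();
        EnumMap<S, double[]> prob = new EnumMap<>(stateType); // prob.get(s)[t] = V_t(s)
        EnumMap<S, int[]> back = new EnumMap<>(stateType);    // ordinal of best predecessor
        for (S s : states) {
            prob.put(s, new double[n]);
            back.put(s, new int[n]);
            prob.get(s)[0] = initial.applyAsDouble(s) * emission.applyAsDouble(s, observations.get(0));
        }
        for (int t = 1; t < n; t++) {
            for (S s : states) {
                double best = -1.0; // below any probability, so a predecessor is always recorded
                for (S prev : states) {
                    double candidate = prob.get(prev)[t - 1] * transition.applyAsDouble(prev, s);
                    if (candidate > best) {
                        best = candidate;
                        back.get(s)[t] = prev.ordinal();
                    }
                }
                prob.get(s)[t] = best * emission.applyAsDouble(s, observations.get(t));
            }
        }
        S last = states[0];
        for (S s : states) {
            if (prob.get(s)[n - 1] > prob.get(last)[n - 1]) {
                last = s;
            }
        }
        List<S> path = new ArrayList<>();
        path.add(last);
        for (int t = n - 1; t > 0; t--) {
            last = states[back.get(last)[t]]; // ordinal indexes the getEnumConstants() order
            path.add(last);
        }
        Collections.reverse(path);
        return Collections.unmodifiableList(path);
    }
}

// Example state type for the usage below.
enum Health { HEALTHY, FEVER }
```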