Algorithm Implementation/Viterbi algorithm

From Wikibooks, open books for an open world
Jump to navigation Jump to search

The following implementations of the Viterbi algorithm (see the Wikipedia article "Viterbi algorithm") were removed from an earlier version of that Wikipedia article because they were too long and unencyclopaedic — but we hope you'll find them useful here!

Java implementation

public class Viterbi {
	private static String[] states = { "#", "NN", "VB" };
	private static String[] observations = { "I", "write", "a letter" };
	private static double[] start_probability = { 0.3, 0.4, 0.3 };
	private static double[][] transition_probability = { { 0.2, 0.2, 0.6 },
			{ 0.4, 0.1, 0.5 }, { 0.1, 0.8, 0.1 } };
	private static double[][] emission_probability = { { 0.01, 0.02, 0.02 },
			{ 0.8, 0.01, 0.5 }, { 0.19, 0.97, 0.48 } };
	private static class TNode {
		public int[] v_path;
		public double v_prob;

		public TNode( int[] v_path, double v_prob) {
			this.v_path = copyIntArray(v_path);
			this.v_prob = v_prob;
		}
	}

	private static int[] copyIntArray(int[] ia) {
		int[] newIa = new int[ia.length];
		for (int i = 0; i < ia.length; i++) {
			newIa[i] = ia[i];
		}
		return newIa;
	}

	private static int[] copyIntArray(int[] ia, int newInt) {
		int[] newIa = new int[ia.length + 1];
		for (int i = 0; i < ia.length; i++) {
			newIa[i] = ia[i];
		}
		newIa[ia.length] = newInt;
		return newIa;
	}

	// forwardViterbi(observations, states, start_probability,
	// transition_probability, emission_probability)
	public int[] forwardViterbi(String[] y, String[] X, double[] sp,
			double[][] tp, double[][] ep) {
		TNode[] T = new TNode[X.length];
		for (int state = 0; state < X.length; state++) {
			int[] intArray = new int[1];
			intArray[0] = state;
			T[state] = new TNode( intArray, sp[state] * ep[state][0]);
		}

		for (int output = 1; output < y.length; output++) {
			TNode[] U = new TNode[X.length];
			for (int next_state = 0; next_state < X.length; next_state++) {
				int[] argmax = new int[0];
				double valmax = 0;
				for (int state = 0; state < X.length; state++) {
					int[] v_path = copyIntArray(T[state].v_path);
					double v_prob = T[state].v_prob;
					double p = ep[next_state][output] * tp[state][next_state];
					v_prob *= p;
					if (v_prob > valmax) {
						if (v_path.length == y.length) {
							argmax = copyIntArray(v_path);
						} else {
							argmax = copyIntArray(v_path, next_state);
						}
						valmax = v_prob;

					}
				}
				U[next_state] = new TNode( argmax, valmax);
			}
			T = U;
		}
		// apply sum/max to the final states:
		int[] argmax = new int[0];
		double valmax = 0;
		for (int state = 0; state < X.length; state++) {
			int[] v_path = copyIntArray(T[state].v_path);
			double v_prob = T[state].v_prob;
			if (v_prob > valmax) {
				argmax = copyIntArray(v_path);
				valmax = v_prob;
			}
		}
		System.out.print("Viterbi path: [");
		for (int i = 0; i < argmax.length; i++) {
			System.out.print(states[argmax[i]] + ", ");
		}
		System.out.println("].\n Probability of the whole system: " + valmax);
		return argmax;
	}

	public static void main(String[] args) throws Exception {
		System.out.print("\nStates: ");
		for (int i = 0; i < states.length; i++) {
			System.out.print(states[i] + ", ");
		}
		System.out.print("\n\nObservations: ");
		for (int i = 0; i < observations.length; i++) {
			System.out.print(observations[i] + ", ");
		}
		System.out.print("\n\nStart probability: ");
		for (int i = 0; i < states.length; i++) {
			System.out.print(states[i] + ": " + start_probability[i] + ", ");
		}
		System.out.println("\n\nTransition probability:");
		for (int i = 0; i < states.length; i++) {
			System.out.print(" " + states[i] + ": {");
			for (int j = 0; j < states.length; j++) {
				System.out.print("  " + states[j] + ": "
						+ transition_probability[i][j] + ", ");
			}
			System.out.println("}");
		}
		System.out.println("\n\nEmission probability:");
		for (int i = 0; i < states.length; i++) {
			System.out.print(" " + states[i] + ": {");
			for (int j = 0; j < observations.length; j++) {
				System.out.print("  " + observations[j] + ": "
						+ emission_probability[i][j] + ", ");
			}
			System.out.println("}");
		}
		Viterbi v=new Viterbi();
		v.forwardViterbi(observations, states, start_probability,
				transition_probability, emission_probability);
		
	}
}

F# implementation

(* Nick Heiner *)

(* Viterbi algorithm, as described here: http://people.ccmr.cornell.edu/~ginsparg/INFO295/vit.pdf

  priorProbs: prior probability of a hidden state occurring
  transitions: probability of one hidden state transitioning into another,
               called as transitions (previous, next)
  emissionProbs: probability of a hidden state emitting an observed state,
                 called as emissionProbs (observed, hidden)
  observation: a non-empty sequence of observed states
  hiddens: a list of all possible hidden states

  Returns: probability of most likely path * hidden state list representing the path
  Raises: ArgumentException if observation is empty.

  Each time step is computed once from the previous one, so the cost is
  O(|observation| * |hiddens|^2) rather than exponential in |observation|.
*)
let viterbi (priorProbs : 'hidden -> float) (transitions : ('hidden * 'hidden) -> float) (emissionProbs : (('observed * 'hidden) -> float))
  (observation : 'observed []) (hiddens : 'hidden list) : float * 'hidden list =

  if Array.isEmpty observation then
    invalidArg "observation" "observation must contain at least one element"

  (* Referred to as v_state(time) in the notes: one entry per hidden state (in
     the order of hiddens) holding the probability of the most probable path
     ending in that state at the current time, and that path in reverse order. *)
  let initial =
    hiddens
    |> List.map (fun state -> emissionProbs (observation.[0], state) * priorProbs state, [state])

  (* Advance every best path by one time step *)
  let step (previous : (float * 'hidden list) list) (time : int) : (float * 'hidden list) list =
    hiddens
    |> List.map (fun state ->
         (* Rate each predecessor path by how likely it is to transition into state *)
         let prob, path =
           previous
           |> List.map (fun (p, revPath) -> transitions (List.head revPath, state) * p, revPath)
           |> List.maxBy fst
         emissionProbs (observation.[time], state) * prob, state :: path)

  (* Look for the most likely path that ends at t_max *)
  let prob, revPath =
    [ 1 .. Array.length observation - 1 ]
    |> List.fold step initial
    |> List.maxBy fst
  prob, List.rev revPath

(* Example using the Healthy/Fever model from the Wikipedia article on the
   Viterbi algorithm: hidden states are health conditions, observations are
   how the patient reports feeling on each day. *)
type wikiHiddens = Healthy | Fever
let wikiHiddenList = [Healthy; Fever]
type wikiObservations = Normal | Cold | Dizzy

(* Probability of each hidden state on the first day *)
let wikiPriors hidden =
  match hidden with
  | Healthy -> 0.6
  | Fever -> 0.4

(* Probability of moving from source to target between consecutive days *)
let wikiTransitions (source, target) =
  match source, target with
  | Healthy, Healthy -> 0.7
  | Healthy, Fever -> 0.3
  | Fever, Healthy -> 0.4
  | Fever, Fever -> 0.6

(* Probability of a reported observation given the hidden state *)
let wikiEmissions (observed, hidden) =
  match hidden, observed with
  | Healthy, Normal -> 0.5
  | Healthy, Cold -> 0.4
  | Healthy, Dizzy -> 0.1
  | Fever, Normal -> 0.1
  | Fever, Cold -> 0.3
  | Fever, Dizzy -> 0.6

viterbi wikiPriors wikiTransitions wikiEmissions [| Dizzy; Normal; Cold |] wikiHiddenList

Clojure implementation

(ns ident.viterbi
  (:use [clojure.pprint]))

;; Slots of an HMM struct-map. :states and :obs are declared too, because
;; make-hmm always stores them alongside the probability tables.
(defstruct hmm :n :m :states :obs :init-probs :emission-probs :state-transitions)

(defn make-hmm
  "Builds an hmm struct-map from a map with keys :states, :obs, :init-probs,
  :emission-probs and :state-transitions. :n is the number of states and :m
  the number of observation symbols. The probability tables are read with
  aget elsewhere in this namespace, so they are expected to be Java arrays."
  [{:keys [states, obs, init-probs, emission-probs, state-transitions]}]
  (struct-map hmm
    :n (count states)
    :m (count obs)
    :states states
    :obs obs
    :init-probs init-probs               ;; n entries: (aget p state)
    :emission-probs emission-probs       ;; n x m: (aget e state obs)
    :state-transitions state-transitions)) ;; n x n: (aget a from to)

(defn indexed
  "Lazily pairs every element of s with its zero-based position: ([0 a] [1 b] ...)."
  [s]
  (map-indexed vector s))

(defn argmax
  "Returns [index value] for the largest element of coll, keeping the earliest
  one on ties, or nil when coll is empty."
  [coll]
  (when-let [pairs (seq (indexed coll))]
    (reduce (fn [[_ best-v :as best] [_ v :as candidate]]
              (if (> v best-v) candidate best))
            pairs)))

(defn pprint-hmm
  "Prints the sizes of hmm followed by each of its probability tables."
  [hmm]
  (println "number of states: " (:n hmm) " number of observations:  " (:m hmm))
  (doseq [[label k] [["init probabilities: " :init-probs]
                     ["emission probs: " :emission-probs]
                     ["state-transitions: " :state-transitions]]]
    (print label)
    (pprint (k hmm))))

(defn init-alphas
  "Initial forward/Viterbi column: for each state, its start probability
  times the probability of emitting observation index obs."
  [hmm obs]
  (let [start (:init-probs hmm)
        emit (:emission-probs hmm)]
    (for [state (range (:n hmm))]
      (* (aget start state) (aget emit state obs)))))

(defn forward
  "One forward-algorithm step: for each target state j, the sum over source
  states i of alphas[i] * a(i,j), times the probability that j emits obs."
  [hmm alphas obs]
  (let [n (:n hmm)
        trans (:state-transitions hmm)
        emit (:emission-probs hmm)]
    (for [j (range n)]
      (let [incoming (reduce + 0 (map (fn [i] (* (nth alphas i) (aget trans i j)))
                                      (range n)))]
        (* incoming (aget emit j obs))))))

(defn delta-max
  "One Viterbi step: for each target state j, the best deltas[i] * a(i,j)
  over source states i, times the probability that j emits obs."
  [hmm deltas obs]
  (let [n (:n hmm)
        trans (:state-transitions hmm)
        emit (:emission-probs hmm)]
    (for [j (range n)]
      (let [best (reduce max (for [i (range n)]
                               (* (nth deltas i) (aget trans i j))))]
        (* best (aget emit j obs))))))

(defn backtrack
  "Recovers the most probable state sequence. paths holds one back-pointer
  seq per step (entry j = best predecessor of state j). The walk starts from
  the argmax of the final deltas. Returns the states in chronological order."
  [paths deltas]
  (loop [pointers (reverse paths)
         state (first (argmax deltas))
         trail ()]
    (if (empty? pointers)
      (conj trail state)
      (recur (rest pointers)
             (nth (first pointers) state)
             (conj trail state)))))

(defn update-paths
  "Back-pointers for one step: for each target state j, the index of the
  source state i that maximises deltas[i] * a(i,j)."
  [hmm deltas]
  (let [n (:n hmm)
        trans (:state-transitions hmm)]
    (for [j (range n)]
      (first (argmax (for [i (range n)]
                       (* (nth deltas i) (aget trans i j))))))))

(defn viterbi
  "Decodes observs (a non-empty seq of observation indices) against hmm.
  Returns [path likelihood]:
  - path is the most probable state-index sequence.
  - likelihood is the forward-algorithm total (sum of the final alphas,
    coerced with float), not the probability of path.
  Throws IllegalArgumentException when observs is empty."
  [hmm observs]
  (when (empty? observs)
    (throw (IllegalArgumentException. "observs must contain at least one observation")))
  (loop [obs (rest observs)
         alphas (vec (init-alphas hmm (first observs)))
         deltas alphas
         paths []]
    (if (empty? obs)
      [(backtrack paths deltas) (float (reduce + alphas))]
      ;; Realise every column eagerly with vec. If left lazy, each column
      ;; closes over the previous lazy column, so realising the last one
      ;; recurses once per time step (stack overflow on long inputs), and
      ;; nth on a lazy seq is linear rather than constant time.
      (recur (rest obs)
             (vec (forward hmm alphas (first obs)))
             (vec (delta-max hmm deltas (first obs)))
             (conj paths (vec (update-paths hmm deltas)))))))

C# implementation

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Viterbi
{
    class Program
    {
        // Hidden states (health conditions)
        static String HEALTHY = "Healthy";
        static String FEVER = "Fever";
        // Observable symptoms (observations)
        static String DIZZY = "dizzy";
        static String COLD = "cold";
        static String NORMAL = "normal";

        static void Main(string[] args)
        {
            //initialize our arrays of states and observations
            String[] states = { HEALTHY, FEVER };
            String[] observations = { DIZZY, COLD, NORMAL };

            var start_probability = new Dictionary<String, float>();
            start_probability.Add(HEALTHY, 0.6f);
            start_probability.Add(FEVER, 0.4f);
            //Transition probability
            var transition_probability = new Dictionary<String, Dictionary<String, float>>();
            var t1 = new Dictionary<String, float>();
            t1.Add(HEALTHY, 0.7f);
            t1.Add(FEVER, 0.3f);
            Dictionary<String, float> t2 = new Dictionary<String, float>();
            t2.Add(HEALTHY, 0.4f);
            t2.Add(FEVER, 0.6f);
            transition_probability.Add(HEALTHY, t1);
            transition_probability.Add(FEVER, t2);

            //emission_probability
            var emission_probability = new Dictionary<String, Dictionary<String, float>>();
            var e1 = new Dictionary<String, float>();
            e1.Add(DIZZY, 0.1f);
            e1.Add(COLD, 0.4f);
            e1.Add(NORMAL, 0.5f);

            Dictionary<String, float> e2 = new Dictionary<String, float>();
            e2.Add(DIZZY, 0.6f);
            e2.Add(COLD, 0.3f);
            e2.Add(NORMAL, 0.1f);

            emission_probability.Add(HEALTHY, e1);
            emission_probability.Add(FEVER, e2);

            // ret = { total likelihood, comma-separated Viterbi path, Viterbi path probability }
            Object[] ret = forward_viterbi(observations, states, start_probability, transition_probability, emission_probability);
            Console.WriteLine((float)ret[0]);
            Console.WriteLine((String)ret[1]);
            Console.WriteLine((float)ret[2]);
            Console.ReadLine();

        }

        /// <summary>
        /// Viterbi decoding combined with the forward algorithm.
        /// Each step applies the transition into a state and then that state's
        /// emission of the current observation. The returned path therefore
        /// has exactly one state per observation.
        /// </summary>
        /// <param name="obs">Observation sequence; must be non-empty.</param>
        /// <param name="states">All hidden states.</param>
        /// <param name="start_p">start_p[state]: probability of starting in state.</param>
        /// <param name="trans_p">trans_p[from][to]: transition probability.</param>
        /// <param name="emit_p">emit_p[state][observation]: emission probability.</param>
        /// <returns>
        /// { total (float): forward probability of obs summed over all paths,
        ///   path (String): comma-separated most probable state sequence,
        ///   valmax (float): joint probability of that path and obs }.
        /// When states is empty the result is { 0, "", 0 }.
        /// </returns>
        /// <exception cref="ArgumentException">obs is null or empty.</exception>
        public static Object[] forward_viterbi(String[] obs, String[] states, Dictionary<String, float> start_p, Dictionary<String, Dictionary<String, float>> trans_p, Dictionary<String, Dictionary<String, float>> emit_p)
        {
            if (obs == null || obs.Length == 0)
            {
                throw new ArgumentException("At least one observation is required.", "obs");
            }

            // Per state: forward probability, best path ending there, and that path's probability.
            var fwd = new Dictionary<String, float>();
            var path = new Dictionary<String, String>();
            var best = new Dictionary<String, float>();

            String first = obs[0];
            foreach (String state in states)
            {
                float p = start_p[state] * emit_p[state][first];
                fwd[state] = p;
                path[state] = state;
                best[state] = p;
            }

            for (int t = 1; t < obs.Length; t++)
            {
                String output = obs[t];
                var nextFwd = new Dictionary<String, float>();
                var nextPath = new Dictionary<String, String>();
                var nextBest = new Dictionary<String, float>();

                foreach (String next_state in states)
                {
                    float total = 0;
                    String argmax = null;
                    // Below any real probability, so a predecessor is always
                    // chosen even when every candidate has probability 0.
                    float valmax = -1;

                    foreach (String source_state in states)
                    {
                        float trans = trans_p[source_state][next_state];
                        total += fwd[source_state] * trans;

                        float candidate = best[source_state] * trans;
                        if (candidate > valmax)
                        {
                            argmax = path[source_state];
                            valmax = candidate;
                        }
                    }

                    float emit = emit_p[next_state][output];
                    nextFwd[next_state] = total * emit;
                    nextPath[next_state] = argmax + "," + next_state;
                    nextBest[next_state] = valmax * emit;
                }

                fwd = nextFwd;
                path = nextPath;
                best = nextBest;
            }

            float xtotal = 0;
            String xargmax = null;
            float xvalmax = -1;

            foreach (String state in states)
            {
                xtotal += fwd[state];
                if (best[state] > xvalmax)
                {
                    xargmax = path[state];
                    xvalmax = best[state];
                }
            }

            if (xargmax == null)
            {
                // No states at all: nothing to decode.
                return new Object[] { xtotal, "", 0f };
            }
            return new Object[] { xtotal, xargmax, xvalmax };
        }
    }
}