Algorithm Implementation/Viterbi algorithm

From Wikibooks, open books for an open world
Jump to: navigation, search

The following implementations of the Viterbi algorithm were removed from an earlier version of the Wikipedia article because they were too long and unencyclopaedic, but we hope you'll find them useful here!

Java implementation[edit]

import java.util.Hashtable;
 
public class Viterbi 
{
	static final String HEALTHY = "Healthy";
	static final String FEVER = "Fever";
 
	static final String DIZZY = "dizzy";
	static final String COLD = "cold";
	static final String NORMAL = "normal";
 
	public static void main(String[] args) 
	{
		String[] states = new String[] {HEALTHY, FEVER};
 
		String[] observations = new String[] {DIZZY, COLD, NORMAL};
 
		Hashtable<String, Float> start_probability = new Hashtable<String, Float>();
		start_probability.put(HEALTHY, 0.6f);
		start_probability.put(FEVER, 0.4f);
 
		// transition_probability
		Hashtable<String, Hashtable<String, Float>> transition_probability = 
			new Hashtable<String, Hashtable<String, Float>>();
			Hashtable<String, Float> t1 = new Hashtable<String, Float>();
			t1.put(HEALTHY, 0.7f);
			t1.put(FEVER, 0.3f);
			Hashtable<String, Float> t2 = new Hashtable<String, Float>();
			t2.put(HEALTHY, 0.4f);
			t2.put(FEVER, 0.6f);
		transition_probability.put(HEALTHY, t1);
		transition_probability.put(FEVER, t2);
 
		// emission_probability
		Hashtable<String, Hashtable<String, Float>> emission_probability = 
			new Hashtable<String, Hashtable<String, Float>>();
			Hashtable<String, Float> e1 = new Hashtable<String, Float>();
			e1.put(DIZZY, 0.1f);		
			e1.put(COLD, 0.4f); 
			e1.put(NORMAL, 0.5f);
			Hashtable<String, Float> e2 = new Hashtable<String, Float>();
			e2.put(DIZZY, 0.6f);		
			e2.put(COLD, 0.3f); 
			e2.put(NORMAL, 0.1f);
		emission_probability.put(HEALTHY, e1);
		emission_probability.put(FEVER, e2);
 
		Object[] ret = forward_viterbi(observations,
                           states,
                           start_probability,
                           transition_probability,
                           emission_probability);
		System.out.println(((Float) ret[0]).floatValue());		
		System.out.println((String) ret[1]);
		System.out.println(((Float) ret[2]).floatValue());
	}
 
	public static Object[] forward_viterbi(String[] obs, String[] states,
			Hashtable<String, Float> start_p,
			Hashtable<String, Hashtable<String, Float>> trans_p,
			Hashtable<String, Hashtable<String, Float>> emit_p)
	{
		Hashtable<String, Object[]> T = new Hashtable<String, Object[]>();
		for (String state : states)
			T.put(state, new Object[] {start_p.get(state), state, start_p.get(state)});
 
		for (String output : obs)
		{
			Hashtable<String, Object[]> U = new Hashtable<String, Object[]>();
			for (String next_state : states)
			{
				float total = 0;
				String argmax = "";
				float valmax = 0;
 
				float prob = 1;
				String v_path = "";
				float v_prob = 1;	
 
				for (String source_state : states)
				{
					Object[] objs = T.get(source_state);
					prob = ((Float) objs[0]).floatValue();
					v_path = (String) objs[1];
					v_prob = ((Float) objs[2]).floatValue();
 
					float p = emit_p.get(source_state).get(output) *
							  trans_p.get(source_state).get(next_state);
					prob *= p;
					v_prob *= p;
					total += prob;
					if (v_prob > valmax)
					{
						argmax = v_path + "," + next_state;
						valmax = v_prob;
					}
				}
				U.put(next_state, new Object[] {total, argmax, valmax});
			}
			T = U;			
		}
 
		float total = 0;
		String argmax = "";
		float valmax = 0;
 
		float prob;
		String v_path;
		float v_prob;
 
		for (String state : states)
		{
			Object[] objs = T.get(state);
			prob = ((Float) objs[0]).floatValue();
			v_path = (String) objs[1];
			v_prob = ((Float) objs[2]).floatValue();
			total += prob;
			if (v_prob > valmax)
			{
				argmax = v_path;
				valmax = v_prob;
			}
		}	
		return new Object[]{total, argmax, valmax};	
	}
}

F# implementation[edit]

(* Nick Heiner *)
 
(* Viterbi algorithm, as described here: http://people.ccmr.cornell.edu/~ginsparg/INFO295/vit.pdf 
 
  priorProbs: prior probability of a hidden state occurring
  transitions: probability of one hidden state transitioning into another
  emissionProbs: probability of a hidden state emitting an observed state
  observation: a sequence of observed states
  hiddens: a list of all possible hidden states
 
  Returns: probability of most likely path * hidden state list representing the path
 
*)
(* Computed bottom-up: the most probable path into every hidden state is worked out
   once per time step and reused for the next step, so the cost is
   (number of observations) * (number of hidden states)^2 transition evaluations
   instead of growing exponentially with the length of the observation.

   Ties are resolved in favour of the hidden state that comes first in hiddens.

   Raises System.IndexOutOfRangeException when observation is empty but hiddens is not,
   and System.ArgumentException (from Seq.maxBy) when hiddens is empty. *)
let viterbi (priorProbs : 'hidden -> float) (transitions : ('hidden * 'hidden) -> float) (emissionProbs : (('observed * 'hidden) -> float))
  (observation : 'observed []) (hiddens : 'hidden list) : float * 'hidden list =

  (* Referred to as v_state(0) in the notes *)
  (* One entry per hidden state, in the order of hiddens: the probability of starting in
     that state and emitting the first observation, paired with the (reversed) path *)
  let initial : (float * 'hidden list) list =
    hiddens |> List.map (fun state -> emissionProbs (observation.[0], state) * priorProbs state, [state])

  (* Referred to as v_state(time) in the notes *)
  (* Given the best paths ending in each hidden state at time - 1, build the best
     paths ending in each hidden state at time *)
  let advance (previous : (float * 'hidden list) list) (time : int) : (float * 'hidden list) list =
    hiddens |> List.map (fun state ->
      (* Rate each earlier path by how likely it is to transition into the current state *)
      let prob, path =
        previous
        |> List.map (fun (prevProb, prevPath) -> transitions (List.head prevPath, state) * prevProb, prevPath)
        |> Seq.maxBy fst
      emissionProbs (observation.[time], state) * prob, state :: path)

  (* Walk the remaining observations, carrying the per-state table forward *)
  let atEnd = Seq.fold advance initial (seq { 1 .. Array.length observation - 1 })

  (* Look for the most likely path that ends at t_max; paths are stored newest-first *)
  let prob, revPath = Seq.maxBy fst atEnd
  prob, List.rev revPath
 
(* example using data from this article: *)
(* Hidden states of the example model *)
type wikiHiddens = Healthy | Fever
let wikiHiddenList = [Healthy; Fever]
(* Symptoms that can be observed *)
type wikiObservations = Normal | Cold | Dizzy

(* Prior probability of each hidden state *)
let wikiPriors = function
  | Healthy -> 0.6
  | Fever -> 0.4

(* Probability of moving from the first hidden state of the pair to the second.
   The probabilities leaving a given state must sum to 1. *)
let wikiTransitions = function
  | (Healthy, Healthy) -> 0.7
  | (Healthy, Fever) -> 0.3
  | (Fever, Healthy) -> 0.4
  | (Fever, Fever) -> 0.6

(* Probability of seeing the observation (first of the pair) while in the
   hidden state (second of the pair) *)
let wikiEmissions = function
  | (Cold, Healthy) -> 0.4
  | (Normal, Healthy) -> 0.5
  | (Dizzy, Healthy) -> 0.1
  | (Cold, Fever) -> 0.3
  | (Normal, Fever) -> 0.1
  | (Dizzy, Fever) -> 0.6
 
(* Decode the observation sequence Dizzy, Normal, Cold with the example model *)
wikiHiddenList |> viterbi wikiPriors wikiTransitions wikiEmissions [| Dizzy; Normal; Cold |]

Clojure implementation[edit]

(ns ident.viterbi
  (:use [clojure.pprint]))
 
;; Struct basis for a hidden Markov model: state count, observation count,
;; initial probabilities, emission probabilities and state transitions.
(def hmm
  (create-struct :n
                 :m
                 :init-probs
                 :emission-probs
                 :state-transitions))
 
(defn make-hmm
  "Builds an hmm struct from a map with the keys :states, :obs, :init-probs,
  :emission-probs and :state-transitions. :n and :m are set to the number of
  states and of observations; :states and :obs are stored as additional keys.
  NOTE(review): the original comments gave the shapes as 'n dim' for
  :init-probs and 'm x n' for :emission-probs, yet init-alphas, forward and
  delta-max index :emission-probs as [state][observation] - confirm."
  [{:keys [states obs init-probs emission-probs state-transitions]}]
  (let [n (count states)
        m (count obs)]
    (assoc (struct hmm n m init-probs emission-probs state-transitions)
           :states states
           :obs obs)))
 
(defn indexed
  "Returns a lazy seq of [index element] vectors for the elements of s,
  counting from 0."
  [s]
  (map-indexed vector s))
 
(defn argmax
  "Returns [index element] for the largest element of coll, or nil when coll
  is empty. On ties the earliest element wins because the comparison is
  strict."
  [coll]
  (let [pairs (indexed coll)]
    ;; Seeding with the first pair means it is also compared against itself,
    ;; which never replaces it.
    (reduce (fn [best candidate]
              (if (> (second candidate) (second best))
                candidate
                best))
            (first pairs)
            pairs)))
 
(defn pprint-hmm
  "Prints the state and observation counts of hmm on one line, then
  pretty-prints its initial probabilities, emission probabilities and state
  transitions, each after its label."
  [hmm]
  (println "number of states: " (:n hmm) " number of observations:  " (:m hmm))
  (doseq [[label field] [["init probabilities: " :init-probs]
                         ["emission probs: " :emission-probs]
                         ["state-transitions: " :state-transitions]]]
    (print label)
    (pprint (field hmm))))
 
(defn init-alphas
  "Returns, for each state index below (:n hmm), the product of the state's
  initial probability and its emission probability for observation index obs.
  Both tables are read with aget, so they are expected to be Java arrays."
  [hmm obs]
  (let [priors (:init-probs hmm)
        emissions (:emission-probs hmm)]
    (for [state (range (:n hmm))]
      (* (aget priors state) (aget emissions state obs)))))
 
(defn forward
  "One step of the forward recursion. For every state j below (:n hmm) it
  sums (nth alphas i) times the transition probability i -> j over all states
  i, and multiplies the sum by the emission probability of observation index
  obs in state j. Both tables are read with aget, so they are expected to be
  Java arrays.

  The result is fully realized before it is returned. A lazy result would
  hold on to the previous step's lazy seq, and a long observation sequence
  would then build a chain of nested lazy seqs that overflows the stack when
  it is finally realized."
  [hmm alphas obs]
  (doall
   (map (fn [j]
          (* (reduce (fn [sum i]
                       (+ sum (* (nth alphas i) (aget (:state-transitions hmm) i j))))
                     0
                     (range (:n hmm)))
             (aget (:emission-probs hmm) j obs)))
        (range (:n hmm)))))
 
(defn delta-max
  "One step of the Viterbi recursion. For every state j below (:n hmm) it
  takes the largest value of (nth deltas i) times the transition probability
  i -> j over all states i, and multiplies it by the emission probability of
  observation index obs in state j. Both tables are read with aget, so they
  are expected to be Java arrays.

  The result is fully realized before it is returned. A lazy result would
  hold on to the previous step's lazy seq, and a long observation sequence
  would then build a chain of nested lazy seqs that overflows the stack when
  it is finally realized."
  [hmm deltas obs]
  (doall
   (map (fn [j]
          (* (apply max (map (fn [i]
                               (* (nth deltas i)
                                  (aget (:state-transitions hmm) i j)))
                             (range (:n hmm))))
             (aget (:emission-probs hmm) j obs)))
        (range (:n hmm)))))
 
(defn backtrack
  "Recovers the state-index sequence from the back-pointers. Starts at the
  index of the largest element of deltas, walks paths from last to first
  looking up the predecessor of the current state in each entry, and returns
  the visited indices in forward order."
  [paths deltas]
  (let [[visited earliest]
        (reduce (fn [[trail state] pointers]
                  [(conj trail state) (nth pointers state)])
                [[] (first (argmax deltas))]
                (reverse paths))]
    (reverse (conj visited earliest))))
 
(defn update-paths
  "For every state j below (:n hmm), returns the index i that maximizes
  (nth deltas i) times the transition probability i -> j, i.e. the best
  predecessor of j. The transition table is read with aget."
  [hmm deltas]
  (let [n (:n hmm)
        transitions (:state-transitions hmm)]
    (for [j (range n)]
      (first (argmax (for [i (range n)]
                       (* (nth deltas i) (aget transitions i j))))))))
 
(defn viterbi
  "Runs the forward and Viterbi recursions over the observation indices in
  observs. Returns a two-element vector: the state-index path produced by
  backtrack, and the sum of the final forward values as a float."
  [hmm observs]
  (let [start (init-alphas hmm (first observs))
        ;; Each step uses the deltas from before the step both to record the
        ;; back-pointers and to compute the new deltas.
        step (fn [[alphas deltas paths] ob]
               [(forward hmm alphas ob)
                (delta-max hmm deltas ob)
                (conj paths (update-paths hmm deltas))])
        [alphas deltas paths] (reduce step [start start []] (rest observs))]
    [(backtrack paths deltas) (float (reduce + alphas))]))

C# implementation[edit]

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
 
namespace Viterbi
{
    /// <summary>
    /// Worked example of a combined forward / Viterbi computation on a small
    /// hidden Markov model with two hidden states and three observation symbols.
    /// </summary>
    class Program
    {
        // Labels of the hidden states.
        static String HEALTHY = "Healthy";
        static String FEVER = "Fever";
        // Labels of the observation symbols.
        static String DIZZY = "dizzy";
        static String COLD = "cold";
        static String NORMAL = "normal";

        /// <summary>
        /// Builds the example model, runs <see cref="forward_viterbi"/> on the
        /// observation sequence dizzy, cold, normal, prints the three returned
        /// values one per line and then waits for a line of input.
        /// </summary>
        static void Main(string[] args)
        {
            String[] states = { HEALTHY, FEVER };
            String[] observations = { DIZZY, COLD, NORMAL };

            var start_probability = new Dictionary<String, float>
            {
                { HEALTHY, 0.6f },
                { FEVER, 0.4f }
            };

            // Outer key is the source state, inner key the destination state.
            var transition_probability = new Dictionary<String, Dictionary<String, float>>
            {
                { HEALTHY, new Dictionary<String, float> { { HEALTHY, 0.7f }, { FEVER, 0.3f } } },
                { FEVER, new Dictionary<String, float> { { HEALTHY, 0.4f }, { FEVER, 0.6f } } }
            };

            // Outer key is the hidden state, inner key the observed symbol.
            var emission_probability = new Dictionary<String, Dictionary<String, float>>
            {
                { HEALTHY, new Dictionary<String, float> { { DIZZY, 0.1f }, { COLD, 0.4f }, { NORMAL, 0.5f } } },
                { FEVER, new Dictionary<String, float> { { DIZZY, 0.6f }, { COLD, 0.3f }, { NORMAL, 0.1f } } }
            };

            Object[] result = forward_viterbi(observations, states, start_probability,
                transition_probability, emission_probability);
            float summed = (float)result[0];
            String bestPath = (String)result[1];
            float bestProbability = (float)result[2];
            Console.WriteLine(summed);
            Console.WriteLine(bestPath);
            Console.WriteLine(bestProbability);
            Console.ReadLine();
        }

        /// <summary>
        /// Runs the summed (forward) and maximised (Viterbi) recursions side by side.
        /// For every observation and every destination state, each source state
        /// contributes the factor emit_p[source][observation] * trans_p[source][destination];
        /// note that the emission is looked up for the source state.
        /// </summary>
        /// <param name="obs">Observation symbols, in order.</param>
        /// <param name="states">All hidden state labels.</param>
        /// <param name="start_p">Initial probability of each state.</param>
        /// <param name="trans_p">Transition probabilities, keyed source then destination.</param>
        /// <param name="emit_p">Emission probabilities, keyed state then observation.</param>
        /// <returns>
        /// A three-element array: the summed probability (float), the best path as
        /// comma-separated state labels (String) and the probability of that path
        /// (float). The path starts with an initial state and gains one label per
        /// observation; it is the empty string when no candidate has a probability
        /// greater than zero.
        /// </returns>
        public static Object[] forward_viterbi(String[] obs, String[] states, Dictionary<String, float> start_p, Dictionary<String, Dictionary<String, float>> trans_p, Dictionary<String, Dictionary<String, float>> emit_p)
        {
            // Per state: { summed probability, best path so far, probability of that path }.
            var current = new Dictionary<String, Object[]>();
            for (int i = 0; i < states.Length; i++)
            {
                String label = states[i];
                float prior = start_p[label];
                current.Add(label, new Object[] { prior, label, prior });
            }

            for (int t = 0; t < obs.Length; t++)
            {
                String symbol = obs[t];
                var following = new Dictionary<String, Object[]>();

                foreach (String destination in states)
                {
                    float sum = 0;
                    String bestPath = "";
                    float bestProbability = 0;

                    foreach (String source in states)
                    {
                        Object[] entry = current[source];
                        float sourceSum = (float)entry[0];
                        String sourcePath = (String)entry[1];
                        float sourceBest = (float)entry[2];

                        float step = emit_p[source][symbol] * trans_p[source][destination];
                        sum += sourceSum * step;

                        float candidate = sourceBest * step;
                        // Strict comparison: on ties the earlier source state wins.
                        if (candidate > bestProbability)
                        {
                            bestPath = sourcePath + "," + destination;
                            bestProbability = candidate;
                        }
                    }
                    following.Add(destination, new Object[] { sum, bestPath, bestProbability });
                }
                current = following;
            }

            // Combine the per-state entries into the overall answer.
            float grandTotal = 0;
            String overallPath = "";
            float overallBest = 0;

            for (int i = 0; i < states.Length; i++)
            {
                Object[] entry = current[states[i]];
                float stateSum = (float)entry[0];
                String statePath = (String)entry[1];
                float stateBest = (float)entry[2];

                grandTotal += stateSum;
                if (stateBest > overallBest)
                {
                    overallPath = statePath;
                    overallBest = stateBest;
                }
            }
            return new Object[] { grandTotal, overallPath, overallBest };
        }
    }
}