Algorithm Implementation/Viterbi algorithm
Appearance
The following implementations of the Viterbi algorithm were removed from an earlier version of the Wikipedia article on the Viterbi algorithm because they were too long and unencyclopaedic — but we hope you'll find them useful here!
Java implementation
/**
 * Viterbi decoding for a small hidden Markov model.
 *
 * <p>The static tables are a toy part-of-speech example: hidden states "#",
 * "NN" and "VB" and the observation sequence "I", "write", "a letter".
 * transition_probability is indexed [from state][to state];
 * emission_probability is indexed [state][position in the observation
 * sequence], i.e. it has one column per entry of {@code observations}.
 */
public class Viterbi {
    private static String[] states = { "#", "NN", "VB" };
    private static String[] observations = { "I", "write", "a letter" };
    private static double[] start_probability = { 0.3, 0.4, 0.3 };
    private static double[][] transition_probability = { { 0.2, 0.2, 0.6 },
            { 0.4, 0.1, 0.5 }, { 0.1, 0.8, 0.1 } };
    private static double[][] emission_probability = { { 0.01, 0.02, 0.02 },
            { 0.8, 0.01, 0.5 }, { 0.19, 0.97, 0.48 } };

    /** Most probable path (as state indices) ending in one state, with its probability. */
    private static class TNode {
        public int[] v_path;
        public double v_prob;

        public TNode(int[] v_path, double v_prob) {
            this.v_path = copyIntArray(v_path);
            this.v_prob = v_prob;
        }
    }

    /** Returns a copy of {@code ia}. */
    private static int[] copyIntArray(int[] ia) {
        return java.util.Arrays.copyOf(ia, ia.length);
    }

    /** Returns a copy of {@code ia} with {@code newInt} appended. */
    private static int[] copyIntArray(int[] ia, int newInt) {
        int[] newIa = java.util.Arrays.copyOf(ia, ia.length + 1);
        newIa[ia.length] = newInt;
        return newIa;
    }

    /** Prints the decoded path using the caller's state names, comma-separated. */
    private static void printResult(String[] X, int[] path, double prob) {
        StringBuilder sb = new StringBuilder("Viterbi path: [");
        for (int i = 0; i < path.length; i++) {
            if (i > 0) {
                sb.append(", ");
            }
            sb.append(X[path[i]]);
        }
        sb.append("].\n Probability of the whole system: ").append(prob);
        System.out.println(sb);
    }

    /**
     * Finds the most probable hidden-state sequence for {@code y}, prints it
     * together with its probability, and returns it.
     *
     * @param y  observation sequence
     * @param X  hidden-state names (also used for printing)
     * @param sp start probability of each state
     * @param tp tp[i][j] = probability of moving from state i to state j
     * @param ep ep[s][t] = probability that state s emits observation y[t]
     *           (indexed by position in y, not by observation symbol)
     * @return indices into {@code X} of the most probable path, one per
     *         observation; an empty array when {@code y} or {@code X} is
     *         empty. Ties are broken in favour of the lower state index.
     */
    public int[] forwardViterbi(String[] y, String[] X, double[] sp,
            double[][] tp, double[][] ep) {
        if (y.length == 0 || X.length == 0) {
            // Empty observation sequence: the empty path has probability 1.
            // No states but some observations: nothing can emit them.
            printResult(X, new int[0], y.length == 0 ? 1.0 : 0.0);
            return new int[0];
        }

        // t = 0: prior times the emission of the first observation.
        TNode[] T = new TNode[X.length];
        for (int state = 0; state < X.length; state++) {
            T[state] = new TNode(new int[] { state }, sp[state] * ep[state][0]);
        }

        for (int output = 1; output < y.length; output++) {
            TNode[] U = new TNode[X.length];
            for (int next_state = 0; next_state < X.length; next_state++) {
                // Seed with state 0 so a predecessor is always chosen, even if
                // every candidate probability is 0; strict '>' keeps the
                // lowest index on ties.
                int best = 0;
                double valmax = T[0].v_prob
                        * (ep[next_state][output] * tp[0][next_state]);
                for (int state = 1; state < X.length; state++) {
                    double v_prob = T[state].v_prob
                            * (ep[next_state][output] * tp[state][next_state]);
                    if (v_prob > valmax) {
                        best = state;
                        valmax = v_prob;
                    }
                }
                U[next_state] = new TNode(copyIntArray(T[best].v_path, next_state), valmax);
            }
            T = U;
        }

        // Pick the final state whose best path is most probable.
        int best = 0;
        for (int state = 1; state < X.length; state++) {
            if (T[state].v_prob > T[best].v_prob) {
                best = state;
            }
        }
        int[] argmax = copyIntArray(T[best].v_path);
        double valmax = T[best].v_prob;
        printResult(X, argmax, valmax);
        return argmax;
    }

    public static void main(String[] args) throws Exception {
        System.out.print("\nStates: ");
        for (int i = 0; i < states.length; i++) {
            System.out.print(states[i] + ", ");
        }
        System.out.print("\n\nObservations: ");
        for (int i = 0; i < observations.length; i++) {
            System.out.print(observations[i] + ", ");
        }
        System.out.print("\n\nStart probability: ");
        for (int i = 0; i < states.length; i++) {
            System.out.print(states[i] + ": " + start_probability[i] + ", ");
        }
        System.out.println("\n\nTransition probability:");
        for (int i = 0; i < states.length; i++) {
            System.out.print(" " + states[i] + ": {");
            for (int j = 0; j < states.length; j++) {
                System.out.print(" " + states[j] + ": "
                        + transition_probability[i][j] + ", ");
            }
            System.out.println("}");
        }
        System.out.println("\n\nEmission probability:");
        for (int i = 0; i < states.length; i++) {
            System.out.print(" " + states[i] + ": {");
            for (int j = 0; j < observations.length; j++) {
                System.out.print(" " + observations[j] + ": "
                        + emission_probability[i][j] + ", ");
            }
            System.out.println("}");
        }
        Viterbi v = new Viterbi();
        v.forwardViterbi(observations, states, start_probability,
                transition_probability, emission_probability);
    }
}
F# implementation
(* Nick Heiner *)
(* Viterbi algorithm, as described here: http://people.ccmr.cornell.edu/~ginsparg/INFO295/vit.pdf
   priorProbs: prior probability of a hidden state occurring
   transitions: probability of one hidden state transitioning into another
   emissionProbs: probability of a hidden state emitting an observed state
   observation: a sequence of observed states
   hiddens: a list of all possible hidden states
   Returns: probability of most likely path * hidden state list representing the path
   Raises: ArgumentException if observation or hiddens is empty.
   Each time step is computed once from the previous one (O(T * H^2) for T observations
   and H hidden states) rather than being re-derived recursively for every state.
   Ties are resolved in favour of the hidden state that comes first in hiddens.
*)
let viterbi (priorProbs : 'hidden -> float) (transitions : ('hidden * 'hidden) -> float) (emissionProbs : (('observed * 'hidden) -> float))
            (observation : 'observed []) (hiddens : 'hidden list) : float * 'hidden list =
    if Array.isEmpty observation then invalidArg "observation" "observation must contain at least one element"
    if List.isEmpty hiddens then invalidArg "hiddens" "hiddens must contain at least one hidden state"
    (* A column holds, for every state in hiddens (same order), the probability of the
       most probable path ending in that state at the current time, and that path stored
       newest-first. Referred to as v_state(time) in the notes. *)
    let firstColumn =
        (* At time 0, just use the emission probability and the prior probability *)
        hiddens |> List.map (fun state -> emissionProbs (observation.[0], state) * priorProbs state, [state])
    let nextColumn (previous : (float * 'hidden list) list) (observed : 'observed) =
        hiddens
        |> List.map (fun state ->
            let emission = emissionProbs (observed, state)
            (* Rate each path from the previous step by how likely it is to transition
               into the current state; List.maxBy keeps the first maximum on ties *)
            let prob, path =
                previous
                |> List.map (fun (prob, path) -> transitions (List.head path, state) * prob, path)
                |> List.maxBy fst
            emission * prob, state :: path)
    (* Look for the most likely path that ends at t_max *)
    let lastColumn = Array.fold nextColumn firstColumn observation.[1..]
    let prob, revPath = List.maxBy fst lastColumn
    prob, List.rev revPath
(* Example using the Healthy/Fever data from the Wikipedia "Viterbi algorithm" article *)
type wikiHiddens = Healthy | Fever
let wikiHiddenList = [Healthy; Fever]
type wikiObservations = Normal | Cold | Dizzy
let wikiPriors = function
    | Healthy -> 0.6
    | Fever -> 0.4
let wikiTransitions = function
    | (Healthy, Healthy) -> 0.7
    | (Healthy, Fever) -> 0.3
    | (Fever, Healthy) -> 0.4
    | (Fever, Fever) -> 0.6
let wikiEmissions = function
    | (Cold, Healthy) -> 0.4
    | (Normal, Healthy) -> 0.5
    | (Dizzy, Healthy) -> 0.1
    | (Cold, Fever) -> 0.3
    | (Normal, Fever) -> 0.1
    | (Dizzy, Fever) -> 0.6
(* Print the result so it is visible in a compiled program too, not only in F# Interactive.
   Expected: roughly (0.01344, [Fever; Healthy; Healthy]) *)
viterbi wikiPriors wikiTransitions wikiEmissions [| Dizzy; Normal; Cold |] wikiHiddenList
|> printfn "%A"
Clojure implementation
[edit | edit source](ns ident.viterbi
(:use [clojure.pprint]))
;; Struct describing a hidden Markov model. make-hmm additionally stores
;; :states and :obs, which struct-map accepts as extra keys.
(defstruct hmm :n :m :init-probs :emission-probs :state-transitions)

(defn make-hmm
  "Builds an hmm struct from a map of model parameters.
  :states and :obs are the hidden-state and observation alphabets; their
  counts become :n and :m. The probability tables are Java arrays read with
  aget by the rest of this namespace:
    init-probs        - n entries, indexed by state
    emission-probs    - n x m, indexed [state observation]
    state-transitions - n x n, indexed [from to]"
  [{:keys [states, obs, init-probs, emission-probs, state-transitions]}]
  (struct-map hmm
    :n (count states)
    :m (count obs)
    :states states
    :obs obs
    :init-probs init-probs                 ;; n dim
    :emission-probs emission-probs         ;; n x m: [state observation]
    :state-transitions state-transitions)) ;; n x n: [from to]
(defn indexed
  "Pairs every element of s with its zero-based position,
  e.g. (indexed [:a :b]) => ([0 :a] [1 :b])."
  [s]
  (map-indexed vector s))
(defn argmax
  "Returns the [index value] pair of the largest element of coll, keeping
  the earliest index on ties (comparison is >, so elements must be numbers).
  Returns nil for an empty coll."
  [coll]
  (when-let [pairs (seq (indexed coll))]
    (reduce (fn [[_ best-v :as best] [_ v :as candidate]]
              (if (> v best-v) candidate best))
            pairs)))
(defn pprint-hmm
  "Prints the model's state/observation counts followed by its three
  probability tables."
  [hmm]
  (println "number of states: " (:n hmm) " number of observations: " (:m hmm))
  (doseq [[label k] [["init probabilities: " :init-probs]
                     ["emission probs: " :emission-probs]
                     ["state-transitions: " :state-transitions]]]
    (print label)
    (pprint (k hmm))))
(defn init-alphas
  "Initial forward values: for each state x in 0..n-1, the prior probability
  of x times the probability of x emitting the observation with index obs."
  [hmm obs]
  (let [priors (:init-probs hmm)
        emissions (:emission-probs hmm)]
    (for [x (range (:n hmm))]
      (* (aget priors x) (aget emissions x obs)))))
(defn forward
  "One step of the forward algorithm: for each state j, sums
  alpha_i * transition(i, j) over all states i, then multiplies by the
  probability of j emitting the observation with index obs."
  [hmm alphas obs]
  (let [states (range (:n hmm))
        trans (:state-transitions hmm)
        emit (:emission-probs hmm)]
    (for [j states]
      (* (reduce + 0 (map #(* (nth alphas %) (aget trans % j)) states))
         (aget emit j obs)))))
(defn delta-max
  "One Viterbi step: for each state j, takes the largest
  delta_i * transition(i, j) over all states i and multiplies it by the
  probability of j emitting the observation with index obs."
  [hmm deltas obs]
  (let [states (range (:n hmm))
        trans (:state-transitions hmm)
        emit (:emission-probs hmm)]
    (map (fn [j]
           (let [best (reduce max (map #(* (nth deltas %) (aget trans % j)) states))]
             (* best (aget emit j obs))))
         states)))
(defn backtrack
  "Recovers the most probable state-index sequence. paths holds one
  back-pointer seq per step (entry j is the chosen predecessor of state j);
  the walk starts at the state with the largest final delta and follows the
  pointers from the last step to the first."
  [paths deltas]
  (reverse
    (reductions (fn [state step] (nth step state))
                (first (argmax deltas))
                (reverse paths))))
(defn update-paths
  "Back-pointers for one step: for each state j, the index i that maximises
  delta_i * transition(i, j)."
  [hmm deltas]
  (let [states (range (:n hmm))
        trans (:state-transitions hmm)]
    (for [j states]
      (first (argmax (for [i states]
                       (* (nth deltas i) (aget trans i j))))))))
(defn viterbi
  "Decodes the sequence of observation indices observs. Returns
  [most-probable state-index sequence, total probability of observs], where
  the total is the sum of the final forward values coerced with float
  (single precision)."
  [hmm observs]
  (let [[o & more] observs
        start (init-alphas hmm o)]
    (loop [remaining more
           alphas start
           deltas start
           paths []]
      (if-let [[obs & tail] (seq remaining)]
        (recur tail
               (forward hmm alphas obs)
               (delta-max hmm deltas obs)
               (conj paths (update-paths hmm deltas)))
        [(backtrack paths deltas) (float (reduce + alphas))]))))
C# implementation
[edit | edit source]using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
namespace Viterbi
{
    class Program
    {
        //Hidden states
        static String HEALTHY = "Healthy";
        static String FEVER = "Fever";
        //Observable symptoms (observations)
        static String DIZZY = "dizzy";
        static String COLD = "cold";
        static String NORMAL = "normal";

        static void Main(string[] args)
        {
            //initialize our arrays of states and observations
            String[] states = { HEALTHY, FEVER };
            String[] observations = { DIZZY, COLD, NORMAL };
            var start_probability = new Dictionary<String, float>();
            start_probability.Add(HEALTHY, 0.6f);
            start_probability.Add(FEVER, 0.4f);
            //Transition probability
            var transition_probability = new Dictionary<String, Dictionary<String, float>>();
            var t1 = new Dictionary<String, float>();
            t1.Add(HEALTHY, 0.7f);
            t1.Add(FEVER, 0.3f);
            Dictionary<String, float> t2 = new Dictionary<String, float>();
            t2.Add(HEALTHY, 0.4f);
            t2.Add(FEVER, 0.6f);
            transition_probability.Add(HEALTHY, t1);
            transition_probability.Add(FEVER, t2);
            //emission_probability
            var emission_probability = new Dictionary<String, Dictionary<String, float>>();
            var e1 = new Dictionary<String, float>();
            e1.Add(DIZZY, 0.1f);
            e1.Add(COLD, 0.4f);
            e1.Add(NORMAL, 0.5f);
            Dictionary<String, float> e2 = new Dictionary<String, float>();
            e2.Add(DIZZY, 0.6f);
            e2.Add(COLD, 0.3f);
            e2.Add(NORMAL, 0.1f);
            emission_probability.Add(HEALTHY, e1);
            emission_probability.Add(FEVER, e2);
            Object[] ret = forward_viterbi(observations, states, start_probability, transition_probability, emission_probability);
            Console.WriteLine((float)ret[0]);
            Console.WriteLine((String)ret[1]);
            Console.WriteLine((float)ret[2]);
            Console.ReadLine();
        }

        /// <summary>
        /// Runs the forward algorithm and the Viterbi algorithm over <paramref name="obs"/>.
        /// </summary>
        /// <param name="obs">Observation sequence.</param>
        /// <param name="states">Hidden states; ties are broken in favour of the earlier state.</param>
        /// <param name="start_p">Initial probability of each state.</param>
        /// <param name="trans_p">trans_p[from][to]: probability of moving between states.</param>
        /// <param name="emit_p">emit_p[state][observation]: probability of a state emitting an observation.</param>
        /// <returns>
        /// { total (float): probability of the whole observation sequence,
        ///   argmax (String): comma-separated most probable state sequence, one state per observation,
        ///   valmax (float): probability of that state sequence }.
        /// With no observations returns { 1f, "", 1f }; with no states returns { 0f, "", 0f }.
        /// </returns>
        public static Object[] forward_viterbi(String[] obs, String[] states, Dictionary<String, float> start_p, Dictionary<String, Dictionary<String, float>> trans_p, Dictionary<String, Dictionary<String, float>> emit_p)
        {
            if (obs.Length == 0)
                return new Object[] { 1f, "", 1f };
            if (states.Length == 0)
                return new Object[] { 0f, "", 0f };

            // alpha: forward probability of the observations so far, ending in each state.
            // delta: probability of the best path ending in each state; path: that path.
            // obs[0] is emitted at t = 0 by the state at t = 0, so every observation is
            // emitted by the state at the same time step and paths have obs.Length states.
            var alpha = new Dictionary<String, float>();
            var delta = new Dictionary<String, float>();
            var path = new Dictionary<String, String>();
            foreach (String state in states)
            {
                float p = start_p[state] * emit_p[state][obs[0]];
                alpha[state] = p;
                delta[state] = p;
                path[state] = state;
            }

            for (int t = 1; t < obs.Length; t++)
            {
                var nextAlpha = new Dictionary<String, float>();
                var nextDelta = new Dictionary<String, float>();
                var nextPath = new Dictionary<String, String>();
                foreach (String next_state in states)
                {
                    float total = 0;
                    String best_source = null;
                    float best = 0;
                    foreach (String source_state in states)
                    {
                        float tp = trans_p[source_state][next_state];
                        total += alpha[source_state] * tp;
                        float candidate = delta[source_state] * tp;
                        // Strict '>' keeps the earliest state on ties; the null check
                        // guarantees a predecessor even when every candidate is 0.
                        if (best_source == null || candidate > best)
                        {
                            best_source = source_state;
                            best = candidate;
                        }
                    }
                    // The emission belongs to the state being entered (next_state) at time t.
                    float e = emit_p[next_state][obs[t]];
                    nextAlpha[next_state] = total * e;
                    nextDelta[next_state] = best * e;
                    nextPath[next_state] = path[best_source] + "," + next_state;
                }
                alpha = nextAlpha;
                delta = nextDelta;
                path = nextPath;
            }

            float xtotal = 0;
            String xargmax = null;
            float xvalmax = 0;
            foreach (String state in states)
            {
                xtotal += alpha[state];
                if (xargmax == null || delta[state] > xvalmax)
                {
                    xargmax = path[state];
                    xvalmax = delta[state];
                }
            }
            return new Object[] { xtotal, xargmax, xvalmax };
        }
    }
}