Session Hidden_Markov_Models

Theory Auxiliary

theory Auxiliary
  imports Main "HOL-Library.Extended_Nonnegative_Real"
begin

section ‹Auxiliary Material›

(* Facts about nested and scaled maxima over one fixed finite set S.
   The context assumption finite S is what allows Max_in and Max_ge to be
   applied to images of S without further side conditions. *)
context
  fixes S :: "'s set"
  assumes "finite S"
begin

(* The two nested maxima over S may be taken in either order.
   Proof via Max_eq_if: cases 3 and 4 show that every inner maximum on one
   side is bounded by some inner maximum on the other side; the remaining
   finiteness goals are closed by the final qed. *)
lemma Max_image_commute:
  "(MAX x  S. MAX y  S. f x y) = (MAX y  S. MAX x  S. f x y)"
proof (rule Max_eq_if, goal_cases)
  (* every inner maximum on the left is bounded by one on the right *)
  case 3
  { fix a assume "a  S"
    with Max_in[OF finite_imageI[OF ‹finite S], of "f a"] have "Max (f a ` S)  f a ` S"
      by auto
    then obtain b where "f a b = Max (f a ` S)" "b  S"
      by auto
    from a  S have "f a b  (MAX a  S. f a b)"
      by (auto intro: Max_ge finite_imageI[OF ‹finite S])
    with f a b = _ b  S have "bS. Max (f a ` S)  (MAX a  S. f a b)"
      by auto
  }
  then show ?case
    by auto
next
  (* and symmetrically for the right-hand side *)
  case 4
  { fix b assume "b  S"
    with Max_in[OF finite_imageI[OF ‹finite S], of "λ a. f a b"] have
      "(MAX a  S. f a b)  (λa. f a b) ` S"
      by auto
    then obtain a where "f a b = (MAX a  S. f a b)" "a  S"
      by auto
    from b  S have "f a b  Max (f a ` S)"
      by (auto intro: Max_ge finite_imageI[OF ‹finite S])
    with f a b = _ a  S have "aS. (MAX a  S. f a b)  Max (f a ` S)"
      by auto
  }
    then show ?case
      by auto
  qed (use ‹finite S in auto)

(* A constant ennreal factor c can be pulled out of a maximum over a non-empty S.
   Proved with Max_eqI: finiteness, every c * f y is bounded by c times the
   maximum (mult_left_mono), and the bound is attained since the maximum of f
   over S is itself a value of f on S (Max_in). *)
lemma Max_image_left_mult:
  "(MAX x  S. c * f x) = (c :: ennreal) * (MAX x  S. f x)" if "S  {}"
  apply (rule Max_eqI)
  subgoal
    using ‹finite S by auto
  subgoal for y
    using ‹finite S by (auto intro: mult_left_mono)
  subgoal
    using Max_in[OF finite_imageI[OF ‹finite S], of f] S  {} by auto
  done

end (* Finite set *)

(* Rewrites a maximum over a set comprehension into a maximum over an image. *)
lemma Max_to_image:
  "Max {f t | t. t  S} = Max (f ` S)"
  by (rule arg_cong[where f = Max]) auto

(* Variant of Max_to_image for a comprehension over all t satisfying the predicate P. *)
lemma Max_to_image2:
  "Max {f t | t. P t} = Max (f ` {t. P t})"
  by (rule arg_cong[where f = Max]) auto

(* Congruence rule: equal index sets and functions that agree on T give equal maxima. *)
lemma Max_image_cong:
  "Max (f ` S) = Max (g ` T)" if "S = T" "x. x  T  f x = g x"
  by (intro arg_cong[where f = Max] image_cong[OF that])

(* Max_image_cong with a =simp=> premise, so that it can be handed to the
   simplifier as a congruence rule (cong:); the premise about x is then
   available while rewriting f x. *)
lemma Max_image_cong_simp:
  "Max (f ` S) = Max (g ` T)" if "S = T" "x. x  T =simp=> f x = g x"
  using Max_image_cong[OF that[unfolded simp_implies_def]] .

(* Two maxima over finite images coincide when every value of f on S is bounded
   by some value of g on T and vice versa; a packaged form of Max_eq_if. *)
lemma Max_eq_image_if:
  assumes
    "finite S" "finite T" "x  S. y  T. f x  g y" "x  T. y  S. g x  f y"
  shows "Max (f ` S) = Max (g ` T)"
  using assms by (auto intro: Max_eq_if)

(* Elimination rule: the maximum of f over a finite, non-empty set A is
   attained at some element x of A. *)
theorem Max_in_image:
  assumes "finite A" and "A  {}"
  obtains x where "x  A" and "Max (f ` A) = f x"
proof -
  from Max_in[of "f ` A"] assms have "Max (f ` A)  f ` A"
    by auto
  then show ?thesis
    by (auto intro: that)
qed

(* Every value f x with x in the finite set S is bounded by the maximum of f over S. *)
lemma Max_ge_image:
  "Max (f ` S)  f x" if "finite S" "x  S"
  using that by (auto intro: Max_ge)

(* A nested maximum over S and T equals the maximum over the product S × T.
   T must be non-empty so that every inner maximum is attained (case 3, via
   Max_in_image). *)
lemma Max_image_pair:
  assumes "finite S" "finite T" "T  {}"
  shows "(MAX s  S. MAX t  T. f s t) = (MAX (s, t)  S × T. f s t)"
proof ((rule Max_eq_image_if; clarsimp?), goal_cases)
  case (3 x)
  from ‹finite T T  {} obtain y where "y  T" and "Max (f x ` T) = f x y"
    by (rule Max_in_image)
  with x  S show ?case
    by auto
next
  case (4 a b)
  with ‹finite T show ?case
    by force
qed (use assms in auto)


(* argmax f xs returns a pair (x, f x) where x is the first element of the
   non-empty list xs at which f is maximal: the fold walks xs from left to right
   starting with the head, and only replaces the current candidate on a strictly
   greater value (w > v), so ties keep the earlier element.
   There is no equation for the empty list.
   The lemma argmax below repeats this fold function verbatim, so the exact form
   of the equation matters. *)
fun argmax where
  "argmax f (x # xs) =
    List.fold (λ a (b, v). let w = f a in if w > v then (a, w) else (b, v)) xs (x, f x)"

(* Three-way case split on a list: empty, singleton, or at least two elements.
   Used with cases ... rule: list_cases in the proof of lemma argmax. *)
lemma list_cases:
  assumes "xs = []  P []"
      and " x. xs = [x]  P [x]"
      and " x y ys. xs = (x # y # ys)  P (x # y # ys)"
    shows "P xs"
  apply (cases xs)
   apply (simp add: assms)
  subgoal for y ys
    by (cases ys; simp add: assms)
  done

(* Correctness of argmax on a non-empty list:
   ?A the returned element belongs to the list,
   ?B the returned value is f applied to the returned element,
   ?C the returned value is the maximum of f over the list.
   Each part is first proved for List.fold ?f xs (x, f x) with an arbitrary start
   element x, by induction over non-empty lists, and then specialised to argmax
   via list_cases. *)
lemma argmax:
  assumes "xs  []"
  shows
    "fst (argmax f xs)  set xs" (is "?A")
    "f (fst (argmax f xs)) = snd (argmax f xs)" (is "?B")
    "snd (argmax f xs) = (MAX x  set xs. f x)" (is "?C")
proof -
  let ?f = "λ a (b, v). let w = f a in if w > v then (a, w) else (b, v)"
  have "fst (List.fold ?f xs (x, f x))  {x}  set xs" if "xs  []" for x xs
    using that by (induction xs arbitrary: x rule: list_nonempty_induct)(auto simp: Let_def max_def)
  with xs  [] show ?A
    by (cases xs rule: list_cases; fastforce)
  have "f (fst (List.fold ?f xs (x, f x))) = snd (List.fold ?f xs (x, f x))" if "xs  []" for x xs
    using that by (induction xs arbitrary: x rule: list_nonempty_induct)(auto simp: Let_def max_def)
  with xs  [] show ?B
    by (cases xs rule: list_cases; fastforce)
  have "snd (List.fold ?f xs (x, f x)) = (MAX x  {x}  set xs. f x)"if "xs  []" for x xs
    using that by (induction xs arbitrary: x rule: list_nonempty_induct)(auto simp: Let_def max_def)
  with xs  [] show ?C
    by (cases xs rule: list_cases; fastforce)
qed

end (* Theory *)

Theory Hidden_Markov_Model

section ‹Hidden Markov Models›

theory Hidden_Markov_Model
  imports
    Markov_Models.Discrete_Time_Markov_Chain Auxiliary
    "HOL-Library.IArray"
begin

subsection ‹Definitions›

text ‹Definition of Markov kernels that are closed with respect to a finite, non-empty set of states.›
(* K is a family of distributions over 't indexed by 's. S is a finite,
   non-empty set, and closed requires every distribution K s to be supported
   within S. *)
locale Closed_Kernel =
  fixes K :: "'s  't pmf" and S :: "'t set"
  assumes finite: "finite S"
      and wellformed: "S  {}"
      and closed: " s. K s  S"

text ‹
  An HMM is parameterized by a Markov kernel for the transition probabilities between internal states,
  a Markov kernel for the output probabilities of observations,
  and a fixed set of observations.
›
(* Parameters of an HMM, without any assumptions:
   the transition kernel 𝒦 on internal states,
   the observation kernel 𝒪 giving each state's output distribution,
   and the set of observations 𝒪s. *)
locale HMM_defs =
  fixes 𝒦 :: "'s  's pmf" and 𝒪 :: "'s  't pmf" and 𝒪s :: "'t set"

(* An HMM whose observation kernel 𝒪 is closed w.r.t. the finite, non-empty
   observation set 𝒪s. The lemmas re-export the Closed_Kernel assumptions,
   which carry the prefix O, under descriptive names. *)
locale HMM =
  HMM_defs + O: Closed_Kernel 𝒪 𝒪s
begin

lemma observations_finite: "finite 𝒪s"
  and observations_wellformed: "𝒪s  {}"
  and observations_closed: " s. 𝒪 s  𝒪s"
  using O.finite O.wellformed O.closed by -

end

text ‹Fixed set of internal states.›
(* HMM_defs extended with a set 𝒮 of internal states; no assumptions yet. *)
locale HMM2_defs = HMM_defs 𝒦 𝒪 for 𝒦 :: "'s  's pmf" and 𝒪 :: "'s  't pmf" +
  fixes 𝒮 :: "'s set"

(* An HMM whose transition kernel 𝒦 is additionally closed w.r.t. the finite,
   non-empty state set 𝒮. The lemmas re-export the Closed_Kernel assumptions,
   which carry the prefix K. *)
locale HMM2 = HMM2_defs + HMM + K: Closed_Kernel 𝒦 𝒮
begin

lemma states_finite: "finite 𝒮"
  and states_wellformed: "𝒮  {}"
  and states_closed: " s. 𝒦 s  𝒮"
  using K.finite K.wellformed K.closed by -

end

text ‹
  The set of internal states is now given as a list to iterate over.
  This is needed for the computations on HMMs.
›
(* HMM2_defs extended with state_list, a list of states to iterate over. *)
locale HMM3_defs = HMM2_defs 𝒪s 𝒦 for 𝒪s :: "'t set" and 𝒦 :: "'s  's pmf" +
  fixes state_list :: "'s list"

(* Requires state_list to enumerate exactly the states in 𝒮.
   Duplicates are not excluded here; the locale HMM4 adds distinctness. *)
locale HMM3 = HMM3_defs _ _ 𝒪s 𝒦 + HMM2 𝒪s 𝒦 for 𝒪s :: "'t set" and 𝒦 :: "'s  's pmf" +
  assumes state_list_𝒮: "set state_list = 𝒮"

context HMM_defs
begin

(* Release the identifier o from the ASCII infix syntax for function composition,
   so that o can be used as a variable for observations in this context. *)
no_notation (ASCII) comp  (infixl "o" 55)

text ‹The ``default'' observation.›
(* Some observation chosen by Hilbert choice. It is only guaranteed to lie in
   𝒪s when 𝒪s is non-empty; see lemma obs. *)
definition
  "obs  SOME x. x  𝒪s"

(* In an HMM the observation set is non-empty, so the default observation belongs to it. *)
lemma (in HMM) obs:
  "obs  𝒪s"
  unfolding obs_def using observations_wellformed by (auto intro: someI_ex)

text ‹
  The HMM is encoded as a Markov chain over pairs of states and observations.
  This is the Markov chain's defining Markov kernel.
›
(* Kernel of the product chain on (state, observation) pairs.
   The current observation o1 is ignored. The next state s2 is drawn from 𝒦 s1,
   and the next observation is drawn from 𝒪 s2 of that new state. *)
definition
  "K  λ (s1, o1 :: 't). bind_pmf (𝒦 s1) (λ s2. map_pmf (λ o2. (s2, o2)) (𝒪 s2))"

(* Instantiate the Markov-chain syntax locale with K. The trace-space notions
   used below (T, T' and the stream space S) presumably come from MC_syntax. *)
sublocale MC_syntax K .

text ‹
  Uniform distribution of the pairs ‹(s, o)› for a fixed state ‹s›.
›
(* Initial distribution for state s: the pair (s, x) with x uniform on 𝒪s. *)
definition "I (s :: 's) = map_pmf (λ x. (s, x)) (pmf_of_set 𝒪s)"

text ‹
  The likelihood of an observation sequence given a starting state ‹s› is defined in terms of
  the trace space of the Markov kernel started from the uniform distribution of pairs for ‹s›.
›
(* Probability, under the trace space started from I s, of the traces that
   start with a pair (s, o0) with arbitrary o0, followed by a finite prefix xs
   whose observation components are exactly os. *)
definition
  "likelihood s os = T' (I s) {ω  space S.  o0 xs ω'. ω = (s, o0) ## xs @- ω'  map snd xs = os}"

(* L os ω: the trace ω begins with a finite prefix whose observation components are exactly os. *)
abbreviation (input) "L os ω   xs ω'. ω = xs @- ω'  map snd xs = os"

(* likelihood with the event written as a plain set comprehension instead of a
   filtered subset of the stream space. *)
lemma likelihood_alt_def: "likelihood s os = T' (I s) {(s, o) ## xs @- ω' |o xs ω'. map snd xs = os}"
  unfolding likelihood_def by (simp add: in_S)


subsection ‹Iteration Rule For Likelihood›

(* Every trace matches the empty observation sequence. *)
lemma L_Nil:
  "L [] ω = True"
  by simp

(* Consequently, matching the empty observation sequence has probability 1 from any start pair. *)
lemma emeasure_T_observation_Nil:
  "T (s, o0) {ω  space S. L [] ω} = 1"
  by simp

(* One-step unfolding of L: the first pair must carry observation o and the
   tail must match os. *)
lemma L_Cons:
  "L (o # os) ω  snd (shd ω) = o  L os (stl ω)"
  apply (cases ω; cases "shd ω"; safe; clarsimp)
   apply force
  subgoal for x xs ω'
    by (force intro: exI[where x = "(x, o) # xs"])
  done

(* L os is measurable on the stream space S; by induction on os, unfolding one
   step with L_Cons. The [measurable] attribute makes the fact available to the
   measurable proof method used later on. *)
lemma L_measurable[measurable]:
  "Measurable.pred S (L os)"
  apply (induction os)
   apply (simp; fail)
  subgoal premises that for o os
    by(subst L_Cons)
      (intro Measurable.pred_intros_logic
        measurable_compose[OF measurable_shd] measurable_compose[OF measurable_stl that];
        measurable)
  done

(* Measurability of the event in the definition of likelihood: it splits into a
   condition on the head state and L os on the tail. *)
lemma init_measurable[measurable]:
  "Measurable.pred S (λx. o0 xs ω'. x = (s, o0) ## xs @- ω'  map snd xs = os)"
  (is "Measurable.pred S ?f")
proof -
  have *: "?f ω  fst (shd ω) = s  L os (stl ω)" for ω
    by (cases ω) auto
  show ?thesis
    by (subst *)
       (intro Measurable.pred_intros_logic measurable_compose[OF measurable_shd]; measurable)
qed

(* The probability of matching os does not depend on the initial observation,
   because K ignores the current observation; see K_def. *)
lemma T_init_observation_eq:
  "T (s, o) {ω  space S. L os ω} = T (s, o') {ω  space S. L os ω}"
  apply (subst emeasure_Collect_T[unfolded space_T], (measurable; fail))
  apply (subst (2) emeasure_Collect_T[unfolded space_T], (measurable; fail))
  apply (simp add: K_def)
  done

text ‹
  Shows that likelihood can equivalently be defined in terms of the trace space starting at a single
  pair of an internal state ‹s› and the default observation @{term obs}.
›
(* The uniform average over initial observations in I s collapses.
   By T_init_observation_eq all card 𝒪s summands are equal; this is the
   auxiliary fact marked with an asterisk. The factor card 𝒪s then cancels
   against the normalisation of pmf_of_set, which is the role of
   nn_integral_pmf_of_set and mult_divide_eq_ennreal. *)
lemma (in HMM) likelihood_init:
  "likelihood s os = T (s, obs) {ω  space S. L os ω}"
proof -
  have *: "(o𝒪s. emeasure (T (s, o)) {ω  space S. L os ω}) =
    of_nat (card 𝒪s) * emeasure (T (s, obs)) {ω  space S. L os ω}"
    by (subst sum_constant[symmetric]) (fastforce intro: sum.cong T_init_observation_eq[simplified])
  show ?thesis
    unfolding likelihood_def
    apply (subst emeasure_T')
    subgoal
      by measurable
    using *
    apply (simp add: I_def in_S observations_finite observations_wellformed nn_integral_pmf_of_set)
    apply (subst mult.commute)
    apply (simp add: observations_finite observations_wellformed mult_divide_eq_ennreal)
    done
qed

(* One-step recurrence for matching o1 # os from the pair (s, o0).
   The chain moves to a successor t drawn from 𝒦 s, must emit o1 there with
   weight pmf (𝒪 t) o1, and then has to match os from (t, o1).
   The auxiliary fact marked with an asterisk evaluates the integral over the
   emitted observation o2; the integrand vanishes unless o2 = o1. *)
lemma emeasure_T_observation_Cons:
  "T (s, o0) {ω  space S. L (o1 # os) ω} =
   (+ t. ennreal (pmf (𝒪 t) o1) * T (t, o1) {ω  space S. L os ω} (𝒦 s))" (is "?l = ?r")
proof -
  have *:
    "+ y. T (s', y) {x  space S. xs. (ω'. (s', y) ## x = xs @- ω')  map snd xs = o1 # os}
       measure_pmf (𝒪 s') =
    ennreal (pmf (𝒪 s') o1) * T (s', o1) {ω  space S. xs. (ω'. ω = xs @- ω')  map snd xs = os}"
    (is "?L = ?R") for s'
  proof -
    have "?L = + x. ennreal (pmf (𝒪 s') x) *
            T (s', x) {ω  space S. xs. (ω'. (s', x) ## ω = xs @- ω')  map snd xs = o1 # os}
          count_space UNIV"
      by (rule nn_integral_measure_pmf)
    also have " =
      + o2. (if o2 = o1
              then ennreal (pmf (𝒪 s') o1) * T (s', o1) {ω  space S. L os ω}
              else 0)
       count_space UNIV"
      apply (rule nn_integral_cong_AE
          [where v = "λ o2. if o2 = o1
            then ennreal (pmf (𝒪 s') o1) * T (s', o1) {ω  space S. L os ω} else 0"]
          )
       apply (rule AE_I2)
       apply (split if_split, safe)
      subgoal
        by (auto intro!: arg_cong2[where f = times, OF HOL.refl] arg_cong2[where f = emeasure];
            metis list.simps(9) shift.simps(2) snd_conv
           )
      subgoal
        by (subst arg_cong2[where f = emeasure and d = "{}", OF HOL.refl]) auto
      done
    also have " = +o2{o1}.
       (ennreal (pmf (𝒪 s') o1) * T (s', o1) {ω  space S. L os ω})
      count_space UNIV"
      by (rule nn_integral_cong_AE) auto
    also have " = ?R"
      by simp
    finally show ?thesis .
  qed
  have "?l = + t. T t {x  space S. xs ω'. t ## x = xs @- ω'  map snd xs = o1 # os}  (K (s, o0))"
    by (subst emeasure_Collect_T[unfolded space_T], measurable)
  also have " = ?r"
    using * by (simp add: K_def)
  finally show ?thesis .
qed


subsection ‹Computation of Likelihood›

(* Backward variable: backward s os is the probability of emitting os when
   starting in state s; see emeasure_T_observation_backward.
   Each step integrates over the successor state t drawn from 𝒦 s. The forward
   variable, defined for a fixed state set, replaces this integral by a finite sum. *)
fun backward where
  "backward s [] = 1" |
  "backward s (o # os) = (+ t. ennreal (pmf (𝒪 t) o) * backward t os measure_pmf (𝒦 s))"

(* The trace probability of matching os from (s, o) is backward s os; by
   induction on os using emeasure_T_observation_Cons. *)
lemma emeasure_T_observation_backward:
  "emeasure (T (s, o)) {ω  space S. L os ω} = backward s os"
  using emeasure_T_observation_Cons by (induction os arbitrary: s o; simp)

(* The likelihood expressed through the backward variable. *)
lemma (in HMM) likelihood_backward:
  "likelihood s os = backward s os"
  unfolding likelihood_init emeasure_T_observation_backward ..

end (* HMM Defs *)

context HMM2
begin

(* Forward variable: forward s t_end os sums, over all state sequences in 𝒮,
   the weight of emitting os along a path that starts in s and ends in t_end.
   A step to state t emitting o contributes pmf (𝒪 t) o * pmf (𝒦 s) t.
   For the empty sequence the weight is 1 if s = t_end and 0 otherwise.
   Only the state set 𝒮 is needed, hence the target HMM2_defs. *)
fun (in HMM2_defs) forward where
  "forward s t_end [] = indicator {t_end} s" |
  "forward s t_end (o # os) =
    (t  𝒮. ennreal (pmf (𝒪 t) o) * ennreal (pmf (𝒦 s) t) * forward t t_end os)"

(* forward composes along concatenated observation sequences by summing over
   the intermediate state t'. The premise s in 𝒮 is used in the base case to
   collapse the sum over 𝒮 to the single term for s. *)
lemma forward_split:
  "forward s t (os1 @ os2) = (t'  𝒮. forward s t' os1 * forward t' t os2)"
  if "s  𝒮"
  using that
  apply (induction os1 arbitrary: s)
  subgoal for s
    apply (simp add: sum_indicator_mult[OF states_finite])
    apply (subst sum.cong[where B = "{s}"])
    by auto
  subgoal for a os1 s
    apply simp
    apply (subst sum_distrib_right)
    apply (subst sum.swap)
    apply (simp add: sum_distrib_left algebra_simps)
    done
  done

(* A sum over a finite set S collapses to the single term f t when t is in S
   and f vanishes on every other element of S. *)
lemma (in -) sum_eq_single_term:
  "(∑t ∈ S. f t) = f t" if "finite S" "t ∈ S" "∀ s ∈ S - {t}. f s = 0"
  apply (subst sum.mono_neutral_right[of S "{t}"])
  using that
     apply auto
  done

(* Summing the forward variable over all end states gives the backward variable,
   for a start state s in 𝒮. In the Cons case, nn_integral_measure_pmf_support
   turns the integral in backward into a finite sum over 𝒮; this is justified
   by states_finite and states_closed. *)
lemma forward_backward:
  "(t  𝒮. forward s t os) = backward s os" if "s  𝒮"
  using s  𝒮
  apply (induction os arbitrary: s)
  subgoal for s
    by (subst sum.mono_neutral_right[of 𝒮 "{s}", OF states_finite])
       (auto split: if_split_asm simp: indicator_def)
  subgoal for a os s
    apply (simp add: sum.swap sum_distrib_left[symmetric])
    apply (subst nn_integral_measure_pmf_support[where A = 𝒮])
    using states_finite states_closed by (auto simp: algebra_simps)
  done

(* Correctness of the forward algorithm: the likelihood is the sum of forward
   over all end states. *)
theorem likelihood_forward:
  "likelihood s os = (t  𝒮. forward s t os)" if s  𝒮
  unfolding likelihood_backward forward_backward[symmetric, OF s  𝒮] ..


subsection ‹Definition of Maximum Probabilities›

(* V os as ω: the trace ω begins with the pairs zip as os, i.e. it visits the
   states as while emitting os. *)
abbreviation (input) "V os as ω  ( ω'. ω = zip as os @- ω')"

(* Maximum, over all state sequences as of length length os with states in 𝒮,
   of the probability (from I s) that the chain starts in s and then visits as
   while emitting os. The index set is finite because 𝒮 is finite; this is fact
   S_finite in the proof of max_prob_Cons'. *)
definition
  "max_prob s os =
  Max {T' (I s) {ω  space S. o ω'. ω = (s, o) ## zip as os @- ω'}
       | as. length as = length os  set as  𝒮}"

(* Max-product counterpart of forward. It uses the same step weights, but the
   next state is chosen by MAX over 𝒮 instead of being summed over. *)
fun viterbi_prob where
  "viterbi_prob s t_end [] = indicator {t_end} s" |
  "viterbi_prob s t_end (o # os) =
    (MAX t  𝒮. ennreal (pmf (𝒪 t) o * pmf (𝒦 s) t) * viterbi_prob t t_end os)"

(* as is a decoding of os from s if it is a valid state sequence (right length,
   states in 𝒮) whose probability attains max_prob s os. *)
definition
  "is_decoding s os as 
    T' (I s) {ω  space S. o ω'. ω = (s, o) ## zip as os @- ω'} = max_prob s os 
    length as = length os  set as  𝒮"


subsection ‹Iteration Rule For Maximum Probabilities›

(* Every trace matches the empty prefix, so this event has probability 1. *)
lemma emeasure_T_state_Nil:
  "T (s, o0) {ω  space S. V [] as ω} = 1"
  by simp

(* With no observations, the maximal path probability is 1. *)
lemma max_prob_T_state_Nil:
  "Max {T (s, o) {ω  space S. V [] as ω} | as. length as = length []  set as  𝒮} = 1"
  by (simp add: emeasure_T_state_Nil)

(* One-step unfolding of V: the first pair must be (a, o) and the tail must match os with as. *)
lemma V_Cons: "V (o # os) (a # as) ω  fst (shd ω) = a  snd (shd ω) = o  V os as (stl ω)"
  by (cases ω) auto

(* V os as is measurable. The proof inducts over both lists simultaneously
   (list_induct2'); only the case where both lists are non-empty, case 4, needs
   work, and the other cases are closed by simp. *)
lemma measurable_V[measurable]:
  "Measurable.pred S (λω. V os as ω)"
proof (induction os as rule: list_induct2')
  case (4 x xs y ys)
  then show ?case
    by (subst V_Cons)
       (intro Measurable.pred_intros_logic
          measurable_compose[OF measurable_shd] measurable_compose[OF measurable_stl];
        measurable)
qed simp+

(* Measurability of the event used in max_prob: a condition on the head state s
   and V os as on the tail. *)
lemma init_V_measurable[measurable]:
  "Measurable.pred S (λx. o ω'. x = (s, o) ## zip as os @- ω')" (is "Measurable.pred S ?f")
proof -
  have *: "?f ω  fst (shd ω) = s  V os as (stl ω)" for ω
    by (cases ω) auto
  show ?thesis
    by (subst *)
       (intro Measurable.pred_intros_logic measurable_compose[OF measurable_shd]; measurable)
qed

(* Two one-step unfoldings for state-sequence events.
   max_prob_Cons': the maximum over state sequences for o # os factors into a
   maximum over the first state t of the step weight, times the maximum over
   the remaining sequences for os.
   T_V_Cons: the probability of following t # as while emitting o # os is the
   step weight times the probability of following as while emitting os from (t, o).
   The premise length as = length os is used explicitly for T_V_Cons (from that);
   in max_prob_Cons' the variable as is bound inside the statement, and
   max_prob_Cons discharges the premise. *)
lemma max_prob_Cons':
  "Max {T (s, o1) {ω  space S. V (o # os) as ω} | as. length as = length (o # os)  set as  𝒮} =
  (
    MAX t  𝒮. ennreal (pmf (𝒪 t) o * pmf (𝒦 s) t) *
      (MAX as  {as. length as = length os  set as  𝒮}. T (t, o) {ω  space S. V os as ω})
  )" (is "?l = ?r")
  and T_V_Cons:
  "T (s, o1) {ω  space S. V (o # os) (t # as) ω}
  = ennreal (pmf (𝒪 t) o * pmf (𝒦 s) t) * T (t, o) {ω  space S. V os as ω}"
  (is "?l' = ?r'")
  if "length as = length os"
proof -
  (* ?S os: all state sequences over 𝒮 of length length os; finite and non-empty. *)
  let ?S = "λ os. {as. length as = length os  set as  𝒮}"
  have S_finite: "finite (?S os)" for os :: "'t list"
    using finite_lists_length_eq[OF states_finite] by (rule finite_subset[rotated]) auto
  have S_nonempty: "?S os  {}" for os :: "'t list"
  proof -
    let ?a = "SOME a. a  𝒮" let ?as = "replicate (length os) ?a"
    from states_wellformed have "?a  𝒮"
      by (auto intro: someI_ex)
    then have "?as  ?S os"
      by auto
    then show ?thesis
      by force
  qed
  let ?f = "λt as os. T t {ω  space S. V os as (t ## ω)}"
  let ?g = "λt as os. T t {ω  space S. V os as ω}"
  have *: "?f t as (o # os) = ?g t (tl as) os * indicator {(hd as, o)} t"
    if "length as = Suc n" for t as n
    unfolding indicator_def using that by (cases as) auto
  (* Probability that one K-step from (s, o1) lands exactly on the pair (t, o). *)
  have **: "K (s, o1) {(t, o)} = pmf (𝒪 t) o * pmf (𝒦 s) t" for t
    unfolding K_def
    apply (simp add: vimage_def)
    apply (subst arg_cong2[where
          f = nn_integral and d = "λ x. 𝒪 x {xa. xa = o  x = t} * indicator {t} x",
          OF HOL.refl])
    subgoal
      by (auto simp: indicator_def)
    by (simp add: emeasure_pmf_single ennreal_mult')
  have "?l = (MAX as  ?S (o # os). + t. ?f t as (o # os) K (s, o1))"
    by (subst Max_to_image2; subst emeasure_Collect_T[unfolded space_T]; rule measurable_V HOL.refl)
  also have " = (MAX as  ?S (o # os). + t. ?g t (tl as) os * indicator {(hd as,o)} t K (s,o1))"
    by (simp cong: Max_image_cong_simp add: *)
  also have " = (MAX(t, as) 𝒮 × ?S os. ennreal (pmf (𝒪 t) o * pmf (𝒦 s) t) * ?g (t, o) as os)"
  proof ((rule Max_eq_image_if; clarsimp?), goal_cases)
    case 1
    from S_finite[of "o # os"] show ?case
      by simp
  next
    case 2
    from states_finite show ?case
      by (blast intro: S_finite)
  next
    case (3 as)
    then show ?case
      by - (rule bexI[where x = "hd as"]; cases as; auto simp: algebra_simps **)
  next
    case (4 x as)
    then show ?case
      by - (rule exI[where x = "x # as"], simp add: algebra_simps **)
  qed
  also have " = ?r"
    by (subst Max_image_left_mult[symmetric], fact+)
       (rule sym, rule Max_image_pair, rule states_finite, fact+)
  finally show "?l = ?r" .
  have "?l' = + t'. ?f t' (t # as) (o # os) K (s, o1)"
    by (rule emeasure_Collect_T[unfolded space_T]; rule measurable_V)
  also from that have " = + t'. ?g t' as os * indicator {(t,o)} t' K (s,o1)"
    by (subst *[of _ "length as"]; simp)
  also have " = ?r'"
    by (simp add: **, simp only: algebra_simps)
  finally show "?l' = ?r'" .
qed

(* max_prob_Cons' with its premise length as = length os discharged by
   length_replicate, which instantiates as with a replicate list of the right
   length. Used in the proof of Max_V_viterbi. *)
lemmas max_prob_Cons = max_prob_Cons'[OF length_replicate]



subsection ‹Computation of Maximum Probabilities›

(* Analogue of T_init_observation_eq for V: the initial observation is
   irrelevant, since K ignores it. *)
lemma T_init_V_eq:
  "T (s, o) {ω  space S. V os as ω} = T (s, o') {ω  space S. V os as ω}"
  apply (subst emeasure_Collect_T[unfolded space_T], (measurable; fail))
  apply (subst (2) emeasure_Collect_T[unfolded space_T], (measurable; fail))
  apply (simp add: K_def)
  done

(* The probability from I s of following as while emitting os equals the
   probability from the single pair (s, o), for any o. The uniform average over
   initial observations has card 𝒪s equal summands, by T_init_V_eq.
   NOTE(review): the goal does not mention max_prob, so unfolding max_prob_def
   looks like a no-op here. *)
lemma T'_I_T:
  "T' (I s) {ω  space S. o ω'. ω = (s, o) ## zip as os @- ω'} = T (s,o) {ω  space S. V os as ω}"
proof -
  have "(o𝒪s. T (s, o) {ω  space S. V os as ω}) =
    of_nat (card 𝒪s) * T (s, o) {ω  space S. V os as ω}" for as
    by (subst sum_constant[symmetric]) (fastforce intro: sum.cong T_init_V_eq[simplified])
  then show ?thesis
    unfolding max_prob_def
    apply (subst emeasure_T')
    subgoal
      by measurable
    apply (simp add: I_def in_S observations_finite observations_wellformed nn_integral_pmf_of_set)
    apply (subst mult.commute)
    apply (simp add: observations_finite observations_wellformed mult_divide_eq_ennreal)
    done
qed

(* max_prob in terms of T from a single start pair (s, o); the observation o is arbitrary. *)
lemma max_prob_init:
  "max_prob s os = Max {T (s,o) {ω  space S. V os as ω} | as. length as = length os  set as  𝒮}"
  unfolding max_prob_def by (simp add: T'_I_T[symmetric])

(* The empty observation sequence has maximal path probability 1. *)
lemma max_prob_Nil[simp]:
  "max_prob s [] = 1"
  unfolding max_prob_init[where o = obs] by auto

(* For s in 𝒮, the maximum over t in 𝒮 of indicator {t} s is 1, attained at t = s. *)
lemma Max_start:
  "(MAX t𝒮. (indicator {t} s :: ennreal)) = 1" if "s  𝒮"
  using states_finite that by (auto simp: indicator_def intro: Max_eqI)

(* Maximising viterbi_prob over all end states yields the maximal path
   probability from (s, o); by induction on os using max_prob_Cons and the
   MAX rearrangement lemmas. *)
lemma Max_V_viterbi:
  "(MAX t  𝒮. viterbi_prob s t os) =
   Max {T (s, o) {ω  space S. V os as ω} | as. length as = length os  set as  𝒮}" if "s  𝒮"
  using that states_finite states_wellformed
  by (induction os arbitrary: s o; simp
        add: Max_start max_prob_Cons[simplified] Max_image_commute Max_image_left_mult Max_to_image2
        cong: Max_image_cong
      )

(* Correctness of viterbi_prob: its maximum over end states is max_prob s os, for s in 𝒮. *)
lemma max_prob_viterbi:
  "(MAX t  𝒮. viterbi_prob s t os) = max_prob s os" if "s  𝒮"
  using max_prob_init[of s os] Max_V_viterbi[OF s  𝒮, symmetric] by simp

end

subsection ‹Decoding the Most Probable Hidden State Sequence›

context HMM3
begin

(* Executable Viterbi recursion. viterbi s t_end os returns a pair (path, value).
   For every candidate next state t in state_list it decodes the rest from t,
   prepends t to that path, and multiplies the value by the step weight
   pmf (𝒪 t) o * pmf (𝒦 s) t. argmax snd picks the candidate with the largest
   value, taking the first one on ties, and fst drops the value that argmax
   returns a second time. argmax needs a non-empty list; see state_list_nonempty. *)
fun viterbi where
  "viterbi s t_end [] = ([], indicator {t_end} s)" |
  "viterbi s t_end (o # os) = fst (
    argmax snd (map
      (λt. let (xs, v) = viterbi t t_end os in (t # xs, ennreal (pmf (𝒪 t) o * pmf (𝒦 s) t) * v))
    state_list))"

(* state_list is non-empty because it enumerates the non-empty set 𝒮. *)
lemma state_list_nonempty:
  "state_list  []"
  using state_list_𝒮 states_wellformed by auto

(* The value computed by viterbi is viterbi_prob. In the Cons case, the value
   of argmax is the MAX over the mapped candidate list (lemma argmax), which is
   matched with the MAX over 𝒮 via Max_eq_image_if. *)
lemma viterbi_viterbi_prob:
  "snd (viterbi s t_end os) = viterbi_prob s t_end os"
proof (induction os arbitrary: s)
  case Nil
  then show ?case
    by simp
next
  case (Cons o os)
  let ?f =
    "λt. let (xs, v) = viterbi t t_end os in (t # xs, ennreal (pmf (𝒪 t) o * pmf (𝒦 s) t) * v)"
  let ?xs = "map ?f state_list"
  from state_list_nonempty have "map ?f state_list  []"
    by simp
  from argmax(2,3)[OF this, of snd] have *:
    "snd (fst (argmax snd ?xs)) = snd (argmax snd ?xs)"
    "snd (argmax snd ?xs) = (MAX x  set ?xs. snd x)" .
  then show ?case
    apply (simp add: state_list_𝒮)
    apply (rule Max_eq_image_if)
       apply (intro finite_imageI states_finite; fail)
      apply (intro finite_imageI states_finite; fail)
    subgoal
      apply clarsimp
      subgoal for x
        using Cons.IH[of x] by (auto split: prod.splits)
      done
    apply clarsimp
    subgoal for x
      using Cons.IH[of x] by (force split: prod.splits)
    done
qed

context
begin

(* val_of s as os: product of the step weights along the state sequence as
   while emitting os, starting from s. The equations only cover lists of equal
   length. The function is private to this anonymous context. *)
private fun val_of where
  "val_of s [] [] = 1" |
  "val_of s (t # xs) (o # os) = ennreal (pmf (𝒪 t) o * pmf (𝒦 s) t) * val_of t xs os"

(* For equal lengths, val_of is the trace probability of following as while
   emitting os; the start observation o1 does not matter. *)
lemma val_of_T:
  "val_of s as os = T (s, o1) {ω  space S. V os as ω}" if "length as = length os"
  using that by (induction arbitrary: o1 rule: val_of.induct; (subst T_V_Cons)?; simp)

(* If viterbi reports a positive value, that value is exactly val_of of the
   returned path, i.e. the path achieves it. Positivity is needed in the base
   case, where the value is indicator {t_end} s. *)
lemma viterbi_sequence:
  "snd (viterbi s t_end os) = val_of s (fst (viterbi s t_end os)) os"
  if "snd (viterbi s t_end os) > 0"
  using that
proof (induction os arbitrary: s)
  case Nil
  then show ?case
    by (simp add: indicator_def split: if_split_asm)
next
  case (Cons o os s)
  let ?xs = "map
    (λt. let (xs, v) = viterbi t t_end os in (t # xs, ennreal (pmf (𝒪 t) o * pmf (𝒦 s) t) * v))
    state_list"
  from state_list_nonempty have "?xs  []"
    by simp
  from argmax(1)[OF this, of snd] obtain t where
    "t  set state_list"
    "fst (argmax snd ?xs) =
    (t # fst (viterbi t t_end os), ennreal (pmf (𝒪 t) o * pmf (𝒦 s) t) * snd (viterbi t t_end os))"
    by (auto split: prod.splits)
  with Cons show ?case
    by (auto simp: ennreal_zero_less_mult_iff)
qed

(* The path returned by viterbi has length length os and only uses states in 𝒮. *)
lemma viterbi_valid_path:
  "length as = length os  set as  𝒮" if "viterbi s t_end os = (as, v)"
using that proof (induction os arbitrary: s as v)
  case Nil
  then show ?case
    by simp
next
  case (Cons o os s as v)
  let ?xs = "map
    (λt. let (xs, v) = viterbi t t_end os in (t # xs, ennreal (pmf (𝒪 t) o * pmf (𝒦 s) t) * v))
    state_list"
  from state_list_nonempty have "?xs  []"
    by simp
  from argmax(1)[OF this, of snd] obtain t where "t  𝒮"
    "fst (argmax snd ?xs) =
    (t # fst (viterbi t t_end os), ennreal (pmf (𝒪 t) o * pmf (𝒦 s) t) * snd (viterbi t t_end os))"
    by (auto simp: state_list_𝒮 split: prod.splits)
  with Cons.prems show ?case
    by (cases "viterbi t t_end os"; simp add: Cons.IH)
qed

(* Runs viterbi once for every end state in state_list and keeps the result
   (path, value) with the largest value, taking the first one on ties. *)
definition
  "viterbi_final s os = fst (argmax snd (map (λ t. viterbi s t os) state_list))"

(* viterbi_final s os coincides with viterbi s t os for some end state t in 𝒮
   whose value is maximal over 𝒮. *)
lemma viterbi_finalE:
  obtains t where
    "t  𝒮" "viterbi_final s os = viterbi s t os"
    "snd (viterbi s t os) = Max ((λt. snd (viterbi s t os)) ` 𝒮)"
proof -
  from state_list_nonempty have "map (λ t. viterbi s t os) state_list  []"
    by simp
  from argmax[OF this, of snd] show ?thesis
    by (auto simp: state_list_𝒮 image_comp comp_def viterbi_final_def intro: that)
qed

(* The value returned by viterbi_final is max_prob s os, for a start state s in 𝒮. *)
theorem viterbi_final_max_prob:
  assumes "viterbi_final s os = (as, v)" "s  𝒮"
  shows "v = max_prob s os"
proof -
  obtain t where "t  𝒮" "viterbi_final s os = viterbi s t os"
    "snd (viterbi s t os) = Max ((λt. snd (viterbi s t os)) ` 𝒮)"
    by (rule viterbi_finalE)
  with assms show ?thesis
    by (simp add: viterbi_viterbi_prob max_prob_viterbi)
qed

(* If the value returned by viterbi_final is positive, the returned path is a
   decoding. The positivity premise is what viterbi_sequence needs in order to
   identify the value with val_of of the returned path. *)
theorem viterbi_final_is_decoding:
  assumes "viterbi_final s os = (as, v)" "v > 0" "s  𝒮"
  shows "is_decoding s os as"
proof -
  from viterbi_valid_path[of s _ os as v] assms have as: "length as = length os" "set as  𝒮"
    by - (rule viterbi_finalE[of s os]; simp)+
  obtain t where "t  𝒮" "viterbi_final s os = viterbi s t os"
    by (rule viterbi_finalE)
  with assms viterbi_sequence[of s t os] have "val_of s as os = v"
    by (cases "viterbi s t os") (auto simp: snd_def split!: prod.splits)
  with val_of_T as have "max_prob s os = T (s, obs) {ω  space S. V os as ω}"
    by (simp add: viterbi_final_max_prob[OF assms(1,3)])
  with as show ?thesis
    unfolding is_decoding_def by (simp only: T'_I_T)
qed

end (* Anonymous context *)

end (* HMM 3 *)

end (* Theory *)

Theory HMM_Implementation

section ‹Implementation›

theory HMM_Implementation
  imports
    Hidden_Markov_Model
    "Monad_Memo_DP.State_Main"
begin

subsection ‹The Forward Algorithm›

(* HMM3 plus a duplicate-free state_list, so that sums over the list state_list
   equal sums over the set 𝒮 (via sum_list_distinct_conv_sum_set). *)
locale HMM4 = HMM3 _ _ _ 𝒪s 𝒦 for 𝒪s :: "'t set" and 𝒦 :: "'s  's pmf" +
  assumes states_distinct: "distinct state_list"

context HMM3_defs
begin

context
  fixes os :: "'t iarray"
begin

text ‹
  Alternative definition using indices into the list of observations.
  The list of observations is implemented as an immutable array for better performance.
›

(* Index-based forward: forward_ix_rec s t_end n computes forward for the
   suffix of the observation array os that starts at index n. Once n reaches
   IArray.length os the result is indicator {t_end} s; otherwise it sums over
   the entries of the list state_list.
   Termination: IArray.length os - n decreases with every recursive call. *)
function forward_ix_rec where
  "forward_ix_rec s t_end n = (if n  IArray.length os then indicator {t_end} s else
    (t  state_list.
      ennreal (pmf (𝒪 t) (os !! n)) * ennreal (pmf (𝒦 s) t) * forward_ix_rec t t_end (n + 1)))
  "
  by auto
termination
  by (relation "Wellfounded.measure (λ(_,_,n). IArray.length os - n)") auto

text ‹Memoization›

(* Generates the memoized state-monad version forward_ixm' of forward_ix_rec,
   whose memory is handled by dp_consistency_mapping; forward_code later runs it
   starting from Mapping.empty. memoize_correct proves agreement with
   forward_ix_rec, available as forward_ixm.memoized_correct.
   The term command only displays the generated constant. *)
memoize_fun forward_ixm: forward_ix_rec
  with_memory dp_consistency_mapping
  monadifies (state) forward_ix_rec.simps[unfolded Let_def]
  term forward_ixm'
memoize_correct
  by memoize_prover

text ‹The main theorems generated by memoization.›
(* Diagnostic only: displays the generated equations and the correctness theorem. *)
context
  includes state_monad_syntax
begin
thm forward_ixm'.simps forward_ixm_def
thm forward_ixm.memoized_correct
end

end (* Fixed IArray *)

(* Entry point: converts the observation list into an immutable array for forward_ix_rec. *)
definition
  "forward_ix os = forward_ix_rec (IArray os)"

(* Executable likelihood. Returns None when s is not listed in state_list, and
   otherwise the sum of forward over the entries of state_list. A state listed
   twice would be counted twice; the locale HMM4 rules this out. *)
definition
  "likelihood_compute s os 
    if s  set state_list then Some (t  state_list. forward s t os) else None"

end (* HMM3 Defs *)

text ‹Correctness of the alternative definition.›

(* Prepending an observation shifts every index by one. *)
lemma (in HMM3) forward_ix_drop_one:
  "forward_ix (o # os) s t (n + 1) = forward_ix os s t n"
  by (induction "length os - n" arbitrary: s n; simp add: forward_ix_def)

(* At index 0, forward_ix agrees with forward. Distinctness of state_list turns
   the list sum of forward_ix_rec into the set sum over 𝒮 used by forward. *)
lemma (in HMM4) forward_ix_forward:
  "forward_ix os s t 0 = forward s t os"
  unfolding forward_ix_def
proof (induction os arbitrary: s)
  case Nil
  then show ?case
    by simp
next
  case (Cons o os)
  show ?case
    using forward_ix_drop_one[unfolded forward_ix_def] states_distinct
    by (subst forward.simps, subst forward_ix_rec.simps)
       (simp add: Cons.IH state_list_𝒮 sum_list_distinct_conv_sum_set
             del: forward_ix_rec.simps forward.simps
       )
qed

text ‹
  Instructs the code generator to execute ‹forward› via this equation instead of its defining equations.
  It uses the memoized version of ‹forward_ix›.
›
(* Code equation for forward: run the memoized forward_ixm' on the array of
   observations from index 0, starting with an empty memo table. Stated in HMM4
   because forward_ix_forward needs a distinct state_list. *)
lemma (in HMM4) forward_code [code]:
  "forward s t os = fst (run_state (forward_ixm' (IArray os) s t 0) Mapping.empty)"
  by (simp only:
      forward_ix_def forward_ixm.memoized_correct forward_ix_forward[symmetric]
      states_distinct
     )

(* Correctness of likelihood_compute: it returns Some x exactly when s is in 𝒮,
   and then x is the likelihood of os from s. *)
theorem (in HMM4) likelihood_compute:
  "likelihood_compute s os = Some x  s  𝒮  x = likelihood s os"
  unfolding likelihood_compute_def
  by (auto simp: states_distinct state_list_𝒮 sum_list_distinct_conv_sum_set likelihood_forward)


subsection ‹The Viterbi Algorithm›

context HMM3_defs
begin

context
  fixes os :: "'t iarray"
begin

text ‹
  Alternative definition using indices into the list of observations.
  The list of observations is implemented as an immutable array for better performance.
›

function viterbi_ix_rec where
  "viterbi_ix_rec s t_end n = (if n  IArray.length os then ([], indicator {t_end} s) else
  fst (
    argmax snd (map
      (λt. let (xs, v) = viterbi_ix_rec t t_end (n + 1) in
        (t # xs, ennreal (pmf (𝒪 t) (os !! n) * pmf (𝒦 s) t) * v))
    state_list)))
  "
  by pat_completeness auto
termination
  by (relation "Wellfounded.measure (λ(_,_,n). IArray.length os - n)") auto

text ‹Memoization›

memoize_fun viterbi_ixm: viterbi_ix_rec
  with_memory dp_consistency_mapping
  monadifies (state) viterbi_ix_rec.simps[unfolded Let_def]

memoize_correct
  by memoize_prover

text ‹The main theorems generated by memoization.›
context
  includes state_monad_syntax
begin
thm viterbi_ixm'.simps viterbi_ixm_def
thm viterbi_ixm.memoized_correct
end

end (* Fixed IArray *)

definition
  "viterbi_ix os = viterbi_ix_rec (IArray os)"

end (* HMM3 Defs *)

context HMM3
begin

(* Index shift: on (o # os), the computation starting at index n + 1 coincides with
   the computation on os starting at index n.  Induction on length os - n. *)
lemma viterbi_ix_drop_one:
  "viterbi_ix (o # os) s t (n + 1) = viterbi_ix os s t n"
  by (induction "length os - n" arbitrary: s n; simp add: viterbi_ix_def)

(* Starting at index 0, the index-based viterbi_ix agrees with the list-recursive
   viterbi.  The Cons case unfolds one step of both definitions and relates the tails
   via viterbi_ix_drop_one. *)
lemma viterbi_ix_viterbi:
  "viterbi_ix os s t 0 = viterbi s t os"
  unfolding viterbi_ix_def
proof (induction os arbitrary: s)
  case Nil
  then show ?case
    by simp
next
  case (Cons o os)
  show ?case
    using viterbi_ix_drop_one[unfolded viterbi_ix_def]
    by (subst viterbi.simps, subst viterbi_ix_rec.simps)
       (simp add: Cons.IH del: viterbi_ix_rec.simps viterbi.simps)
qed

(* Code equation for viterbi: run the memoized viterbi_ixm' on IArray os from index 0,
   starting with an empty memo table, and keep the result component (fst). *)
lemma viterbi_code [code]:
  "viterbi s t os = fst (run_state (viterbi_ixm' (IArray os) s t 0) Mapping.empty)"
  by (simp only: viterbi_ix_def viterbi_ixm.memoized_correct viterbi_ix_viterbi[symmetric])

end (* Hidden Markov Model 3 *)

subsection ‹Misc›

(* If every weight listed in the association list μ is nonnegative, then looking up x
   in μ (0 when x is not listed) yields a nonnegative real. *)
lemma pmf_of_alist_support_aux_1:
  assumes " (_, p)  set μ. p  0"
  shows "(0 :: real)  (case map_of μ x of None  0 | Some p  p)"
  using assms by (auto split: option.split dest: map_of_SomeD)

(* For an association list μ with nonnegative weights summing to 1 and distinct keys,
   the nonnegative integral of the lookup density over count_space UNIV is 1,
   i.e. the density describes a probability mass function. *)
lemma pmf_of_alist_support_aux_2:
  assumes " (_, p)  set μ. p  0"
    and "sum_list (map snd μ) = 1"
    and "distinct (map fst μ)"
  shows "+ x. ennreal (case map_of μ x of None  0 | Some p  p) count_space UNIV = 1"
  using assms
  (* Rewrite the integral over count_space as a finite sum over the support; the
     support is finite because it is contained in the key set fst ` set μ. *)
  apply (subst nn_integral_count_space)
  subgoal
    by (rule finite_subset[where B = "fst ` set μ"];
        force split: option.split_asm simp: image_iff dest: map_of_SomeD)
  (* Enlarge the summation domain to the whole key set fst ` set μ. *)
  apply (subst sum.mono_neutral_left[where T = "fst ` set μ"])
     apply blast
  subgoal
    by (smt ennreal_less_zero_iff map_of_eq_None_iff mem_Collect_eq option.case(1) subsetI)
  subgoal
    by auto
  subgoal premises prems
  proof -
    (* Reindex the sum over keys as a sum over list positions 0..<length μ, using that
       i -> fst (μ ! i) is injective because the keys are distinct (prems(3)). *)
    have "(x = 0..<length μ. snd (μ ! x))
      = sum (λ x. case map_of μ x of None  0 | Some v  v) (fst ` set μ)"
      apply (rule sym)
      apply (rule sum.reindex_cong[where l = "λ i. fst (μ ! i)"])
        apply (auto split: option.split)
      subgoal
        using prems(3) by (intro inj_onI, auto simp: distinct_conv_nth)
      subgoal
        by (auto simp: in_set_conv_nth rev_image_eqI)
      subgoal
        by (simp add: map_of_eq_None_iff)
      subgoal
        using map_of_eq_Some_iff[OF prems(3)]
        by (metis fst_conv nth_mem option.inject prod_eqI snd_conv)
      done
    (* Combine with the weights summing to 1 (prems(2)), moving between real and
       ennreal sums. *)
    with prems(2) show ?thesis
      by (smt pmf_of_alist_support_aux_1[OF assms(1)] atLeastLessThan_iff ennreal_1
          length_map nth_map sum.cong sum_ennreal sum_list_sum_nth
          )
  qed
  done

(* Under the same well-formedness assumptions, the support of pmf_of_alist μ is bounded
   by the keys fst ` set μ.  The embed_pmf side conditions are nonnegativity and total
   mass 1 (pmf_of_alist_support_aux_2). *)
lemma pmf_of_alist_support:
  assumes " (_, p)  set μ. p  0"
    and "sum_list (map snd μ) = 1"
    and "distinct (map fst μ)"
  shows "set_pmf (pmf_of_alist μ)  fst ` set μ"
  unfolding pmf_of_alist_def
  apply (subst set_embed_pmf)
  subgoal for x
    using assms(1) by (auto split: option.split dest: map_of_SomeD)
  subgoal
    using pmf_of_alist_support_aux_2[OF assms] .
  apply (force split: option.split_asm simp: image_iff dest: map_of_SomeD)+
  done

text ‹Defining a Markov kernel from an association list.›
(* K associates each source with an association list of (target, weight) pairs, and S
   lists the admissible targets.  Assumptions: S is nonempty, every target listed in K
   occurs in S, every inner list has nonnegative weights, distinct targets and weights
   summing to 1, and the sources of K are distinct. *)
locale Closed_Kernel_From =
  fixes K :: "('s × ('t × real) list) list"
    and S :: "'t list"
  assumes wellformed: "S  []"
      and closed: " (s, μ)  set K.  (t, _)  set μ. t  set S"
      and is_pmf:
        " (_, μ)  set K.  (_, p)  set μ. p  0"
        " (_, μ)  set K. distinct (map fst μ)"
        " (s, μ)  set K. sum_list (map snd μ) = 1"
      and is_unique:
        "distinct (map fst K)"
begin

(* K' s is the pmf built from the association list listed for s in K; a source without
   an entry is mapped to the point mass on the first element of S. *)
definition
  "K' s  case map_of (map (λ (s, μ). (s, PMF_Impl.pmf_of_alist μ)) K) s of
  None  return_pmf (hd S) |
  Some s  s"

(* K' satisfies the Closed_Kernel locale over the target set set S. *)
sublocale Closed_Kernel K' "set S"
  using wellformed closed is_pmf pmf_of_alist_support
  unfolding K'_def by - (standard; fastforce split: option.split_asm dest: map_of_SomeD)

(* Executable lookup table: each source is mapped to map_of of its association list.
   K'_code below expresses pmf (K' s) t through lookups in this table. *)
definition [code]:
  "K1 = map_of (map (λ (s, μ). (s, map_of μ)) K)"

(* For an entry (s, μ) of K, the probability of t under pmf_of_alist μ is the weight
   listed for t in μ, or 0 if t is not listed. *)
lemma pmf_of_alist_aux:
  assumes "(s, μ)  set K"
  shows
    "pmf (pmf_of_alist μ) t = (case map_of μ t of
      None  0
    | Some p  p)"
  using assms is_pmf unfolding pmf_of_alist_def
  by (intro pmf_embed_pmf pmf_of_alist_support_aux_2) 
     (auto 4 3 split: option.split dest: map_of_SomeD)

(* Because the sources of K are distinct, each source has exactly one association list. *)
lemma unique: "μ = μ'" if "(s, μ)  set K" "(s, μ')  set K"
  using that is_unique
  by (smt Pair_inject distinct_conv_nth fst_conv in_set_conv_nth length_map nth_map)

(* Global helper, stated outside the locale: a key whose map_of lookup is None does
   not occur among the keys of M. *)
lemma (in -) map_of_NoneD:
  "x  fst ` set M" if "map_of M x = None"
  using that by (auto dest: weak_map_of_SomeI)

(* Post-processing equation for generated code: pmf (K' s) t via lookups in K1.  A
   source without an entry yields the point mass at hd S; otherwise the listed weight
   of t is returned, or 0 if t is not listed. *)
lemma K'_code [code_post]:
  "pmf (K' s) t = (case K1 s of
      None  (if t = hd S then 1 else 0)
    | Some μ  case μ t of
        None  0
      | Some p  p
  )"
  unfolding K'_def K1_def
  apply (clarsimp split: option.split, safe)
                 apply (drule map_of_SomeD, drule map_of_NoneD, force)+
         apply (fastforce dest: unique map_of_SomeD simp: pmf_of_alist_aux)+
  done

end

subsection ‹Executing Concrete HMMs›

locale Concrete_HMM_defs =
  fixes 𝒦 :: "('s × ('s × real) list) list"
    and 𝒪 :: "('s × ('t × real) list) list"
    and 𝒪s :: "'t list"
    and 𝒦s :: "'s list"
begin

(* Transition kernel: the pmf built from the association list that 𝒦 gives for s;
   a state without an entry in 𝒦 moves to the first element of 𝒦s with certainty. *)
definition 𝒦' where
  "𝒦' s  case map_of (map (λ (src, d). (src, PMF_Impl.pmf_of_alist d)) 𝒦) s of
    None  return_pmf (hd 𝒦s) |
    Some p  p"

(* Observation kernel: the pmf built from the association list that 𝒪 gives for s;
   a state without an entry in 𝒪 yields the first element of 𝒪s with certainty. *)
definition 𝒪' where
  "𝒪' s  case map_of (map (λ (src, d). (src, PMF_Impl.pmf_of_alist d)) 𝒪) s of
    None  return_pmf (hd 𝒪s) |
    Some p  p"

end

(* Well-formedness of the concrete tables.  Both 𝒪 (per-state distributions over the
   observations 𝒪s) and 𝒦 (per-state distributions over the states 𝒦s) must satisfy the
   Closed_Kernel_From conditions; additionally the state list 𝒦s must be duplicate-free. *)
locale Concrete_HMM = Concrete_HMM_defs +
  assumes observations_wellformed': "𝒪s  []"
      and observations_closed': " (s, μ)  set 𝒪.  (t, _)  set μ. t  set 𝒪s"
      and observations_form_pmf':
        " (_, μ)  set 𝒪.  (_, p)  set μ. p  0"
        " (_, μ)  set 𝒪. distinct (map fst μ)"
        " (s, μ)  set 𝒪. sum_list (map snd μ) = 1"
      and observations_unique:
        "distinct (map fst 𝒪)"
  assumes states_wellformed: "𝒦s  []"
      and states_closed: " (s, μ)  set 𝒦.  (t, _)  set μ. t  set 𝒦s"
      and states_form_pmf:
        " (_, μ)  set 𝒦.  (_, p)  set μ. p  0"
        " (_, μ)  set 𝒦. distinct (map fst μ)"
        " (s, μ)  set 𝒦. sum_list (map snd μ) = 1"
      and states_unique:
        "distinct (map fst 𝒦)" "distinct 𝒦s"
begin

(* Interpret Closed_Kernel_From for the observation table; its kernel K' is rewritten
   to 𝒪', since both definitions unfold to the same term. *)
interpretation O: Closed_Kernel_From 𝒪 𝒪s
  rewrites "O.K' = 𝒪'"
proof -
  show ‹Closed_Kernel_From 𝒪 𝒪s
    using observations_wellformed' observations_closed' observations_form_pmf' observations_unique
    by unfold_locales auto
  show ‹Closed_Kernel_From.K' 𝒪 𝒪s = 𝒪'›
    unfolding Closed_Kernel_From.K'_def[OF ‹Closed_Kernel_From 𝒪 𝒪s] 𝒪'_def
    by auto
qed

(* The same for the transition table, with K' rewritten to 𝒦'. *)
interpretation K: Closed_Kernel_From 𝒦 𝒦s
  rewrites "K.K' = 𝒦'"
proof -
  show ‹Closed_Kernel_From 𝒦 𝒦s
    using states_wellformed states_closed states_form_pmf states_unique by unfold_locales auto
  show ‹Closed_Kernel_From.K' 𝒦 𝒦s = 𝒦'›
    unfolding Closed_Kernel_From.K'_def[OF ‹Closed_Kernel_From 𝒦 𝒦s] 𝒦'_def
    by auto
qed

(* Lookup-based equations for pmf (𝒪' s) and pmf (𝒦' s), exported for code setup. *)
lemmas O_code = O.K'_code O.K1_def
lemmas K_code = K.K'_code K.K1_def

(* The concrete HMM instantiates HMM4 with 𝒪', set 𝒦s, 𝒦s, set 𝒪s and 𝒦'.
   states_unique(2) (distinct 𝒦s) is passed to the proof, presumably for HMM4's
   distinctness assumption on the state list. *)
sublocale HMM_interp: HMM4 𝒪' "set 𝒦s" 𝒦s "set 𝒪s" 𝒦'
  using O.Closed_Kernel_axioms K.Closed_Kernel_axioms states_unique(2)
  by (intro_locales; intro HMM4_axioms.intro HMM3_axioms.intro HOL.refl)

end (* Concrete HMM *)

end

Theory HMM_Example

section ‹Example›

theory HMM_Example
  imports
    HMM_Implementation
    "HOL-Library.AList_Mapping"
begin

text ‹
  We would like to implement mappings as red-black trees,
  but these require the key type to be linearly ordered.
  Unfortunately, ‹HOL-Analysis› fixes the product order to the element-wise order,
  so we cannot restore a linear order on pairs,
  and the red-black tree implementation (from ‹HOL-Library›) cannot be used.
›

text ‹The ice cream example from Jurafsky and Martin \cite{Jurafsky}.›

(* The states of the ice cream HMM, including the distinguished ''start'' and ''end''. *)
definition states :: "string list" where
  "states = [''start'', ''hot'', ''cold'', ''end'']"

(* Observation alphabet.  Observation 0 is emitted only by ''end''; 1, 2 and 3 are
   emitted by ''hot'' and ''cold'' (see emissions). *)
definition observations :: "int list" where
  "observations = [0, 1, 2, 3]"

(* Transition probabilities: each state is mapped to its list of
   (successor, probability) pairs; ''end'' is absorbing. *)
definition kernel :: "(string × (string × real) list) list" where
  "kernel = [
    (''start'', [(''hot'', 0.8), (''cold'', 0.2)]),
    (''hot'',   [(''hot'', 0.6), (''cold'', 0.3), (''end'', 0.1)]),
    (''cold'',  [(''hot'', 0.4), (''cold'', 0.5), (''end'', 0.1)]),
    (''end'',   [(''end'', 1)])
  ]"

(* Emission probabilities: each state is mapped to its list of
   (observation, probability) pairs; only ''end'' emits observation 0.
   The type is given explicitly, as for observations and kernel, so that observations
   are int and probabilities real.  Without it, the untyped numerals make the constant
   polymorphic with sort constraints, although it is only ever used at this type in
   the Concrete_HMM interpretation below. *)
definition emissions :: "(string × (int × real) list) list" where
  "emissions =
    [
      (''hot'',   [(1, 0.2), (2, 0.4), (3, 0.4)]),
      (''cold'',  [(1, 0.5), (2, 0.4), (3, 0.1)]),
      (''end'',   [(0, 1)])
    ]
  "

(* Instantiate Concrete_HMM with the example tables.  The defines clauses introduce
   global constants for the interpreted algorithms so that they can be evaluated with
   value; the locale assumptions are discharged by evaluation (eval). *)
global_interpretation Concrete_HMM kernel emissions observations states
  defines
      viterbi_rec   = HMM_interp.viterbi_ixm'
  and viterbi       = HMM_interp.viterbi
  and viterbi_final = HMM_interp.viterbi_final
  and forward_rec   = HMM_interp.forward_ixm'
  and forward       = HMM_interp.forward
  and likelihood    = HMM_interp.likelihood_compute
  by (standard; eval)

(* Register the memoized recursion equations as code equations, with the kernel
   probabilities rewritten by O_code/K_code into lookups in the association lists. *)
lemmas [code] = HMM_interp.viterbi_ixm'.simps[unfolded O_code K_code]

lemmas [code] = HMM_interp.forward_ixm'.simps[unfolded O_code K_code]

(* Likelihood of the observations 1, 1, 1 from ''start''; the last observation is not 0. *)
value "likelihood ''start'' [1, 1, 1]"

text ‹
  If we enforce the last observation to correspond to @{term ''end''},
  then @{term forward} and @{term likelihood} yield the same result.
›
value "likelihood ''start'' [1, 1, 1, 0]"

value "forward ''start'' ''end'' [1, 1, 1, 0]"

value "forward ''start'' ''end'' [3, 3, 3, 0]"

value "forward ''start'' ''end'' [3, 1, 3, 0]"

value "forward ''start'' ''end'' [3, 1, 3, 1, 0]"

(* Viterbi state paths from ''start'' to ''end'' for the same observation sequences. *)
value "viterbi ''start'' ''end'' [1, 1, 1, 0]"

value "viterbi ''start'' ''end'' [3, 3, 3, 0]"

value "viterbi ''start'' ''end'' [3, 1, 3, 0]"

value "viterbi ''start'' ''end'' [3, 1, 3, 1, 0]"

text ‹
  If we enforce the last observation to correspond to @{term ''end''},
  then @{term viterbi} and @{term viterbi_final} yield the same result.
›
value "viterbi_final ''start'' [3, 1, 3, 1, 0]"

value "viterbi_final ''start'' [1, 1, 1, 1, 1, 1, 1, 0]"

(* For contrast: here the last observation is 1, not 0. *)
value "viterbi_final ''start'' [1, 1, 1, 1, 1, 1, 1, 1]"

end