Session MDP-Algorithms

Theory Value_Iteration

(* Author: Maximilian Schäffeler *)

theory Value_Iteration
  imports "MDP-Rewards.MDP_reward"
begin

context MDP_att_ℒ
begin

section ‹Value Iteration›
text ‹
In the previous sections we derived that repeated application of @{const "ℒb"} to any bounded 
function from states to the reals converges to the optimal value of the MDP @{const "νb_opt"}.

We can turn this procedure into an algorithm that computes not only an approximation of 
@{const "νb_opt"} but also a policy that is arbitrarily close to optimal.

Most of the proofs rely on the assumption that the supremum in @{const "ℒb"} can always be attained.
›

text ‹
The following lemma shows that the measure we use to prove termination of the value iteration 
algorithm strictly decreases in each step.
In essence, the distance between the estimate and the optimal value shrinks to at most @{term l} 
times its previous value per iteration.›
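
text ‹
Spelled out: contraction gives dist (ℒb v) νb_opt ≤ l * dist v νb_opt, and since 1 / l > 1 and 
log (1/l) l = -1, taking logarithms to base 1 / l yields
  log (1/l) (dist (ℒb v) νb_opt) ≤ log (1/l) (dist v νb_opt) - 1,
so the ceiling expression below decreases by at least one with each application of @{const ℒb}.
›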


lemma vi_rel_dec: 
  assumes "l ≠ 0" "ℒb v ≠ νb_opt"
  shows "⌈log (1 / l) (dist (ℒb v) νb_opt) - c⌉ < ⌈log (1 / l) (dist v νb_opt) - c⌉"
proof -
  have "log (1 / l) (dist (ℒb v) νb_opt) - c ≤ log (1 / l) (l * dist v νb_opt) - c"
    using contraction_ℒ[of _ "νb_opt"] disc_lt_one
    by (auto simp: assms less_le intro: log_le)
  also have "… = log (1 / l) l + log (1/l) (dist v νb_opt) - c"
    using assms disc_lt_one 
    by (auto simp: less_le intro!: log_mult)
  also have "… = -(log (1 / l) (1/l)) + (log (1/l) (dist v νb_opt)) - c"
    using assms disc_lt_one
    by (subst log_inverse[symmetric]) (auto simp: less_le right_inverse_eq)
  also have "… = (log (1/l) (dist v νb_opt)) - 1 - c"
    using assms order.strict_implies_not_eq[OF disc_lt_one]
    by (auto intro!: log_eq_one neq_le_trans)
  finally have "log (1 / l) (dist (ℒb v) νb_opt) - c ≤ log (1 / l) (dist v νb_opt) - 1 - c" .
  thus ?thesis
    by linarith
qed

lemma dist_ℒb_lt_dist_opt: "dist v (ℒb v) ≤ 2 * dist v νb_opt"
proof -
  have le1: "dist v (ℒb v) ≤ dist v νb_opt + dist (ℒb v) νb_opt"
    by (simp add: dist_triangle dist_commute)
  have le2: "dist (ℒb v) νb_opt ≤ l * dist v νb_opt"
    using ℒb_opt contraction_ℒ
    by metis
  show ?thesis
    using mult_right_mono[of l 1] disc_lt_one 
    by (fastforce intro!: order.trans[OF le2] order.trans[OF le1])
qed

abbreviation "term_measure ≡ (λ(eps, v).
    if v = νb_opt ∨ l = 0
    then 0
    else nat (ceiling (log (1/l) (dist v νb_opt) - log (1/l) (eps * (1-l) / (8 * l)))))"

function value_iteration :: "real ⇒ ('s ⇒b real) ⇒ ('s ⇒b real)" where
  "value_iteration eps v =
  (if 2 * l * dist v (ℒb v) < eps * (1-l) ∨ eps ≤ 0 then ℒb v else value_iteration eps (ℒb v))"
  by auto
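
text ‹
For intuition, the recursion above corresponds to a simple loop: the estimate is replaced by its 
Bellman update until the stopping criterion holds, and one final update is returned. The informal 
Python sketch below is an illustration only, not part of the formalization; bellman_update and 
dist are hypothetical stand-ins for @{const ℒb} and the supremum distance.
›
(*
def value_iteration(eps, v, l, bellman_update, dist):
    # loop while the stopping criterion of the Isabelle function fails
    while eps > 0 and not 2 * l * dist(v, bellman_update(v)) < eps * (1 - l):
        v = bellman_update(v)
    # the result is one more Bellman update of the final estimate
    return bellman_update(v)
*)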

termination
proof (relation "Wellfounded.measure term_measure", (simp; fail), cases "l = 0")
  case False
  fix eps v
  assume h: "¬ (2 * l * dist v (ℒb v) < eps * (1 - l) ∨ eps ≤ 0)"
  show "((eps, ℒb v), eps, v) ∈ Wellfounded.measure term_measure"
  proof -
    have gt_zero[simp]: "l ≠ 0" "eps > 0" and dist_ge: "eps * (1 - l) ≤ dist v (ℒb v) * (2 * l)"
      using h
      by (auto simp: algebra_simps)
    have v_not_opt: "v ≠ νb_opt"
      using h
      by force
    have "log (1 / l) (eps * (1 - l) / (8 * l)) < log (1 / l) (dist v νb_opt)"
    proof (intro log_less)
      show "1 < 1 / l"
        by (auto intro!: mult_imp_less_div_pos intro: neq_le_trans)
      show "0 < eps * (1 - l) / (8 * l)" 
        by (auto intro!: mult_imp_less_div_pos intro: neq_le_trans)
      show "eps * (1 - l) / (8 * l) < dist v νb_opt" 
        using dist_pos_lt[OF v_not_opt] dist_ℒb_lt_dist_opt[of v] gt_zero zero_le_disc 
          mult_strict_left_mono[of "dist v (ℒb v)" "(4 * dist v νb_opt)" l]
        by (intro mult_imp_div_pos_less le_less_trans[OF dist_ge], argo+)
    qed
    thus ?thesis
      using vi_rel_dec h
      by auto
  qed
qed auto

text ‹
The distance between a value estimate and the optimal value can be bounded in terms of the 
distance between the estimate and the result of applying @{const ℒb} to it.
›
lemma contraction_ℒ_dist: "(1 - l) * dist v νb_opt ≤ dist v (ℒb v)"
  using contraction_dist contraction_ℒ disc_lt_one zero_le_disc
  by fastforce
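
text ‹
Proof idea: the triangle inequality and the contraction property (νb_opt is a fixed point of 
@{const ℒb}) give
  dist v νb_opt ≤ dist v (ℒb v) + dist (ℒb v) νb_opt ≤ dist v (ℒb v) + l * dist v νb_opt,
and rearranging yields (1 - l) * dist v νb_opt ≤ dist v (ℒb v).
›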

lemma dist_ℒb_opt_eps:
  assumes "eps > 0" "2 * l * dist v (ℒb v) < eps * (1-l)"
  shows "dist (ℒb v) νb_opt < eps / 2"
proof -
  have "dist v νb_opt ≤ dist v (ℒb v) / (1 - l)"
    using contraction_ℒ_dist
    by (simp add: mult.commute pos_le_divide_eq)
  hence "2 * l * dist v νb_opt ≤ 2 * l * (dist v (ℒb v) / (1 - l))"
    using contraction_ℒ_dist assms mult_le_cancel_left_pos[of "2 * l"]
    by (fastforce intro!: mult_left_mono[of _ _ "2 * l"])
  hence "2 * l * dist v νb_opt < eps"
    by (auto simp: assms(2) pos_divide_less_eq intro: order.strict_trans1)
  hence "dist v νb_opt * l < eps / 2"
    by argo
  hence "l * dist v νb_opt < eps / 2"
    by (auto simp: algebra_simps)
  thus "dist (ℒb v) νb_opt < eps / 2"
    using contraction_ℒ[of v νb_opt] 
    by auto
qed
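
text ‹
In summary, the lemma chains the bounds
  dist (ℒb v) νb_opt ≤ l * dist v νb_opt ≤ l * dist v (ℒb v) / (1 - l) < eps / 2,
where the final step is exactly the assumption 2 * l * dist v (ℒb v) < eps * (1 - l).
›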

text ‹
The estimates above yield a bound on the error of @{const value_iteration}.
›
declare value_iteration.simps[simp del]

lemma value_iteration_error: 
  assumes "eps > 0"
  shows "dist (value_iteration eps v) νb_opt < eps / 2"
  using assms dist_ℒb_opt_eps value_iteration.simps
  by (induction eps v rule: value_iteration.induct) auto

text ‹
After the value iteration terminates, one can easily obtain a stationary deterministic 
epsilon-optimal policy.

Such a policy does not exist in general; it requires that the supremum in @{const ℒb} be attained.
›
definition "find_policy (v :: 's ⇒b real) s = arg_max_on (λa. La a v s) (A s)"

definition "vi_policy eps v = find_policy (value_iteration eps v)"

text ‹
We formalize the attainment of the supremum using a predicate @{const has_arg_max}.
›

abbreviation "vi u n ≡ (ℒb ^^n) u"

lemma ℒb_iter_mono:
  assumes "u ≤ v" shows "vi u n ≤ vi v n"
  using assms ℒb_mono 
  by (induction n) auto

lemma 
  assumes "vi v (Suc n) ≤ vi v n" 
  shows "vi v (Suc n + m) ≤ vi v (n + m)"
proof -
  have "vi v (Suc n + m) = vi (vi v (Suc n)) m"
    by (simp add: Groups.add_ac(2) funpow_add funpow_swap1)
  also have "... ≤ vi (vi v n) m"
    using ℒb_iter_mono[OF assms]
    by auto
  also have "... = vi v (n + m)"
    by (simp add: add.commute funpow_add)
  finally show ?thesis .
qed


lemma 
  assumes "vi v n ≤ vi v (Suc n)" 
  shows "vi v (n + m) ≤ vi v (Suc n + m)"
proof -
  have "vi v (n + m) ≤ vi (vi v n) m"
    by (simp add: Groups.add_ac(2) funpow_add funpow_swap1)
  also have "… ≤ vi v (Suc n + m)"
    using ℒb_iter_mono[OF assms]
    by (auto simp only: add.commute funpow_add o_apply)
  finally show ?thesis .
qed

(* 6.3.1 *)
(* a) *)
lemma "vi v ⇢ νb_opt"
  using ℒb_lim.

lemma "(λn. dist (vi v (Suc n)) (vi v n)) ⇢ 0"
  using thm_6_3_1_b_aux[of v]
  by (auto simp only: dist_commute[of "((ℒb ^^ Suc _) v)"])



end

context MDP_att_ℒ 
begin

text ‹
The error of the resulting policy is bounded by the distance from its value to the value computed 
by the value iteration plus the error in the value iteration itself.
We show that both are less than @{term "eps / 2"} when the algorithm terminates.
›
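
text ‹
Concretely, writing p for the stationary policy induced by find_policy (ℒb v), the proof shows 
dist (νb p) (ℒb v) ≤ eps / 2 and dist (ℒb v) νb_opt < eps / 2, so the triangle inequality gives 
dist (νb p) νb_opt < eps. The first bound uses that νb p is a fixed point of the operator 
L (mk_dec_det (find_policy (ℒb v))), which agrees with @{const ℒb} at ℒb v, yielding 
dist (νb p) (ℒb v) ≤ l * dist (ℒb v) v / (1 - l); the second is the previous lemma 
dist_ℒb_opt_eps.
›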
lemma find_policy_error_bound:
  assumes "eps > 0" "2 * l * dist v (ℒb v) < eps * (1-l)"
  shows "dist (νb (mk_stationary_det (find_policy (ℒb v)))) νb_opt < eps"
proof -
  let ?d = "mk_dec_det (find_policy (ℒb v))"
  let ?p = "mk_stationary ?d"
    (* shorter proof:     by (auto simp: arg_max_SUP[OF find_policy_QR_is_arg_max] ℒb_split.rep_eq ℒ_split_def )*)
  have L_eq_ℒb: "L (mk_dec_det (find_policy v)) v = ℒb v" for v
    unfolding find_policy_def
  proof (intro antisym)
    show "L (mk_dec_det (λs. arg_max_on (λa. La a v s) (A s))) v ≤ ℒb v"
      using Sup_att has_arg_max_arg_max abs_L_le
      unfolding ℒb.rep_eq ℒ_eq_SUP_det less_eq_bfun_def arg_max_on_def is_dec_det_def max_L_ex_def
      by (auto intro!: cSUP_upper bounded_imp_bdd_above boundedI[of _ "rM + l * norm v"])
  next
    show "ℒb v ≤ L (mk_dec_det (λs. arg_max_on (λa. La a v s) (A s))) v"
      unfolding less_eq_bfun_def ℒb.rep_eq ℒ_eq_SUP_det
      using Sup_att ex_dec_det
      by (auto intro!: cSUP_least app_arg_max_ge simp: L_eq_La_det max_L_ex_def is_dec_det_def)
  qed
  have "dist (νb ?p) (ℒb v) = dist (L ?d (νb ?p)) (ℒb v)"
    using L_ν_fix 
    by force
  also have "… ≤ dist (L ?d (νb ?p)) (ℒb (ℒb v)) + dist (ℒb (ℒb v)) (ℒb v)"
    using dist_triangle 
    by blast
  also have "… = dist (L ?d (νb ?p)) (L ?d (ℒb v)) + dist (ℒb (ℒb v)) (ℒb v)"
    by (auto simp: L_eq_ℒb)
  also have "… ≤ l *  dist (νb ?p) (ℒb v) + l * dist (ℒb v) v"
    using contraction_ℒ contraction_L
    by (fastforce intro!: add_mono)
  finally have aux: "dist (νb ?p) (ℒb v) ≤ l * dist (νb ?p) (ℒb v) + l * dist (ℒb v) v" .
  hence "dist (νb ?p) (ℒb v) - l * dist (νb ?p) (ℒb v) ≤ l * dist (ℒb v) v"
    by auto
  hence "dist (νb ?p) (ℒb v) * (1 - l) ≤ l * dist (ℒb v) v"
    by argo
  hence  "2 * dist (νb ?p) (ℒb v) * (1-l) ≤ 2 * (l * dist (ℒb v) v)"
    using zero_le_disc mult_left_mono 
    by auto
  also have "… ≤ eps * (1-l)"
    using assms
    by (auto intro!: mult_left_mono simp: dist_commute pos_divide_le_eq)
  finally have "2 * dist (νb ?p) (ℒb v) * (1 - l) ≤ eps * (1 - l)" .
  hence "2 * dist (νb ?p) (ℒb v) ≤ eps"
    using disc_lt_one mult_right_le_imp_le
    by auto
  moreover have "2 * dist (ℒb v) νb_opt < eps"
    using dist_ℒb_opt_eps assms 
    by fastforce
  moreover have "dist (νb ?p) νb_opt ≤ dist (νb ?p) (ℒb v) + dist (ℒb v) νb_opt"
    using dist_triangle 
    by blast  
  ultimately show ?thesis 
    by auto
qed

lemma vi_policy_opt:
  assumes "0 < eps"
  shows "dist (νb (mk_stationary_det (vi_policy eps v))) νb_opt < eps"
  unfolding vi_policy_def 
  using assms
proof (induction eps v rule: value_iteration.induct)
  case (1 v)
  then show ?case
    using find_policy_error_bound
    by (subst value_iteration.simps) auto
qed

lemma lemma_6_3_1_d:
  assumes "eps > 0"
  assumes "2 * l * dist (vi v (Suc n)) (vi v n) < eps * (1-l)"
  shows "dist (vi v (Suc n)) νb_opt < eps / 2"
  using dist_ℒb_opt_eps assms
  by (simp add: dist_commute)

end

context MDP_act begin

definition "find_policy' (v :: 's ⇒b real) s = arb_act (opt_acts v s)"

definition "vi_policy' eps v = find_policy' (value_iteration eps v)"

lemma find_policy'_error_bound:
  assumes "eps > 0" "2 * l * dist v (ℒb v) < eps * (1-l)"
  shows "dist (νb (mk_stationary_det (find_policy' (ℒb v)))) νb_opt < eps"
proof -
  let ?d = "mk_dec_det (find_policy' (ℒb v))"
  let ?p = "mk_stationary ?d"
  have L_eq_ℒb: "L (mk_dec_det (find_policy' v)) v = ℒb v" for v
    unfolding find_policy'_def
    by (metis ν_improving_imp_ℒb ν_improving_opt_acts)
  have "dist (νb ?p) (ℒb v) = dist (L ?d (νb ?p)) (ℒb v)"
    using L_ν_fix 
    by force
  also have "… ≤ dist (L ?d (νb ?p)) (ℒb (ℒb v)) + dist (ℒb (ℒb v)) (ℒb v)"
    using dist_triangle 
    by blast
  also have "… = dist (L ?d (νb ?p)) (L ?d (ℒb v)) + dist (ℒb (ℒb v)) (ℒb v)"
    by (auto simp: L_eq_ℒb)
  also have "… ≤ l *  dist (νb ?p) (ℒb v) + l * dist (ℒb v) v"
    using contraction_ℒ contraction_L
    by (fastforce intro!: add_mono)
  finally have aux: "dist (νb ?p) (ℒb v) ≤ l * dist (νb ?p) (ℒb v) + l * dist (ℒb v) v" .
  hence "dist (νb ?p) (ℒb v) - l * dist (νb ?p) (ℒb v) ≤ l * dist (ℒb v) v"
    by auto
  hence "dist (νb ?p) (ℒb v) * (1 - l) ≤ l * dist (ℒb v) v"
    by argo
  hence  "2 * dist (νb ?p) (ℒb v) * (1-l) ≤ 2 * (l * dist (ℒb v) v)"
    using zero_le_disc mult_left_mono 
    by auto
  also have "… ≤ eps * (1-l)"
    using assms
    by (auto intro!: mult_left_mono simp: dist_commute pos_divide_le_eq)
  finally have "2 * dist (νb ?p) (ℒb v) * (1 - l) ≤ eps * (1 - l)".
  hence "2 * dist (νb ?p) (ℒb v) ≤ eps"
    using disc_lt_one mult_right_le_imp_le
    by auto
  moreover have "2 * dist (ℒb v) νb_opt < eps"
    using dist_ℒb_opt_eps assms 
    by fastforce
  moreover have "dist (νb ?p) νb_opt ≤ dist (νb ?p) (ℒb v) + dist (ℒb v) νb_opt"
    using dist_triangle 
    by blast  
  ultimately show ?thesis 
    by auto
qed

lemma vi_policy'_opt:
  assumes "eps > 0" "l > 0"
  shows "dist (νb (mk_stationary_det (vi_policy' eps v))) νb_opt < eps"
  unfolding vi_policy'_def 
  using assms
proof (induction eps v rule: value_iteration.induct)
  case (1 v)
  then show ?case
    using find_policy'_error_bound
    by (subst value_iteration.simps) auto
qed

end
end

Theory Policy_Iteration

(* Author: Maximilian Schäffeler *)

theory Policy_Iteration
  imports "MDP-Rewards.MDP_reward"

begin

section ‹Policy Iteration›
text ‹
The policy iteration algorithm provides another way to find optimal policies under the expected 
total reward criterion.
It differs from value iteration in that it iteratively improves an initial guess of an optimal 
decision rule. Its execution can be subdivided into two alternating steps: policy evaluation and 
policy improvement.

Policy evaluation means the calculation of the value of the current decision rule.

During the improvement phase, we choose a decision rule that maximizes the value of L, 
preferring to keep the old action selection in case of ties.
›
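
text ‹
The following informal Python sketch (illustration only, not part of the formalization) shows the 
overall loop structure; policy_eval and policy_improvement are hypothetical stand-ins for the 
constants defined below, and the formalization additionally stops on invalid decision rules.
›
(*
def policy_iteration(d, policy_eval, policy_improvement):
    while True:
        v = policy_eval(d)                 # policy evaluation
        d_next = policy_improvement(d, v)  # greedy improvement, keeps old action on ties
        if d_next == d:                    # stable: d is optimal
            return d
        d = d_next
*)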


context MDP_att_ℒ begin
definition "policy_eval d = νb (mk_stationary_det d)"
end

context MDP_act
begin

definition "policy_improvement d v s = (
  if is_arg_max (λa. La a (apply_bfun v) s) (λa. a ∈ A s) (d s) 
  then d s
  else arb_act (opt_acts v s))"

definition "policy_step d = policy_improvement d (policy_eval d)"

(* todo: move check is_dec_det outside the recursion *)
function policy_iteration :: "('s ⇒ 'a) ⇒ ('s ⇒ 'a)" where
  "policy_iteration d = (
  let d' = policy_step d in
  if d = d' ∨ ¬is_dec_det d then d else policy_iteration d')"
  by auto

text ‹
The policy iteration algorithm as stated above requires that the supremum in @{const ℒb} is
always attained.
›

text ‹
Each policy improvement returns a valid decision rule.
›
lemma is_dec_det_pi: "is_dec_det (policy_improvement d v)"
  unfolding policy_improvement_def is_dec_det_def is_arg_max_def
  by (auto simp: some_opt_acts_in_A)

lemma policy_improvement_is_dec_det: "d ∈ DD ⟹ policy_improvement d v ∈ DD"
  unfolding policy_improvement_def is_dec_det_def
  using some_opt_acts_in_A
  by auto

lemma policy_improvement_improving: 
  assumes "d ∈ DD" 
  shows "ν_improving v (mk_dec_det (policy_improvement d v))"
proof -
  have "ℒb v x = L (mk_dec_det (policy_improvement d v)) v x" for x
    using is_opt_act_some
    by (fastforce simp: thm_6_2_10_a_aux' L_eq_La_det is_opt_act_def policy_improvement_def
        arg_max_SUP[symmetric, of _ _ "(policy_improvement d v x)"] )
  thus ?thesis
    using policy_improvement_is_dec_det assms
    by (auto simp: ν_improving_alt)
qed

lemma eval_policy_step_L:
  assumes "is_dec_det d"
  shows "L (mk_dec_det (policy_step d)) (policy_eval d) = ℒb (policy_eval d)"
  unfolding policy_step_def
  using assms
  by (auto simp: ν_improving_imp_ℒb[OF policy_improvement_improving])

text ‹ The sequence of policies generated by policy iteration has monotonically increasing 
discounted reward.›
lemma policy_eval_mon:
  assumes "is_dec_det d"
  shows "policy_eval d ≤ policy_eval (policy_step d)"
proof -
  let ?d' = "mk_dec_det (policy_step d)"
  let ?dp = "mk_stationary_det d"
  let ?P = "∑t. l ^ t *R 𝒫1 ?d' ^^ t"

  have "L (mk_dec_det d) (policy_eval d) ≤ L ?d' (policy_eval d)"
    using assms
    by (auto simp: L_le_ℒb eval_policy_step_L)
  hence "policy_eval d ≤ L ?d' (policy_eval d)"
    using L_ν_fix policy_eval_def
    by auto
  hence "νb ?dp ≤ r_decb ?d' + l *R 𝒫1 ?d' (νb ?dp)"
    unfolding policy_eval_def L_def
    by auto
  hence "(id_blinfun - l *R 𝒫1 ?d') (νb ?dp) ≤ r_decb ?d'"
    by (simp add: blinfun.diff_left diff_le_eq scaleR_blinfun.rep_eq)
  hence "?P ((id_blinfun - l *R 𝒫1 ?d') (νb ?dp)) ≤ ?P (r_decb ?d')"
    using lemma_6_1_2_b
    by auto
  hence "νb ?dp ≤ ?P (r_decb ?d')"
    using inv_norm_le'(2)[OF norm_𝒫1_l_less] blincomp_scaleR_right suminf_cong
    by (metis (mono_tags, lifting))
  thus ?thesis
    unfolding policy_eval_def
    by (auto simp: ν_stationary)
qed
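
text ‹
The key step above is a Neumann-series argument: writing d' for mk_dec_det (policy_step d), the 
old value v = policy_eval d satisfies v ≤ L d' v, hence (id_blinfun - l *R 𝒫1 d') v ≤ r_decb d'. 
Applying the monotone operator ∑t. l^t *R 𝒫1 d' ^^ t to both sides yields 
v ≤ policy_eval (policy_step d), since the value of a stationary policy equals exactly this 
series applied to its one-step reward.
›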

text ‹
If policy iteration terminates, i.e. @{term "d = policy_step d"}, then the resulting decision 
rule is optimal.
›
lemma policy_step_eq_imp_opt:
  assumes "is_dec_det d" "d = policy_step d" 
  shows "νb (mk_stationary (mk_dec_det d)) = νb_opt"
proof -
  have "policy_eval d = ℒb (policy_eval d)"
    unfolding policy_eval_def
    using L_ν_fix assms eval_policy_step_L[unfolded policy_eval_def]
    by fastforce
  thus ?thesis
    unfolding policy_eval_def
    using ℒ_fix_imp_opt
    by blast
qed

end

text ‹We prove termination of policy iteration only under the assumption that both the state and 
  action sets are finite.›
locale MDP_PI_finite = MDP_act A K r l arb_act
  for
    A and
    K :: "'s ::countable × 'a ::countable ⇒ 's pmf" and r l arb_act +
  assumes fin_states: "finite (UNIV :: 's set)" and fin_actions: "⋀s. finite (A s)"
begin

text ‹If the state and action sets are both finite, 
  then so is the set of deterministic decision rules @{const "DD"}.›
lemma finite_DD[simp]: "finite DD"
proof -
  let ?set = "{d. ∀x :: 's. (x ∈ UNIV ⟶ d x ∈ (⋃s. A s)) ∧ (x ∉ UNIV ⟶ d x = undefined)}"
  have "finite (⋃s. A s)"
    using fin_actions fin_states by blast
  hence "finite ?set"
    using fin_states
    by (fastforce intro: finite_set_of_finite_funs)
  moreover have "DD ⊆ ?set"
    unfolding is_dec_det_def 
    by auto
  ultimately show ?thesis
    using finite_subset 
    by auto
qed

lemma finite_rel: "finite {(u, v). is_dec_det u ∧ is_dec_det v ∧ νb (mk_stationary_det u) > 
  νb (mk_stationary_det v)}"
proof-
  have aux: "finite {(u, v). is_dec_det u ∧ is_dec_det v}"
    by auto
  show ?thesis
    by (auto intro: finite_subset[OF _ aux])
qed

text ‹
This auxiliary lemma shows that policy iteration terminates once no improvement to the value of 
the policy can be made, as then the policy remains unchanged.
›
lemma eval_eq_imp_policy_eq: 
  assumes "policy_eval d = policy_eval (policy_step d)" "is_dec_det d"
  shows "d = policy_step d"
proof -
  have "policy_eval d s = policy_eval (policy_step d) s" for s
    using assms 
    by auto
  have "policy_eval d = L (mk_dec_det d) (policy_eval (policy_step d))"
    unfolding policy_eval_def
    using L_ν_fix 
    by (auto simp: assms(1)[symmetric, unfolded policy_eval_def])
  hence "policy_eval d = ℒb (policy_eval d)"
    by (metis L_ν_fix policy_eval_def assms eval_policy_step_L)
  hence "L (mk_dec_det d) (policy_eval d) s = ℒb (policy_eval d) s" for s
    using ‹policy_eval d = L (mk_dec_det d) (policy_eval (policy_step d))› assms(1) by auto
  hence "is_arg_max (λa. La a (νb (mk_stationary (mk_dec_det d))) s) (λa. a ∈ A s) (d s)" for s
    unfolding L_eq_La_det
    unfolding policy_eval_def ℒb.rep_eq ℒ_eq_SUP_det SUP_step_det_eq
    using assms(2) is_dec_det_def La_le
    by (auto simp del: νb.rep_eq simp: νb.rep_eq[symmetric] 
        intro!: SUP_is_arg_max boundedI[of _ "rM + l * norm _"] bounded_imp_bdd_above)
  thus ?thesis
    unfolding policy_eval_def policy_step_def policy_improvement_def
    by auto
qed

text ‹
We are now ready to prove termination in the context of finite state-action spaces.
Intuitively, the algorithm terminates as there are only finitely many decision rules,
and in each recursive call the value of the decision rule strictly increases.
›
termination policy_iteration
proof (relation "{(u, v). u ∈ DD ∧ v ∈ DD ∧ νb (mk_stationary_det u) > νb (mk_stationary_det v)}")
  show "wf {(u, v). u ∈ DD ∧ v ∈ DD ∧ νb (mk_stationary_det v) < νb (mk_stationary_det u)}"
    using finite_rel 
    by (auto intro!: finite_acyclic_wf acyclicI_order)
next
  fix d x
  assume h: "x = policy_step d" "¬ (d = x ∨ ¬ is_dec_det d)"
  have "is_dec_det d ⟹ νb (mk_stationary_det d) ≤ νb (mk_stationary_det (policy_step d))"
    using policy_eval_mon  
    by (simp add: policy_eval_def)
  hence "is_dec_det d ⟹ d ≠ policy_step d ⟹
    νb (mk_stationary_det d) < νb (mk_stationary_det (policy_step d))"
    using eval_eq_imp_policy_eq policy_eval_def
    by (force intro!: order.not_eq_order_implies_strict)
  thus "(x, d) ∈ {(u, v). u ∈ DD ∧ v ∈ DD ∧ νb (mk_stationary_det v) < νb (mk_stationary_det u)}"
    using is_dec_det_pi policy_step_def h 
    by auto
qed

text ‹
The termination proof gives us access to the induction rule and the simplification lemmas 
associated with the definition of @{const policy_iteration}.
Thus we can prove that the algorithm finds an optimal policy.
›

lemma is_dec_det_pi': "d ∈ DD ⟹ is_dec_det (policy_iteration d)"
  using is_dec_det_pi
  by (induction d rule: policy_iteration.induct) (auto simp: Let_def policy_step_def)

lemma pi_pi[simp]: "d ∈ DD ⟹ policy_step (policy_iteration d) = policy_iteration d"
  using is_dec_det_pi
  by (induction d rule: policy_iteration.induct) (auto simp: policy_step_def Let_def)

lemma policy_iteration_correct: 
  "d ∈ DD ⟹ νb (mk_stationary_det (policy_iteration d)) = νb_opt" 
  by (induction d rule: policy_iteration.induct)
    (fastforce intro!: policy_step_eq_imp_opt is_dec_det_pi' simp del: policy_iteration.simps)
end

context MDP_finite_type begin
text ‹
The following proofs concern code generation, i.e. how to represent @{const 𝒫1} as a matrix.
›

sublocale MDP_att_ℒ
  by (auto simp: A_ne finite_is_arg_max MDP_att_ℒ_def MDP_att_ℒ_axioms_def max_L_ex_def 
      has_arg_max_def MDP_reward_axioms) 

definition "fun_to_matrix f = matrix (λv. (χ j. f (vec_nth v) j))"
definition "Ek_mat d = fun_to_matrix (λv. ((𝒫1 d) (Bfun v)))"
definition "nu_inv_mat d = fun_to_matrix ((λv. ((id_blinfun - l *R 𝒫1 d) (Bfun v))))"
definition "nu_mat d = fun_to_matrix (λv. ((∑i. (l *R 𝒫1 d) ^^ i) (Bfun v)))"

lemma apply_nu_inv_mat: 
  "(id_blinfun - l *R 𝒫1 d) v = Bfun (λi. ((nu_inv_mat d) *v (vec_lambda v)) $ i)"
proof -
  have eq_onpI: "P x ⟹ eq_onp P x x" for P x
    by(simp add: eq_onp_def)

  have "Real_Vector_Spaces.linear (λv. vec_lambda (((id_blinfun - l *R 𝒫1 d) (bfun.Bfun (($) v)))))"
    by (auto simp del: real_scaleR_def intro: linearI
        simp: scaleR_vec_def eq_onpI plus_vec_def vec_lambda_inverse plus_bfun.abs_eq[symmetric] 
        scaleR_bfun.abs_eq[symmetric] blinfun.scaleR_right blinfun.add_right)
  thus ?thesis
    unfolding Ek_mat_def fun_to_matrix_def nu_inv_mat_def
    by (auto simp: apply_bfun_inverse vec_lambda_inverse)
qed

lemma bounded_linear_vec_lambda: "bounded_linear (λx. vec_lambda (x :: 's ⇒b real))"
proof (intro bounded_linear_intro)
  fix x :: "'s ⇒b real"
  have "sqrt (∑ i ∈ UNIV . (apply_bfun x i)2) ≤ (∑ i ∈ UNIV . ¦(apply_bfun x i)¦)"
    using L2_set_le_sum_abs 
    unfolding L2_set_def
    by auto
  also have "(∑ i ∈ UNIV . ¦(apply_bfun x i)¦) ≤ (card (UNIV :: 's set) * (⨆xa. ¦apply_bfun x xa¦))"
    by (auto intro!: cSup_upper sum_bounded_above)
  finally show "norm (vec_lambda (apply_bfun x)) ≤ norm x * CARD('s)"
    unfolding norm_vec_def norm_bfun_def dist_bfun_def L2_set_def
    by (auto simp add: mult.commute)
qed (auto simp: plus_vec_def scaleR_vec_def)


lemma bounded_linear_vec_lambda_blinfun: 
  fixes f :: "('s ⇒b real) ⇒L ('s ⇒b real)"
  shows "bounded_linear (λv. vec_lambda (apply_bfun (blinfun_apply f (bfun.Bfun (($) v)))))" 
  using blinfun.bounded_linear_right
  by (fastforce intro: bounded_linear_compose[OF bounded_linear_vec_lambda] 
      bounded_linear_bfun_nth bounded_linear_compose[of f])

lemma invertible_nu_inv_max: "invertible (nu_inv_mat d)"
  unfolding nu_inv_mat_def fun_to_matrix_def
  by (auto simp: matrix_invertible inv_norm_le' vec_lambda_inverse apply_bfun_inverse 
      bounded_linear.linear[OF bounded_linear_vec_lambda_blinfun]
      intro!: exI[of _ "λv. (χ j. (λv. (∑i. (l *R 𝒫1 d) ^^ i) (Bfun v)) (vec_nth v) j)"])

end

definition "least_arg_max f P = (LEAST x. is_arg_max f P x)"

locale MDP_ord = MDP_finite_type A K r l
  for A and
    K :: "'s :: {finite, wellorder} × 'a :: {finite, wellorder} ⇒ 's pmf"
    and r l
begin

lemma ℒ_fin_eq_det: "ℒ v s = (⨆a ∈ A s. La a v s)"
  by (simp add: SUP_step_det_eq ℒ_eq_SUP_det)

lemma ℒb_fin_eq_det: "ℒb v s = (⨆a ∈ A s. La a v s)"
  by (simp add: SUP_step_det_eq ℒb.rep_eq ℒ_eq_SUP_det)

sublocale MDP_PI_finite A K r l "λX. Least (λx. x ∈ X)"
  by unfold_locales (auto intro: LeastI)

end
end

Theory Modified_Policy_Iteration

(* Author: Maximilian Schäffeler *)

theory Modified_Policy_Iteration
  imports 
    Policy_Iteration
    Value_Iteration
begin

section ‹Modified Policy Iteration›

locale MDP_MPI = MDP_finite_type A K r l + MDP_act A K r l arb_act
  for A and K :: "'s :: finite × 'a :: finite ⇒ 's pmf" and r l arb_act
begin

subsection ‹The Advantage Function @{term B}›

definition "B v s = (⨆d ∈ DR. (r_dec d s + (l *R 𝒫1 d - id_blinfun) v s))"

text "The function @{const B} denotes the advantage of choosing the optimal action vs.
  the current value estimate"

lemma B_eq_ℒ: "B v s = ℒ v s - v s"
proof -
  have *: "B v s = (⨆d ∈ DR. L d v s - v s)"
    unfolding B_def L_def
    by (auto simp add: blinfun.bilinear_simps add_diff_eq)
  show ?thesis
    unfolding *
  proof (rule antisym)
    show "(⨆d∈DR. L d v s - v s) ≤ ℒ v s - v s"
      unfolding ℒ_def
      using ex_dec
      by (fastforce intro!: cSUP_upper cSUP_least)
  next
    have "bdd_above ((λd. L d v s - v s) ` DR)"
      by (auto intro!: bounded_const bounded_minus_comp bounded_imp_bdd_above)
    thus "ℒ v s - v s ≤ (⨆d ∈ DR. L d v s - v s)"
      unfolding ℒ_def diff_le_eq
      by (intro cSUP_least) (auto intro: cSUP_upper2 simp: diff_le_eq[symmetric])
  qed
qed

text ‹@{const B} is a bounded function.›

lift_definition Bb :: "('s ⇒b real) ⇒ 's ⇒b real" is "B"
  using ℒb.rep_eq[symmetric] B_eq_ℒ
  by (auto intro!: bfun_normI order.trans[OF abs_triangle_ineq4] add_mono abs_le_norm_bfun)

lemma Bb_eq_ℒb: "Bb v = ℒb v - v"
  by (auto simp: ℒb.rep_eq Bb.rep_eq B_eq_ℒ)

lemma ℒb_eq_SUP_La: "ℒb v s = (⨆a ∈ A s. La a v s)"
  using L_eq_La_det ℒb_eq_SUP_det SUP_step_det_eq
  by auto

subsection ‹Optimization of the Value Function over Multiple Steps›

definition "U m v s = (⨆d ∈ DR. (νb_fin (mk_stationary d) m + ((l *R 𝒫1 d)^^m) v) s)"

text ‹@{const U} expresses the value estimate obtained by optimizing over the first @{term m} steps 
  and using the current estimate afterwards.›
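
text ‹
In particular, U 0 v = v and U 1 v s = ℒ v s, as the next two lemmas show; moreover Ub m is a 
contraction with factor l ^ m, so optimizing over more steps contracts faster towards the 
optimal value.
›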

lemma U_zero [simp]: "U 0 v = v"
  unfolding U_def ℒ_def
  by (auto simp: νb_fin.rep_eq)

lemma U_one_eq_ℒ: "U 1 v s = ℒ v s"
  unfolding U_def ℒ_def 
  by (auto simp: νb_fin_eq_𝒫X L_def blinfun.bilinear_simps)

lift_definition Ub :: "nat ⇒ ('s ⇒b real) ⇒ ('s ⇒b real)" is U
proof -
  fix n v
  have "norm (νb_fin (mk_stationary d) m) ≤ (∑i<m. l ^ i * rM)" for d m
    using abs_ν_fin_le νb_fin.rep_eq
    by (auto intro!: norm_bound)
  moreover have "norm (((l *R 𝒫1 d)^^m) v) ≤ l ^ m * norm v" for d m
    by (auto simp: 𝒫X_const[symmetric] blinfun.bilinear_simps blincomp_scaleR_right simp del: 𝒫X_sconst 
        intro!: boundedI order.trans[OF abs_le_norm_bfun] mult_left_mono)
  ultimately have *: "norm (νb_fin (mk_stationary d) m + ((l *R 𝒫1 d)^^m) v) ≤ (∑i<m. l ^ i * rM) +  l ^ m * norm v" for d m
    using norm_triangle_mono by blast
  show "U n v ∈ bfun"
    using ex_dec order.trans[OF abs_le_norm_bfun *]
    by (fastforce simp: U_def intro!: bfun_normI cSup_abs_le)
qed

lemma Ub_contraction: "dist (Ub m v) (Ub m u) ≤ l ^ m * dist v u"
proof -
  have aux: "dist (Ub m v s) (Ub m u s) ≤ l ^ m * dist v u" if le: "Ub m u s ≤ Ub m v s" for s v u
  proof -
    let ?U = "λm v d. (νb_fin (mk_stationary d) m + ((l *R 𝒫1 d) ^^ m) v) s"
    have "Ub m v s - Ub m u s ≤ (⨆d ∈ DR. ?U m v d - ?U m u d)"
      using bounded_stationary_νb_fin bounded_disc_𝒫1 le
      unfolding Ub.rep_eq U_def
      by (intro le_SUP_diff') (auto intro: bounded_plus_comp)
    also have "… = (⨆d ∈ DR. ((l *R 𝒫1 d) ^^ m) (v - u) s)"
      by (simp add: L_def scale_right_diff_distrib blinfun.bilinear_simps)
    also have "… = (⨆d ∈ DR. l^m * ((𝒫1 d ^^ m) (v - u) s))"
      by (simp add: blincomp_scaleR_right blinfun.scaleR_left)
    also have "… = l^m * (⨆d ∈ DR. ((𝒫1 d ^^ m) (v - u) s))" 
      using DR_ne bounded_P bounded_disc_𝒫1'
      by (auto intro: bounded_SUP_mul)
    also have "… ≤ l^m * norm (⨆d ∈ DR. ((𝒫1 d ^^ m) (v - u) s))"
      by (simp add: mult_left_mono)
    also have "… ≤ l^m * (⨆d ∈ DR. norm (((𝒫1 d ^^ m) (v - u) s)))"
      using DR_ne ex_dec bounded_norm_comp bounded_disc_𝒫1'
      by (fastforce intro!: mult_left_mono)
    also have "… ≤ l^m * (⨆d ∈ DR. norm ((𝒫1 d ^^ m) ((v - u))))"
      using ex_dec
      by (fastforce intro!: order.trans[OF norm_blinfun] abs_le_norm_bfun mult_left_mono cSUP_mono)
    also have "… ≤ l^m * (⨆d ∈ DR. norm ((v - u)))"
      using norm_𝒫X_apply
      by (auto simp: 𝒫X_const[symmetric] cSUP_least mult_left_mono)
    also have "… = l ^m * dist v u"
      by (auto simp: dist_norm)
    finally have "Ub m v s - Ub m u s ≤ l^m * dist v u" .
    thus ?thesis
      by (simp add: dist_real_def le)
  qed
  moreover have "Ub m v s ≤ Ub m u s ⟹ dist (Ub m v s) (Ub m u s) ≤ l^m * dist v u" for u v s
    by (simp add: aux dist_commute)
  ultimately have "dist (Ub m v s) (Ub m u s) ≤ l^m * dist v u" for u v s
    using linear 
    by blast
  thus "dist (Ub m v) (Ub m u) ≤ l^m * dist v u"
    by (simp add: dist_bound)
qed

lemma Ub_conv:
  "∃!v. Ub (Suc m) v = v" 
  "(λn. (Ub (Suc m) ^^ n) v) ⇢ (THE v. Ub (Suc m) v = v)"
proof -
  have *: "is_contraction (Ub (Suc m))"
    unfolding is_contraction_def
    using Ub_contraction[of "Suc m"] le_neq_trans[OF zero_le_disc] 
    by (cases "l = 0")
      (auto intro!: power_Suc_less_one intro: exI[of _ "l^(Suc m)"])
  show "∃!v. Ub (Suc m) v = v" "(λn. (Ub (Suc m) ^^ n) v) ⇢ (THE v. Ub (Suc m) v = v)"
    using banach'[OF *]
    by auto
qed

lemma Ub_convergent: "convergent (λn. (Ub (Suc m) ^^ n) v)"
  by (intro convergentI[OF Ub_conv(2)])

lemma Ub_mono:
  assumes "v ≤ u" 
  shows "Ub m v ≤ Ub m u"
proof  -
  have "Ub m v s ≤ Ub m u s" for s
    unfolding Ub.rep_eq U_def
  proof (intro cSUP_mono, goal_cases)
    case 2
    thus ?case
      by (simp add: bounded_imp_bdd_above bounded_disc_𝒫1 bounded_plus_comp bounded_stationary_νb_fin)
  next
    case (3 n)
    thus ?case 
      using less_eq_bfunD[OF 𝒫X_mono[OF assms]]
      by (auto simp: 𝒫X_const[symmetric] blincomp_scaleR_right blinfun.scaleR_left intro!: mult_left_mono exI)
  qed auto
  thus ?thesis
    using assms
    by auto
qed

lemma Ub_le_ℒb: "Ub m v ≤ (ℒb ^^ m) v"
proof -
  have "Ub m v s = (⨆d ∈ DR. (L d^^ m) v s)" for m v s
    by (auto simp: L_iter Ub.rep_eq ℒb.rep_eq U_def ℒ_def)
  thus ?thesis
    using L_iter_le_ℒb ex_dec
    by (fastforce intro!: cSUP_least)
qed


lemma L_iter_le_Ub: 
  assumes "d ∈ DR" 
  shows "(L d^^m) v ≤ Ub m v"
  using assms
  by (fastforce intro!: cSUP_upper bounded_imp_bdd_above
      simp: L_iter Ub.rep_eq U_def bounded_disc_𝒫1 bounded_plus_comp bounded_stationary_νb_fin)


lemma lim_Ub: "lim (λn. (Ub (Suc m) ^^ n) v) = νb_opt"
proof -
  have le_U: "νb_opt ≤ Ub m νb_opt" for m
  proof -
    obtain d where d: "ν_improving νb_opt (mk_dec_det d)" "d ∈ DD"
      using ex_improving_det by auto
    have "νb_opt = (L (mk_dec_det d) ^^ m) νb_opt"
      by (induction m) (metis L_ν_fix_iff ℒb_opt ν_improving_imp_ℒb d(1) funpow_swap1)+
    thus ?thesis
      using ‹d ∈ DD›
      by (auto intro!: order.trans[OF _ L_iter_le_Ub])
  qed
  have "Ub m νb_opt ≤ νb_opt" for m
    using ℒ_inc_le_opt
    by (auto intro!: order.trans[OF Ub_le_ℒb] simp: funpow_swap1)
  hence "Ub (Suc m) νb_opt = νb_opt"
    using le_U
    by (simp add: antisym)
  moreover have "(lim (λn. (Ub (Suc m) ^^n) v)) = Ub (Suc m) (lim (λn. (Ub (Suc m) ^^n) v))"
    using limI[OF Ub_conv(2)] theI'[OF Ub_conv(1)]
    by auto
  ultimately show ?thesis
    using Ub_conv(1)
    by metis
qed

lemma Ub_tendsto: "(λn. (Ub (Suc m) ^^ n) v) ⇢ νb_opt"
  using lim_Ub Ub_convergent convergent_LIMSEQ_iff
  by metis

lemma Ub_fix_unique: "Ub (Suc m) v = v ⟷ v = νb_opt" 
  using theI'[OF Ub_conv(1)] Ub_conv(1)
  by (auto simp: LIMSEQ_unique[OF Ub_tendsto Ub_conv(2)[of m]])

lemma dist_Ub_opt: "dist (Ub m v) νb_opt ≤ l^m * dist v νb_opt"
proof -
  have "dist (Ub m v) νb_opt = dist (Ub m v) (Ub m νb_opt)"
    by (metis Ub.abs_eq Ub_fix_unique U_zero apply_bfun_inverse not0_implies_Suc)
  also have "… ≤ l^m * dist v νb_opt"
    by (meson Ub_contraction)
  finally show ?thesis .
qed

subsection ‹Expressing a Single Step of Modified Policy Iteration›
text ‹The function @{term W} equals the value computed by the Modified Policy Iteration algorithm
  in a single iteration.
  The right-hand summand in the definition describes the advantage of using optimal actions for 
  the first @{term m} steps.
  ›
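
text ‹
Unfolding @{const Bb}, the definition reads W d m v = v + (∑i < m. (l *R 𝒫1 d)^^i) (ℒb v - v): 
the current estimate plus the accumulated, discounted advantage. For a ν-improving decision rule 
the sum telescopes and W d m v = (L d ^^ m) v, i.e. @{term m} partial backups with the greedy 
policy (lemma W_eq_L_iter below).
›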
definition "W d m v = v + (∑i < m. (l *R 𝒫1 d)^^i) (Bb v)"


lemma W_eq_L_iter:
  assumes "ν_improving v d"
  shows "W d m v = (L d^^m) v"
proof -
  have "(∑i<m. (l *R 𝒫1 d)^^i) (ℒb v) = (∑i<m. (l *R 𝒫1 d)^^i) (L d v)"
    using ν_improving_imp_ℒb assms by auto
  hence "W d m v =  v + ((∑i<m. (l *R 𝒫1 d)^^i) (L d v)) - (∑i<m. (l *R 𝒫1 d)^^i) v"
    by (auto simp: W_def Bb_eq_ℒb blinfun.bilinear_simps algebra_simps )
  also have "… = v + νb_fin (mk_stationary d) m + (∑i<m. ((l *R 𝒫1 d)^^i) ((l *R 𝒫1 d) v)) - (∑i<m. (l *R 𝒫1 d)^^i) v"
    unfolding L_def
    by (auto simp: νb_fin_eq blinfun.bilinear_simps blinfun.sum_left scaleR_right.sum)
  also have "… = v + νb_fin (mk_stationary d) m + (∑i<m. ((l *R 𝒫1 d)^^Suc i) v) - (∑i<m. (l *R 𝒫1 d)^^i) v"
    by (auto simp del: blinfunpow.simps simp: blinfunpow_assoc)
  also have "… = νb_fin (mk_stationary d) m + (∑i<Suc m. ((l *R 𝒫1 d)^^ i) v)  - (∑i<m. (l *R 𝒫1 d)^^ i) v"
    by (subst sum.lessThan_Suc_shift) auto
  also have "… =  νb_fin (mk_stationary d) m + ((l *R 𝒫1 d)^^m) v"
    by (simp add: blinfun.sum_left)
  also have "… = (L d ^^ m) v"
    using L_iter by auto
  finally show ?thesis .
qed

lemma W_le_Ub:
  assumes "v ≤ u" "ν_improving v d"
  shows "W d m v ≤ Ub m u"
proof -
  have "Ub m u - W d m v ≥ νb_fin (mk_stationary d) m + ((l *R 𝒫1 d) ^^ m) u - (νb_fin (mk_stationary d) m + ((l *R 𝒫1 d)^^m) v)"
    using ν_improving_D_MR assms(2) bounded_stationary_νb_fin bounded_disc_𝒫1
    by (fastforce intro!: diff_mono bounded_imp_bdd_above cSUP_upper bounded_plus_comp simp: Ub.rep_eq U_def L_iter W_eq_L_iter)
  hence *: "Ub m u - W d m v ≥ ((l *R 𝒫1 d) ^^ m) (u - v)"
    by (auto simp: blinfun.diff_right)
  show "W d m v ≤ Ub m u"
    using order.trans[OF 𝒫1_n_disc_pos[unfolded blincomp_scaleR_right[symmetric]] *] assms
    by auto
qed

lemma W_ge_ℒb:
  assumes "v ≤ u" "0 ≤ Bb u" "ν_improving u d'"
  shows "ℒb v ≤ W d' (Suc m) u"
proof -
  have "ℒb v ≤ u + Bb u"
    using assms(1) ℒb_mono Bb_eq_ℒb
    by auto
  also have "… ≤ W d' (Suc m) u"
    using L_mono ν_improving_imp_ℒb assms(3) assms 
    by (induction m) (auto simp: W_eq_L_iter Bb_eq_ℒb)
  finally show ?thesis .
qed

lemma Bb_le:
  assumes "ν_improving v d"
  shows "Bb v + (l *R 𝒫1 d - id_blinfun) (u - v) ≤ Bb u"
proof -
  have "r_decb d + (l *R 𝒫1 d - id_blinfun) u ≤ Bb u"
    using L_def L_le_ℒb assms     
    by (auto simp: Bb_eq_ℒb ℒb.rep_eq ℒ_def blinfun.bilinear_simps)
  moreover have "Bb v = r_decb d + (l *R 𝒫1 d - id_blinfun) v"
    using assms
    by (auto simp: Bb_eq_ℒb ν_improving_imp_ℒb[of _ d] L_def blinfun.bilinear_simps)
  ultimately show ?thesis
    by (simp add: blinfun.diff_right)
qed

lemma ℒb_W_ge:
  assumes "u ≤ ℒb u" "ν_improving u d"
  shows "W d m u ≤ ℒb (W d m u)"
proof -
  have "0 ≤ ((l *R 𝒫1 d) ^^ m) (Bb u)"
    by (metis Bb_eq_ℒb 𝒫1_n_disc_pos assms(1) blincomp_scaleR_right diff_ge_0_iff_ge)
  also have "… = ((l *R 𝒫1 d)^^0 + (∑i < m. (l *R 𝒫1 d)^^(Suc i))) (Bb u) - (∑i < m. (l *R 𝒫1 d)^^ i) (Bb u)"
    by (subst sum.lessThan_Suc_shift[symmetric]) (auto simp: blinfun.diff_left[symmetric])
  also have "… = Bb u + ((l *R 𝒫1 d - id_blinfun) oL (∑i < m. (l *R 𝒫1 d)^^i)) (Bb u)" 
    by (auto simp: blinfun.bilinear_simps sum_subtractf)
  also have "… = Bb u + (l *R 𝒫1 d - id_blinfun) (W d m u - u)"
    by (auto simp: W_def sum.lessThan_Suc[unfolded lessThan_Suc_atMost])
  also have "… ≤ Bb (W d m u)"
    using Bb_le assms(2) by blast
  finally have "0 ≤ Bb (W d m u)" .
  thus ?thesis using Bb_eq_ℒb
    by auto
qed

subsection ‹Computing the Bellman Operator over Multiple Steps›
definition L_pow :: "('s ⇒b real) ⇒ ('s ⇒ 'a) ⇒ nat ⇒ ('s ⇒b real)" where
  "L_pow v d m = (L (mk_dec_det d) ^^ Suc m) v"

lemma sum_telescope': "(∑i≤k. f (Suc i) - f i ) = f (Suc k) - (f 0 :: 'c :: ab_group_add)"
  using sum_telescope[of "-f" k]
  by auto

(* eq 6.5.7 *)
lemma L_pow_eq:
  assumes "ν_improving v (mk_dec_det d)"
  shows "L_pow v d m = v + (∑i ≤ m. ((l *R 𝒫1 (mk_dec_det d))^^i)) (Bb v)"
proof -
  let ?d = "(mk_dec_det d)"
  have "(∑i ≤ m. ((l *R 𝒫1 ?d)^^i)) (Bb v) = (∑i ≤ m. ((l *R 𝒫1 ?d)^^i)) (L ?d v) - (∑i ≤ m. ((l *R 𝒫1 ?d)^^i)) v"
    using assms
    by (auto simp: Bb_eq_ℒb blinfun.bilinear_simps ν_improving_imp_ℒb)
  also have "… = (∑i ≤ m. ((l *R 𝒫1 ?d)^^i)) (r_decb ?d) + (∑i ≤ m. ((l *R 𝒫1 ?d)^^i)) ((l *R 𝒫1 ?d) v) - (∑i ≤ m. ((l *R 𝒫1 ?d)^^i)) v"
    by (simp add: L_def blinfun.bilinear_simps)
  also have "… = (∑i ≤ m. ((l *R 𝒫1 ?d)^^i)) (r_decb ?d) + (∑i ≤ m. ((l *R 𝒫1 ?d)^^Suc i)) v - (∑i ≤ m. ((l *R 𝒫1 ?d)^^i)) v"
    by (auto simp: blinfun.sum_left blinfunpow_assoc simp del: blinfunpow.simps)
  also have "… = (∑i ≤ m. ((l *R 𝒫1 ?d)^^i)) (r_decb ?d) + (∑i ≤ m. ((l *R 𝒫1 ?d)^^Suc i) - (l *R 𝒫1 ?d)^^i) v"
    by (simp add: blinfun.diff_left sum_subtractf)
  also have "… = (∑i ≤ m. ((l *R 𝒫1 ?d)^^i)) (r_decb ?d) + ((l *R 𝒫1 ?d)^^Suc m) v - v"
    by (subst sum_telescope') (auto simp: blinfun.bilinear_simps)
  finally have "(∑i ≤ m. ((l *R 𝒫1 ?d)^^i)) (Bb v) = (∑i ≤ m. ((l *R 𝒫1 ?d)^^i)) (r_decb ?d) + ((l *R 𝒫1 ?d)^^Suc m) v - v" .
  moreover have "L_pow v d m = νb_fin (mk_stationary_det d) (Suc m) + ((l *R 𝒫1 ?d)^^Suc m) v"
    by (simp only: L_pow_def L_iter lessThan_Suc_atMost[symmetric])
  ultimately show ?thesis
    by (auto simp: νb_fin_eq lessThan_Suc_atMost)
qed

lemma L_pow_eq_W:
  assumes "d ∈ DD" 
  shows "L_pow v (policy_improvement d v) m = W (mk_dec_det (policy_improvement d v)) (Suc m) v" 
  using assms policy_improvement_improving 
  by (auto simp: W_eq_L_iter L_pow_def)

lemma L_pow_ℒb_mono_inv:
  assumes "d ∈ DD" "v ≤ ℒb v"
  shows "L_pow v (policy_improvement d v) m ≤ ℒb (L_pow v (policy_improvement d v) m)"
  using assms L_pow_eq_W ℒb_W_ge policy_improvement_improving 
  by auto

subsection ‹The Modified Policy Iteration Algorithm›
context
  fixes d0 :: "'s ⇒ 'a"
  fixes v0 :: "'s ⇒b real"
  fixes m :: "nat ⇒ ('s ⇒b real) ⇒ nat"
  assumes d0: "d0 ∈ DD"
begin

text ‹We first define a function that executes the algorithm for n steps.›
fun mpi :: "nat ⇒ (('s ⇒ 'a) × ('s ⇒b real))" where
  "mpi 0 = (policy_improvement d0 v0, v0)" |
  "mpi (Suc n) =
  (let (d, v) = mpi n; v' = L_pow v d (m n v) in
  (policy_improvement d v', v'))"

definition "mpi_val n = snd (mpi n)"
definition "mpi_pol n = fst (mpi n)"

lemma mpi_pol_zero[simp]: "mpi_pol 0 = policy_improvement d0 v0"
  unfolding mpi_pol_def 
  by auto

lemma mpi_pol_Suc: "mpi_pol (Suc n) = policy_improvement (mpi_pol n) (mpi_val (Suc n))"
  by (auto simp: case_prod_beta' Let_def mpi_pol_def mpi_val_def)

lemma mpi_pol_is_dec_det: "mpi_pol n ∈ DD"
  unfolding mpi_pol_def
  using policy_improvement_is_dec_det d0
  by (induction n) (auto simp: Let_def split: prod.splits)

lemma ν_improving_mpi_pol: "ν_improving (mpi_val n) (mk_dec_det (mpi_pol n))"
  using d0 policy_improvement_improving mpi_pol_is_dec_det mpi_pol_Suc
  by (cases n) (auto simp: mpi_pol_def mpi_val_def)

lemma mpi_val_zero[simp]: "mpi_val 0 = v0"
  unfolding mpi_val_def by auto

lemma mpi_val_Suc: "mpi_val (Suc n) = L_pow (mpi_val n) (mpi_pol n) (m n (mpi_val n))"
  unfolding mpi_val_def mpi_pol_def
  by (auto simp: case_prod_beta' Let_def)

lemma mpi_val_eq: "mpi_val (Suc n) = 
  mpi_val n + (∑i ≤ m n (mpi_val n). (l *R 𝒫1 (mk_dec_det (mpi_pol n))) ^^ i) (Bb (mpi_val n))"  
  using L_pow_eq[OF ν_improving_mpi_pol] mpi_val_Suc
  by auto


text ‹Value Iteration is a special case of MPI where @{term "∀n v. m n v = 0"}.›
lemma mpi_includes_value_it: 
  assumes "∀n v. m n v = 0"
  shows "mpi_val (Suc n) = ℒb (mpi_val n)"
  using assms Bb_eq_ℒb mpi_val_eq
  by auto

subsection ‹Convergence Proof›
text ‹We define the sequence @{term w} as an upper bound for the values of MPI.›
fun w where
  "w 0 = v0" |
  "w (Suc n) = Ub (Suc (m n (mpi_val n))) (w n)"

lemma dist_νb_opt: "dist (w (Suc n)) νb_opt ≤ l * dist (w n) νb_opt"
  by (fastforce simp: algebra_simps intro: order.trans[OF dist_Ub_opt] mult_left_mono power_le_one
      mult_left_le_one_le order.strict_implies_order)

lemma dist_νb_opt_n: "dist (w n) νb_opt ≤ l^n * dist v0 νb_opt"
  by (induction n) (fastforce simp: algebra_simps intro: order.trans[OF dist_νb_opt] mult_left_mono)+

lemma w_conv: "w ⇢ νb_opt"
proof -
  have "(λn. l^n * dist v0 νb_opt) ⇢ 0"
    using LIMSEQ_realpow_zero
    by (cases "v0 = νb_opt") auto
  then show ?thesis
    by (fastforce intro: metric_LIMSEQ_I order.strict_trans1[OF dist_νb_opt_n] simp: LIMSEQ_def)
qed

text ‹MPI converges monotonically to the optimal value from below. 
  The iterates are sandwiched between the iterates of @{const ℒb} from below and those of 
  @{const Ub} from above.›
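
text ‹
Formally, the induction below maintains the invariant
  (ℒb ^^ n) v0 ≤ mpi_val n ≤ w n
together with mpi_val n ≤ ℒb (mpi_val n). Both the lower and the upper bound converge to 
νb_opt, so the convergence of mpi_val follows by a sandwich argument.
›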
theorem mpi_conv:
  assumes "v0 ≤ ℒb v0"
  shows "mpi_val ⇢ νb_opt" and "⋀n. mpi_val n ≤ mpi_val (Suc n)"
proof -
  define y where "y n = (ℒb^^n) v0" for n
  have aux: "mpi_val n ≤ ℒb (mpi_val n) ∧ mpi_val n ≤ mpi_val (Suc n) ∧ y n ≤ mpi_val n ∧ mpi_val n ≤ w n" for n
  proof (induction n)
    case 0
    show ?case
      using assms Bb_eq_ℒb
      unfolding y_def
      by (auto simp: mpi_val_eq blinfun.sum_left 𝒫1_n_disc_pos blincomp_scaleR_right sum_nonneg)
  next
    case (Suc n)
    have val_eq_W: "mpi_val (Suc n) = W (mk_dec_det (mpi_pol n)) (Suc (m n (mpi_val n))) (mpi_val n)"
      using ν_improving_mpi_pol mpi_val_Suc W_eq_L_iter L_pow_def
      by auto
    hence *: "mpi_val (Suc n) ≤ ℒb (mpi_val (Suc n))"
      using Suc.IH ℒb_W_ge ν_improving_mpi_pol by presburger
    moreover have "mpi_val (Suc n) ≤ mpi_val (Suc (Suc n))"
      using *
      by (simp add: Bb_eq_ℒb mpi_val_eq 𝒫1_n_disc_pos blincomp_scaleR_right blinfun.sum_left sum_nonneg)
    moreover have "mpi_val (Suc n) ≤ w (Suc n)"
      using Suc.IH ν_improving_mpi_pol 
      by (auto simp: val_eq_W intro: order.trans[OF _ W_le_Ub])
    moreover have "y (Suc n) ≤ mpi_val (Suc n)"
      using Suc.IH ν_improving_mpi_pol W_ge_ℒb
      by (auto simp: y_def Bb_eq_ℒb val_eq_W)
    ultimately show ?case
      by auto
  qed
  thus "mpi_val n ≤ mpi_val (Suc n)" for n
    by auto
  have "y ⇢ νb_opt"
    using ℒb_lim y_def by presburger
  thus "mpi_val ⇢ νb_opt"
    using aux
    by (auto intro: tendsto_bfun_sandwich[OF _ w_conv])
qed

subsection ‹$\epsilon$-Optimality›
text ‹This gives an upper bound on the error of MPI.›
lemma mpi_pol_eps_opt:
  assumes "2 * l * dist (mpi_val n) (ℒb (mpi_val n)) < eps * (1 - l)" "eps > 0"
  shows "dist (νb (mk_stationary_det (mpi_pol n))) (ℒb (mpi_val n)) ≤ eps / 2"
proof -
  let ?p = "mk_stationary_det (mpi_pol n)"
  let ?d = "mk_dec_det (mpi_pol n)"
  let ?v = "mpi_val n"
  have "dist (νb ?p) (ℒb ?v) = dist (L ?d (νb ?p)) (ℒb ?v)"
    using L_ν_fix
    by force
  also have "… = dist (L ?d (νb ?p)) (L ?d ?v)"
    by (metis ν_improving_imp_ℒb ν_improving_mpi_pol)
  also have "… ≤ dist (L ?d (νb ?p)) (L ?d (ℒb ?v)) + dist (L ?d (ℒb ?v)) (L ?d ?v)"
    using dist_triangle 
    by blast
  also have "… ≤ l * dist (νb ?p) (ℒb ?v) + dist (L ?d (ℒb ?v)) (L ?d ?v)"
    using contraction_L by auto
  also have "… ≤ l * dist (νb ?p) (ℒb ?v) + l * dist (ℒb ?v) ?v"
    using contraction_L by auto
  finally have "dist (νb ?p) (ℒb ?v) ≤ l * dist (νb ?p) (ℒb ?v) + l * dist (ℒb ?v) ?v".
  hence *:"(1-l) * dist (νb ?p) (ℒb ?v) ≤ l * dist (ℒb ?v) ?v"
    by (auto simp: left_diff_distrib)
  thus ?thesis
  proof (cases "l = 0")
    case True
    thus ?thesis
      using assms *
      by auto
  next
    case False
    have **: "dist (ℒb ?v) (mpi_val n) < eps * (1 - l) / (2 * l)"
      using False le_neq_trans[OF zero_le_disc False[symmetric]] assms
      by (auto simp: dist_commute pos_less_divide_eq Groups.mult_ac(2))
    have "dist (νb ?p) (ℒb ?v) ≤ (l/ (1-l)) * dist (ℒb ?v) ?v"
      using *
      by (auto simp: mult.commute pos_le_divide_eq)
    also have "… ≤ (l/ (1-l)) * (eps * (1 - l) / (2 * l))"
      using **
      by (fastforce intro!: mult_left_mono simp: divide_nonneg_pos)
    also have "… = eps / 2"
      using False disc_lt_one
      by (auto simp: order.strict_iff_order)
    finally show "dist (νb ?p) (ℒb ?v) ≤ eps / 2".    
  qed
qed

lemma mpi_pol_opt:
  assumes "2 * l * dist (mpi_val n) (ℒb (mpi_val n)) < eps * (1 - l)" "eps > 0"
  shows "dist (νb (mk_stationary_det (mpi_pol n))) (νb_opt) < eps"
proof -
  have "dist (νb (mk_stationary_det (mpi_pol n))) (νb_opt) ≤ eps/2 + dist (ℒb (mpi_val n)) νb_opt"
    by (metis mpi_pol_eps_opt[OF assms] dist_commute dist_triangle_le add_right_mono)
  thus ?thesis
    using dist_ℒb_opt_eps assms
    by fastforce
qed

lemma mpi_val_term_ex:
  assumes "v0 ≤ ℒb v0" "eps > 0"
  shows "∃n. 2 * l * dist (mpi_val n) (ℒb (mpi_val n)) < eps * (1 - l)"
proof -
  note dist_ℒb_lt_dist_opt
  have "(λn. dist (mpi_val n) νb_opt) ⇢ 0"
    using mpi_conv(1)[OF assms(1)] tendsto_dist_iff 
    by blast
  hence "(λn. dist (mpi_val n) (ℒb (mpi_val n))) ⇢ 0"
    using dist_ℒb_lt_dist_opt
    by (auto simp: metric_LIMSEQ_I intro: tendsto_sandwich[of "λ_. 0" _ _ "λn. 2 * dist (mpi_val n) νb_opt"])
  hence "∀e >0. ∃n. dist (mpi_val n) (ℒb (mpi_val n)) < e"
    by (fastforce dest!: metric_LIMSEQ_D)
  hence "l ≠ 0 ⟹ ∃n. dist (mpi_val n) (ℒb (mpi_val n)) < eps * (1 - l) / (2 * l)"
    by (simp add: assms order.not_eq_order_implies_strict)
  thus "∃n. (2 * l) * dist (mpi_val n) (ℒb (mpi_val n)) < eps * (1 - l)"
    using assms le_neq_trans[OF zero_le_disc]
    by (cases "l = 0") (auto simp: mult.commute pos_less_divide_eq)
qed
end

subsection ‹Unbounded MPI›
context
  fixes eps δ :: real and M :: nat
begin

function (domintros) mpi_algo where "mpi_algo d v m = (
  if 2 * l * dist v (ℒb v) <  eps * (1 - l)
  then (policy_improvement d v, v)
  else mpi_algo (policy_improvement d v) (L_pow v (policy_improvement d v) (m 0 v)) (λn. m (Suc n)))"
  by auto
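
text ‹
Operationally, mpi_algo implements the loop in the informal Python sketch below (illustration 
only, not part of the formalization). The helpers bellman_update, policy_improvement, 
partial_backups and dist are hypothetical stand-ins for @{const ℒb}, @{const policy_improvement}, 
@{const L_pow} and the supremum distance.
›
(*
def mpi_algo(d, v, m, eps, l, bellman_update, policy_improvement, partial_backups, dist):
    i = 0
    while not 2 * l * dist(v, bellman_update(v)) < eps * (1 - l):
        d = policy_improvement(d, v)        # greedy improvement
        v = partial_backups(v, d, m(i, v))  # m(i, v) + 1 applications of L d
        i += 1                              # corresponds to shifting m
    return policy_improvement(d, v), v
*)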

text ‹We define a tail-recursive version of @{const mpi} which more closely resembles @{const mpi_algo}.›
fun mpi' where
  "mpi' d v 0 m = (policy_improvement d v, v)" |
  "mpi' d v (Suc n) m = (
  let d' = policy_improvement d v; v' = L_pow v d' (m 0 v) in mpi' d' v' n (λn. m (Suc n)))"

lemma mpi_Suc':
  assumes "d ∈ DD"
  shows "mpi d v m (Suc n) = mpi (policy_improvement d v) (L_pow v (policy_improvement d v) (m 0 v)) (λa. m (Suc a)) n"
  using assms policy_improvement_is_dec_det
  by (induction n rule: nat.induct) (auto simp: Let_def)

lemma 
  assumes "d ∈ DD"
  shows "mpi d v m n = mpi' d v n m"
  using assms
proof (induction n arbitrary: d v m rule: nat.induct)
  case (Suc nat)
  thus ?case
    using policy_improvement_is_dec_det
    by (auto simp: Let_def mpi_Suc'[OF Suc(2)] Suc.IH[symmetric])
qed auto

lemma termination_mpi_algo: 
  assumes "eps > 0" "d ∈ DD" "v ≤ ℒb v"
  shows "mpi_algo_dom (d, v, m)"
proof -
  define n where "n = (LEAST n. 2 * l * dist (mpi_val d v m n) (ℒb (mpi_val d v m n)) < eps * (1 - l))" (is "n = (LEAST n. ?P d v m n)")
  have least0: "∃n. P n ⟹ (LEAST n. P n) = (0 :: nat) ⟹ P 0"  for P
    by (metis LeastI_ex)
  from n_def assms show ?thesis
  proof (induction n arbitrary: v d m)
    case 0
    have "2 * l * dist (mpi_val d v m 0) (ℒb (mpi_val d v m 0)) < eps * (1 - l)"
      using least0 mpi_val_term_ex 0
      by (metis (no_types, lifting))
    thus ?case
      using 0 mpi_algo.domintros mpi_val_zero
      by (metis (no_types, opaque_lifting))
  next
    case (Suc n v d m)
    let ?d = "policy_improvement d v"
    have "Suc n = Suc (LEAST n. 2 * l * dist (mpi_val d v m (Suc n)) (ℒb (mpi_val d v m (Suc n))) < eps * (1 - l))"
      using mpi_val_term_ex[OF Suc.prems(3) ‹v ≤ ℒb v› ‹0 < eps›, of m] Suc.prems 
      by (subst Nat.Least_Suc[symmetric]) (auto intro: LeastI_ex)
    hence "n = (LEAST n. 2 * l * dist (mpi_val d v m (Suc n)) (ℒb (mpi_val d v m (Suc n))) < eps * (1 - l))"
      by auto
    hence n_eq: "n =
    (LEAST n. 2 * l * dist (mpi_val ?d (L_pow v ?d (m 0 v)) (λa. m (Suc a)) n) (ℒb (mpi_val ?d (L_pow v ?d (m 0 v)) (λa. m (Suc a)) n))
        < eps * (1 - l))"
      using Suc.prems mpi_Suc'
      by (auto simp: is_dec_det_pi mpi_val_def)
    have "¬ 2 * l * dist v (ℒb v) < eps * (1 - l)"
      using Suc mpi_val_zero by force
    moreover have "mpi_algo_dom (?d, L_pow v ?d (m 0 v), λa. m (Suc a))"
      using Suc.IH[OF n_eq ‹0 < eps›] Suc.prems is_dec_det_pi L_pow_ℒb_mono_inv by auto
    ultimately show ?case
      using mpi_algo.domintros 
      by blast
  qed
qed

abbreviation "mpi_alg_rec d v m ≡ 
    (if 2 * l * dist v (ℒb v) < eps * (1 - l) then (policy_improvement d v, v)
     else mpi_algo (policy_improvement d v) (L_pow v (policy_improvement d v) (m 0 v))
           (λn. m (Suc n)))"

lemma mpi_algo_def':
  assumes "d ∈ DD" "v ≤ ℒb v" "eps > 0"
  shows "mpi_algo d v m = mpi_alg_rec d v m"
  using mpi_algo.psimps termination_mpi_algo assms
  by auto

lemma mpi_algo_eq_mpi:
  assumes "d ∈ DD" "v ≤ ℒb v" "eps > 0"
  shows "mpi_algo d v m = mpi d v m (LEAST n. 2 * l * dist (mpi_val d v m n) (ℒb (mpi_val d v m n)) < eps * (1 - l))"
proof -
  define n where "n = (LEAST n. 2 * l * dist (mpi_val d v m n) (ℒb (mpi_val d v m n)) < eps * (1 - l))" (is "n = (LEAST n. ?P d v m n)")
  from n_def assms show ?thesis
  proof (induction n arbitrary: d v m)
    case 0
    have "?P d v m 0"
      by (metis (no_types, lifting) assms(3) LeastI_ex 0 mpi_val_term_ex)
    thus ?case
      using assms 0
      by (auto simp: mpi_val_def mpi_algo_def')
  next
    case (Suc n)
    hence not0: "¬ (2 * l * dist v (ℒb v) < eps * (1 - l))"
      using Suc(3) mpi_val_zero
      by auto
    obtain n' where "2 * l * dist (mpi_val d v m n') (ℒb (mpi_val d v m n')) < eps * (1 - l)"
      using mpi_val_term_ex[OF Suc(3) Suc(4), of _ m] assms by blast
    hence "n = (LEAST n. ?P d v m (Suc n))"
      using Suc(2) Suc
      by (subst (asm) Least_Suc) auto
    hence "n = (LEAST n. ?P (policy_improvement d v) (L_pow v (policy_improvement d v) (m 0 v)) (λn. m (Suc n)) n)"
      using Suc(3) policy_improvement_is_dec_det mpi_Suc' 
      by (auto simp: mpi_val_def)
    hence "mpi_algo d v m = mpi d v m (Suc n)"
      unfolding mpi_algo_def'[OF Suc.prems(2-4)]
      using Suc(1) Suc.prems(2-4) is_dec_det_pi mpi_Suc' not0 L_pow_ℒb_mono_inv by force
    thus ?case
      using Suc.prems(1) by presburger
  qed
qed

lemma mpi_algo_opt: 
  assumes "v0 ≤ ℒb v0" "eps > 0" "d ∈ DD"
  shows "dist (νb (mk_stationary_det (fst (mpi_algo d v0 m)))) νb_opt < eps"
proof -
  let ?P = "λn. 2 * l * dist (mpi_val d v0 m n) (ℒb (mpi_val d v0 m n)) <  eps * (1 - l)"
  let ?n = "Least ?P"
  have "mpi_algo d v0 m = mpi d v0 m ?n" and "?P ?n"
    using mpi_algo_eq_mpi LeastI_ex[OF mpi_val_term_ex] assms by auto
  thus ?thesis
    using assms
    by (auto simp: mpi_pol_opt mpi_pol_def[symmetric])
qed

end


subsection ‹Initial Value Estimate @{term v0_mpi}›
text ‹We define an initial estimate of the value function for which Modified Policy Iteration 
  always terminates.›

abbreviation "r_min ≡ (⨅s' a. r (s', a))"
definition "v0_mpi s = r_min / (1 - l)"

lift_definition v0_mpib :: "'s ⇒b real" is "v0_mpi"
  by fastforce
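
text ‹
Intuitively, every reward is at least r_min, so the value of any policy is bounded below by 
(∑t. l^t * r_min) = r_min / (1 - l). The estimate @{const v0_mpib} is therefore a uniform lower 
bound on the values of the MDP, and the following lemma shows that it is increased by 
@{const ℒb}, which is precisely the precondition for the convergence of MPI from below.
›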

lemma v0_mpib_le_ℒb: "v0_mpib ≤ ℒb v0_mpib"
proof (rule less_eq_bfunI)
  fix x
  have "r_min ≤ r (s, a)" for s a
    by (fastforce intro: cInf_lower2)
  hence "r_min ≤ (1-l) * r (s, a) + l * r_min" for s a
    using disc_lt_one zero_le_disc
    by (meson order_less_imp_le order_refl segment_bound_lemma)
  hence "r_min / (1 - l) ≤ ((1-l) * r (s, a) + l * r_min) / (1 - l)" for s a
    using order_less_imp_le[OF disc_lt_one]
    by (auto intro!: divide_right_mono)
  hence "r_min / (1 - l) ≤ r (s, a) + (l * r_min) / (1 - l)" for s a
    using disc_lt_one
    by (auto simp: add_divide_distrib)
  thus "v0_mpib x ≤ ℒb v0_mpib x"
    unfolding ℒb_eq_SUP_La v0_mpib.rep_eq v0_mpi_def
    by (auto simp: A_ne intro: cSUP_upper2[where x = "arb_act (A x)"])
qed

subsection ‹An Instance of Modified Policy Iteration with a Valid Conservative Initial Value Estimate›
definition "mpi_user eps m = (
  if eps ≤ 0 then undefined else mpi_algo eps (λx. arb_act (A x)) v0_mpib m)"

lemma mpi_user_eq: 
  assumes "eps > 0"
  shows "mpi_user eps = mpi_alg_rec eps (λx. arb_act (A x)) v0_mpib"
  using v0_mpib_le_ℒb assms
  by (auto simp: mpi_user_def mpi_algo_def' A_ne is_dec_det_def)

lemma mpi_user_opt:
  assumes "eps > 0"
  shows "dist (νb (mk_stationary_det (fst (mpi_user eps n)))) νb_opt < eps"
  unfolding mpi_user_def using assms
  by (auto intro: mpi_algo_opt simp: is_dec_det_def A_ne v0_mpib_le_ℒb)

end

end

Theory Matrix_Util

theory Matrix_Util
  imports "HOL-Analysis.Analysis"
begin

section ‹Matrices›

proposition scalar_matrix_assoc':
  fixes C :: "('b::real_algebra_1)^'m^'n"
  shows "k *R (C ** D) = C ** (k *R D)"
  by (simp add: matrix_matrix_mult_def sum_distrib_left mult_ac vec_eq_iff scaleR_sum_right)

subsection ‹Nonnegative Matrices›

lemma nonneg_matrix_nonneg [dest]: "0 ≤ m ⟹ 0 ≤ m $ i $ j"
  by (simp add: Finite_Cartesian_Product.less_eq_vec_def)

lemma matrix_mult_mono: 
  assumes "0 ≤ E" "0 ≤ C" "(E :: real^'c^'c) ≤ B" "C ≤ D"
  shows "E ** C ≤ B ** D"
  using order.trans[OF assms(1) assms(3)] assms
  unfolding Finite_Cartesian_Product.less_eq_vec_def
  by (auto intro!: sum_mono mult_mono simp: matrix_matrix_mult_def)

lemma nonneg_matrix_mult: "0 ≤ (C :: ('b::{field, ordered_ring})^_^_) ⟹ 0 ≤ D ⟹ 0 ≤ C ** D"
  unfolding Finite_Cartesian_Product.less_eq_vec_def
  by (auto simp: matrix_matrix_mult_def intro!: sum_nonneg)

lemma zero_le_mat_iff [simp]: "0 ≤ mat (x :: 'c :: {zero, order}) ⟷ 0 ≤ x"
  by (auto simp: Finite_Cartesian_Product.less_eq_vec_def mat_def)

lemma nonneg_mat_ge_zero: "0 ≤ Q ⟹ 0 ≤ v ⟹ 0 ≤ Q *v (v :: real^'c)"
  unfolding Finite_Cartesian_Product.less_eq_vec_def
  by (auto intro!: sum_nonneg simp: matrix_vector_mult_def)

lemma nonneg_mat_mono: "0 ≤ Q ⟹ u ≤ v ⟹ Q *v u ≤ Q *v (v :: real^'c)"
  using nonneg_mat_ge_zero[of Q "v - u"]
  by (simp add: vec.diff)

lemma nonneg_mult_imp_nonneg_mat:
  assumes "⋀v. v ≥ 0 ⟹ X *v v ≥ 0"
  shows "X ≥ (0 :: real ^ _ ^_)"
proof -
  { assume "¬ (0 ≤ X)"
    then obtain i j where neg: "X $ i $ j < 0" 
      by (metis less_eq_vec_def not_le zero_index)
    let ?v = "χ k. if j = k then 1::real else 0"
    have "(X *v ?v) $ i < 0"
      using neg
      by (auto simp: matrix_vector_mult_def if_distrib cong: if_cong)
    hence "?v ≥ 0 ∧ ¬ ((X *v ?v) ≥ 0)"
      by (auto simp: less_eq_vec_def not_le)
    hence "∃v. v ≥ 0 ∧ ¬ X *v v ≥ 0"
      by blast
  }
  thus ?thesis
    using assms by auto
qed

lemma nonneg_mat_iff:
  "(X ≥ (0 :: real ^ _ ^_)) ⟷ (∀v. v ≥ 0 ⟶ X *v v ≥ 0)"
  using nonneg_mat_ge_zero nonneg_mult_imp_nonneg_mat by auto

lemma mat_le_iff: "(X ≤ Y) ⟷ (∀x≥0. (X::real^_^_) *v x ≤ Y *v x)"
  by (metis diff_ge_0_iff_ge matrix_vector_mult_diff_rdistrib nonneg_mat_iff)

subsection ‹Matrix Powers›

(* copied from Perron-Frobenius *)
primrec matpow :: "'a::semiring_1^'n^'n ⇒ nat ⇒ 'a^'n^'n" where
  matpow_0:   "matpow A 0 = mat 1" |
  matpow_Suc: "matpow A (Suc n) = (matpow A n) ** A"

lemma nonneg_matpow: "0 ≤ X ⟹ 0 ≤ matpow (X :: real ^ _ ^ _) i"
  by (induction i) (auto simp: nonneg_matrix_mult)

lemma matpow_mono: "0 ≤ C ⟹ C ≤ D ⟹ matpow (C :: real^_^_) n ≤ matpow D n"
  by (induction n) (auto intro!: matrix_mult_mono nonneg_matpow)

lemma matpow_scaleR: "matpow (c *R (X :: 'b :: real_algebra_1^_^_)) n = (c^n) *R (matpow X) n"
proof (induction n arbitrary: X c)
  case (Suc n)
  have "matpow (c *R X) (Suc n) = (c^n)*R (matpow X) n ** c *R X"
    using Suc by auto
  also have "… = c *R ((c^n) *R (matpow X) n ** X)"
    using scalar_matrix_assoc' 
    by (auto simp: scalar_matrix_assoc')
  finally show ?case
    by (simp add: scalar_matrix_assoc)
qed auto

lemma matrix_vector_mult_code': "(X *v x) $ i = (∑j∈UNIV. X $ i $ j * x $ j)"
  by (simp add: matrix_vector_mult_def)