# Theory While_SPMF

(* Title: While_SPMF.thy
Author: Andreas Lochbihler, ETH Zurich *)

theory While_SPMF imports
MFMC_Countable.Rel_PMF_Characterisation
"HOL-Types_To_Sets.Types_To_Sets"
"HOL-Library.Complete_Partial_Order2"
begin

text ‹
This theory defines a probabilistic while combinator for discrete (sub-)probabilities and
formalises rules for probabilistic termination similar to those by Hurd \cite{Hurd2002TPHOLs}
and McIver and Morgan \cite{McIverMorgan2005}.
›

(* Lifts a set-valued function f to optional arguments:
   None is sent to the singleton {None}, and Some x is sent to f x. *)
fun map_option_set :: "('a ⇒ 'b option set) ⇒ 'a option ⇒ 'b option set"
where
"map_option_set f None = {None}"
| "map_option_set f (Some x) = f x"

(* None is in map_option_set f x iff x itself is None or f yields None
   on the element contained in x. *)
lemma None_in_map_option_set:
"None ∈ map_option_set f x ⟷ None ∈ Set.bind (set_option x) f ∨ x = None"
by(cases x) simp_all

(* Introduction rule: mapping over None always yields a set containing None. *)
lemma None_in_map_option_set_None [intro!]: "None ∈ map_option_set f None"
by simp

(* Introduction rule: None is in the image of Some x whenever None ∈ f x. *)
lemma None_in_map_option_set_Some [intro!]: "None ∈ f x ⟹ None ∈ map_option_set f (Some x)"
by simp

(* Introduction rule: Some y is in the image of Some x whenever Some y ∈ f x. *)
lemma Some_in_map_option_set [intro!]: "Some y ∈ f x ⟹ Some y ∈ map_option_set f (Some x)"
by simp

(* For a singleton-valued function, map_option_set coincides with the
   singleton of Option.bind. *)
lemma map_option_set_singleton [simp]: "map_option_set (λx. {f x}) y = {Option.bind y f}"
by(cases y) simp_all

(* Option.bind x f equals Some y iff x is some z and f z = Some y. *)
lemma Some_eq_bind_conv: "Some y = Option.bind x f ⟷ (∃z. x = Some z ∧ f z = Some y)"
by(cases x) auto

(* map_option_set distributes over Option.bind in its argument. *)
lemma map_option_set_bind: "map_option_set f (Option.bind x g) = map_option_set (map_option_set f ∘ g) x"
by(cases x) simp_all

(* Some y can only occur in map_option_set f x if x is some z with Some y ∈ f z. *)
lemma Some_in_map_option_set_conv: "Some y ∈ map_option_set f x ⟷ (∃z. x = Some z ∧ Some y ∈ f z)"
by(cases x) auto

(* Discharge the locale rel_spmf_characterisation with rel_pmf_measureI, then
   hide the short name of that fact (it remains accessible by qualified name). *)
interpretation rel_spmf_characterisation by unfold_locales(rule rel_pmf_measureI)
hide_fact (open) rel_pmf_measureI

(* On function types, Sup is the pointwise lifting fun_lub of Sup. *)
lemma Sup_conv_fun_lub: "Sup = fun_lub Sup"
by(auto simp add: Sup_fun_def fun_eq_iff fun_lub_def intro: arg_cong[where f=Sup])

(* On function types, (≤) is the pointwise lifting fun_ord of (≤). *)
lemma le_conv_fun_ord: "(≤) = fun_ord (≤)"
by(auto simp add: fun_eq_iff fun_ord_def le_fun_def)

(* Instance of parallel_fixp_induct_uc where the first fixpoint is over an
   uncurried (case_prod/curry) function and the second uses the identity
   conversions; the curry/case_prod round trips are simplified away. *)
lemmas parallel_fixp_induct_2_1 = parallel_fixp_induct_uc[
of _ _ _ _ "case_prod" _ "curry" "λx. x" _ "λx. x",
where P="λf g. P (curry f) g",
unfolded case_prod_curry curry_case_prod curry_K,
OF _ _ _ _ _ _ refl refl]
for P

(* Pairing two monotone functions gives a function that is monotone
   w.r.t. the componentwise order rel_prod orda ordb. *)
lemma monotone_Pair:
"⟦ monotone ord orda f; monotone ord ordb g ⟧
⟹ monotone ord (rel_prod orda ordb) (λx. (f x, g x))"
by(simp add: monotone_def)

(* Pairing two continuous functions is continuous w.r.t. the product lub
   prod_lub and the componentwise order. *)
lemma cont_Pair:
"⟦ cont lub ord luba orda f; cont lub ord lubb ordb g ⟧
⟹ cont lub ord (prod_lub luba lubb) (rel_prod orda ordb) (λx. (f x, g x))"
by(rule contI)(auto simp add: prod_lub_def image_image dest!: contD)

(* Monotone-continuity version of pairing, combining monotone_Pair and cont_Pair. *)
lemma mcont_Pair:
"⟦ mcont lub ord luba orda f; mcont lub ord lubb ordb g ⟧
⟹ mcont lub ord (prod_lub luba lubb) (rel_prod orda ordb) (λx. (f x, g x))"
by(rule mcontI)(simp_all add: monotone_Pair mcont_mono cont_Pair)

(* The measure of an spmf (as a function on sets) is monotone in the spmf
   w.r.t. ord_spmf (=).  The plain statement is named monotone_emeasure_spmf;
   the lemma name holds its composition with lfp.mono2mono. *)
lemma mono2mono_emeasure_spmf [THEN lfp.mono2mono]:
shows monotone_emeasure_spmf:
"monotone (ord_spmf (=)) (≤) (λp. emeasure (measure_spmf p))"
by(rule monotoneI le_funI ord_spmf_eqD_emeasure)+

(* The measure of an spmf commutes with least upper bounds (lub_spmf vs. Sup). *)
lemma cont_emeasure_spmf: "cont lub_spmf (ord_spmf (=)) Sup (≤) (λp. emeasure (measure_spmf p))"
by (rule contI) (simp add: emeasure_lub_spmf fun_eq_iff image_comp)

(* The measure of an spmf is monotone and continuous in the spmf; this combines
   monotone_emeasure_spmf and cont_emeasure_spmf.  The plain statement is named
   mcont_emeasure_spmf; the lemma name holds its composition with
   lfp.mcont2mcont and is declared as a cont_intro rule. *)
lemma mcont2mcont_emeasure_spmf [THEN lfp.mcont2mcont, cont_intro]:
shows mcont_emeasure_spmf: "mcont lub_spmf (ord_spmf (=)) Sup (≤) (λp. emeasure (measure_spmf p))"
by(simp add: mcont_def monotone_emeasure_spmf cont_emeasure_spmf)

(* Pointwise variant of mcont_emeasure_spmf for a fixed set A, obtained by
   viewing Sup and (≤) on functions as pointwise liftings. *)
lemma mcont2mcont_emeasure_spmf' [THEN lfp.mcont2mcont, cont_intro]:
shows mcont_emeasure_spmf': "mcont lub_spmf (ord_spmf (=)) Sup (≤) (λp. emeasure (measure_spmf p) A)"
using mcont_emeasure_spmf[unfolded Sup_conv_fun_lub le_conv_fun_ord]
by(subst (asm) mcont_fun_lub_apply) blast

(* Binding a fixed pmf p preserves monotone-continuity of the continuation;
   derived from mcont_bind_spmf via spmf_of_pmf. *)
lemma mcont_bind_pmf [cont_intro]:
assumes g: "⋀y. mcont luba orda lub_spmf (ord_spmf (=)) (g y)"
shows "mcont luba orda lub_spmf (ord_spmf (=)) (λx. bind_pmf p (λy. g y x))"
using mcont_bind_spmf[where f="λ_. spmf_of_pmf p" and g=g, OF _ assms] by(simp)

(* An extended non-negative real is below ⊤ iff it is different from ⊤. *)
lemma ennreal_less_top_iff: "x < ⊤ ⟷ x ≠ (⊤ :: ennreal)"
by(cases x) simp_all

(* For a type definition with representation set A, the domain of the
   correspondence relation T x y ≡ (x = Rep y) is exactly membership in A. *)
lemma type_definition_Domainp:
fixes Rep Abs A T
assumes type: "type_definition Rep Abs A"
assumes T_def: "T ≡ (λ(x::'a) (y::'b). x = Rep y)"
shows "Domainp T = (λx. x ∈ A)"
proof -
interpret type_definition Rep Abs A by(rule type)
show ?thesis unfolding Domainp_iff[abs_def] T_def fun_eq_iff by(metis Abs_inverse Rep)
qed

(* Parametricity (transfer) rules stated with the ===> notation from
   lifting_syntax. *)
context includes lifting_syntax begin

(* Related spmfs have the same weight. *)
lemma weight_spmf_parametric [transfer_rule]:
"(rel_spmf A ===> (=)) weight_spmf weight_spmf"
by(simp add: rel_fun_def rel_spmf_weightD)

(* Consequently, losslessness is preserved between related spmfs. *)
lemma lossless_spmf_parametric [transfer_rule]:
"(rel_spmf A ===> (=)) lossless_spmf lossless_spmf"
by(simp add: rel_fun_def lossless_spmf_def rel_spmf_weightD)

(* UNIV, viewed as a predicate, is related to itself by rel_pred R.
   Not declared as a transfer rule here. *)
lemma UNIV_parametric_pred: "rel_pred R UNIV UNIV"
by(auto intro!: rel_predI)
end

(* For a finite non-empty set A, binding over spmf_of_set A is the same as
   binding over pmf_of_set A. *)
lemma bind_spmf_spmf_of_set:
"⋀A. ⟦ finite A; A ≠ {} ⟧ ⟹ bind_spmf (spmf_of_set A) = bind_pmf (pmf_of_set A)"
by(simp add: spmf_of_set_def fun_eq_iff del: spmf_of_pmf_pmf_of_set)

(* Support (as a pmf over options) of bind_spmf, expressed with map_option_set. *)
lemma set_pmf_bind_spmf: "set_pmf (bind_spmf M f) = set_pmf M ⤜ map_option_set (set_pmf ∘ f)"
by(auto 4 3 simp add: bind_spmf_def split: option.splits intro: rev_bexI)

(* Support (as a pmf over options) of spmf_of_set A: the image of A under Some
   when A is finite and non-empty, and {None} otherwise.
   The "then" branch must be the image Some ` A; the application "Some A" would
   have type 'a set option and not be a set. *)
lemma set_pmf_spmf_of_set:
"set_pmf (spmf_of_set A) = (if finite A ∧ A ≠ {} then Some ` A else {None})"
by(simp add: spmf_of_set_def spmf_of_pmf_def del: spmf_of_pmf_pmf_of_set)

(* Named constant for "measure (measure_spmf p)"; the definition is a simp rule,
   so the constant is unfolded automatically. *)
definition measure_measure_spmf :: "'a spmf ⇒ 'a set ⇒ real"
where [simp]: "measure_measure_spmf p = measure (measure_spmf p)"

(* Parametricity of measure_measure_spmf, inherited from measure_spmf_parametric. *)
lemma measure_measure_spmf_parametric [transfer_rule]:
includes lifting_syntax shows
"(rel_spmf A ===> rel_pred A ===> (=)) measure_measure_spmf measure_measure_spmf"
unfolding measure_measure_spmf_def[abs_def] by(rule measure_spmf_parametric)

(* Comparison with 1 can be moved through the embedding of nat into real. *)
lemma of_nat_le_one_cancel_iff [simp]:
fixes n :: nat shows "real n ≤ 1 ⟷ n ≤ 1"
by linarith

(* The ceiling of r is strictly less than r + 1. *)
lemma of_int_ceiling_less_add_one [simp]: "of_int ⌈r⌉ < r + 1"
by linarith

(* The initial segment {..<x} is contained in Collect P iff P holds for all y < x. *)
lemma lessThan_subset_Collect: "{..<x} ⊆ Collect P ⟷ (∀y<x. P y)"
by(auto simp add: lessThan_def)

(* If f bounds the mass function of p pointwise from above and f sums to the
   weight of p, then f coincides with the mass function of p.
   Proof idea: if spmf p x < f x for some x, split both sums at x; the sum of f
   would then be strictly larger than the weight of p, contradicting "sum". *)
lemma spmf_ub_tight:
assumes ub: "⋀x. spmf p x ≤ f x"
and sum: "(∫⇧+ x. f x ∂count_space UNIV) = weight_spmf p"
shows "spmf p x = f x"
proof -
have [rule_format]: "∀x. f x ≤ spmf p x"
proof(rule ccontr)
assume "¬ ?thesis"
then obtain x where x: "spmf p x < f x" by(auto simp add: not_le)
(* Finiteness of the sum of f away from x; needed for the strict inequality below. *)
have *: "(∫⇧+ y. ennreal (f y) * indicator (- {x}) y ∂count_space UNIV) ≠ ⊤"
by(rule neq_top_trans[where y="weight_spmf p"], simp)(auto simp add: sum[symmetric] intro!: nn_integral_mono split: split_indicator)

have "weight_spmf p = ∫⇧+ y. spmf p y ∂count_space UNIV"
by(simp add: nn_integral_spmf space_measure_spmf measure_spmf.emeasure_eq_measure)
also have "… = (∫⇧+ y. ennreal (spmf p y) * indicator (- {x}) y ∂count_space UNIV) +
(∫⇧+ y. spmf p y * indicator {x} y ∂count_space UNIV)"
by(subst nn_integral_add[symmetric])(auto intro!: nn_integral_cong split: split_indicator)
also have "… ≤ (∫⇧+ y. ennreal (f y) * indicator (- {x}) y ∂count_space UNIV) + spmf p x"
using ub by(intro add_mono nn_integral_mono)(auto split: split_indicator intro: ennreal_leI)
also have "… < (∫⇧+ y. ennreal (f y) * indicator (- {x}) y ∂count_space UNIV) + (∫⇧+ y. f y * indicator {x} y ∂count_space UNIV)"
using * x by(simp add: ennreal_less_iff)
also have "… = (∫⇧+ y. ennreal (f y) ∂count_space UNIV)"
by(subst nn_integral_add[symmetric])(auto intro: nn_integral_cong split: split_indicator)
also have "… = weight_spmf p" using sum by simp
finally show False by simp
qed
from this[of x] ub[of x] show ?thesis by simp
qed

section ‹Probabilistic while loop›

(* Locale for a probabilistic while loop over states of type 'a:
   "guard" decides whether to continue, "body" performs one (sub-)probabilistic
   step. *)
locale loop_spmf =
fixes guard :: "'a ⇒ bool"
and body :: "'a ⇒ 'a spmf"
begin

(* function_internals keeps the internal constants and facts of the
   partial_function definition (e.g. while_def, while.mono) accessible. *)
context notes [[function_internals]] begin

(* The loop as a least fixpoint in the spmf domain: iterate body while guard
   holds, return the state once it fails. *)
partial_function (spmf) while :: "'a ⇒ 'a spmf"
where "while s = (if guard s then bind_spmf (body s) while else return_spmf s)"

end

(* Fixpoint induction for while with named cases:
   adm (admissibility of P), bottom (P holds for the everywhere-failing
   function) and step (P is preserved by one unfolding of the loop). *)
lemma while_fixp_induct [case_names adm bottom step]:
assumes "spmf.admissible P"
and "P (λwhile. return_pmf None)"
and "⋀while'. P while' ⟹ P (λs. if guard s then body s ⤜ while' else return_spmf s)"
shows "P while"
using assms by(rule while.fixp_induct)

(* The unfolding equation split by the value of the guard; unlike while.simps,
   these conditional equations cannot unfold indefinitely. *)
lemma while_simps:
"guard s ⟹ while s = bind_spmf (body s) while"
"¬ guard s ⟹ while s = return_spmf s"
by(rewrite while.simps; simp; fail)+

end

(* Parametricity of the while combinator: related guards, bodies and start
   states yield related result distributions.  Proved by unfolding the
   fixpoint definition and applying fixp_spmf_parametric. *)
lemma while_spmf_parametric [transfer_rule]:
includes lifting_syntax shows
"((S ===> (=)) ===> (S ===> rel_spmf S) ===> S ===> rel_spmf S) loop_spmf.while loop_spmf.while"
unfolding loop_spmf.while_def[abs_def]
apply(rule rel_funI)
apply(rule rel_funI)
apply(rule fixp_spmf_parametric[OF loop_spmf.while.mono loop_spmf.while.mono])
subgoal premises [transfer_rule] by transfer_prover
done

(* Congruence rule: the loop only depends on the body at states where the
   guard holds. *)
lemma loop_spmf_while_cong:
"⟦ guard = guard'; ⋀s. guard' s ⟹ body s = body' s ⟧
⟹ loop_spmf.while guard body = loop_spmf.while guard' body'"
unfolding loop_spmf.while_def[abs_def] by(simp cong: if_cong)

section ‹Rules for probabilistic termination›

context loop_spmf begin

subsection ‹0/1 termination laws›

(* 0/1 law, immediate form: if every iteration started in a guarded state
   falsifies the guard with probability at least p > 0, and the body is
   lossless on guarded states, then the loop terminates with probability 1.

   Proof idea (by contradiction): let k be the infimum of the termination
   probabilities over all start states.  If some state does not terminate almost
   surely then k < 1.  The auxiliary fact new_bound shows that any uniform
   lower bound k can be improved to p * (1 - k) + k, which is strictly larger
   than k, contradicting the choice of k as infimum. *)
lemma termination_0_1_immediate:
assumes p: "⋀s. guard s ⟹ spmf (map_spmf guard (body s)) False ≥ p"
and p_pos: "0 < p"
and lossless: "⋀s. guard s ⟹ lossless_spmf (body s)"
shows "lossless_spmf (while s)"
proof -
have "∀s. lossless_spmf (while s)"
proof(rule ccontr)
assume "¬ ?thesis"
then obtain s where s: "¬ lossless_spmf (while s)" by blast
hence True: "guard s" by(simp add: while.simps split: if_split_asm)

from p[OF this] have p_le_1: "p ≤ 1" using pmf_le_1 by(rule order_trans)
(* One unfolding of the loop improves a uniform lower bound k on the weight. *)
have new_bound: "p * (1 - k) + k ≤ weight_spmf (while s)"
if k: "0 ≤ k" "k ≤ 1" and k_le: "⋀s. k ≤ weight_spmf (while s)" for k s
proof(cases "guard s")
case False
have "p * (1 - k) + k ≤ 1 * (1 - k) + k" using p_le_1 k by(intro mult_right_mono add_mono; simp)
also have "… ≤ 1" by simp
finally show ?thesis using False by(simp add: while.simps)
next
case True
let ?M = "λs. measure_spmf (body s)"
have bounded: "¦∫ s''. weight_spmf (while s'') ∂?M s'¦ ≤ 1" for s'
using integral_nonneg_AE[of "λs''. weight_spmf (while s'')" "?M s'"]
by(auto simp add: weight_spmf_nonneg weight_spmf_le_1 intro!: measure_spmf.nn_integral_le_const integral_real_bounded)
(* Assumption p, rephrased as a measure of the set of unguarded successor states. *)
have "p ≤ measure (?M s) {s'. ¬ guard s'}" using p[OF True]
by(simp add: spmf_map vimage_def)
hence "p * (1 - k) + k ≤ measure (?M s) {s'. ¬ guard s'} * (1 - k) + k"
using k by(intro add_mono mult_right_mono)(simp_all)
also have "… = ∫ s'. indicator {s'. ¬ guard s'} s' * (1 - k) +  k ∂?M s"
using True by(simp add: ennreal_less_top_iff lossless lossless_weight_spmfD)
also have "… = ∫ s'. indicator {s'. ¬ guard s'} s' + indicator {s'. guard s'} s' * k ∂?M s"
by(rule Bochner_Integration.integral_cong)(simp_all split: split_indicator)
also have "… = ∫ s'. indicator {s'. ¬ guard s'} s' + indicator {s'. guard s'} s' * ∫ s''. k ∂?M s' ∂?M s"
by(rule Bochner_Integration.integral_cong)(auto simp add: lossless lossless_weight_spmfD split: split_indicator)
also have "… ≤ ∫ s'. indicator {s'. ¬ guard s'} s' + indicator {s'. guard s'} s' * ∫ s''. weight_spmf (while s'') ∂?M s' ∂?M s"
using k bounded
by(intro integral_mono integrable_add measure_spmf.integrable_const_bound[where B=1] add_mono mult_left_mono)
(simp_all add: weight_spmf_nonneg weight_spmf_le_1 mult_le_one k_le split: split_indicator)
also have "… = ∫s'. (if ¬ guard s' then 1 else ∫ s''. weight_spmf (while s'') ∂?M s') ∂?M s"
by(rule Bochner_Integration.integral_cong)(simp_all split: split_indicator)
also have "… = ∫ s'. weight_spmf (while s') ∂measure_spmf (body s)"
by(rule Bochner_Integration.integral_cong; simp add: while.simps weight_bind_spmf o_def)
also have "… = weight_spmf (while s)" using True
by(simp add: while.simps weight_bind_spmf o_def)
finally show ?thesis .
qed

(* k: greatest uniform lower bound on the termination probability;
   k': the improved bound obtained from new_bound. *)
define k where "k ≡ INF s. weight_spmf (while s)"
define k' where "k' ≡ p * (1 - k) + k"
from s have "weight_spmf (while s) < 1"
using weight_spmf_le_1[of "while s"] by(simp add: lossless_spmf_def)
then have "k < 1"
unfolding k_def by(rewrite cINF_less_iff)(auto intro!: bdd_belowI2 weight_spmf_nonneg)

have "0 ≤ k" unfolding k_def by(auto intro: cINF_greatest simp add: weight_spmf_nonneg)
moreover from ‹k < 1› have "k ≤ 1" by simp
moreover have "k ≤ weight_spmf (while s)" for s unfolding k_def
by(rule cINF_lower)(auto intro!: bdd_belowI2 weight_spmf_nonneg)
ultimately have "⋀s. k' ≤ weight_spmf (while s)"
unfolding k'_def by(rule new_bound)
hence "k' ≤ k" unfolding k_def by(auto intro: cINF_greatest)
also have "k < k'" using p_pos ‹k < 1› by(auto simp add: k'_def)
finally show False by simp
qed
thus ?thesis by blast
qed

(* Bounded unrolling of the loop: iter n s runs at most n iterations of body
   from state s and returns the state reached (also when the guard still holds
   after n iterations). *)
primrec iter :: "nat ⇒ 'a ⇒ 'a spmf"
where
"iter 0 s = return_spmf s"
| "iter (Suc n) s = (if guard s then bind_spmf (body s) (iter n) else return_spmf s)"

(* From a state where the guard fails, any number of iterations is a no-op. *)
lemma iter_unguarded [simp]: "¬ guard s ⟹ iter n s = return_spmf s"
by(induction n) simp_all

(* Running m iterations followed by n iterations equals running m + n iterations. *)
lemma iter_bind_iter: "bind_spmf (iter m s) (iter n) = iter (m + n) s"
by(induction m arbitrary: s) simp_all

(* Alternative unfolding of iter (Suc n): the extra iteration is performed at
   the end instead of at the beginning. *)
lemma iter_Suc2: "iter (Suc n) s = bind_spmf (iter n s) (λs. if guard s then body s else return_spmf s)"
using iter_bind_iter[of n s 1, symmetric]
by(simp del: iter.simps)(rule bind_spmf_cong; simp cong: bind_spmf_cong)

(* If the body is lossless on guarded states, every bounded unrolling is lossless. *)
lemma lossless_iter: "(⋀s. guard s ⟹ lossless_spmf (body s)) ⟹ lossless_spmf (iter n s)"
by(induction n arbitrary: s) simp_all

(* The probability of having left the loop (guard false) does not decrease when
   one more iteration is allowed.  Uses iter_Suc2 to append the extra iteration
   at the end and compares the integrands pointwise. *)
lemma iter_mono_emeasure1:
"emeasure (measure_spmf (iter n s)) {s. ¬ guard s} ≤ emeasure (measure_spmf (iter (Suc n) s)) {s. ¬ guard s}"
(is "?lhs ≤ ?rhs")
proof(cases "guard s")
case True
have "?lhs = emeasure (measure_spmf (bind_spmf (iter n s) return_spmf)) {s. ¬ guard s}" by simp
also have "… = ∫⇧+ s'. emeasure (measure_spmf (return_spmf s')) {s. ¬ guard s} ∂measure_spmf (iter n s)"
by(simp del: bind_return_spmf add: measure_spmf_bind o_def emeasure_bind[where N="measure_spmf _"] space_measure_spmf Pi_def space_subprob_algebra)
also have "… ≤ ∫⇧+ s'. emeasure (measure_spmf (if guard s' then body s' else return_spmf s')) {s. ¬ guard s} ∂measure_spmf (iter n s)"
by(rule nn_integral_mono)(simp add: measure_spmf_return_spmf)
also have "… = ?rhs"
by(simp add: iter_Suc2 measure_spmf_bind o_def emeasure_bind[where N="measure_spmf _"] space_measure_spmf Pi_def space_subprob_algebra del: iter.simps)
finally show ?thesis .
qed simp

(* The termination probability of the loop is the supremum, over the number n
   of allowed iterations, of the probability that iter n has left the loop.
   "≤" is shown by fixpoint induction on while (in ennreal, using monotone
   convergence); "≥" by induction on n. *)
lemma weight_while_conv_iter:
"weight_spmf (while s) = (SUP n. measure (measure_spmf (iter n s)) {s. ¬ guard s})"
(is "?lhs = ?rhs")
proof(rule antisym)
have "emeasure (measure_spmf (while s)) UNIV ≤ (SUP n. emeasure (measure_spmf (iter n s)) {s. ¬ guard s})"
(is "_ ≤ (SUP n. ?f n s)")
proof(induction arbitrary: s rule: while_fixp_induct)
case adm show ?case by simp
case bottom show ?case by simp
case (step while')
show ?case (is "?lhs' ≤ ?rhs'")
proof(cases "guard s")
case True
(* The approximations are increasing in n, as needed for monotone convergence. *)
have inc: "incseq ?f" by(rule incseq_SucI le_funI iter_mono_emeasure1)+

from True have "?lhs' = ∫⇧+ s'. emeasure (measure_spmf (while' s')) UNIV ∂measure_spmf (body s)"
by(simp add: measure_spmf_bind o_def emeasure_bind[where N="measure_spmf _"] space_measure_spmf Pi_def space_subprob_algebra)
also have "… ≤ ∫⇧+ s'. (SUP n. ?f n s') ∂measure_spmf (body s)"
by(rule nn_integral_mono)(rule step.IH)
also have "… = (SUP n. ∫⇧+ s'. ?f n s' ∂measure_spmf (body s))" using inc
by(subst nn_integral_monotone_convergence_SUP) simp_all
also have "… = (SUP n. ?f (Suc n) s)" using True
by(simp add: measure_spmf_bind o_def emeasure_bind[where N="measure_spmf _"] space_measure_spmf Pi_def space_subprob_algebra)
also have "… ≤ (SUP n. ?f n s)"
by(rule SUP_mono)(auto intro: exI[where x="Suc _"])
finally show ?thesis .
next
case False
(* Guard already false: both sides are the measure of a point mass at s. *)
then have "?lhs' = emeasure (measure_spmf (iter 0 s)) {s. ¬ guard s}"
by(simp add: measure_spmf_return_spmf)
also have ‹… ≤ ?rhs'› by(rule SUP_upper) simp
finally show ?thesis .
qed
qed
also have "… = ennreal (SUP n. measure (measure_spmf (iter n s)) {s. ¬ guard s})"
by(subst ennreal_SUP)(fold measure_spmf.emeasure_eq_measure, auto simp add: not_less measure_spmf.subprob_emeasure_le_1 intro!: exI[where x="1"])
also have "0 ≤ (SUP n. measure (measure_spmf (iter n s)) {s. ¬ guard s})"
by(rule cSUP_upper2)(auto intro!: bdd_aboveI[where M=1] simp add: measure_spmf.subprob_measure_le_1)
ultimately show "?lhs ≤ ?rhs" by(simp add: measure_spmf.emeasure_eq_measure space_measure_spmf)

show "?rhs ≤ ?lhs"
proof(rule cSUP_least)
show "measure (measure_spmf (iter n s)) {s. ¬ guard s} ≤ weight_spmf (while s)" (is "?f n s ≤ _") for n
proof(induction n arbitrary: s)
case 0 show ?case
by(simp add: measure_spmf_return_spmf measure_return while_simps split: split_indicator)
next
case (Suc n)
show ?case
proof(cases "guard s")
case True
have "?f (Suc n) s = ∫⇧+ s'. ?f n s' ∂measure_spmf (body s)"
using True unfolding measure_spmf.emeasure_eq_measure[symmetric]
by(simp add: measure_spmf_bind o_def emeasure_bind[where N="measure_spmf _"] space_measure_spmf Pi_def space_subprob_algebra)
also have "… ≤ ∫⇧+ s'. weight_spmf (while s') ∂measure_spmf (body s)"
by(rule nn_integral_mono ennreal_leI Suc.IH)+
also have "… = weight_spmf (while s)"
using True unfolding measure_spmf.emeasure_eq_measure[symmetric] space_measure_spmf
by(simp add: while_simps measure_spmf_bind o_def emeasure_bind[where N="measure_spmf _"] space_measure_spmf Pi_def space_subprob_algebra)
finally show ?thesis by(simp)
next
case False then show ?thesis
by(simp add: measure_spmf_return_spmf measure_return while_simps split: split_indicator)
qed
qed
qed simp
qed

(* 0/1 law: if from every guarded state the loop terminates with probability at
   least p > 0 and the body is lossless on guarded states, then the loop
   terminates with probability 1.

   Proof idea: by weight_while_conv_iter, for each guarded state s some bounded
   unrolling iter (N s) s already leaves the loop with probability ≥ p / 2.
   The "fused" loop whose body is λs. iter (N s) s then satisfies
   termination_0_1_immediate, and its termination probability is bounded by
   that of the original loop. *)
lemma termination_0_1:
assumes p: "⋀s. guard s ⟹ p ≤ weight_spmf (while s)"
and p_pos: "0 < p"
and lossless: "⋀s. guard s ⟹ lossless_spmf (body s)"
shows "lossless_spmf (while s)"
unfolding lossless_spmf_def
proof(rule antisym)
let ?X = "{s. ¬ guard s}"
show "weight_spmf (while s) ≤ 1" by(rule weight_spmf_le_1)

define p' where "p' ≡ p / 2"
have p'_pos: "p' > 0" and "p' < p" using p_pos by(simp_all add: p'_def)

have "∃n. p' < measure (measure_spmf (iter n s)) ?X" if "guard s" for s using p[OF that] ‹p' < p›
unfolding weight_while_conv_iter
by(subst (asm) le_cSUP_iff)(auto intro!: measure_spmf.subprob_measure_le_1)
(* Choose, for every state, a number N s of iterations that achieves the bound. *)
then obtain N where p': "p' ≤ measure (measure_spmf (iter (N s) s)) ?X" if "guard s" for s
using p by atomize_elim(rule choice, force dest: order.strict_implies_order)

interpret fuse: loop_spmf guard "λs. iter (N s) s" .

have "1 = weight_spmf (fuse.while s)"
by(rule lossless_weight_spmfD[symmetric])
(rule fuse.termination_0_1_immediate; auto simp add: spmf_map vimage_def intro: p' p'_pos lossless_iter lossless)
also have "… ≤ (⨆n. measure (measure_spmf (iter n s)) ?X)"
unfolding fuse.weight_while_conv_iter
proof(rule cSUP_least)
fix n
have "emeasure (measure_spmf (fuse.iter n s)) ?X ≤ (SUP n. emeasure (measure_spmf (iter n s)) ?X)"
proof(induction n arbitrary: s)
case 0 show ?case by(auto intro!: SUP_upper2[where i=0])
next
case (Suc n)
have inc: "incseq (λn s'. emeasure (measure_spmf (iter n s')) ?X)"
by(rule incseq_SucI le_funI iter_mono_emeasure1)+

have "emeasure (measure_spmf (fuse.iter (Suc n) s)) ?X = emeasure (measure_spmf (iter (N s) s ⤜ fuse.iter n)) ?X"
by simp
also have "… = ∫⇧+ s'. emeasure (measure_spmf (fuse.iter n s')) ?X ∂measure_spmf (iter (N s) s)"
by(simp add: measure_spmf_bind o_def emeasure_bind[where N="measure_spmf _"] space_measure_spmf Pi_def space_subprob_algebra)
also have "… ≤ ∫⇧+ s'. (SUP n. emeasure (measure_spmf (iter n s')) ?X) ∂measure_spmf (iter (N s) s)"
by(rule nn_integral_mono Suc.IH)+
also have "… = (SUP n. ∫⇧+ s'. emeasure (measure_spmf (iter n s')) ?X ∂measure_spmf (iter (N s) s))"
by(rule nn_integral_monotone_convergence_SUP[OF inc]) simp
also have "… = (SUP n. emeasure (measure_spmf (bind_spmf (iter (N s) s) (iter n))) ?X)"
by(simp add: measure_spmf_bind o_def emeasure_bind[where N="measure_spmf _"] space_measure_spmf Pi_def space_subprob_algebra)
also have "… = (SUP n. emeasure (measure_spmf (iter (N s + n) s)) ?X)" by(simp add: iter_bind_iter)
also have "… ≤ (SUP n. emeasure (measure_spmf (iter n s)) ?X)" by(rule SUP_mono) auto
finally show ?case .
qed
also have "… = ennreal (SUP n. measure (measure_spmf (iter n s)) ?X)"
by(subst ennreal_SUP)(fold measure_spmf.emeasure_eq_measure, auto simp add: not_less measure_spmf.subprob_emeasure_le_1 intro!: exI[where x="1"])
also have "0 ≤ (SUP n. measure (measure_spmf (iter n s)) ?X)"
by(rule cSUP_upper2)(auto intro!: bdd_aboveI[where M=1] simp add: measure_spmf.subprob_measure_le_1)
(* Move the ennreal-level bound back to real-valued measures. *)
ultimately show "measure (measure_spmf (fuse.iter n s)) ?X ≤ …"
by(simp add: measure_spmf.emeasure_eq_measure)
qed simp
finally show  "1 ≤ weight_spmf (while s)" unfolding weight_while_conv_iter .
qed

end

(* Invariant version of termination_0_1_immediate: the hypotheses are only
   required for states satisfying the invariant I, which must be preserved by
   the body on guarded states and hold for the start state.

   Proof technique (Types-To-Sets): assume a type 's' that is isomorphic to
   {s. I s}, transfer guard, body and start state to that type, apply
   termination_0_1_immediate there, transfer the result back, and finally
   discharge the type-definition assumption with cancel_type_definition. *)
lemma termination_0_1_immediate_invar:
fixes I :: "'s ⇒ bool"
assumes p: "⋀s. ⟦ guard s; I s ⟧ ⟹ spmf (map_spmf guard (body s)) False ≥ p"
and p_pos: "0 < p"
and lossless: "⋀s. ⟦ guard s; I s ⟧ ⟹ lossless_spmf (body s)"
and invar: "⋀s s'. ⟦ s' ∈ set_spmf (body s); I s; guard s ⟧ ⟹ I s'"
and I: "I s"
shows "lossless_spmf (loop_spmf.while guard body s)"
including lifting_syntax
proof -
{ assume "∃(Rep :: 's' ⇒ 's) Abs. type_definition Rep Abs {s. I s}"
then obtain Rep :: "'s' ⇒ 's" and Abs where td: "type_definition Rep Abs {s. I s}" by blast
then interpret td: type_definition Rep Abs "{s. I s}" .
(* Correspondence relation between 's and the local type 's'. *)
define cr where "cr ≡ λx y. x = Rep y"
have [transfer_rule]: "bi_unique cr" "right_total cr" using td cr_def by(rule typedef_bi_unique typedef_right_total)+
have [transfer_domain_rule]: "Domainp cr = I" using type_definition_Domainp[OF td cr_def] by simp

define guard' where "guard' ≡ (Rep ---> id) guard"
have [transfer_rule]: "(cr ===> (=)) guard guard'" by(simp add: rel_fun_def cr_def guard'_def)
(* body1 agrees with body on guarded states and fails elsewhere. *)
define body1 where "body1 ≡ λs. if guard s then body s else return_pmf None"
define body1' where "body1' ≡ (Rep ---> map_spmf Abs) body1"
have [transfer_rule]: "(cr ===> rel_spmf cr) body1 body1'"
by(auto simp add: rel_fun_def body1'_def body1_def cr_def spmf_rel_map td.Rep[simplified] invar td.Abs_inverse intro!: rel_spmf_reflI)
define s' where "s' ≡ Abs s"
have [transfer_rule]: "cr s s'" by(simp add: s'_def cr_def I td.Abs_inverse)

have "⋀s. guard' s ⟹ p ≤ spmf (map_spmf guard' (body1' s)) False"
by(transfer fixing: p)(simp add: body1_def p)
moreover note p_pos
moreover have "⋀s. guard' s ⟹ lossless_spmf (body1' s)" by transfer(simp add: lossless body1_def)
ultimately have "lossless_spmf (loop_spmf.while guard' body1' s')" by(rule loop_spmf.termination_0_1_immediate)
hence "lossless_spmf (loop_spmf.while guard body1 s)" by transfer }
from this[cancel_type_definition] I show ?thesis by(auto cong: loop_spmf_while_cong)
qed

(* Invariant version of termination_0_1: the lower bound p on the termination
   probability and losslessness of the body are only required for guarded
   states satisfying the invariant I, which must be preserved by the body and
   hold initially.

   Same Types-To-Sets proof scheme as termination_0_1_immediate_invar: work on
   a local type isomorphic to {s. I s}, apply termination_0_1 there, transfer
   back and cancel the type definition. *)
lemma termination_0_1_invar:
fixes I :: "'s ⇒ bool"
assumes p: "⋀s. ⟦ guard s; I s ⟧ ⟹ p ≤ weight_spmf (loop_spmf.while guard body s)"
and p_pos: "0 < p"
and lossless: "⋀s. ⟦ guard s; I s ⟧ ⟹ lossless_spmf (body s)"
and invar: "⋀s s'. ⟦ s' ∈ set_spmf (body s); I s; guard s ⟧ ⟹ I s'"
and I: "I s"
shows "lossless_spmf (loop_spmf.while guard body s)"
including lifting_syntax
proof-
{ assume "∃(Rep :: 's' ⇒ 's) Abs. type_definition Rep Abs {s. I s}"
then obtain Rep :: "'s' ⇒ 's" and Abs where td: "type_definition Rep Abs {s. I s}" by blast
then interpret td: type_definition Rep Abs "{s. I s}" .
(* Correspondence relation between 's and the local type 's'. *)
define cr where "cr ≡ λx y. x = Rep y"
have [transfer_rule]: "bi_unique cr" "right_total cr" using td cr_def by(rule typedef_bi_unique typedef_right_total)+
have [transfer_domain_rule]: "Domainp cr = I" using type_definition_Domainp[OF td cr_def] by simp

define guard' where "guard' ≡ (Rep ---> id) guard"
have [transfer_rule]: "(cr ===> (=)) guard guard'" by(simp add: rel_fun_def cr_def guard'_def)
(* body1 agrees with body on guarded states and fails elsewhere. *)
define body1 where "body1 ≡ λs. if guard s then body s else return_pmf None"
define body1' where "body1' ≡ (Rep ---> map_spmf Abs) body1"
have [transfer_rule]: "(cr ===> rel_spmf cr) body1 body1'"
by(auto simp add: rel_fun_def body1'_def body1_def cr_def spmf_rel_map td.Rep[simplified] invar td.Abs_inverse intro!: rel_spmf_reflI)
define s' where "s' ≡ Abs s"
have [transfer_rule]: "cr s s'" by(simp add: s'_def cr_def I td.Abs_inverse)

interpret loop_spmf guard' body1' .

(* The weight is a measure of UNIV, so transferring it needs UNIV_parametric_pred
   and the named constant measure_measure_spmf. *)
note UNIV_parametric_pred[transfer_rule]
have "⋀s. guard' s ⟹ p ≤ weight_spmf (while s)"
unfolding measure_measure_spmf_def[symmetric] space_measure_spmf
by(transfer fixing: p)(simp add: body1_def p[simplified space_measure_spmf] cong: loop_spmf_while_cong)
moreover note p_pos
moreover have "⋀s. guard' s ⟹ lossless_spmf (body1' s)" by transfer(simp add: lossless body1_def)
ultimately have "lossless_spmf (while s')" by(rule termination_0_1)
hence "lossless_spmf (loop_spmf.while guard body1 s)" by transfer }
from this[cancel_type_definition] I show ?thesis by(auto cong: loop_spmf_while_cong)
qed

subsection ‹Variant rule›

context loop_spmf begin

(* Variant rule: let f be a natural-number valued variant that is bounded by
   "bound" on guarded states and that every iteration decreases with
   probability at least p > 0.  If the body is lossless on guarded states, the
   loop terminates with probability 1.

   Proof idea: with p' = min p 1 and n = bound + 1, show by induction that from
   any state with f s < n the loop terminates with probability at least p' ^ n,
   then conclude with termination_0_1_invar using the invariant
   "guard s ⟶ f s < n". *)
lemma termination_variant:
fixes bound :: nat
assumes bound: "⋀s. guard s ⟹ f s ≤ bound"
and step: "⋀s. guard s ⟹ p ≤ spmf (map_spmf (λs'. f s' < f s) (body s)) True"
and p_pos: "0 < p"
and lossless: "⋀s. guard s ⟹ lossless_spmf (body s)"
shows "lossless_spmf (while s)"
proof -
define p' and n where "p' ≡ min p 1" and "n ≡ bound + 1"
have p'_pos: "0 < p'" and p'_le_1: "p' ≤ 1"
and step': "guard s ⟹ p' ≤ measure (measure_spmf (body s)) {s'. f s' < f s}" for s
using p_pos step[of s] by(simp_all add: p'_def spmf_map vimage_def)
have "p' ^ n ≤ weight_spmf (while s)" if "f s < n" for s using that
proof(induction n arbitrary: s)
case 0 thus ?case by simp
next
case (Suc n)
show ?case
proof(cases "guard s")
case False
hence "weight_spmf (while s) = 1" by(simp add: while.simps)
thus ?thesis using p'_le_1 p_pos
by simp(meson less_eq_real_def mult_le_one p'_pos power_le_one zero_le_power)
next
case True
let ?M = "measure_spmf (body s)"
have "p' ^ Suc n ≤ (∫ s'. indicator {s'. f s' < f s} s' ∂?M) * p' ^ n"
using step'[OF True] p'_pos by(simp add: mult_right_mono)
also have "… = (∫ s'. indicator {s'. f s' < f s} s' * p' ^ n ∂?M)" by simp
also have "… ≤ (∫ s'. indicator {s'. f s' < f s} s' * weight_spmf (while s') ∂?M)"
using Suc.prems p'_le_1 p'_pos
by(intro integral_mono)(auto simp add: Suc.IH power_le_one weight_spmf_le_1 split: split_indicator intro!: measure_spmf.integrable_const_bound[where B=1])
(* Adding the non-negative contribution of successors where f did not decrease. *)
also have "… ≤ … + (∫ s'. indicator {s'. f s' ≥ f s} s' * weight_spmf (while s') ∂?M)"
by(simp add: integral_nonneg_AE weight_spmf_nonneg)
also have "… = ∫ s'. weight_spmf (while s') ∂?M"
by(subst Bochner_Integration.integral_add[symmetric])
(auto intro!: Bochner_Integration.integral_cong measure_spmf.integrable_const_bound[where B=1] weight_spmf_le_1 split: split_indicator)
also have "… = weight_spmf (while s)"
using True by(subst (1 2) while.simps)(simp add: weight_bind_spmf o_def)
finally show ?thesis .
qed
qed
moreover have "0 < p' ^ n" using p'_pos by simp
ultimately show ?thesis using lossless
proof(rule termination_0_1_invar)
show "f s < n" if "guard s" "guard s ⟶ f s < n" for s using that by simp
show "guard s ⟶ f s < n" using bound[of s] by(auto simp add: n_def)
show "guard s' ⟶ f s' < n" for s' using bound[of s'] by(clarsimp simp add: n_def)
qed
qed

end

(* Invariant version of termination_variant: bound, step probability and
   losslessness are only required for guarded states satisfying the invariant
   I, which must be preserved by the body and hold initially.

   Same Types-To-Sets proof scheme as the other *_invar lemmas: work on a local
   type isomorphic to {s. I s}, apply loop_spmf.termination_variant there,
   transfer back and cancel the type definition. *)
lemma termination_variant_invar:
fixes bound :: nat and I :: "'s ⇒ bool"
assumes bound: "⋀s. ⟦ guard s; I s ⟧ ⟹ f s ≤ bound"
and step: "⋀s. ⟦ guard s; I s ⟧ ⟹ p ≤ spmf (map_spmf (λs'. f s' < f s) (body s)) True"
and p_pos: "0 < p"
and lossless: "⋀s. ⟦ guard s; I s ⟧ ⟹ lossless_spmf (body s)"
and invar: "⋀s s'. ⟦ s' ∈ set_spmf (body s); I s; guard s ⟧ ⟹ I s'"
and I: "I s"
shows "lossless_spmf (loop_spmf.while guard body s)"
including lifting_syntax
proof -
{ assume "∃(Rep :: 's' ⇒ 's) Abs. type_definition Rep Abs {s. I s}"
then obtain Rep :: "'s' ⇒ 's" and Abs where td: "type_definition Rep Abs {s. I s}" by blast
then interpret td: type_definition Rep Abs "{s. I s}" .
(* Correspondence relation between 's and the local type 's'. *)
define cr where "cr ≡ λx y. x = Rep y"
have [transfer_rule]: "bi_unique cr" "right_total cr" using td cr_def by(rule typedef_bi_unique typedef_right_total)+
have [transfer_domain_rule]: "Domainp cr = I" using type_definition_Domainp[OF td cr_def] by simp

define guard' where "guard' ≡ (Rep ---> id) guard"
have [transfer_rule]: "(cr ===> (=)) guard guard'" by(simp add: rel_fun_def cr_def guard'_def)
(* body1 agrees with body on guarded states and fails elsewhere. *)
define body1 where "body1 ≡ λs. if guard s then body s else return_pmf None"
define body1' where "body1' ≡ (Rep ---> map_spmf Abs) body1"
have [transfer_rule]: "(cr ===> rel_spmf cr) body1 body1'"
by(auto simp add: rel_fun_def body1'_def body1_def cr_def spmf_rel_map td.Rep[simplified] invar td.Abs_inverse intro!: rel_spmf_reflI)
define s' where "s' ≡ Abs s"
have [transfer_rule]: "cr s s'" by(simp add: s'_def cr_def I td.Abs_inverse)
define f' where "f' ≡ (Rep ---> id) f"
have [transfer_rule]: "(cr ===> (=)) f f'" by(simp add: rel_fun_def cr_def f'_def)

have "⋀s. guard' s ⟹ f' s ≤ bound" by(transfer fixing: bound)(rule bound)
moreover have "⋀s. guard' s ⟹ p ≤ spmf (map_spmf (λs'. f' s' < f' s) (body1' s)) True"
by(transfer fixing: p)(simp add: step body1_def)
(* Append p_pos to the step fact so that both enter the calculation together. *)
note this p_pos
moreover have "⋀s. guard' s ⟹ lossless_spmf (body1' s)"
by transfer(simp add: lossless body1_def)
ultimately have "lossless_spmf (loop_spmf.while guard' body1' s')" by(rule loop_spmf.termination_variant)
hence "lossless_spmf (loop_spmf.while guard body1 s)" by transfer }
from this[cancel_type_definition] I show ?thesis by(auto cong: loop_spmf_while_cong)
qed

end


# Theory Bernoulli

(* Title: Bernoulli.thy
Author: Andreas Lochbihler, ETH Zurich *)

section ‹Distributions built from coin flips›

subsection ‹ The Bernoulli distribution›

theory Bernoulli imports "HOL-Probability.Probability" begin

(* Numerals are strictly positive in canonically ordered semirings of
   characteristic 0 (declared as a simp rule). *)
lemma zero_lt_num [simp]: "0 < (numeral n :: _ :: {canonically_ordered_monoid_add, semiring_char_0})"
by (metis not_gr_zero zero_neq_numeral)

(* Multiplication by a numeral on the right can be moved inside ennreal. *)
lemma ennreal_mult_numeral: "ennreal x * numeral n = ennreal (x * numeral n)"
by (simp add: ennreal_mult'')

(* For non-negative x, adding 1 on the left can be moved inside ennreal. *)
lemma one_plus_ennreal: "0 ≤ x ⟹ 1 + ennreal x = ennreal (1 + x)"
by simp

text ‹
We define the Bernoulli distribution as a least fixpoint instead of a loop because this
avoids the need to add a condition flag to the distribution, which we would have to project
out at the end again.  As the direct termination proof is so simple, we do not bother to prove
it equivalent to a while loop.
›

(* Bernoulli sampler built from fair coin flips, defined as a least fixpoint.
   Each round flips a coin b: if b is true the result is (p ≥ 1/2); otherwise
   the sampler recurses on 2 * p (when p < 1/2) or on 2 * p - 1. *)
partial_function (spmf) bernoulli :: "real ⇒ bool spmf" where
"bernoulli p = do {
b ← coin_spmf;
if b then return_spmf (p ≥ 1 / 2)
else if p < 1 / 2 then bernoulli (2 * p)
else bernoulli (2 * p - 1)
}"

(* The Bernoulli sampler fails (returns None) with probability 0.
   Proof idea: by induction on n, the failure probability is at most 1 / 2 ^ n
   for every n, and the infimum of these bounds is 0. *)
lemma pmf_bernoulli_None: "pmf (bernoulli p) None = 0"
proof -
have "ereal (pmf (bernoulli p) None) ≤ (INF n∈UNIV. ereal (1 / 2 ^ n))"
proof(rule INF_greatest)
show "ereal (pmf (bernoulli p) None) ≤ ereal (1 / 2 ^ n)" for n
proof(induction n arbitrary: p)
case (Suc n)
show ?case using Suc.IH[of "2 * p"] Suc.IH[of "2 * p - 1"]
by(subst bernoulli.simps)(simp add: UNIV_bool max_def field_simps spmf_of_pmf_pmf_of_set[symmetric] pmf_bind_pmf_of_set ennreal_pmf_bind nn_integral_pmf_of_set del: spmf_of_pmf_pmf_of_set)
(* Base case n = 0: any point probability is at most 1. *)
qed(simp add: pmf_le_1)
qed
also have "… = ereal 0"
proof(rule LIMSEQ_unique)
show "(λn. ereal (1 / 2 ^ n)) ⇢ …" by(rule LIMSEQ_INF)(simp add: field_simps decseq_SucI)
show "(λn. ereal (1 / 2 ^ n)) ⇢ ereal 0" by(simp add: LIMSEQ_divide_realpow_zero)
qed
finally show ?thesis by simp
qed

(* The Bernoulli sampler terminates with probability 1; immediate from
   pmf_bernoulli_None. *)
lemma lossless_bernoulli [simp]: "lossless_spmf (bernoulli p)"
by(simp add: lossless_iff_pmf_None pmf_bernoulli_None)

(* For 0 ≤ p ≤ 1 the sampler returns True with probability p and False with
   probability 1 - p.
   Proof idea: fixpoint induction gives the upper bounds p and 1 - p; since the
   sampler is lossless the two probabilities add up to 1, so both bounds are
   attained. *)
lemma [simp]: assumes "0 ≤ p" "p ≤ 1"
shows bernoulli_True: "spmf (bernoulli p) True = p" (is ?True)
and bernoulli_False: "spmf (bernoulli p) False = 1 - p" (is ?False)
proof -
{ have "ennreal (spmf (bernoulli p) b) ≤ ennreal (if b then p else 1 - p)" for b using assms
proof(induction arbitrary: p rule: bernoulli.fixp_induct[case_names adm bottom step])
case adm show ?case by(rule cont_intro)+
next
case (step bernoulli' p)
show ?case using step.prems step.IH[of "2 * p"] step.IH[of "2 * p - 1"]
by(auto simp add: UNIV_bool max_def divide_le_posI_ennreal ennreal_mult_numeral numeral_mult_ennreal field_simps spmf_of_pmf_pmf_of_set[symmetric] ennreal_pmf_bind nn_integral_pmf_of_set one_plus_ennreal simp del: spmf_of_pmf_pmf_of_set ennreal_plus)
qed simp }
note this[of True] this[of False]
moreover have "spmf (bernoulli p) True + spmf (bernoulli p) False = 1"
by(simp add: spmf_False_conv_True)
ultimately show ?True ?False using assms by(auto simp add: ennreal_le_iff2)
qed

(* For p ≤ 0 the sampler always returns False.
   Fixpoint induction shows it is below return_spmf False in the spmf order;
   hence True has probability 0 and, by spmf_False_conv_True, False has
   probability 1. *)
lemma bernoulli_neg [simp]:
assumes "p ≤ 0"
shows "bernoulli p = return_spmf False"
proof -
from assms have "ord_spmf (=) (bernoulli p) (return_spmf False)"
proof(induction arbitrary: p rule: bernoulli.fixp_induct[case_names adm bottom step])
case (step bernoulli' p)
show ?case using step.prems step.IH[of "2 * p"]
by(auto simp add: ord_spmf_return_spmf2 set_bind_spmf bind_UNION field_simps)
qed simp_all
from ord_spmf_eq_leD[OF this, of True] have "spmf (bernoulli p) True = 0" by simp
moreover then have "spmf (bernoulli p) False = 1" by(simp add: spmf_False_conv_True)
ultimately show ?thesis by(auto intro: spmf_eqI split: split_indicator)
qed

(* For p ≥ 1 the sampler always returns True.
   Symmetric to bernoulli_neg: fixpoint induction shows it is below
   return_spmf True, so False has probability 0 and True has probability 1. *)
lemma bernoulli_pos [simp]:
assumes "1 ≤ p"
shows "bernoulli p = return_spmf True"
proof -
from assms have "ord_spmf (=) (bernoulli p) (return_spmf True)"
proof(induction arbitrary: p rule: bernoulli.fixp_induct[case_names adm bottom step])
case (step bernoulli' p)
show ?case using step.prems step.IH[of "2 * p - 1"]
by(auto simp add: ord_spmf_return_spmf2 set_bind_spmf bind_UNION field_simps)
qed simp_all
from ord_spmf_eq_leD[OF this, of False] have "spmf (bernoulli p) False = 0" by simp
moreover then have "spmf (bernoulli p) True = 1" by(simp add: spmf_False_conv_True)
ultimately show ?thesis by(auto intro: spmf_eqI split: split_indicator)
qed

(* bernoulli p coincides with bernoulli_pmf p embedded into spmf via spmf_of_pmf.
   The pmf_as_function interpretation is local to this context and supplies the
   transfer setup used in the proof. *)
context begin interpretation pmf_as_function .
lemma bernoulli_eq_bernoulli_pmf:
"bernoulli p = spmf_of_pmf (bernoulli_pmf p)"
by(rule spmf_eqI; simp)(transfer; auto simp add: max_def min_def)
end

end


# Theory Geometric

(* Title: Geometric.thy
Author: Andreas Lochbihler, ETH Zurich *)

subsection ‹The geometric distribution›

theory Geometric imports
Bernoulli
While_SPMF
begin

text ‹
We define the geometric distribution as a least fixpoint, which is more elegant than
defining it as a loop. To prove probabilistic termination, we show that it is equivalent
to a loop and use the proof rules for probabilistic termination.
›

(* geometric_spmf p: draw a bit from bernoulli p; return 0 if the bit is True,
   otherwise recurse and add 1 to the recursive result.
   function_internals is enabled locally for the partial_function command. *)
context notes [[function_internals]] begin
partial_function (spmf) geometric_spmf :: "real ⇒ nat spmf" where
"geometric_spmf p = do {
b ← bernoulli p;
if b then return_spmf 0 else map_spmf ((+) 1) (geometric_spmf p)
}"
end

(* Fixpoint induction for geometric_spmf with named cases adm, bottom and step.
   The admissibility premise was missing (the statement started with "and");
   it is required by geometric_spmf.fixp_induct and by the "case adm" in later proofs. *)
lemma geometric_spmf_fixp_induct [case_names adm bottom step]:
  assumes "spmf.admissible P"
    and "P (λgeometric_spmf. return_pmf None)"
    and "⋀geometric_spmf'. P geometric_spmf' ⟹ P (λp. bernoulli p ⤜ (λb. if b then return_spmf 0 else map_spmf ((+) 1) (geometric_spmf' p)))"
  shows "P geometric_spmf"
  using assms by(rule geometric_spmf.fixp_induct)

(* For p ≤ 0, geometric_spmf p equals return_pmf None; shown by fixpoint induction. *)
lemma spmf_geometric_nonpos: "p ≤ 0 ⟹ geometric_spmf p = return_pmf None"
by(induction rule: geometric_spmf_fixp_induct) simp_all

(* For p ≥ 1, geometric_spmf p equals return_spmf 0.
   The lemma had no proof; unfolding the recursion once and simplifying suffices,
   because bernoulli p collapses to return_spmf True (bernoulli_pos). *)
lemma spmf_geometric_ge_1: "1 ≤ p ⟹ geometric_spmf p = return_spmf 0"
  by(subst geometric_spmf.simps) simp

(* Local setup for relating geometric_spmf to a while loop.
   The loop state is a pair (b, x): one iteration draws b' from bernoulli p, stores ¬ b'
   in the first component and adds 1 to x when b' is False (0 otherwise).
   The loop guard is fst. *)
context
fixes p :: real
and body :: "bool × nat ⇒ (bool × nat) spmf"
defines [simp]: "body ≡ λ(b, x). map_spmf (λb'. (¬ b', x + (if b' then 0 else 1))) (bernoulli p)"
begin

(* Instantiate the generic loop locale; the rewrites clause replaces body by its definition. *)
interpretation loop_spmf fst body
rewrites "body ≡ λ(b, x). map_spmf (λb'. (¬ b', x + (if b' then 0 else 1))) (bernoulli p)"
by(fact body_def)

(* geometric_spmf p is the second component of the while loop started in (True, 0).
   The generalised statement (arbitrary start counter x) is proved by antisymmetry of
   ord_spmf (=), with one fixpoint induction per direction. *)
lemma geometric_spmf_conv_while:
  shows "geometric_spmf p = map_spmf snd (while (True, 0))"
proof -
  have "map_spmf ((+) x) (geometric_spmf p) = map_spmf snd (while (True, x))" (is "?lhs = ?rhs") for x
  proof(rule spmf.leq_antisym)
    show "ord_spmf (=) ?lhs ?rhs"
    proof(induction arbitrary: x rule: geometric_spmf_fixp_induct)
      case adm show ?case by simp
      case bottom show ?case by simp
      case (step geometric')
      show ?case using step.IH[of "Suc x"]
        apply(rewrite while.simps)
        apply(clarsimp simp add: map_spmf_bind_spmf bind_map_spmf intro!: ord_spmf_bind_reflI)
        apply(rewrite while.simps)
        (* Closing step restored: the script previously ended right after the rewrite. *)
        apply(clarsimp simp add: spmf.map_comp o_def)
        done
    qed
    (* The converse direction is proved simultaneously with the behaviour on a stopped state. *)
    have "ord_spmf (=) ?rhs ?lhs"
      and "ord_spmf (=) (map_spmf snd (while (False, x))) (return_spmf x)"
    proof(induction arbitrary: x and x rule: while_fixp_induct)
      case adm show ?case by simp
      case bottom case 1 show ?case by simp
      case bottom case 2 show ?case by simp
    next
      case (step while')
      case 1 show ?case using step.IH(1)[of "Suc x"] step.IH(2)[of x]
        by(rewrite geometric_spmf.simps)(clarsimp simp add: map_spmf_bind_spmf bind_map_spmf spmf.map_comp o_def intro!: ord_spmf_bind_reflI)
      case 2 show ?case by simp
    qed
    then show "ord_spmf (=) ?rhs ?lhs" by -
  qed
  from this[of 0] show ?thesis by(simp cong: map_spmf_cong)
qed

(* geometric_spmf p is lossless exactly when p > 0.
   For 0 < p < 1 the while-loop characterisation and termination_0_1_immediate apply.
   The remaining case (p ≤ 0 or 1 ≤ p) was not closed before; it follows from
   spmf_geometric_nonpos and spmf_geometric_ge_1. *)
lemma lossless_geometric [simp]: "lossless_spmf (geometric_spmf p) ⟷ p > 0"
proof(cases "0 < p ∧ p < 1")
  case True
  let ?body = "λ(b, x :: nat). map_spmf (λb'. (¬ b', x + (if b' then 0 else 1))) (bernoulli p)"
  have "lossless_spmf (while (True, 0))"
  proof(rule termination_0_1_immediate)
    have "{x. x} = {True}" by auto
    (* Every iteration leaves the loop with probability at least p. *)
    then show "p ≤ spmf (map_spmf fst (?body s)) False" for s :: "bool × nat" using True
      by(cases s)(simp add: spmf.map_comp o_def spmf_map vimage_def spmf_conv_measure_spmf[symmetric])
    show "0 < p" using True by simp
  qed(clarsimp)
  with True show ?thesis by(simp add: geometric_spmf_conv_while)
qed(auto simp add: spmf_geometric_nonpos spmf_geometric_ge_1)

end

(* Point probabilities of the geometric distribution for 0 < p < 1:
   n has probability (1 - p) ^ n * p.
   Via spmf_ub_tight: the closed form is an upper bound (fixpoint induction) and its total
   mass equals the weight of geometric_spmf p (geometric series plus losslessness). *)
lemma spmf_geometric:
  assumes p: "0 < p" "p < 1"
  shows "spmf (geometric_spmf p) n = (1 - p) ^ n * p" (is "?lhs n = ?rhs n")
proof(rule spmf_ub_tight)
  fix n
  have "ennreal (?lhs n) ≤ ennreal (?rhs n)" using p
  proof(induction arbitrary: n rule: geometric_spmf_fixp_induct)
    case adm show ?case by(rule cont_intro)+
    case bottom show ?case by simp
    case (step geometric_spmf')
    then show ?case
      by(cases n)(simp_all add: ennreal_spmf_bind nn_integral_measure_spmf UNIV_bool nn_integral_count_space_finite ennreal_mult spmf_map vimage_def mult.assoc spmf_conv_measure_spmf[symmetric] mult_mono split: split_indicator)
  qed
  then show "?lhs n ≤ ?rhs n" using p by(simp)
next
  have "(∑i. ennreal (p * (1 - p) ^ i)) = ennreal (p * (1 / (1 - (1 - p))))" using p
    by (intro suminf_ennreal_eq sums_mult geometric_sums) auto
  (* Terminal method restored: turn the integral over nat into the series above. *)
  then show "(∑⇧+ x. ennreal ((1 - p) ^ x * p)) = weight_spmf (geometric_spmf p)"
    using lossless_geometric[of p] p unfolding lossless_spmf_def
    by (simp add: nn_integral_count_space_nat field_simps)
qed

end


# Theory Fast_Dice_Roll

(* Title: Fast_Dice_Roll.thy
Author: Andreas Lochbihler, ETH Zurich *)

subsection ‹Arbitrary uniform distributions›

theory Fast_Dice_Roll imports
Bernoulli
While_SPMF
begin

text ‹This formalisation follows the ideas by J\'er\'emie Lumbroso \cite{Lumbroso2013arxiv}.›

(* Fusing a uniform sample with one extra low-order bit:
   drawing c uniformly below v and a uniform bit b and continuing with 2 * c + b
   is the same as drawing uniformly below 2 * v (for v > 0).
   The core step shows that the image of the product distribution under
   (c, b) ↦ 2 * c + b is the uniform distribution on {..<2 * v}. *)
lemma sample_bits_fusion:
fixes v :: nat
assumes "0 < v"
shows
"bind_pmf (pmf_of_set {..<v}) (λc. bind_pmf (pmf_of_set UNIV) (λb. f (2 * c + (if b then 1 else 0)))) =
bind_pmf (pmf_of_set {..<2 * v}) f"
(is "?lhs = ?rhs")
proof -
have "?lhs = bind_pmf (map_pmf (λ(c, b). (2 * c + (if b then 1 else 0))) (pair_pmf (pmf_of_set {..<v}) (pmf_of_set UNIV))) f"
(is "_ = bind_pmf (map_pmf ?f _) _")
by(simp add: pair_pmf_def bind_map_pmf bind_assoc_pmf bind_return_pmf)
also have "map_pmf ?f (pair_pmf (pmf_of_set {..<v}) (pmf_of_set UNIV)) = pmf_of_set {..<2 * v}"
(is "?l = ?r" is "map_pmf ?f ?p = _")
proof(rule pmf_eqI)
fix i :: nat
have [simp]: "inj ?f" by(auto simp add: inj_on_def) arith+
(* Decompose i into its quotient and parity, i.e. the unique preimage under ?f. *)
define i' where "i' ≡ i div 2"
define b where "b ≡ odd i"
have i: "i = ?f (i', b)" by(simp add: i'_def b_def)
show "pmf ?l i = pmf ?r i"
by(subst i; subst pmf_map_inj')(simp_all add: pmf_pair i'_def assms lessThan_empty_iff split: split_indicator)
qed
finally show ?thesis .
qed

(* Variant of sample_bits_fusion in which the extra bit is the high-order one:
   drawing a uniform bit b and c uniformly below v and continuing with c + v * b
   is the same as drawing uniformly below 2 * v (for v > 0). *)
lemma sample_bits_fusion2:
  fixes v :: nat
  assumes "0 < v"
  shows
    "bind_pmf (pmf_of_set UNIV) (λb. bind_pmf (pmf_of_set {..<v}) (λc. f (c + v * (if b then 1 else 0)))) =
     bind_pmf (pmf_of_set {..<2 * v}) f"
    (is "?lhs = ?rhs")
proof -
  have "?lhs = bind_pmf (map_pmf (λ(c, b). (c + v * (if b then 1 else 0))) (pair_pmf (pmf_of_set {..<v}) (pmf_of_set UNIV))) f"
    (is "_ = bind_pmf (map_pmf ?f _) _")
    unfolding pair_pmf_def
    by(subst bind_commute_pmf)(simp add: bind_map_pmf bind_assoc_pmf bind_return_pmf)
  also have "map_pmf ?f (pair_pmf (pmf_of_set {..<v}) (pmf_of_set UNIV)) = pmf_of_set {..<2 * v}"
    (is "?l = ?r" is "map_pmf ?f ?p = _")
  proof(rule pmf_eqI)
    fix i :: nat
    (* ?f is only injective on the support of the product distribution. *)
    have [simp]: "inj_on ?f ({..<v} × UNIV)" by(auto simp add: inj_on_def)
    define i' where "i' ≡ if i ≥ v then i - v else i"
    define b where "b ≡ i ≥ v"
    have decomp: "i = ?f (i', b)" by(simp add: i'_def b_def)
    show "pmf ?l i = pmf ?r i"
    proof(cases "i < 2 * v")
      case True
      then show ?thesis
        by(subst decomp; subst pmf_map_inj)(auto simp add: pmf_pair i'_def assms lessThan_empty_iff split: split_indicator)
    next
      case False
      (* Outside the range, i lies in neither support, so both point masses are 0. *)
      then have "i ∉ set_pmf ?l" "i ∉ set_pmf ?r"
        using assms by(auto simp add: lessThan_empty_iff split: if_split_asm)
      then show ?thesis by(simp add: set_pmf_iff del: set_map_pmf)
    qed
  qed
  finally show ?thesis .
qed

context fixes n :: nat notes [[function_internals]] begin

text ‹
The check for @{term "v >= n"} should be done already at the start of the loop.
Otherwise we do not see why this algorithm should be optimal (when we start with @{term "v = n"}
and @{term "c = n - 1"}, then it can go round a few loops before it returns something).

We define the algorithm as a least fixpoint. To prove termination, we later show that it is
equivalent to a while loop which samples bitstrings of a given length, which could in turn
be implemented as a loop.  The fixpoint formulation is more elegant because we do not need to
nest any loops.
›

(* fast_dice_roll v c, for the bound n fixed in the surrounding context:
   - if v ≥ n and c < n, return c;
   - if v ≥ n and c ≥ n, recurse on (v - n, c - n);
   - otherwise flip a coin b and recurse on (2 * v, 2 * c + b). *)
partial_function (spmf) fast_dice_roll :: "nat ⇒ nat ⇒ nat spmf"
where
"fast_dice_roll v c =
(if v ≥ n then if c < n then return_spmf c else fast_dice_roll (v - n) (c - n)
else do {
b ← coin_spmf;
fast_dice_roll (2 * v) (2 * c + (if b then 1 else 0)) } )"

(* Fixpoint induction for fast_dice_roll with named cases adm, bottom and step.
   As the function takes two arguments, admissibility is stated for the curried functional. *)
lemma fast_dice_roll_fixp_induct [case_names adm bottom step]:
assumes "spmf.admissible (λfast_dice_roll. P (curry fast_dice_roll))"
and "P (λv c. return_pmf None)"
and "⋀fdr. P fdr ⟹ P (λv c. if v ≥ n then if c < n then return_spmf c else fdr (v - n) (c - n)
else bind_spmf coin_spmf (λb. fdr (2 * v) (2 * c + (if b then 1 else 0))))"
shows "P fast_dice_roll"
using assms by(rule fast_dice_roll.fixp_induct)

(* The sampler proper: fast_dice_roll started with v = 1 and c = 0. *)
definition fast_uniform :: "nat spmf"
where "fast_uniform = fast_dice_roll 1 0"

(* Upper bound on the point probabilities of fast_dice_roll v when the second argument is
   drawn uniformly below v (v > 0): at most 1 / n for x < n, and 0 otherwise.
   Proof by fixpoint induction. Several proof methods and the admissibility case were
   missing from this proof and have been restored.
   NOTE(review): the restored steps are reconstructions - re-check them in Isabelle. *)
lemma spmf_fast_dice_roll_ub:
  assumes "0 < v"
  shows "spmf (bind_pmf (pmf_of_set {..<v}) (fast_dice_roll v)) x ≤ (if x < n then 1 / n else 0)"
  (is "?lhs ≤ ?rhs")
proof -
  have "ennreal ?lhs ≤ ennreal ?rhs" using assms
  proof(induction arbitrary: v x rule: fast_dice_roll_fixp_induct)
    case adm thus ?case
      by(rule cont_intro ccpo_class.admissible_leI)+ simp_all
    case bottom thus ?case by simp
    case (step fdr)
    show ?case (is "?lhs ≤ ?rhs")
    proof(cases "n ≤ v")
      case le: True
      then have "?lhs = spmf (bind_pmf (pmf_of_set {..<v}) (λc. if c < n then return_spmf c else fdr (v - n) (c - n))) x"
        by simp
      (* Split the integral into the part that returns immediately and the part that recurses. *)
      also have "… = (∫⇧+ c'. indicator (if x < n then {x} else {}) c' ∂measure_pmf (pmf_of_set {..<v})) +
        (∫⇧+ c'. indicator {n ..< v} c' * spmf (fdr (v - n) (c' - n)) x ∂measure_pmf (pmf_of_set {..<v}))"
        (is "?then = ?found + ?continue") using step.prems
        by(subst nn_integral_add[symmetric])(auto simp add: ennreal_pmf_bind AE_measure_pmf_iff lessThan_empty_iff split: split_indicator intro!: nn_integral_cong_AE)
      also have "?found = (if x < n then 1 else 0) / v" using step.prems le
        by(auto simp add: measure_pmf.emeasure_eq_measure measure_pmf_of_set lessThan_empty_iff Iio_Int_singleton)
      also have "?continue = (∫⇧+ c'. indicator {n ..< v} c' * 1 / v * spmf (fdr (v - n) (c' - n)) x ∂count_space UNIV)"
        using step.prems by(auto simp add: nn_integral_measure_pmf lessThan_empty_iff ennreal_mult[symmetric] intro!: nn_integral_cong split: split_indicator)
      also have "… = (if v = n then 0 else ennreal ((v - n) / v) * spmf (bind_pmf (pmf_of_set {n..<v}) (λc'. fdr (v - n) (c' - n))) x)"
        using le step.prems
        by(subst ennreal_pmf_bind)(auto simp add: ennreal_mult[symmetric] nn_integral_measure_pmf nn_integral_0_iff_AE AE_count_space nn_integral_cmult[symmetric] split: split_indicator)
      also {
        assume *: "n < v"
        (* Shift the interval {n..<v} down to {..<v - n} so that the IH applies. *)
        then have "pmf_of_set {n..<v} = map_pmf ((+) n) (pmf_of_set {..<v - n})"
          by(subst map_pmf_of_set_inj)(auto 4 3 simp add: inj_on_def lessThan_empty_iff intro!: arg_cong[where f=pmf_of_set] intro: rev_image_eqI[where x="_ - n"] diff_less_mono)
        also have "bind_pmf … (λc'. fdr (v - n) (c' - n)) = bind_pmf (pmf_of_set {..<v - n}) (fdr (v - n))"
          by(simp add: bind_map_pmf)
        also have "ennreal (spmf … x) ≤ (if x < n then 1 / n else 0)"
          by(rule step.IH)(simp add: *)
        also note calculation }
      then have "… ≤ ennreal ((v - n) / v) * (if x < n then 1 / n else 0)" using le
        by(cases "v = n")(auto split del: if_split intro: divide_right_mono mult_left_mono)
      also have "… = (v - n) / v * (if x < n then 1 / n else 0)" by(simp add: ennreal_mult[symmetric])
      finally show ?thesis using le by(auto simp add: add_mono field_simps of_nat_diff ennreal_plus[symmetric] simp del: ennreal_plus)
    next
      case False
      then have "?lhs = spmf (bind_pmf (pmf_of_set {..<v}) (λc. bind_pmf (pmf_of_set UNIV) (λb. fdr (2 * v) (2 * c + (if b then 1 else 0))))) x"
        by(simp add: bind_spmf_spmf_of_set)
      (* Fuse the uniform sample below v with the fresh coin flip. *)
      also have "… = spmf (bind_pmf (pmf_of_set {..<2 * v}) (fdr (2 * v))) x" using step.prems
        by(simp add: sample_bits_fusion[symmetric])
      also have "… ≤ ?rhs" using step.prems by(intro step.IH) simp
      finally show ?thesis .
    qed
  qed
  thus ?thesis by simp
qed

(* Upper bound for fast_uniform: instance of spmf_fast_dice_roll_ub at v = 1,
   where the uniform distribution on {..<1} is the point distribution on 0. *)
lemma spmf_fast_uniform_ub:
"spmf fast_uniform x ≤ (if x < n then 1 / n else 0)"
proof -
have "{..<Suc 0} = {0}" by auto
then show ?thesis using spmf_fast_dice_roll_ub[of 1 x]
by(simp add: fast_uniform_def pmf_of_set_singleton bind_return_pmf split: if_split_asm)
qed

(* Started with v = 0, fast_dice_roll equals return_pmf None for every c. *)
lemma fast_dice_roll_0 [simp]: "fast_dice_roll 0 c = return_pmf None"
by(induction arbitrary: c rule: fast_dice_roll_fixp_induct)(simp_all add: bind_eq_return_pmf_None)

text ‹To prove termination, we fold all the iterations that only double into one big step›

(* fdr_step v c: for v = 0 the result is return_pmf None. Otherwise, with
   x = 2 ^ (nat ⌈log 2 (max 1 n) - log 2 v⌉), draw bs uniformly below x and
   return the pair (x * v, x * c + bs). *)
definition fdr_step :: "nat ⇒ nat ⇒ (nat × nat) spmf"
where
"fdr_step v c =
(if v = 0 then return_pmf None
else let x = 2 ^ (nat ⌈log 2 (max 1 n) - log 2 v⌉) in
map_spmf (λbs. (x * v, x * c + bs)) (spmf_of_set {..<x}))"

(* Recursive characterisation of fdr_step: fail for v = 0, return (v, c) once n ≤ v,
   otherwise flip one coin and continue with the doubled arguments.
   Restored in this proof: the closing of the outer case distinction (case v = 0), the
   method for the sample_bits_fusion2 step, and the missing steps of the apply script for *.
   NOTE(review): the restored steps are reconstructions - re-check them in Isabelle. *)
lemma fdr_step_unfold:
  "fdr_step v c =
  (if v = 0 then return_pmf None
   else if n ≤ v then return_spmf (v, c)
   else do {
     b ← coin_spmf;
     fdr_step (2 * v) (2 * c + (if b then 1 else 0)) })"
  (is "?lhs = ?rhs" is "_ = (if _ then _ else ?else)")
proof(cases "v = 0")
  case v: False
  (* x v is the number of values sampled in one big step started at v. *)
  define x where "x ≡ λv :: nat. 2 ^ (nat ⌈log 2 (max 1 n) - log 2 v⌉) :: nat"
  have x_pos: "x v > 0" by(simp add: x_def)

  show ?thesis
  proof(cases "n ≤ v")
    case le: True
    hence "x v = 1" using v by(simp add: x_def log_le)
    moreover have "{..<1} = {0 :: nat}" by auto
    ultimately show ?thesis using le v by(simp add: fdr_step_def spmf_of_set_singleton)
  next
    case less: False
    hence even: "even (x v)" using v by(simp add: x_def)
    with x_pos have x_ge_1: "x v > 1" by(cases "x v = 1") auto
    (* Doubling v halves the number of remaining samples. *)
    have *: "x (2 * v) = x v div 2" using v less unfolding x_def
      apply(simp add: log_mult diff_add_eq_diff_diff_swap)
      apply(rewrite in "_ = 2 ^ ⌑ div _" le_add_diff_inverse2[symmetric, where b=1])
       apply (simp add: Suc_leI)
      apply(simp del: Suc_pred)
      done

    have "?lhs = map_spmf (λbs. (x v * v, x v * c + bs)) (spmf_of_set {..<x v})"
      using v by(simp add: fdr_step_def x_def Let_def)
    also from even have "… = bind_pmf (pmf_of_set {..<2 * (x v div 2)}) (λbs. return_spmf (x v * v, x v * c + bs))"
      by(simp add: map_spmf_conv_bind_spmf bind_spmf_spmf_of_set x_pos lessThan_empty_iff)
    (* Split off the high-order bit of the sample as a separate coin flip. *)
    also have "… = bind_spmf coin_spmf (λb. bind_spmf (spmf_of_set {..<x v div 2})
      (λc'. return_spmf (x v * v, x v * c + c' + (x v div 2) * (if b then 1 else 0))))"
      using x_ge_1
      by(simp add: sample_bits_fusion2[symmetric] bind_spmf_spmf_of_set lessThan_empty_iff add.assoc)
    also have "… = bind_spmf coin_spmf (λb. map_spmf (λbs. (x (2 * v) * (2 * v), x (2 * v) * (2 * c + (if b then 1 else 0)) + bs)) (spmf_of_set {..<x (2 * v)}))"
      using * even by(simp add: map_spmf_conv_bind_spmf algebra_simps)
    also have "… = ?rhs" using v less by(simp add: fdr_step_def Let_def x_def)
    finally show ?thesis .
  qed
qed(simp add: fdr_step_def)

(* Induction rule with one named case fdr_step: to show P v c, one may assume P for the
   doubled arguments (2 * v, 2 * c + bit) whenever v ≠ 0 and v < n.
   Well-foundedness is established with the measure n - v. *)
lemma fdr_step_induct [case_names fdr_step]:
"(⋀v c. (⋀b. ⟦v ≠ 0; v < n⟧ ⟹ P (2 * v) (2 * c + (if b then 1 else 0))) ⟹ P v c)
⟹ P v c"
apply induction_schema
apply pat_completeness
apply(relation "Wellfounded.measure (λ(v, c). n - v)")
apply simp_all
done

(* Alternative formulation built on fdr_step: run one fdr_step from (v, c) to get (v', c');
   return c' if c' < n, otherwise recurse on (v' - n, c' - n). *)
partial_function (spmf) fdr_alt :: "nat ⇒ nat ⇒ nat spmf"
where
"fdr_alt v c = do {
(v', c') ← fdr_step v c;
if c' < n then return_spmf c' else fdr_alt (v' - n) (c' - n) }"

(* fdr_alt and fast_dice_roll are the same function.
   Proved pointwise by antisymmetry of ord_spmf (=): one fixpoint induction per direction,
   the first with an inner induction along fdr_step_induct. *)
lemma fast_dice_roll_alt: "fdr_alt = fast_dice_roll"
proof(intro ext)
show "fdr_alt v c = fast_dice_roll v c" for v c
proof(rule spmf.leq_antisym)
show "ord_spmf (=) (fdr_alt v c) (fast_dice_roll v c)"
proof(induction arbitrary: v c rule: fdr_alt.fixp_induct[case_names adm bottom step])
case adm show ?case by simp
case bottom show ?case by simp
case (step fdra)
show ?case
proof(induction v c rule: fdr_step_induct)
case inner: (fdr_step v c)
show ?case
apply(rewrite fdr_step_unfold)
apply(rewrite fast_dice_roll.simps)
apply(auto intro!: ord_spmf_bind_reflI simp add: Let_def inner.IH step.IH)
done
qed
qed
(* The converse direction is proved together with the behaviour at v = 0. *)
have "ord_spmf (=) (fast_dice_roll v c) (fdr_alt v c)"
and "fast_dice_roll 0 c = return_pmf None"
proof(induction arbitrary: v c rule: fast_dice_roll_fixp_induct)
case adm thus ?case by simp
case bottom case 1 thus ?case by simp
case bottom case 2 thus ?case by simp
case (step fdr) case 1 show ?case
apply(rewrite fdr_alt.simps)
apply(rewrite fdr_step_unfold)
apply(auto intro!: ord_spmf_bind_reflI simp add: fdr_alt.simps[symmetric] step.IH rel_pmf_return_pmf2 set_pmf_bind_spmf o_def set_pmf_spmf_of_set split: if_split_asm)
done
case step case 2 from step.IH show ?case by(simp add: Let_def bind_eq_return_pmf_None)
qed
then show "ord_spmf (=) (fast_dice_roll v c) (fdr_alt v c)" by -
qed
qed

(* fdr_step v c is lossless exactly when v > 0.
   The lemma had no proof; it follows by unfolding fdr_step_def, since the sampled
   set {..<x} is finite and non-empty for the power of two x. *)
lemma lossless_fdr_step [simp]: "lossless_spmf (fdr_step v c) ⟷ v > 0"
  by(simp add: fdr_step_def Let_def lessThan_empty_iff)

(* fdr_alt as a while loop: run fdr_step once, then iterate fdr_step on (v - n, c - n)
   while n ≤ c, and finally project to the second component.
   Proved by parallel fixpoint induction over fdr_alt and the loop; the admissibility
   case was missing and is restored here. *)
lemma fast_dice_roll_alt_conv_while:
  "fdr_alt v c =
  map_spmf snd (bind_spmf (fdr_step v c) (loop_spmf.while (λ(v, c). n ≤ c) (λ(v, c). fdr_step (v - n) (c - n))))"
proof(induction arbitrary: v c rule: parallel_fixp_induct_2_1[OF partial_function_definitions_spmf partial_function_definitions_spmf fdr_alt.mono loop_spmf.while.mono fdr_alt_def loop_spmf.while_def, case_names adm bottom step])
  case adm show ?case by simp
  case bottom show ?case by simp
  case (step fdr while)
  show ?case using step.IH
    by(auto simp add: map_spmf_bind_spmf o_def intro!: bind_spmf_cong[OF refl])
qed

(* fast_dice_roll v c is lossless whenever c < v and v ≤ n.
   For v < n the while-loop characterisation is used with termination_variant_invar:
   ?I is the loop invariant, ?f the variant, and 1 / (2 * n) the lower bound on the
   probability that the variant decreases. For v = n one unfolding step suffices.
   Three "have" steps in this proof had no proof method; they are restored below.
   NOTE(review): the restored steps are reconstructions - re-check them in Isabelle. *)
lemma lossless_fast_dice_roll:
  assumes "c < v" "v ≤ n"
  shows "lossless_spmf (fast_dice_roll v c)"
proof(cases "v < n")
  case True
  let ?I = "λ(v, c). c < v ∧ n ≤ v ∧ v < 2 * n"
  let ?f = "λ(v, c). if n ≤ c then n + c - v + 1 else 0"
  (* One loop iteration preserves the invariant. *)
  have invar: "?I (v', c')" if step: "(v', c') ∈ set_spmf (fdr_step (v - n) (c - n))"
    and I: "c < v" "n ≤ v" "v < 2 * n" and c: "n ≤ c" for v' c' v c
  proof(clarsimp; safe)
    define x where "x = nat ⌈log 2 (max 1 n) - log 2 (v - n)⌉"
    have **: "-1 < log 2 (real n / real (v - n))" by(rule less_le_trans[where y=0])(use I c in ‹auto›)

    from I c step obtain bs where v': "v' = 2 ^ x * (v - n)"
      and c': "c' = 2 ^ x * (c - n) + bs"
      and bs: "bs < 2 ^ x"
      unfolding fdr_step_def x_def[symmetric] by(auto simp add: Let_def)
    have "2 ^ x * (c - n) + bs < 2 ^ x * (c - n + 1)" unfolding distrib_left using bs
      by simp
    also have "… ≤ 2 ^ x * (v - n)" using I c by(intro mult_left_mono) auto
    finally show "c' < v'" using c' v' by simp

    have "v' = 2 powr x * (v - n)" by(simp add: powr_realpow v')
    also have "… < 2 powr (log 2 (max 1 n) - log 2 (v - n) + 1) * (v - n)"
      using ** I c by(intro mult_strict_right_mono)(auto simp add: x_def log_divide)
    also have "… ≤ 2 * n" unfolding powr_add using I c
      by(simp add: log_divide[symmetric] max_def)
    finally show "v' < 2 * n" using c' by(simp del: of_nat_add)

    have "log 2 (n / (v - n)) ≤ x" using I c ** by(auto simp add: x_def log_divide max_def)
    hence "2 powr log 2 (n / (v - n)) ≤ 2 powr x" by(rule powr_mono) simp
    also have "2 powr log 2 (n / (v - n)) = n / (v - n)" using I c by(simp)
    finally have "n ≤ real (2 ^ x * (v - n))" using I c by(simp add: field_simps powr_realpow)
    then show "n ≤ v'" by(simp add: v' del: of_nat_mult)
  qed

  have loop: "lossless_spmf (loop_spmf.while (λ(v, c). n ≤ c) (λ(v, c). fdr_step (v - n) (c - n)) (v, c))"
    if "c < 2 * n" and "n ≤ v" and "c < v" and "v < 2 * n"
    for v c
  proof(rule termination_variant_invar; clarify?)
    fix v c
    assume I: "?I (v, c)" and c: "n ≤ c"
    show "?f (v, c) ≤ n" using I c by auto

    define x where "x = nat ⌈log 2 (max 1 n) - log 2 (v - n)⌉"
    define p :: real where "p ≡ 1 / (2 * n)"

    from I c have n: "0 < n" and v: "n < v" by auto
    from I c v n have x_pos: "x > 0" by(auto simp add: x_def max_def)

    have **: "-1 < log 2 (real n / real (v - n))" by(rule less_le_trans[where y=0])(use I c in ‹auto›)
    then have "x ≤ log 2 (real n) + 1" using v n
      by(auto simp add: x_def log_divide[symmetric] max_def field_simps intro: order_trans[OF of_int_ceiling_le_add_one])
    hence "2 powr x ≤ 2 powr …" by(rule powr_mono) simp
    hence "p ≤ 1 / 2 ^ x" unfolding powr_add using n
      by(subst (asm) powr_realpow, simp)(subst (asm) powr_log_cancel; simp_all add: p_def field_simps)
    also
    (* ?X collects the samples for which the variant strictly decreases; 0 is one of them. *)
    let ?X = "{c'. n ≤ 2 ^ x * (c - n) + c' ⟶ n + (2 ^ x * (c - n) + c') - 2 ^ x * (v - n) < n + c - v}"
    have "n + c * 2 ^ x - v * 2 ^ x < c + n - v" using I c
    proof(cases "n + c * 2 ^ x ≥ v * 2 ^ x")
      case True
      have "(int c - v) * 2 ^ x < (int c - v) * 1"
        using x_pos I c by(intro mult_strict_left_mono_neg) simp_all
      then have "int n + c * 2 ^ x - v * 2 ^ x < c + int n - v" by(simp add: algebra_simps)
      also have "… = int (c + n - v)" using I c by auto
      also have "int n + c * 2 ^ x - v * 2 ^ x = int (n + c * 2 ^ x - v * 2 ^ x)"
        using True that by(simp add: of_nat_diff)
      finally show ?thesis by simp
    qed auto
    then have "{..<2 ^ x} ∩ ?X ≠ {}" using that n v
      by(auto simp add: disjoint_eq_subset_Compl Collect_neg_eq[symmetric] lessThan_subset_Collect algebra_simps intro: exI[where x=0])
    then have "0 < card ({..<2 ^ x} ∩ ?X)" by(simp add: card_gt_0_iff)
    hence "1 / 2 ^ x ≤ … / 2 ^ x" by(simp add: field_simps)
    finally show "p ≤ spmf (map_spmf (λs'. ?f s' < ?f (v, c)) (fdr_step (v - n) (c - n))) True"
      using I c unfolding fdr_step_def x_def[symmetric]
      by(clarsimp simp add: Let_def spmf.map_comp o_def spmf_map measure_spmf_of_set vimage_def p_def)

    show "lossless_spmf (fdr_step (v - n) (c - n))" using I c by simp
    show "?I (v', c')" if step: "(v', c') ∈ set_spmf (fdr_step (v - n) (c - n))" for v' c'
      using that by(rule invar)(use I c in auto)
  next
    show "(0 :: real) < 1 / (2 * n)" using that by(simp)
    show "?I (v, c)" using that by simp
  qed
  show ?thesis using assms True
    by(auto simp add: fast_dice_roll_alt[symmetric] fast_dice_roll_alt_conv_while intro!: loop dest: invar[of _ _ "n + v" "n + c", simplified])
next
  case False
  with assms have "v = n" by simp
  thus ?thesis using assms by(subst fast_dice_roll.simps) simp
qed

(* For n = 0, fast_dice_roll equals return_pmf None for all arguments. *)
lemma fast_dice_roll_n0:
assumes "n = 0"
shows "fast_dice_roll v c = return_pmf None"
by(induction arbitrary: v c rule: fast_dice_roll_fixp_induct)(simp_all add: assms)

(* fast_uniform is lossless exactly when n > 0:
   for n = 0 it is return_pmf None (fast_dice_roll_n0), otherwise
   lossless_fast_dice_roll applies to the start values 1 and 0. *)
lemma lossless_fast_uniform [simp]: "lossless_spmf fast_uniform ⟷ n > 0"
proof(cases "n = 0")
  case True
  thus ?thesis
    using fast_dice_roll_n0
    unfolding fast_uniform_def
    by simp
next
  case False
  thus ?thesis
    by(simp add: fast_uniform_def lossless_fast_dice_roll)
qed

(* Exact point probabilities of fast_uniform: 1 / n for x < n and 0 otherwise.
   For n > 0 the upper bound spmf_fast_uniform_ub is tight (spmf_ub_tight), because the
   bound sums to 1, which is the weight of the lossless fast_uniform.
   For n = 0 the sampler is return_pmf None. *)
lemma spmf_fast_uniform: "spmf fast_uniform x = (if x < n then 1 / n else 0)"
proof(cases "n > 0")
case n: True
show ?thesis using spmf_fast_uniform_ub
proof(rule spmf_ub_tight)
have "(∑⇧+ x. ennreal (if x < n then 1 / n else 0)) = (∑⇧+ x∈{..<n}. 1 / n)"
by(auto simp add: nn_integral_count_space_indicator simp del: nn_integral_const intro: nn_integral_cong)
also have "… = 1" using n by(simp add: field_simps ennreal_of_nat_eq_real_of_nat ennreal_mult[symmetric])
also have "… = weight_spmf fast_uniform" using lossless_fast_uniform n unfolding lossless_spmf_def by simp
finally show "(∑⇧+ x. ennreal (if x < n then 1 / n else 0)) = …" .
qed
next
case False
with fast_dice_roll_n0[of 1 0] show ?thesis unfolding fast_uniform_def by(simp)
qed

end

(* Outside the context (n is now an explicit argument): fast_uniform n is the uniform
   distribution on {..<n}. The lemma had no proof; both sides have the same point
   probabilities by spmf_fast_uniform. *)
lemma fast_uniform_conv_uniform: "fast_uniform n = spmf_of_set {..<n}"
  by(rule spmf_eqI)(simp add: spmf_fast_uniform spmf_of_set)

end


# Theory Resampling

(* Title: Resampling.thy
Author: Andreas Lochbihler, ETH Zurich *)

theory Resampling imports
While_SPMF
begin

(* A lossless spmf that is below q with respect to ord_spmf (=) is equal to q. *)
lemma ord_spmf_lossless:
assumes "ord_spmf (=) p q" "lossless_spmf p"
shows "p = q"
unfolding pmf.rel_eq[symmetric] using assms(1)
by(rule pmf.rel_mono_strong)(use assms(2) in ‹auto elim!: ord_option.cases simp add: lossless_iff_set_pmf_None›)

(* resample A B: draw x from spmf_of_set A; return x if x ∈ B, otherwise start over. *)
context notes [[function_internals]] begin

partial_function (spmf) resample :: "'a set ⇒ 'a set ⇒ 'a spmf" where
"resample A B = bind_spmf (spmf_of_set A) (λx. if x ∈ B then return_spmf x else resample A B)"

end

(* Fixpoint induction for resample with named cases adm, bottom and step. *)
lemmas resample_fixp_induct[case_names adm bottom step] = resample.fixp_induct

(* Fix the sampling set A and the target set B, and instantiate the loop locale:
   the loop continues while the current value is not in B, and every iteration
   draws a fresh value from spmf_of_set A. *)
context
fixes A :: "'a set"
and B :: "'a set"
begin

interpretation loop_spmf "λx. x ∉ B" "λ_. spmf_of_set A" .

(* resample as a while loop: draw once from spmf_of_set A, then run the loop.
   Proved by parallel fixpoint induction over resample and while. *)
lemma resample_conv_while: "resample A B = bind_spmf (spmf_of_set A) while"
proof(induction rule: parallel_fixp_induct_2_1[OF partial_function_definitions_spmf partial_function_definitions_spmf resample.mono while.mono resample_def while_def, case_names adm bottom step])
case adm show ?case by simp
case bottom show ?case by simp
case (step resample' while') then show ?case by(simp add: z3_rule(33) cong del: if_cong)
qed

(* Assumptions for the results below: A is finite and B is a non-empty subset of A. *)
context
assumes A: "finite A"
and B: "B ⊆ A" "B ≠ {}"
begin

(* A is non-empty because it contains the non-empty set B. *)
private lemma A_nonempty: "A ≠ {}"
using B by blast

(* B is finite as a subset of the finite set A. *)
private lemma B_finite: "finite B"
using A B by(blast intro: finite_subset)

(* resample A B is lossless: by termination_0_1_immediate, instantiated with
   p = card (A ∩ B) / card A, the loop is lossless from every start state. *)
lemma lossless_resample: "lossless_spmf (resample A B)"
proof -
from B have [simp]: "A ∩ B ≠ {}" by auto
have "lossless_spmf (while x)" for x
by(rule termination_0_1_immediate[where p="card (A ∩ B) / card A"])
(simp_all add: spmf_map vimage_def measure_spmf_of_set field_simps A_nonempty A not_le card_gt_0_iff B)
then show ?thesis by(clarsimp simp add: resample_conv_while A A_nonempty)
qed

(* resample A B is below spmf_of_set B with respect to ord_spmf (=).
   Fixpoint induction; the step case compares point probabilities, splitting the sum
   over A into the part over B (immediate return) and the part over A - B (retry). *)
lemma resample_le_sample:
"ord_spmf (=) (resample A B) (spmf_of_set B)"
proof(induction rule: resample_fixp_induct)
case adm show ?case by simp
case bottom show ?case by simp
case (step resample')
note [simp] = B_finite A
show ?case
proof(rule ord_pmf_increaseI)
fix x
let ?f = "λx. if x ∈ B then return_spmf x else resample' A B"
have "spmf (bind_spmf (spmf_of_set A) ?f) x =
(∑n∈B ∪ (A - B). if n ∈ B then (if n = x then 1 else 0) / card A else spmf (resample' A B) x / card A)"
using B
by(auto simp add: spmf_bind integral_spmf_of_set sum_divide_distrib if_distrib[where f="λp. spmf p _ / _"] cong: if_cong intro!: sum.cong split: split_indicator_asm)
also have "… = (∑n∈B. (if n = x then 1 else 0) / card A) + (∑n∈A - B. spmf (resample' A B) x / card A)"
by(subst sum.union_disjoint)(auto)
also have "… = (if x ∈ B then 1 / card A else 0) + card (A - B) / card A * spmf (resample' A B) x"
by(simp cong: sum.cong add: if_distrib[where f="λx. x / _"] cong: if_cong)
(* Apply the induction hypothesis to the retry branch. *)
also have "… ≤ (if x ∈ B then 1 / card A else 0) + card (A - B) / card A * spmf (spmf_of_set B) x"
by(intro add_left_mono mult_left_mono step.IH[THEN ord_spmf_eq_leD]) simp
also have "… = spmf (spmf_of_set B) x" using B
by(simp add: spmf_of_set field_simps A_nonempty card_Diff_subset card_mono of_nat_diff)
finally show "spmf (bind_spmf (spmf_of_set A) ?f) x ≤ …" .
qed simp
qed

(* Main result: resample A B equals spmf_of_set B.
   Combines the upper bound resample_le_sample with losslessness via ord_spmf_lossless. *)
lemma resample_eq_sample: "resample A B = spmf_of_set B"
using resample_le_sample lossless_resample by(rule ord_spmf_lossless)

end

end

end