# Theory Cyclic_Group_Ext

theory Cyclic_Group_Ext imports
CryptHOL.CryptHOL
"HOL-Number_Theory.Cong"
begin

context cyclic_group begin

text ‹The generator raised to the group order is the identity. If the carrier is infinite,
order G is 0 and the claim is closed by the final simp.›

lemma generator_pow_order: "❙g [^] order G = 𝟭"
proof(cases "order G > 0")
case True
hence fin: "finite (carrier G)" by(simp add: order_gt_0_iff_finite)
then have [symmetric]: "(λx. x ⊗ ❙g) ` carrier G = carrier G"
by(rule endo_inj_surj)(auto simp add: inj_on_multc)
then have "carrier G = (λ n. ❙g [^] Suc n) ` {..<order G}"
using fin by(simp add: carrier_conv_generator image_image)
then obtain n where n: "𝟭 = ❙g [^] Suc n" "n < order G" by auto
have "n = order G - 1" using n inj_onD[OF inj_on_generator, of 0 "Suc n"] by fastforce
with True n show ?thesis by auto
qed simp

text ‹Natural-number exponents of the generator may be reduced modulo the group order.›

lemma pow_generator_mod: "❙g [^] (k mod order G) = ❙g [^] k"
proof(cases "order G > 0")
case True
obtain n where n: "k = n * order G + k mod order G" by (metis div_mult_mod_eq)
have "❙g [^] k = (❙g [^] order G) [^] n ⊗ ❙g [^] (k mod order G)"
by(subst n)(simp add: nat_pow_mult nat_pow_pow mult_ac)
then show ?thesis by(simp add: generator_pow_order)
qed simp

text ‹Power rule for an integer power of a natural power of the generator.
The assumption a ≥ 0 is vacuous because a is a natural number.›

lemma int_nat_pow:
assumes "a ≥ 0"
shows "(❙g [^] (int (a ::nat))) [^] (b::int)  = ❙g [^] (a*b)"
using assms
proof(cases "a > 0")
case True
show ?thesis
using int_pow_pow by blast
next case False
have "(❙g [^] (int (a ::nat))) [^] (b::int) = 𝟭" using False by simp
also have "❙g [^] (a*b) = 𝟭" using False by simp
ultimately show ?thesis by simp
qed

text ‹Integer exponents of the generator may be reduced modulo the group order.›

lemma pow_generator_mod_int: "❙g [^] ((k :: int) mod order G) = ❙g [^] k"
proof(cases "order G > 0")
case True
obtain n :: int where n: "k = order G * n + k mod order G"
by (metis div_mult_mod_eq mult.commute)
then have "❙g [^] k = ❙g [^] (order G * n) ⊗ ❙g [^] (k mod order G)"
using int_pow_mult nat_pow_mult by (metis generator_closed)
then have "❙g [^] k = (❙g [^] order G) [^] n ⊗ ❙g [^] (k mod order G)"
using int_nat_pow by (simp add: int_pow_int)
then show ?thesis by(simp add: generator_pow_order)
qed simp

text ‹An integer power of a product of generator powers may have its exponent reduced modulo
the group order.›

lemma pow_gen_mod_mult:
shows "(❙g [^] (a::nat) ⊗ ❙g [^] (b::nat)) [^] ((c::int)* int (d::nat)) = (❙g [^] a ⊗ ❙g [^] b) [^] ((c*int d) mod (order G))"
proof-
have "(❙g [^] (a::nat) ⊗ ❙g [^] (b::nat)) ∈ carrier G" by simp
then obtain n :: nat where n: "❙g [^] n = (❙g [^] (a::nat) ⊗ ❙g [^] (b::nat))"
by (metis generator_closed nat_pow_mult)
also obtain r where r: "r = c*int d" by simp
have "(❙g [^] (a::nat) ⊗ ❙g [^] (b::nat)) [^] ((c::int)*int (d::nat)) = (❙g [^] n) [^] r" using n r by simp
moreover have"... = (❙g [^] n) [^] (r mod (order G))" using pow_generator_mod_int pow_generator_mod
by (metis int_nat_pow int_pow_int mod_mult_right_eq zero_le)
moreover have "... =  (❙g [^] a ⊗ ❙g [^] b) [^] ((c*int d) mod (order G))" using r n by simp
ultimately show ?thesis by simp
qed

text ‹In a finite cyclic group, two powers of the generator coincide iff their exponents are
congruent modulo the group order.›

lemma pow_generator_eq_iff_cong:
"finite (carrier G) ⟹ ❙g [^] x = ❙g [^] y ⟷ [x = y] (mod order G)"
by(subst (1 2) pow_generator_mod[symmetric])(auto simp add: cong_def order_gt_0_iff_finite intro: inj_onD[OF inj_on_generator])

text ‹Cyclic groups are commutative: both factors are powers of the generator.›

lemma cyclic_group_commute:
assumes "a ∈ carrier G" "b ∈ carrier G"
shows "a ⊗ b = b ⊗ a"
(is "?lhs = ?rhs")
proof-
obtain n :: nat where n: "a = ❙g [^] n" using generatorE assms by auto
also  obtain k :: nat where k: "b = ❙g [^] k" using generatorE assms by auto
ultimately have "?lhs =  ❙g [^] n ⊗ ❙g [^] k" by simp
then have "... = ❙g [^] (n + k)" by(simp add: nat_pow_mult)
then have "... = ❙g [^] (k + n)" by(simp add: add.commute)
then show ?thesis by(simp add: nat_pow_mult n k)
qed

text ‹Associativity, proved by writing all three elements as powers of the generator.›

lemma cyclic_group_assoc:
assumes "a ∈ carrier G" "b ∈ carrier G" "c ∈ carrier G"
shows "(a ⊗ b) ⊗ c = a ⊗ (b ⊗ c)"
(is "?lhs = ?rhs")
proof-
obtain n :: nat where n: "a = ❙g [^] n" using generatorE assms by auto
obtain k :: nat where k: "b = ❙g [^] k" using generatorE assms by auto
obtain j :: nat where j: "c = ❙g [^] j" using generatorE assms by auto
have "?lhs = (❙g [^] n ⊗ ❙g [^] k) ⊗ ❙g [^] j" using n k j by simp
then have "... = ❙g [^] (n + (k + j))" by(simp add: nat_pow_mult add.assoc)
then show ?thesis by(simp add: nat_pow_mult n k j)
qed

text ‹Left-multiplying by g^a ⊗ inv (g^a) leaves an element unchanged.›

lemma l_cancel_inv:
assumes "h ∈ carrier G"
shows "(❙g [^] (a :: nat) ⊗ inv (❙g [^] a)) ⊗ h = h"
(is "?lhs = ?rhs")
proof-
have "?lhs = (❙g [^] int a ⊗ inv (❙g [^] int a)) ⊗ h" by simp
then have "... = (❙g [^] int a ⊗ (❙g [^] (- a))) ⊗ h" using int_pow_neg[symmetric] by simp
then have "... = ❙g [^] (int a - a)  ⊗ h" by(simp add: int_pow_mult)
then have "... = ❙g [^] ((0:: int)) ⊗ h" by simp
then show ?thesis by (simp add: assms)
qed

text ‹Inversion distributes over products; this relies on commutativity.›

lemma inverse_split:
assumes "a ∈ carrier G" and "b ∈ carrier G"
shows "inv (a ⊗ b) = inv a ⊗ inv b"
by (simp add:  assms comm_group.inv_mult cyclic_group_commute group_comm_groupI)

text ‹Inversion commutes with natural powers.›

lemma inverse_pow_pow:
assumes "a ∈ carrier G"
shows "inv (a [^] (r::nat)) = (inv a) [^] r"
proof -
have "a [^] r ∈ carrier G"
using assms by blast
then show ?thesis
using assms by (simp add: nat_pow_inv)
qed

text ‹A non-identity element that is written as a generator power has a nonzero exponent.›

lemma l_neq_1_exp_neq_0:
assumes "l ∈ carrier G"
and "l ≠ 𝟭"
and "l = ❙g [^] (t::nat)"
shows "t ≠ 0"
proof(rule ccontr)
assume "¬ (t ≠ 0)"
hence "t = 0" by simp
hence "❙g [^] t = 𝟭" by simp
then show "False" using assms by simp
qed

text ‹If the group has more than one element, the generator is not the identity.›

lemma order_gt_1_gen_not_1:
assumes "order G > 1"
shows "❙g ≠ 𝟭"
proof(rule ccontr)
assume "¬ ❙g ≠ 𝟭"
hence "❙g = 𝟭" by simp
hence g_pow_eq_1: "❙g [^] n = 𝟭" for n :: nat by simp
hence "range (λn :: nat. ❙g [^] n) = {𝟭}" by auto
hence "carrier G ⊆ {𝟭}" using generator by auto
hence "order G < 1"
by (metis One_nat_def assms g_pow_eq_1 inj_onD inj_on_generator lessThan_iff not_gr_zero zero_less_Suc)
with assms show "False" by simp
qed

text ‹The order of two successive natural exponentiations of the generator can be swapped.›

lemma power_swap: "((❙g [^] (α0::nat)) [^] (r::nat)) = ((❙g [^] r) [^] α0)"
(is "?lhs = ?rhs")
proof-
have "?lhs = ❙g [^] (α0 * r)"
using nat_pow_pow mult.commute by auto
hence "... = ❙g [^] (r * α0)"
by(metis mult.commute)
thus ?thesis using nat_pow_pow by auto
qed

end

end

# Theory Number_Theory_Aux

theory Number_Theory_Aux imports
"HOL-Number_Theory.Cong"
"HOL-Number_Theory.Residues"
begin

text ‹If gcd e N = 1, the first Bezout coefficient of e and N, reduced modulo N, is a
multiplicative inverse of e modulo N.›

lemma bezw_inverse:
assumes "gcd (e :: nat) (N ::nat) = 1"
shows "[nat e * nat ((fst (bezw e N)) mod N) = 1] (mod nat N)"
proof-
have "(fst (bezw e N) * e + snd (bezw e N) * N) mod N = 1 mod N"
by (metis assms bezw_aux zmod_int)
(* The multiple of N vanishes modulo N; then reduce the left factor modulo N. *)
hence "(fst (bezw e N) mod N * e mod N) = 1 mod N"
by (simp add: mod_mult_left_eq)
hence cong_eq: "[(fst (bezw e N) mod N * e) = 1] (mod N)"
by (metis of_nat_1 zmod_int cong_def)
hence "[nat (fst (bezw e N) mod N) * e = 1] (mod N)"
proof -
{ assume "int (nat (fst (bezw e N) mod int N)) ≠ fst (bezw e N) mod int N"
have "N = 0 ⟶ 0 ≤ fst (bezw e N) mod int N"
by fastforce
then have "int (nat (fst (bezw e N) mod int N)) = fst (bezw e N) mod int N"
by fastforce }
then have "[int (nat (fst (bezw e N) mod int N) * e) = int 1] (mod int N)"
by (metis cong_eq of_nat_1 of_nat_mult)
then show ?thesis
using cong_int_iff by blast
qed
then show ?thesis by(simp add: mult.commute)
qed

text ‹For x coprime to q, x times the first Bezout coefficient of x and q is congruent to 1
modulo q.›

lemma inverse:
assumes "gcd x (q::nat) = 1"
and "q > 0"
shows "[x * (fst (bezw x q)) = 1] (mod q)"
proof-
have int_eq: "fst (bezw  x q) * x + snd (bezw x q) * int q = 1"
by (metis assms(1) bezw_aux of_nat_1)
hence int_eq': "(fst (bezw  x q) * x + snd (bezw x q) * int q) mod q = 1 mod q"
by (metis of_nat_1 zmod_int)
hence "(fst (bezw x q) * x) mod q = 1 mod q"
by simp
hence "[(fst (bezw x q)) * x  = 1] (mod q)"
using cong_def int_eq int_eq' by metis
then show ?thesis by(simp add: mult.commute)
qed

text ‹For primes x and y greater than 2, the product (x-1)*(y-1) is not prime.›

lemma prod_not_prime:
assumes "prime (x::nat)"
and "prime y"
and "x > 2"
and "y > 2"
shows "¬ prime ((x-1)*(y-1))"
by (metis assms One_nat_def Suc_diff_1 nat_neq_iff numeral_2_eq_2 prime_gt_0_nat prime_product)

text ‹If e is coprime to (P-1)*(Q-1), then e has a nonzero inverse modulo P-1.
The visible proof only uses the coprimality assumption.›

lemma ex_inverse:
assumes coprime: "coprime (e :: nat) ((P-1)*(Q-1))"
and "prime P"
and "prime Q"
and "P ≠ Q"
shows "∃ d. [e*d = 1] (mod (P-1)) ∧ d ≠ 0"
proof-
have "coprime e (P-1)"
using assms(1) by simp
then obtain d where d: "[e*d = 1] (mod (P-1))"
using cong_solve_coprime_nat by auto
then show ?thesis by (metis cong_0_1_nat cong_1 mult_0_right zero_neq_one)
qed

text ‹Linear form of the congruence e*d ≡ 1 (mod P-1); only the congruence assumption is used.›

lemma ex_k1_k2:
assumes coprime: "coprime (e :: nat) ((P-1)*(Q-1))"
and "[e*d = 1] (mod (P-1))"
shows "∃ k1 k2. e*d + k1*(P-1) = 1 + k2*(P-1)"
by (metis assms(2) cong_iff_lin_nat)

text ‹Because e*d ≥ 1, the congruence e*d ≡ 1 (mod P-1) can be written as e*d = 1 + k*(P-1).›

lemma ex_k_mod:
assumes coprime: "coprime (e :: nat) ((P-1)*(Q-1))"
and "P ≠ Q"
and "prime P"
and "prime Q"
and "d ≠ 0"
and " [e*d = 1] (mod (P-1))"
shows "∃ k. e*d = 1 + k*(P-1)"
proof-
have "e > 0"
using assms(1) assms(2) prime_gt_0_nat by fastforce
then have "e*d ≥ 1" using assms by simp
then obtain k where k: "e*d = 1 + k*(P-1)"
using assms(6) cong_to_1'_nat by auto
then show ?thesis
by simp
qed

text ‹Fermat's little theorem in the form x^P ≡ x (mod P). It holds for every x, including
multiples of P.›

lemma fermat_little:
assumes "prime (P :: nat)"
shows "[x^P = x] (mod P)"
proof(cases "P dvd x")
case True
hence "x mod P = 0" by simp
moreover have "x ^ P mod P = 0"
by (simp add: True assms prime_dvd_power_nat_iff prime_gt_0_nat)
(* Both sides reduce to 0 modulo P. *)
ultimately show ?thesis
by (simp add: cong_def)
next
case False
hence "[x ^ (P - 1) = 1] (mod P)"
using fermat_theorem assms by blast
then show ?thesis
by (metis assms cong_def diff_diff_cancel diff_is_0_eq' diff_zero mod_mult_right_eq power_eq_if power_one_right prime_ge_1_nat zero_le_one)
qed

end

# Theory Uniform_Sampling

section ‹Uniform Sampling›

text‹Here we prove the one-time-pad lemmas, based on uniform sampling, that we require throughout our proofs.›

theory Uniform_Sampling
imports
CryptHOL.Cyclic_Group_SPMF
"HOL-Number_Theory.Cong"
CryptHOL.List_Bits
begin

text ‹We sample uniformly from the nonzero residues below q; when q is prime these are exactly the units modulo q.›

text ‹Uniform distribution over the nonzero residues below q.›

definition sample_uniform_units :: "nat ⇒ nat spmf"
where "sample_uniform_units q = spmf_of_set ({..< q} - {0})"

lemma set_spmf_sampl_uni_units [simp]: "set_spmf (sample_uniform_units q) = {..< q} - {0}"
by(simp add: sample_uniform_units_def)

text ‹For q > 1 the support is nonempty (it contains 1), so sampling is lossless.›

lemma lossless_sample_uniform_units:
assumes "q > 1"
shows "lossless_spmf (sample_uniform_units q)"
unfolding sample_uniform_units_def
using assms by auto

text ‹General lemma for mapping using uniform sampling from units.›

lemma one_time_pad_units:
assumes inj_on: "inj_on f ({..<q} - {0})"
and sur: "f ` ({..<q} - {0}) = ({..<q} - {0})"
shows "map_spmf f (sample_uniform_units q) = (sample_uniform_units q)"
(is "?lhs = ?rhs")
proof-
have rhs: "?rhs = spmf_of_set (({..<q} - {0}))"
by(simp add: sample_uniform_units_def)
(* An injective map pushes the uniform distribution on a set to the uniform distribution on its image. *)
also have "map_spmf(λs. f s) (spmf_of_set ({..<q} - {0})) = spmf_of_set ((λs. f s) ` ({..<q} - {0}))"
using inj_on by(simp add: map_spmf_of_set_inj_on)
also have "f ` ({..<q} - {0}) = ({..<q} - {0})"
by(rule sur)
ultimately show ?thesis using rhs by simp
qed

text ‹General lemma for mapping using uniform sampling.›

lemma one_time_pad:
assumes inj_on: "inj_on f {..<q}"
and sur: "f ` {..<q} = {..<q}"
shows "map_spmf f (sample_uniform q) = (sample_uniform q)"
(is "?lhs = ?rhs")
proof-
have rhs: "?rhs = spmf_of_set ({..< q})"
by(simp add: sample_uniform_def)
(* An injective map pushes the uniform distribution on a set to the uniform distribution on its image. *)
also have "map_spmf(λs. f s) (spmf_of_set {..<q}) = spmf_of_set ((λs. f s) ` {..<q})"
using inj_on by(simp add: map_spmf_of_set_inj_on)
also have "f ` {..<q} = {..<q}"
by(rule sur)
ultimately show ?thesis using rhs by simp
qed

text ‹The addition map case.›

lemma inj_add:
assumes x:  "x < q"
and x': "x' < q"
and map: "((y :: nat) + x) mod q = (y + x') mod q"
shows "x = x'"
proof-
have aa: "((y :: nat) + x) mod q = (y + x') mod q ⟹ x mod q = x' mod q"
proof-
have 4: "((y:: nat) + x) mod q = (y + x') mod q ⟹ [((y:: nat) + x) = (y + x')] (mod q)"
by(simp add: cong_def)
have 5: "[((y:: nat) + x) = (y + x')] (mod q) ⟹ [x = x'] (mod q)"
by(simp add: cong_add_lcancel_nat)
have 6: "[x = x'] (mod q) ⟹ x mod q = x' mod q"
by(simp add: cong_def)
then show ?thesis by(simp add: map 4 5 6)
qed
(* Both values lie below q, so equal residues mean equal values. *)
also have bb: "x mod q = x' mod q ⟹ x = x'"
by(simp add: x x')
ultimately show ?thesis by(simp add: map)
qed

lemma inj_uni_samp_add: "inj_on  (λ(b :: nat). (y + b) mod q ) {..<q}"
by(simp add: inj_on_def)(auto simp only: inj_add)

lemma surj_uni_samp:
assumes inj: "inj_on  (λ(b :: nat). (y + b) mod q ) {..<q}"
shows "(λ(b :: nat). (y + b) mod q) ` {..< q} =  {..< q}"
apply(rule endo_inj_surj) using inj by auto

text ‹Adding a fixed value modulo q preserves the uniform distribution on {..<q}.›

lemma samp_uni_plus_one_time_pad:
shows "map_spmf (λb. (y + b) mod q) (sample_uniform q) = (sample_uniform q)"
using inj_uni_samp_add surj_uni_samp one_time_pad by simp
text ‹The multiplication map case.›

text ‹Multiplication by a unit modulo q is injective on {..<q}.›

lemma inj_mult:
assumes coprime: "coprime x (q::nat)"
and y: "y < q"
and y': "y' < q"
and map: "x * y mod q = x * y' mod q"
shows "y = y'"
proof-
have "x*y mod q = x*y' mod q ⟹ y mod q = y' mod q"
proof-
have "x*y mod q = x*y' mod q ⟹ [x*y = x*y'] (mod q)"
by(simp add: cong_def)
(* x is coprime to q, so it can be cancelled in the congruence. *)
also have "[x*y = x*y'] (mod q) = [y = y'] (mod q)"
by(simp add: cong_mult_lcancel_nat coprime)
also have "[y = y'] (mod q) ⟹ y mod q = y' mod q"
by(simp add: cong_def)
ultimately show ?thesis by(simp add: map)
qed
also have "y mod q = y' mod q ⟹ y = y'"
by(simp add: y y')
ultimately show ?thesis by(simp add: map)
qed

lemma inj_on_mult:
assumes coprime: "coprime x (q::nat)"
shows "inj_on (λ b. x*b mod q) {..<q}"
apply(auto simp add: inj_on_def)
using coprime by(simp only: inj_mult)

lemma surj_on_mult:
assumes coprime: "coprime x (q::nat)"
and inj: "inj_on (λ b. x*b mod q) {..<q}"
shows "(λ b. x*b mod q) ` {..< q} = {..< q}"
apply(rule endo_inj_surj) using coprime inj by auto

text ‹Multiplying by a unit modulo q preserves the uniform distribution on {..<q}.›

lemma mult_one_time_pad:
assumes coprime: "coprime x q"
shows "map_spmf (λ b. x*b mod q) (sample_uniform q) = (sample_uniform q)"
using inj_on_mult surj_on_mult one_time_pad coprime by simp

text ‹The multiplication map for sampling from units.›

lemma inj_on_mult_units:
assumes 1: "coprime x (q::nat)" shows "inj_on (λ b. x*b mod q) ({..<q} - {0})"
apply(auto simp add: inj_on_def)
using 1 by(simp only: inj_mult)

text ‹Multiplication by a unit maps the nonzero residues below q onto themselves
(automatically generated proof of the subset property).›

lemma surj_on_mult_units:
assumes coprime: "coprime x (q::nat)"
and inj: "inj_on (λ b. x*b mod q) ({..<q} - {0})"
shows "(λ b. x*b mod q) ` ({..<q} - {0}) = ({..<q} - {0})"
proof(rule endo_inj_surj)
show "finite ({..<q} - {0})" using coprime inj by(simp)
show "(λb. x * b mod q) ` ({..<q} - {0}) ⊆ {..<q} - {0}"
proof -
obtain n :: "nat set ⇒ (nat ⇒ nat) ⇒ nat set ⇒ nat" where
"∀x0 x1 x2. (∃v3. v3 ∈ x2 ∧ x1 v3 ∉ x0) = (n x0 x1 x2 ∈ x2 ∧ x1 (n x0 x1 x2) ∉ x0)"
by moura
then have subset: "∀N f Na. n Na f N ∈ N ∧ f (n Na f N) ∉ Na ∨ f ` N ⊆ Na"
by (meson image_subsetI)
have mem_insert: "x * n ({..<q} - {0}) (λn. x * n mod q) ({..<q} - {0}) mod q ∉ {..<q} ∨ x * n ({..<q} - {0}) (λn. x * n mod q) ({..<q} - {0}) mod q ∈ insert 0 {..<q}"
by force
have map_eq: "(x * n ({..<q} - {0}) (λn. x * n mod q) ({..<q} - {0}) mod q ∈ insert 0 {..<q} - {0}) = (x * n ({..<q} - {0}) (λn. x * n mod q) ({..<q} - {0}) mod q ∈ {..<q} - {0})"
by simp
{ assume "x * n ({..<q} - {0}) (λn. x * n mod q) ({..<q} - {0}) mod q = x * 0 mod q"
then have "(0 ≤ q) = (0 = q) ∨ (n ({..<q} - {0}) (λn. x * n mod q) ({..<q} - {0}) ∉ {..<q} ∨ n ({..<q} - {0}) (λn. x * n mod q) ({..<q} - {0}) ∈ {0}) ∨ n ({..<q} - {0}) (λn. x * n mod q) ({..<q} - {0}) ∉ {..<q} - {0} ∨ x * n ({..<q} - {0}) (λn. x * n mod q) ({..<q} - {0}) mod q ∈ {..<q} - {0}"
by (metis antisym_conv1 insertCI lessThan_iff local.coprime inj_mult) }
moreover
{ assume "0 ≠ x * n ({..<q} - {0}) (λn. x * n mod q) ({..<q} - {0}) mod q"
moreover
{ assume "x * n ({..<q} - {0}) (λn. x * n mod q) ({..<q} - {0}) mod q ∈ insert 0 {..<q} ∧ x * n ({..<q} - {0}) (λn. x * n mod q) ({..<q} - {0}) mod q ∉ {0}"
then have "(λn. x * n mod q) ` ({..<q} - {0}) ⊆ {..<q} - {0}"
using map_eq subset by (meson Diff_iff) }
ultimately have "(λn. x * n mod q) ` ({..<q} - {0}) ⊆ {..<q} - {0} ∨ (0 ≤ q) = (0 = q)"
using mem_insert by (metis antisym_conv1 lessThan_iff mod_less_divisor singletonD) }
ultimately have "(λn. x * n mod q) ` ({..<q} - {0}) ⊆ {..<q} - {0} ∨ n ({..<q} - {0}) (λn. x * n mod q) ({..<q} - {0}) ∉ {..<q} - {0} ∨ x * n ({..<q} - {0}) (λn. x * n mod q) ({..<q} - {0}) mod q ∈ {..<q} - {0}"
by force
then show "(λn. x * n mod q) ` ({..<q} - {0}) ⊆ {..<q} - {0}"
using subset by meson
qed
show "inj_on (λb. x * b mod q) ({..<q} - {0})" using assms by(simp)
qed

text ‹Multiplying by a unit preserves the uniform distribution on the nonzero residues.›

lemma mult_one_time_pad_units:
assumes coprime: "coprime x q"
shows "map_spmf (λ b. x*b mod q) (sample_uniform_units q) = sample_uniform_units q"
using inj_on_mult_units surj_on_mult_units one_time_pad_units coprime by simp

text ‹The affine map b ↦ (y + x*b) mod q with x a unit is injective on {..<q}.›

lemma inj_add_mult:
assumes coprime: "coprime x (q::nat)"
and xa: "xa < q"
and ya: "ya < q"
and map: "(y + x * xa) mod q = (y + x * ya) mod q"
shows "xa = ya"
proof-
have "(y + x * xa) mod q = (y + x * ya) mod q ⟹ xa mod q = ya mod q"
proof-
have "(y + x * xa) mod q = (y + x * ya) mod q ⟹ [y + x*xa = y + x *ya] (mod q)"
using cong_def by blast
(* Cancel the common summand y, then the unit factor x. *)
also have "[y + x*xa = y + x *ya] (mod q) ⟹ [xa = ya] (mod q)"
by(simp add: cong_add_lcancel_nat cong_mult_lcancel_nat coprime)
ultimately show ?thesis by(simp add: cong_def map)
qed
also have "xa mod q = ya mod q ⟹ xa = ya"
by(simp add: xa ya)
ultimately show ?thesis by(simp add: map)
qed

lemma inj_on_add_mult:
assumes coprime: "coprime x (q::nat)"
shows "inj_on (λ b. (y + x*b) mod q) {..<q}"
apply(auto simp add: inj_on_def)
using coprime by(simp only: inj_add_mult)

lemma surj_on_add_mult: assumes coprime: "coprime x (q::nat)" and inj: "inj_on (λ b. (y + x*b) mod q) {..<q}"
shows "(λ b. (y + x*b) mod q) ` {..< q} = {..< q}"
apply(rule endo_inj_surj) using coprime inj by auto

text ‹The affine map with a unit coefficient preserves the uniform distribution on {..<q}.›

lemma add_mult_one_time_pad:
assumes coprime: "coprime x q"
shows "map_spmf (λ b. (y + x*b) mod q) (sample_uniform q) = (sample_uniform q)"
using inj_on_add_mult surj_on_add_mult one_time_pad coprime by simp

text ‹Subtraction Map.›

text ‹Injectivity of the subtraction map b ↦ (y + q - b) mod q on {..<q}.›

lemma inj_minus:
assumes x: "(x :: nat) < q"
and ya: "ya < q"
and map: "(y + q - x) mod q = (y + q - ya) mod q"
shows  "x = ya"
proof-
(* Adding x + ya to both sides removes the truncated subtractions; this is valid since x, ya < q. *)
have e1: "(y + q - x) + (x + ya) = (y + q) + ya"
using x by arith
have e2: "(y + q - ya) + (x + ya) = (y + q) + x"
using ya by arith
have "[y + q - x = y + q - ya] (mod q)"
using map by (simp add: cong_def)
then have "[(y + q - x) + (x + ya) = (y + q - ya) + (x + ya)] (mod q)"
by (rule cong_add[OF _ cong_refl])
then have "[(y + q) + ya = (y + q) + x] (mod q)"
by (simp only: e1 e2)
then have "[ya = x] (mod q)"
by (simp only: cong_add_lcancel_nat)
then have "ya mod q = x mod q"
by (simp only: cong_def)
with x ya show ?thesis
by simp
qed

lemma inj_on_minus: "inj_on  (λ(b :: nat). (y + (q - b)) mod q ) {..<q}"
proof(rule inj_onI)
fix a b :: nat
assume "a ∈ {..<q}" and "b ∈ {..<q}"
then have a: "a < q" and b: "b < q" by simp_all
assume eq: "(y + (q - a)) mod q = (y + (q - b)) mod q"
(* Rewrite into the form y + q - b used by inj_minus. *)
have "y + (q - a) = y + q - a"
using a by arith
moreover have "y + (q - b) = y + q - b"
using b by arith
ultimately have "(y + q - a) mod q = (y + q - b) mod q"
using eq by (simp only:)
then show "a = b"
by (rule inj_minus[OF a b])
qed

lemma surj_on_minus:
assumes inj: "inj_on  (λ(b :: nat). (y + (q - b)) mod q ) {..<q}"
shows "(λ(b :: nat). (y + (q - b)) mod q) ` {..< q} = {..< q}"
apply(rule endo_inj_surj)
using inj by auto

text ‹Subtracting from a fixed value modulo q preserves the uniform distribution on {..<q}.›

lemma minus_one_time_pad:
shows "map_spmf(λ b. (y + (q - b)) mod q) (sample_uniform q) = (sample_uniform q)"
using inj_on_minus surj_on_minus one_time_pad by simp

text ‹Negating a fair coin yields a fair coin.›

lemma not_coin_flip: "map_spmf (λ a. ¬ a) coin_spmf = coin_spmf"
proof-
have "inj_on Not {True, False}"
by simp
also have  "Not ` {True, False} = {True, False}"
by auto
ultimately show ?thesis
by (simp add: UNIV_bool insert_commute map_spmf_of_set_inj_on)
qed

text ‹XOR-ing a fair coin with a fixed bit yields a fair coin.›

lemma xor_uni_samp: "map_spmf(λ b. y ⊕ b) (coin_spmf) = map_spmf(λ b. b) (coin_spmf)"
(is "?lhs = ?rhs")
proof-
have rhs: "?rhs = spmf_of_set {True, False}"
by (simp add: UNIV_bool insert_commute)
also have "map_spmf(λ b. y ⊕ b) (spmf_of_set {True, False}) = spmf_of_set((λ b. y ⊕ b) ` {True, False})"
by (simp add: map_spmf_of_set_inj_on inj_on_def xor_def)
also have "(λ b. xor y b) ` {True, False} = {True, False}"
using xor_def by auto
finally show ?thesis using rhs by(simp)
qed

end

# Theory Semi_Honest_Def

section ‹Semi-Honest Security›

text ‹We follow the security definitions for the semi-honest setting as described in \cite{DBLP:books/sp/17/Lindell17}.
In the semi-honest model the parties are assumed not to deviate from the protocol description.
Semi-honest security guarantees that no party learns more during the running of the protocol than what follows from its own input and output.›

subsection ‹Security definitions›

theory Semi_Honest_Def imports
CryptHOL.CryptHOL
begin

subsubsection ‹Security for deterministic functionalities›

locale sim_det_def =
fixes R1 :: "'msg1 ⇒ 'msg2 ⇒ 'view1 spmf"
and S1  :: "'msg1 ⇒ 'out1 ⇒ 'view1 spmf"
and R2 :: "'msg1 ⇒ 'msg2 ⇒ 'view2 spmf"
and S2  :: "'msg2 ⇒ 'out2 ⇒ 'view2 spmf"
and funct :: "'msg1 ⇒ 'msg2 ⇒ ('out1 × 'out2) spmf"
and protocol :: "'msg1 ⇒ 'msg2 ⇒ ('out1 × 'out2) spmf"
assumes lossless_R1: "lossless_spmf (R1 m1 m2)"
and lossless_S1: "lossless_spmf (S1 m1 out1)"
and lossless_R2: "lossless_spmf (R2 m1 m2)"
and lossless_S2: "lossless_spmf (S2 m2 out2)"
and lossless_funct: "lossless_spmf (funct m1 m2)"
begin

type_synonym 'view' adversary_det = "'view' ⇒ bool spmf"

text ‹The protocol is correct if its output distribution equals that of the functionality.›

definition "correctness m1 m2 ≡ (protocol m1 m2 = funct m1 m2)"

text ‹Advantage of a distinguisher D in telling Party 1's real view from the simulated view.›

definition adv_P1 :: "'msg1 ⇒ 'msg2 ⇒ 'view1 adversary_det ⇒ real"
where "adv_P1 m1 m2 D ≡ ¦(spmf (R1 m1 m2 ⤜ D) True)
- spmf (funct m1 m2 ⤜ (λ (o1, o2). S1 m1 o1 ⤜ D)) True¦"

definition "perfect_sec_P1 m1 m2 ≡ (R1 m1 m2 = funct m1 m2 ⤜ (λ (s1, s2). S1 m1 s1))"

text ‹Advantage of a distinguisher D in telling Party 2's real view from the simulated view.›

definition adv_P2 :: "'msg1 ⇒ 'msg2 ⇒ 'view2 adversary_det ⇒ real"
where "adv_P2 m1 m2 D = ¦spmf (R2 m1 m2 ⤜ (λ view. D view)) True
- spmf (funct m1 m2 ⤜ (λ (o1, o2). S2 m2 o2 ⤜ (λ view. D view))) True¦"

definition "perfect_sec_P2 m1 m2 ≡ (R2 m1 m2 = funct m1 m2 ⤜ (λ (s1, s2). S2 m2 s2))"

text ‹We also define the security games (for Party 1 and Party 2) used in EasyCrypt to define semi-honest
security, and show that they are equivalent to the advantage-based definitions above.›

definition P1_game_alt :: "'msg1 ⇒ 'msg2 ⇒ 'view1 adversary_det ⇒ bool spmf"
where "P1_game_alt m1 m2 D = do {
b ← coin_spmf;
(out1, out2) ← funct m1 m2;
rview :: 'view1 ← R1 m1 m2;
sview :: 'view1 ← S1 m1 out1;
b' ← D (if b then rview else sview);
return_spmf (b = b')}"

definition adv_P1_game :: "'msg1 ⇒ 'msg2 ⇒ 'view1 adversary_det ⇒ real"
where "adv_P1_game m1 m2 D = ¦2*(spmf (P1_game_alt m1 m2 D ) True) - 1¦"

text ‹We show the two definitions are equivalent for lossless distinguishers.›

lemma equiv_defs_P1:
assumes lossless_D: "∀ view. lossless_spmf ((D:: 'view1 adversary_det) view)"
shows "adv_P1_game m1 m2 D = adv_P1 m1 m2 D"
proof-
have return_True_not_False: "spmf (return_spmf (b)) True = spmf (return_spmf (¬ b)) False"
for b by(cases b; auto)
have lossless_ideal: "lossless_spmf ((funct m1 m2 ⤜ (λ(out1, out2). S1 m1 out1 ⤜ (λsview. D sview ⤜ (λb'. return_spmf (False = b'))))))"
by(simp add: lossless_S1 lossless_funct lossless_weight_spmfD split_def lossless_D)
have return: "spmf (funct m1 m2 ⤜ (λ(o1, o2). S1 m1 o1 ⤜ D)) True
= spmf (funct m1 m2 ⤜ (λ(o1, o2). S1 m1 o1 ⤜ (λ view. D view ⤜ (λ b. return_spmf b)))) True"
by simp
have "2*(spmf (P1_game_alt m1 m2 D ) True) - 1 = (spmf (R1 m1 m2 ⤜ (λrview. D rview ⤜ (λ(b':: bool). return_spmf (True = b'))))) True
- (1 - (spmf (funct m1 m2 ⤜ (λ(out1, out2). S1 m1 out1 ⤜ (λsview. D sview ⤜ (λb'. return_spmf (False = b')))))) True)"
by(simp add: P1_game_alt_def spmf_bind integral_spmf_of_set
UNIV_bool bind_spmf_const lossless_R1 lossless_S1 lossless_funct lossless_weight_spmfD)
hence "adv_P1_game m1 m2 D = ¦(spmf (R1 m1 m2 ⤜ (λrview. D rview ⤜ (λ(b':: bool). return_spmf (True = b'))))) True
- (1 - (spmf (funct m1 m2 ⤜ (λ(out1, out2). S1 m1 out1 ⤜ (λsview. D sview ⤜ (λb'. return_spmf (False = b')))))) True)¦"
by(simp add: adv_P1_game_def)
also have "¦(spmf (R1 m1 m2 ⤜ (λrview. D rview ⤜ (λ(b':: bool). return_spmf (True = b'))))) True
- (1 - (spmf (funct m1 m2 ⤜ (λ(out1, out2). S1 m1 out1 ⤜ (λsview. D sview
⤜ (λb'. return_spmf (False = b')))))) True)¦ = adv_P1 m1 m2 D"
apply(simp only: adv_P1_def spmf_False_conv_True[symmetric] lossless_ideal; simp)
by(simp only: return)(simp only: split_def spmf_bind return_True_not_False)
ultimately show ?thesis by simp
qed

definition P2_game_alt :: "'msg1 ⇒ 'msg2 ⇒ 'view2 adversary_det ⇒ bool spmf"
where "P2_game_alt m1 m2 D = do {
b ← coin_spmf;
(out1, out2) ← funct m1 m2;
rview :: 'view2 ← R2 m1 m2;
sview :: 'view2 ← S2 m2 out2;
b' ← D (if b then rview else sview);
return_spmf (b = b')}"

definition adv_P2_game :: "'msg1 ⇒ 'msg2 ⇒ 'view2 adversary_det ⇒ real"
where "adv_P2_game m1 m2 D = ¦2*(spmf (P2_game_alt m1 m2 D ) True) - 1¦"

lemma equiv_defs_P2:
assumes lossless_D: "∀ view. lossless_spmf ((D:: 'view2 adversary_det) view)"
shows "adv_P2_game m1 m2 D = adv_P2 m1 m2 D"
proof-
have return_True_not_False: "spmf (return_spmf (b)) True = spmf (return_spmf (¬ b)) False"
for b by(cases b; auto)
have lossless_ideal: "lossless_spmf ((funct m1 m2 ⤜ (λ(out1, out2). S2 m2 out2 ⤜ (λsview. D sview ⤜ (λb'. return_spmf (False = b'))))))"
by(simp add: lossless_S2 lossless_funct lossless_weight_spmfD split_def lossless_D)
have return: "spmf (funct m1 m2 ⤜ (λ(o1, o2). S2 m2 o2 ⤜ D)) True = spmf (funct m1 m2 ⤜ (λ(o1, o2). S2 m2 o2 ⤜ (λ view. D view ⤜ (λ b. return_spmf b)))) True"
by simp
have
"2*(spmf (P2_game_alt m1 m2 D ) True) - 1 = (spmf (R2 m1 m2 ⤜ (λrview. D rview ⤜ (λ(b':: bool). return_spmf (True = b'))))) True
- (1 - (spmf (funct m1 m2 ⤜ (λ(out1, out2). S2 m2 out2 ⤜ (λsview. D sview ⤜ (λb'. return_spmf (False = b')))))) True)"
by(simp add: P2_game_alt_def spmf_bind integral_spmf_of_set
UNIV_bool bind_spmf_const lossless_R2 lossless_S2 lossless_funct lossless_weight_spmfD)
hence "adv_P2_game m1 m2 D = ¦(spmf (R2 m1 m2 ⤜ (λrview. D rview ⤜ (λ(b':: bool). return_spmf (True = b'))))) True
- (1 - (spmf (funct m1 m2 ⤜ (λ(out1, out2). S2 m2 out2 ⤜ (λsview. D sview ⤜ (λb'. return_spmf (False = b')))))) True)¦"
by(simp add: adv_P2_game_def)
also have "¦(spmf (R2 m1 m2 ⤜ (λrview. D rview ⤜ (λ(b':: bool). return_spmf (True = b'))))) True
- (1 - (spmf (funct m1 m2 ⤜ (λ(out1, out2). S2 m2 out2 ⤜ (λsview. D sview ⤜ (λb'. return_spmf (False = b')))))) True)¦ = adv_P2 m1 m2 D"
apply(simp only: adv_P2_def spmf_False_conv_True[symmetric] lossless_ideal; simp)
by(simp only: return)(simp only: split_def spmf_bind return_True_not_False)
ultimately show ?thesis by simp
qed

end

subsubsection ‹ Security definitions for non deterministic functionalities ›

text ‹Simulation-based definitions for non-deterministic functionalities. The real views R1 and R2
also return the outputs of both parties. The ideal views pair a simulated view (from S1 or S2) with
outputs formed by Out1 or Out2.›

locale sim_non_det_def =
fixes R1 :: "'msg1 ⇒ 'msg2 ⇒ ('view1 × ('out1 × 'out2)) spmf"
and S1  :: "'msg1 ⇒ 'out1 ⇒ 'view1 spmf"
and Out1 :: "'msg1 ⇒ 'msg2 ⇒ 'out1 ⇒ ('out1 × 'out2) spmf" ― ‹takes the input of the other party so can form the outputs of parties›
and R2 :: "'msg1 ⇒ 'msg2 ⇒ ('view2 × ('out1 × 'out2)) spmf"
and S2  :: "'msg2 ⇒ 'out2 ⇒ 'view2 spmf"
and Out2 :: "'msg2 ⇒ 'msg1 ⇒ 'out2 ⇒ ('out1 × 'out2) spmf"
and funct :: "'msg1 ⇒ 'msg2 ⇒ ('out1 × 'out2) spmf"
begin

type_synonym ('view', 'out1', 'out2') adversary_non_det = "('view' × ('out1' × 'out2')) ⇒ bool spmf"

text ‹Ideal view of Party 1: a simulated view together with the joint outputs produced by Out1.›

definition Ideal1 :: "'msg1 ⇒ 'msg2 ⇒ 'out1  ⇒ ('view1 × ('out1 × 'out2)) spmf"
where "Ideal1 m1 m2 out1 = do {
view1 :: 'view1 ← S1 m1 out1;
out1 ← Out1 m1 m2 out1;
return_spmf (view1, out1)}"

text ‹Ideal view of Party 2: a simulated view together with the joint outputs produced by Out2.›

definition Ideal2 :: "'msg2 ⇒ 'msg1 ⇒ 'out2 ⇒ ('view2 × ('out1 × 'out2)) spmf"
where "Ideal2 m2 m1 out2 = do {
view2 :: 'view2 ← S2 m2 out2;
out2 ← Out2 m2 m1 out2;
return_spmf (view2, out2)}"

text ‹Distinguishing advantage between Party 1's real view and its ideal view.›

definition adv_P1 :: "'msg1 ⇒ 'msg2 ⇒ ('view1, 'out1, 'out2) adversary_non_det ⇒ real"
where "adv_P1 m1 m2 D ≡ ¦(spmf (R1 m1 m2 ⤜ (λ view. D view)) True) - spmf (funct m1 m2 ⤜ (λ (o1, o2). Ideal1 m1 m2 o1 ⤜ (λ view. D view))) True¦"

definition "perfect_sec_P1 m1 m2 ≡ (R1 m1 m2 = funct m1 m2 ⤜ (λ (s1, s2). Ideal1 m1 m2 s1))"

text ‹Distinguishing advantage between Party 2's real view and its ideal view.›

definition adv_P2 :: "'msg1 ⇒ 'msg2 ⇒ ('view2, 'out1, 'out2) adversary_non_det ⇒ real"
where "adv_P2 m1 m2 D = ¦spmf (R2 m1 m2 ⤜ (λ view. D view)) True - spmf (funct m1 m2 ⤜ (λ (o1, o2). Ideal2 m2 m1 o2 ⤜ (λ view. D view))) True¦"

definition "perfect_sec_P2 m1 m2 ≡ (R2 m1 m2 = funct m1 m2 ⤜ (λ (s1, s2). Ideal2 m2 m1 s2))"

end

subsubsection ‹ Secret sharing schemes ›

text ‹A two-party secret sharing scheme given by share and reconstruct, together with a set F of gate
functionalities to be evaluated on shares.›

locale secret_sharing_scheme =
fixes share :: "'input_out ⇒ ('share × 'share) spmf"
and reconstruct :: "('share × 'share) ⇒ 'input_out spmf"
and F :: "('input_out ⇒ 'input_out ⇒ 'input_out spmf) set"
begin

text ‹Reconstructing a fresh sharing of the input returns the input.›

definition "sharing_correct input ≡ (share input ⤜ (λ (s1,s2). reconstruct (s1,s2)) = return_spmf input)"

text ‹Every gate in F can be computed on shares by some gate protocol whose reconstructed result has
the same distribution as the gate applied to the inputs.›

definition "correct_share_eval input1 input2 ≡ (∀ gate_eval ∈ F.
∃ gate_protocol :: ('share × 'share) ⇒ ('share × 'share) ⇒ ('share × 'share) spmf.
share input1 ⤜ (λ (s1,s2). share input2
⤜ (λ (s3,s4). gate_protocol (s1,s3) (s2,s4)
⤜ (λ (S1,S2). reconstruct (S1,S2)))) = gate_eval input1 input2)"

end

end



# Theory OT_Functionalities

subsection ‹Oblivious Transfer functionalities›

text‹Here we define the functionalities for 1-out-of-2 and 1-out-of-4 OT.›

theory OT_Functionalities imports
CryptHOL.CryptHOL
begin

text ‹1-out-of-2 OT: the sender holds a pair of messages and the receiver a choice bit σ.
The receiver learns the second message if σ holds and the first otherwise; the sender gets unit.›

definition funct_OT_12 :: "('a ×  'a) ⇒ bool ⇒ (unit × 'a) spmf"
where "funct_OT_12 input⇩1 σ = return_spmf (() , if σ then (snd input⇩1) else (fst input⇩1))"

lemma lossless_funct_OT_12: "lossless_spmf (funct_OT_12 msgs σ)"
by(simp add: funct_OT_12_def)

text ‹1-out-of-4 OT: the two choice bits (c0, c1) select one of four messages m00, m01, m10, m11.›

definition funct_OT_14 :: "('a × 'a × 'a × 'a) ⇒ (bool × bool) ⇒ (unit × 'a) spmf"
where "funct_OT_14 M C = do {
let (c0,c1) = C;
let (m00, m01, m10, m11) = M;
return_spmf ((),if c0 then (if c1 then m11 else m10) else (if c1 then m01 else m00))}"

lemma lossless_funct_14_OT: "lossless_spmf (funct_OT_14 M C)"
by(simp add: funct_OT_14_def split_def Let_def)

end

# Theory ETP

subsection ‹ ETP definitions ›

text ‹ We define Extended Trapdoor Permutations (ETPs) following \cite{DBLP:books/sp/17/Lindell17} and \cite{DBLP:books/cu/Goldreich2004}.
In particular we consider the property of Hard Core Predicates (HCPs). ›

theory ETP imports
CryptHOL.CryptHOL
begin

text ‹Distinguisher type passed to HCP adversaries, and the type of those adversaries (used by
HCP_game). The 'range parameter of dist2 does not occur on its right-hand side.›

type_synonym ('index,'range) dist2 = "(bool × 'index × bool × bool) ⇒ bool spmf"

type_synonym ('index,'range) advP2 = "'index ⇒ bool ⇒ bool ⇒ ('index,'range) dist2 ⇒ 'range ⇒ bool spmf"

locale etp =
fixes I :: "('index × 'trap) spmf" ― ‹samples index and trapdoor›
and domain :: "'index ⇒ 'range set"
and range :: "'index ⇒ 'range set"
and F :: "'index ⇒ ('range ⇒ 'range)" ― ‹permutation›
and F⇩i⇩n⇩v :: "'index ⇒ 'trap ⇒ 'range ⇒ 'range" ― ‹must be efficiently computable›
and B :: "'index ⇒ 'range ⇒ bool" ― ‹hard core predicate›
assumes dom_eq_ran: "y ∈ set_spmf I ⟶ domain (fst y) = range (fst y)"
and finite_range: "y ∈ set_spmf I ⟶ finite (range (fst y))"
and non_empty_range: "y ∈ set_spmf I ⟶ range (fst y) ≠ {}"
and bij_betw: "y ∈ set_spmf I ⟶ bij_betw (F (fst y)) (domain (fst y)) (range (fst y))"
and lossless_I: "lossless_spmf I"
and F_f_inv: "y ∈ set_spmf I ⟶ x ∈ range (fst y) ⟶ F⇩i⇩n⇩v (fst y) (snd y) (F (fst y) x) = x"
begin

text ‹Uniform sampling from the range of the permutation with index α.›

definition S :: "'index ⇒ 'range spmf"
where "S α = spmf_of_set (range α)"

lemma lossless_S: "y ∈ set_spmf I ⟶  lossless_spmf (S (fst y))"
by(simp add: lossless_spmf_def S_def finite_range non_empty_range)

lemma set_spmf_S [simp]: "y ∈ set_spmf I ⟶ set_spmf (S (fst y)) = range (fst y)"
by(simp add: S_def finite_range)

lemma f_inj_on: "y ∈ set_spmf I ⟶ inj_on (F (fst y)) (range (fst y))"
by(metis bij_betw_def bij_betw dom_eq_ran)

lemma range_f: "y ∈ set_spmf I ⟶  x ∈ range (fst y) ⟶ F (fst y) x ∈ range (fst y)"
by (metis bij_betw dom_eq_ran bij_betwE)

lemma f_inv_f [simp]: "y ∈ set_spmf I ⟶ x ∈ range (fst y) ⟶ F⇩i⇩n⇩v (fst y) (snd y) (F (fst y) x) = x"
by (metis bij_betw bij_betw_inv_into_left dom_eq_ran F_f_inv)

lemma f_inv_f' [simp]: "y ∈ set_spmf I ⟶ x ∈ range (fst y) ⟶ Hilbert_Choice.inv_into (range (fst y)) (F (fst y)) (F (fst y) x) = x"
by (metis bij_betw bij_betw_inv_into_left dom_eq_ran)

lemma B_F_inv_rewrite: "(B α (F⇩i⇩n⇩v α τ y⇩σ') = (B α (F⇩i⇩n⇩v α τ y⇩σ') = m1)) = m1"
by auto

text ‹Applying the permutation to a uniform sample from its range yields a uniform sample.›

lemma uni_set_samp:
assumes "y ∈ set_spmf I"
shows "map_spmf (λ x. F (fst y) x) (S (fst y)) = (S (fst y))"
(is "?lhs = ?rhs")
proof-
have rhs: "?rhs = spmf_of_set (range (fst y))"
unfolding S_def by(simp)
also have "map_spmf (λ x. F (fst y) x) (spmf_of_set (range (fst y))) = spmf_of_set ((λ x. F (fst y) x) ` (range (fst y)))"
using f_inj_on assms
by (metis map_spmf_of_set_inj_on)
also have "(λ x. F (fst y) x) ` (range (fst y)) = range (fst y)"
apply(rule endo_inj_surj)
using bij_betw
by (auto simp add: bij_betw_def dom_eq_ran f_inj_on bij_betw finite_range assms)
finally show ?thesis by(simp add: rhs)
qed

text‹We define the security property of the hard core predicate (HCP) using a game.›

definition HCP_game :: "('index,'range) advP2 ⇒  bool ⇒ bool ⇒ ('index,'range) dist2 ⇒ bool spmf"
where "HCP_game A = (λ σ b⇩σ D. do {
(α, τ) ← I;
x ← S α;
b' ← A α σ b⇩σ D x;
let b = B α (F⇩i⇩n⇩v α τ x);
return_spmf (b = b')})"

definition "HCP_adv A σ b⇩σ D = ¦((spmf (HCP_game A σ b⇩σ D) True) - 1/2)¦"

end

end



# Theory ETP_OT

subsection ‹ Oblivious transfer constructed from ETPs ›

text‹Here we construct the OT protocol based on ETPs given in \cite{DBLP:books/sp/17/Lindell17} (Chapter 4) and prove
semi-honest security for both parties. We show information-theoretic security for Party 1 and reduce the security of
Party 2 to the HCP assumption.›

theory ETP_OT imports
"HOL-Number_Theory.Cong"
ETP
OT_Functionalities
Semi_Honest_Def
begin

text ‹Types of the parties' views and of the distinguishers and adversaries on them.
NOTE(review): the imported theory ETP also declares type synonyms named dist2 and advP2, with a
different parameter list; the declarations below take their place here. Confirm this is intended.›

type_synonym 'range viewP1 = "((bool × bool) × 'range × 'range) spmf"
type_synonym 'range dist1 = "((bool × bool) × 'range × 'range) ⇒ bool spmf"
type_synonym 'index viewP2 = "(bool × 'index × (bool × bool)) spmf"
type_synonym 'index dist2 = "(bool × 'index × bool × bool) ⇒ bool spmf"
type_synonym ('index, 'range) advP2 = "'index ⇒ bool ⇒ bool ⇒ 'index dist2 ⇒ 'range ⇒ bool spmf"

text ‹Simplification rules for boolean conditionals.›

lemma if_False_True: "(if x then False else ¬ False) ⟷ (if x then False else True)"
by simp

lemma if_then_True [simp]: "(if b then True else x) ⟷ (¬ b ⟶ x)"
by simp

lemma if_else_True [simp]: "(if b then x else True) ⟷ (b ⟶ x)"
by simp

text ‹Negation is injective on every set of booleans.›

lemma inj_on_Not [simp]: "inj_on Not A"
by(auto simp add: inj_on_def)

locale ETP_base = etp: etp I domain range F F⇩i⇩n⇩v B
for I :: "('index × 'trap) spmf" ― ‹samples index and trapdoor›
and domain :: "'index ⇒ 'range set"
and range :: "'index ⇒ 'range set"
and B :: "'index ⇒ 'range ⇒ bool" ― ‹hard core predicate›
and F :: "'index ⇒ 'range ⇒ 'range"
and F⇩i⇩n⇩v :: "'index ⇒ 'trap ⇒ 'range ⇒ 'range"
begin

text‹The probabilistic program that defines the protocol.›

text ‹Party 1 holds the bit pair input⇩1 and Party 2 the choice bit σ. Despite their names, b⇩σ and
b⇩σ' are simply the first and second components of input⇩1. The choice σ only decides which masked
bit is unmasked in the result, so the output is the second bit if σ holds and the first otherwise.
Note that x⇩σ is rebound to the trapdoor inverse of y⇩σ.›

definition protocol :: "(bool × bool) ⇒ bool ⇒ (unit × bool) spmf"
where "protocol input⇩1 σ = do {
let (b⇩σ, b⇩σ') = input⇩1;
(α :: 'index, τ :: 'trap) ← I;
x⇩σ :: 'range ← etp.S α;
y⇩σ' :: 'range ← etp.S α;
let (y⇩σ :: 'range) = F α x⇩σ;
let (x⇩σ :: 'range) = F⇩i⇩n⇩v α τ y⇩σ;
let (x⇩σ' :: 'range) = F⇩i⇩n⇩v α τ y⇩σ';
let (β⇩σ :: bool) = xor (B α x⇩σ) b⇩σ;
let (β⇩σ' :: bool) = xor (B α x⇩σ') b⇩σ';
return_spmf ((), if σ then xor (B α x⇩σ') β⇩σ' else xor (B α x⇩σ) β⇩σ)}"

text ‹Correctness: the protocol has the same output distribution as funct_OT_12.›

lemma correctness: "protocol (m0,m1) c = funct_OT_12 (m0,m1) c"
proof-
have "(B α (F⇩i⇩n⇩v α τ y⇩σ') = (B α (F⇩i⇩n⇩v α τ y⇩σ') = m1)) = m1"
for α τ y⇩σ'  by auto
then show ?thesis
by(auto simp add: protocol_def funct_OT_12_def Let_def etp.B_F_inv_rewrite bind_spmf_const etp.lossless_S local.etp.lossless_I lossless_weight_spmfD split_def cong: bind_spmf_cong)
qed

text ‹ Party 1 views ›

(* Real view of Party 1: its input bits together with two range elements.
   One element is F α x⇩σ for a sampled x⇩σ and the other (y⇩σ') is sampled
   directly.  F α x⇩σ is placed second when σ is True and first otherwise. *)
definition R1 :: "(bool × bool) ⇒ bool ⇒ 'range viewP1"
where "R1 input⇩1 σ = do {
let (b⇩0, b⇩1) = input⇩1;
(α, τ) ← I;
x⇩σ ← etp.S α;
y⇩σ' ← etp.S α;
let y⇩σ = F α x⇩σ;
return_spmf ((b⇩0, b⇩1), if σ then y⇩σ' else y⇩σ, if σ then y⇩σ else y⇩σ')}"

(* R1 is lossless, because I and etp.S are lossless. *)
lemma lossless_R1: "lossless_spmf (R1 msgs σ)"
by(simp add: R1_def local.etp.lossless_I split_def etp.lossless_S Let_def)

(* Simulator for Party 1.  It receives only Party 1's input and its (unit)
   output, and returns the input with two independently sampled range
   elements. *)
definition S1 :: "(bool × bool) ⇒ unit ⇒ 'range viewP1"
where "S1 == (λ input⇩1 (). do {
let (b⇩0, b⇩1) = input⇩1;
(α, τ) ← I;
y⇩0 :: 'range ← etp.S α;
y⇩1 ← etp.S α;
return_spmf ((b⇩0, b⇩1), y⇩0, y⇩1)})"

(* S1 is lossless, because I and etp.S are lossless. *)
lemma lossless_S1: "lossless_spmf (S1 msgs ())"
by(simp add: S1_def local.etp.lossless_I split_def etp.lossless_S)

text ‹ The real and simulated views of Party 2 ›

(* Real view of Party 2: its choice bit σ, the index α and the masked bits
   (β⇩σ, β⇩σ').  β⇩σ masks the bit Party 2 chose, (if σ then b1 else b0).
   β⇩σ' masks the other bit with B α (F⇩i⇩n⇩v α τ y⇩σ'). *)
definition R2 :: "(bool × bool) ⇒ bool ⇒ 'index viewP2"
where "R2 msgs σ = do {
let (b0,b1) = msgs;
(α, τ) ← I;
x⇩σ ← etp.S α;
y⇩σ' ← etp.S α;
let y⇩σ = F α x⇩σ;
let x⇩σ = F⇩i⇩n⇩v α τ y⇩σ;
let x⇩σ' = F⇩i⇩n⇩v α τ y⇩σ';
let β⇩σ = (B α x⇩σ) ⊕ (if σ then b1 else b0) ;
let β⇩σ' = (B α x⇩σ') ⊕ (if σ then b0 else b1);
return_spmf (σ, α,(β⇩σ, β⇩σ'))}"

(* R2 is lossless, because I and etp.S are lossless. *)
lemma lossless_R2: "lossless_spmf (R2 msgs σ)"
by(simp add: R2_def split_def local.etp.lossless_I etp.lossless_S)

(* Simulator for Party 2.  It knows only σ and its output bit b⇩σ.  It masks
   b⇩σ as R2 does, but uses the unmasked hard-core bit B α x⇩σ' as the
   second component. *)
definition S2 :: "bool ⇒ bool ⇒ 'index viewP2"
where "S2 σ b⇩σ = do {
(α, τ) ← I;
x⇩σ ← etp.S α;
y⇩σ' ← etp.S α;
let x⇩σ' = F⇩i⇩n⇩v α τ y⇩σ';
let β⇩σ = (B α x⇩σ) ⊕ b⇩σ;
let β⇩σ' = B α x⇩σ';
return_spmf (σ, α, (β⇩σ, β⇩σ'))}"

(* S2 is lossless, because I and etp.S are lossless. *)
lemma lossless_S2: "lossless_spmf (S2 σ b⇩σ)"
by(simp add: S2_def local.etp.lossless_I etp.lossless_S split_def)

text ‹ Security for Party 1 ›

text‹Security for Party 1 is information-theoretic: the real view is equal to the simulated view.›

(* Perfect security for Party 1: the real view R1 equals the simulator S1
   composed with the functionality.  The key step (second equality) replaces
   F α x⇩σ for a sampled x⇩σ with a freshly sampled range element, using
   etp.uni_set_samp. *)
lemma P1_security: "R1 input⇩1 σ = funct_OT_12 x y ⤜ (λ (s1, s2). S1 input⇩1 s1)"
proof-
have "R1 input⇩1 σ =  do {
let (b0,b1) = input⇩1;
(α, τ) ← I;
y⇩σ' :: 'range ← etp.S α;
y⇩σ ← map_spmf (λ x⇩σ. F α x⇩σ) (etp.S α);
return_spmf ((b0,b1), if σ then y⇩σ' else y⇩σ, if σ then y⇩σ else y⇩σ')}"
by(simp add: bind_map_spmf o_def Let_def R1_def)
also have "... = do {
let (b0,b1) = input⇩1;
(α, τ) ← I;
y⇩σ' :: 'range ← etp.S α;
y⇩σ ← etp.S α;
return_spmf ((b0,b1), if σ then y⇩σ' else y⇩σ, if σ then y⇩σ else y⇩σ')}"
by(simp add: etp.uni_set_samp Let_def split_def cong: bind_spmf_cong)
also have "... = funct_OT_12 x y ⤜ (λ (s1, s2). S1 input⇩1 s1)"
by(cases σ; simp add: S1_def R1_def Let_def funct_OT_12_def)
ultimately show ?thesis by auto
qed

text ‹ The adversary used in the proof of security for Party 2 ›

(* Adversary against the hard-core predicate, built from a Party 2
   distinguisher D2.  It guesses β⇩σ' with a fair coin, builds a Party 2 view
   from it, and returns the guess if D2 accepts and its negation otherwise.
   The last argument x is not used in the body. *)
definition 𝒜 :: "('index, 'range) advP2"
where "𝒜 α σ b⇩σ D2 x = do {
β⇩σ' ← coin_spmf;
x⇩σ ← etp.S α;
let β⇩σ = (B α x⇩σ) ⊕ b⇩σ;
d ← D2(σ, α, β⇩σ, β⇩σ');
return_spmf(if d then β⇩σ' else ¬ β⇩σ')}"

(* 𝒜 is lossless for every index in the support of I, provided D2 is
   lossless on every view.
   NOTE(review): this lemma has no proof in this copy of the theory, so the
   theory will not check as given. *)
lemma lossless_𝒜:
assumes "∀ view. lossless_spmf (D2 view)"
shows "y ∈ set_spmf I ⟶  lossless_spmf (𝒜 (fst y) σ b⇩σ D2 x)"

(* Transfers an assumed bound on 𝒜's HCP advantage, for the bit Party 2
   receives (if σ then b1 else b0), to the HCP game composed with funct_OT_12.
   NOTE(review): in this copy the first `have` has no proof method, and ?lhs
   is never bound, because the statement has no `is` pattern. *)
lemma assm_bound_funct_OT_12:
assumes "etp.HCP_adv 𝒜 σ (if σ then b1 else b0) D ≤ HCP_ad"
shows "¦spmf (funct_OT_12 (b0,b1) σ ⤜ (λ (out1,out2).
etp.HCP_game 𝒜 σ out2 D)) True - 1/2¦ ≤ HCP_ad"
proof-
have "?lhs = ¦spmf (etp.HCP_game 𝒜 σ (if σ then b1 else b0) D) True - 1/2¦"
thus ?thesis using assms etp.HCP_adv_def by simp
qed

(* Variant of assm_bound_funct_OT_12 for an input pair m1 that has not been
   split into its components.
   NOTE(review): the proof cites `assms`, but this copy of the lemma declares
   no assumptions.  Presumably an assumption bounding etp.HCP_adv by HCP_ad
   was lost; without one the statement is unconstrained in HCP_ad. *)
lemma assm_bound_funct_OT_12_collapse:
shows "¦spmf (funct_OT_12 m1 σ ⤜ (λ (out1,out2). etp.HCP_game 𝒜 σ out2 D)) True - 1/2¦ ≤ HCP_ad"
using assm_bound_funct_OT_12 surj_pair assms by metis

text ‹ To prove security for Party 2, we split the proof into cases on the value of the bit that Party 2 does not receive (which bit that is depends on Party 2's input σ) ›

(* Case where the bit Party 2 does not receive, (if σ then b0 else b1), is
   False: D2 accepts the real and the simulated view with the same
   probability. *)
lemma R2_S2_False:
assumes "((if σ then b0 else b1) = False)"
shows "spmf (R2 (b0,b1) σ ⤜ (D2 :: (bool × 'index × bool × bool) ⇒ bool spmf)) True
= spmf (funct_OT_12 (b0,b1) σ ⤜ (λ (out1,out2). S2 σ out2 ⤜ D2)) True"
proof-
have "σ ⟹ ¬ b0" using assms by simp
moreover have "¬ σ ⟹ ¬ b1" using assms by simp
ultimately show ?thesis
by(auto simp add: R2_def S2_def split_def local.etp.F_f_inv assms funct_OT_12_def cong: bind_spmf_cong_simp)
qed

(* Case where the bit Party 2 does not receive, (if σ then b0 else b1), is
   True.  The difference between the real and simulated acceptance
   probabilities of D2 equals, in absolute value, twice 𝒜's advantage in the
   hard-core-predicate game.
   NOTE(review): in this copy the step
   `also have "... =  1/2*((spmf (D_true ...` has no proof method, and the
   apply-script after the following `also have` is not closed by `done`, so
   the lemma will not check as given. *)
lemma R2_S2_True:
assumes "((if σ then b0 else b1) = True)"
and lossless_D: "∀ a. lossless_spmf (D2 a)"
shows "¦(spmf (bind_spmf (R2 (b0,b1) σ) D2) True) - spmf (funct_OT_12 (b0,b1) σ ⤜ (λ (out1, out2). S2 σ out2 ⤜ (λ view. D2 view))) True¦
= ¦2*((spmf (etp.HCP_game 𝒜 σ (if σ then b1 else b0) D2) True) - 1/2)¦"
proof-
(* Prove the equation without absolute values, with the simulated term
   first; the absolute values in the statement then agree. *)
have  "(spmf (funct_OT_12 (b0,b1) σ ⤜ (λ (out1, out2). S2 σ out2 ⤜ D2)) True
- spmf (bind_spmf (R2 (b0,b1) σ) D2) True)
= 2 * ((spmf (etp.HCP_game 𝒜 σ (if σ then b1 else b0) D2) True) - 1/2)"
proof-
(* Equivalent form: the HCP advantage is half the difference between the
   simulated and real acceptance probabilities. *)
have  "((spmf (etp.HCP_game 𝒜 σ (if σ then b1 else b0) D2) True) - 1/2)  =
1/2*(spmf (bind_spmf (S2 σ (if σ then b1 else b0)) D2) True
- spmf (bind_spmf (R2 (b0,b1) σ) D2) True)"
proof-
have σ_true_b0_true: "σ ⟹ b0 = True" using assms(1) by simp
have σ_false_b1_true: "¬ σ ⟹ b1" using assms(1) by simp
(* Negating a returned bit swaps the probabilities of True and False. *)
have return_True_False: "spmf (return_spmf (¬ d)) True = spmf (return_spmf d) False"
for d by(cases d; simp)
(* Auxiliary games.
   HCP_game_true / HCP_game_false: the HCP game with the second view bit
     fixed to B α (F⇩i⇩n⇩v α τ x) or to its negation.
   HCP_game_𝒜: the HCP game with 𝒜 inlined.
   S2D / R2D: the simulated / real Party 2 views fed to D2.
   D_true / D_false: D2 run on views whose second bit is B α (F⇩i⇩n⇩v α τ x)
     or its negation. *)
define HCP_game_true where "HCP_game_true == λ σ b⇩σ. do {
(α, τ) ← I;
x⇩σ ← etp.S α;
x ← (etp.S α);
let β⇩σ = (B α x⇩σ) ⊕ b⇩σ;
let β⇩σ' = B α (F⇩i⇩n⇩v α τ x);
d ← D2(σ, α, β⇩σ, β⇩σ');
let b' = (if d then β⇩σ' else ¬ β⇩σ');
let b = B α (F⇩i⇩n⇩v α τ x);
return_spmf (b = b')}"
define HCP_game_false where "HCP_game_false == λ σ b⇩σ. do {
(α, τ) ← I;
x⇩σ ← etp.S α;
x ← (etp.S α);
let β⇩σ = (B α x⇩σ) ⊕ b⇩σ;
let β⇩σ' = ¬ B α (F⇩i⇩n⇩v α τ x);
d ← D2(σ, α, β⇩σ, β⇩σ');
let b' = (if d then β⇩σ' else ¬ β⇩σ');
let b = B α (F⇩i⇩n⇩v α τ x);
return_spmf (b = b')}"
define HCP_game_𝒜 where "HCP_game_𝒜 == λ σ b⇩σ. do {
β⇩σ' ← coin_spmf;
(α, τ) ← I;
x ← etp.S α;
x' ← etp.S α;
d ← D2 (σ, α, (B α x) ⊕ b⇩σ, β⇩σ');
let b' = (if d then  β⇩σ' else ¬ β⇩σ');
return_spmf (B α (F⇩i⇩n⇩v α τ x') = b')}"
define S2D where "S2D == λ σ b⇩σ . do {
(α, τ) ← I;
x⇩σ ← etp.S α;
y⇩σ' ← etp.S α;
let x⇩σ' = F⇩i⇩n⇩v α τ y⇩σ';
let β⇩σ = (B α x⇩σ) ⊕ b⇩σ;
let β⇩σ' = B α x⇩σ';
d :: bool ← D2(σ, α, β⇩σ, β⇩σ');
return_spmf d}"
define R2D where "R2D == λ msgs σ.  do {
let (b0,b1) = msgs;
(α, τ) ← I;
x⇩σ ← etp.S α;
y⇩σ' ← etp.S α;
let y⇩σ = F α x⇩σ;
let x⇩σ = F⇩i⇩n⇩v α τ y⇩σ;
let x⇩σ' = F⇩i⇩n⇩v α τ y⇩σ';
let β⇩σ = (B α x⇩σ) ⊕ (if σ then b1 else b0) ;
let β⇩σ' = (B α x⇩σ') ⊕ (if σ then b0 else b1);
b :: bool ← D2(σ, α,(β⇩σ, β⇩σ'));
return_spmf b}"
define D_true where "D_true  == λσ b⇩σ. do {
(α, τ) ← I;
x⇩σ ← etp.S α;
x ← (etp.S α);
let β⇩σ = (B α x⇩σ) ⊕ b⇩σ;
let β⇩σ' = B α (F⇩i⇩n⇩v α τ x);
d :: bool ← D2(σ, α, β⇩σ, β⇩σ');
return_spmf d}"
define D_false where "D_false == λ σ b⇩σ. do {
(α, τ) ← I;
x⇩σ ← etp.S α;
x ← etp.S α;
let β⇩σ = (B α x⇩σ) ⊕ b⇩σ;
let β⇩σ' = ¬ B α (F⇩i⇩n⇩v α τ x);
d :: bool ← D2(σ, α, β⇩σ, β⇩σ');
return_spmf d}"
(* D_false is lossless because I, etp.S and D2 are lossless (presumably
   needed to rewrite the False-probability as 1 minus the True-probability). *)
have lossless_D_false: "lossless_spmf (D_false σ (if σ then b1 else b0))"
apply(auto simp add: D_false_def lossless_D local.etp.lossless_I)
using local.etp.lossless_S by auto
(* Step 1: inline 𝒜 into the HCP game, moving the coin flip outward. *)
have "spmf (etp.HCP_game 𝒜 σ (if σ then b1 else b0) D2) True =  spmf (HCP_game_𝒜 σ (if σ then b1 else b0)) True"
apply(simp add: etp.HCP_game_def HCP_game_𝒜_def 𝒜_def split_def etp.F_f_inv)
by(rewrite bind_commute_spmf[where q = "coin_spmf"]; rewrite bind_commute_spmf[where q = "coin_spmf"]; rewrite bind_commute_spmf[where q = "coin_spmf"]; auto)+
(* Step 2: split on the coin flip; one outcome gives HCP_game_true and the
   other HCP_game_false. *)
also have "... = spmf (bind_spmf (map_spmf Not coin_spmf) (λb. if b then HCP_game_true σ (if σ then b1 else b0) else HCP_game_false σ (if σ then b1 else b0))) True"
unfolding HCP_game_𝒜_def HCP_game_true_def HCP_game_false_def 𝒜_def Let_def
apply(subst if_distrib[where f = "bind_spmf _" for f, symmetric]; simp cong: bind_spmf_cong add: if_distribR )+
apply(rewrite in "_ = ⌑" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑"  in "_ = ⌑" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "_ = ⌑" bind_commute_spmf)
apply(rewrite in "⌑ = _" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "⌑ = _" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "⌑ = _" bind_commute_spmf)
apply(fold map_spmf_conv_bind_spmf)
apply(rule conjI; rule impI; simp)
apply(simp only: spmf_bind)
apply(rule Bochner_Integration.integral_cong[OF refl])+
apply clarify
subgoal for r r⇩σ α τ
apply(simp only: UNIV_bool spmf_of_set integral_spmf_of_set)
apply(simp cong: if_cong split del: if_split)
apply(cases "B r (F⇩i⇩n⇩v r r⇩σ τ)")
by auto
apply(rewrite in "_ = ⌑" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑"  in "_ = ⌑" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "_ = ⌑" bind_commute_spmf)
apply(rewrite in "⌑ = _" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "⌑ = _" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "⌑ = _" bind_commute_spmf)
apply(simp only: spmf_bind)
apply(rule Bochner_Integration.integral_cong[OF refl])+
apply clarify
subgoal for r r⇩σ α τ
apply(simp only: UNIV_bool spmf_of_set integral_spmf_of_set)
apply(simp cong: if_cong split del: if_split)
apply(cases " B r (F⇩i⇩n⇩v r r⇩σ τ)")
by auto
done
(* Step 3: the coin is uniform, so each branch has weight 1/2. *)
also have "... = 1/2*(spmf (HCP_game_true σ (if σ then b1 else b0)) True) + 1/2*(spmf (HCP_game_false σ (if σ then b1 else b0)) True)"
by(simp add: spmf_bind UNIV_bool spmf_of_set integral_spmf_of_set)
(* Step 4: rewrite the games as distinguisher probabilities.  In the false
   branch the output is negated, so True becomes False. *)
also have "... = 1/2*(spmf (D_true σ (if σ then b1 else b0)) True) + 1/2*(spmf (D_false σ (if σ then b1 else b0)) False)"
proof-
have "spmf (I ⤜ (λ(α, τ). etp.S α ⤜ (λx⇩σ. etp.S α ⤜ (λx. D2 (σ, α, B α x⇩σ = (¬ (if σ then b1 else b0)), ¬ B α (F⇩i⇩n⇩v α τ x)) ⤜ (λd. return_spmf (¬ d)))))) True
= spmf (I ⤜ (λ(α, τ). etp.S α ⤜ (λx⇩σ. etp.S α ⤜ (λx. D2 (σ, α, B α x⇩σ = (¬ (if σ then b1 else b0)), ¬ B α (F⇩i⇩n⇩v α τ x)))))) False"
(is "?lhs = ?rhs")
proof-
have "?lhs = spmf (I ⤜ (λ(α, τ). etp.S α ⤜ (λx⇩σ. etp.S α ⤜ (λx. D2 (σ, α, B α x⇩σ = (¬ (if σ then b1 else b0)), ¬ B α (F⇩i⇩n⇩v α τ x)) ⤜ (λd. return_spmf (d)))))) False"
by(simp only: split_def return_True_False spmf_bind)
then show ?thesis by simp
qed
then show ?thesis  by(simp add: HCP_game_true_def HCP_game_false_def Let_def D_true_def D_false_def if_distrib[where f="(=) _"] cong: if_cong)
qed
(* Step 5: complement the False-probability of D_false.
   NOTE(review): this step has no proof method in this copy. *)
also have "... =  1/2*((spmf (D_true σ (if σ then b1 else b0) ) True) + (1 - spmf (D_false σ (if σ then b1 else b0) ) True))"
also have "... = 1/2 + 1/2* (spmf (D_true σ (if σ then b1 else b0)) True) - 1/2*(spmf (D_false σ (if σ then b1 else b0)) True)"
by(simp)
(* Step 6: identify D_true with the simulated view and D_false with the real
   view (uses assms and etp.F_f_inv).
   NOTE(review): the apply-script below is not closed by `done` in this copy. *)
also have "... =  1/2 + 1/2* (spmf (S2D σ (if σ then b1 else b0) ) True) - 1/2*(spmf (R2D (b0,b1) σ ) True)"
apply(auto  simp add: local.etp.F_f_inv S2D_def R2D_def D_true_def D_false_def  assms split_def cong: bind_spmf_cong_simp)
ultimately show ?thesis by(simp add: S2D_def R2D_def R2_def S2_def split_def)
qed
then show ?thesis by(auto simp add: funct_OT_12_def)
qed
thus ?thesis by simp
qed

(* Both cases combined: Party 2's distinguishing advantage between the real
   and simulated views is at most twice 𝒜's advantage in the
   hard-core-predicate game.  Proved by case analysis on the bit Party 2
   does not receive, using R2_S2_False and R2_S2_True.
   The original copy lacked the `lemma` header, which left a top-level
   `assumes` (a syntax error). *)
lemma R2_S2:
assumes lossless_D: "∀ a. lossless_spmf (D2 a)"
shows "¦(spmf (bind_spmf (R2 (b0,b1) σ) D2) True) - spmf (funct_OT_12 (b0,b1) σ ⤜ (λ (out1, out2). S2 σ out2 ⤜ (λ view. D2 view))) True¦
≤ ¦2*((spmf (etp.HCP_game 𝒜 σ (if σ then b1 else b0) D2) True) - 1/2)¦"
by(cases "(if σ then b0 else b1)"; auto simp add: R2_S2_False R2_S2_True assms)

(* Interpret the sim_det_def locale with the views R1/R2, the simulators
   S1/S2, the functionality funct_OT_12 and the protocol above.  The locale
   obligations are discharged by the losslessness lemmas. *)
sublocale OT_12: sim_det_def R1 S1 R2 S2 funct_OT_12 protocol
unfolding sim_det_def_def
by(simp add: lossless_R1 lossless_S1 lossless_R2 lossless_S2 funct_OT_12_def)

(* Correctness in the sense of the OT_12 interpretation, derived from the
   lemma correctness for an arbitrary input pair. *)
lemma correct: "OT_12.correctness m1 m2"
unfolding OT_12.correctness_def
by (metis prod.collapse correctness)

(* Perfect security for Party 1 in the sense of the OT_12 interpretation,
   derived from P1_security. *)
lemma P1_security_inf_the: "OT_12.perfect_sec_P1 m1 m2"
unfolding OT_12.perfect_sec_P1_def using P1_security by simp

(* Security for Party 2: the visible steps bound OT_12.adv_P2 m1 m2 D first
   by the HCP game advantage and then by ¦2*HCP_ad¦.
   NOTE(review): in this copy the `shows` clause is missing, and the steps
   starting `have "spmf (etp.HCP_game ...` and
   `hence "OT_12.adv_P2 m1 m2 D ≤ ...` have no proof method, so the lemma
   will not check as given. *)
lemma P2_security:
assumes "∀ a. lossless_spmf (D a)"
proof-
have "spmf (etp.HCP_game 𝒜 σ (if σ then b1 else b0) D) True = spmf (funct_OT_12 (b0,b1) σ ⤜ (λ (out1, out2). etp.HCP_game 𝒜 σ out2 D)) True"
for σ b0 b1
hence "OT_12.adv_P2 m1 m2 D ≤ ¦2*((spmf (funct_OT_12 m1 m2 ⤜ (λ (out1, out2). etp.HCP_game 𝒜 m2 out2 D)) True) - 1/2)¦"
moreover have "¦2*((spmf (funct_OT_12 m1 m2 ⤜ (λ (out1, out2). etp.HCP_game 𝒜 m2 out2 D)) True) - 1/2)¦ ≤ ¦2*HCP_ad¦"
proof -
have "(∃r. ¦(1::real) / r¦ ≠ 1 / ¦r¦) ∨ 2 / ¦1 / (spmf (funct_OT_12 m1 m2
⤜ (λ(x, y). ((λu b. etp.HCP_game 𝒜 m2 b D)::unit ⇒ bool ⇒ bool spmf) x y)) True - 1 / 2)¦
≤ HCP_ad / (1 / 2)"
using assm_bound_funct_OT_12_collapse assms by auto
then show ?thesis
by fastforce
qed
ultimately show ?thesis by argo
qed

end

text ‹ We also consider the asymptotic setting for the security proofs. ›

(* Asymptotic version: the sampler I is indexed by a security parameter
   n :: nat, and ETP_base is assumed for every n.  The fixed parameter f does
   not occur in the visible assumption. *)
locale ETP_sec_para =
fixes I :: "nat ⇒ ('index × 'trap) spmf"
and domain ::  "'index ⇒ 'range set"
and range ::  "'index ⇒ 'range set"
and f :: "'index ⇒ ('range ⇒ 'range)"
and F :: "'index ⇒ 'range ⇒ 'range"
and F⇩i⇩n⇩v :: "'index ⇒ 'trap ⇒ 'range ⇒ 'range"
and B :: "'index ⇒ 'range ⇒ bool"
assumes ETP_base: "⋀ n. ETP_base (I n) domain range F F⇩i⇩n⇩v"
begin

(* Make all ETP_base results available for each security parameter n. *)
sublocale ETP_base "(I n)" domain range
using ETP_base  by simp

(* Correctness for every security parameter n.
   NOTE(review): this lemma has no proof in this copy. *)
lemma correct_asym: "OT_12.correctness n m1 m2"

(* Perfect security for Party 1 for every security parameter n. *)
lemma P1_sec_asym: "OT_12.perfect_sec_P1 n m1 m2"
using P1_security_inf_the by simp

(* Asymptotic security for Party 2: Party 2's advantage is negligible in the
   security parameter n.
   NOTE(review): in this copy the first `moreover` has no preceding fact (a
   step appears to be missing), and etp_advantage and HCP_adv_neg are not
   defined in the visible part of this locale. *)
lemma P2_sec_asym:
assumes "∀ a. lossless_spmf (D a)"
shows "negligible (λ n. OT_12.adv_P2 n m1 m2 D)"
proof-
moreover have "¦OT_12.adv_P2 n m1 m2 D¦ = OT_12.adv_P2 n m1 m2 D" for n unfolding OT_12.adv_P2_def by simp
moreover have  "OT_12.adv_P2 n m1 m2 D ≤ 2 * etp_advantage n" for n using assms P2_security by blast
ultimately show ?thesis
using assms negligible_le HCP_adv_neg P2_security by presburger
qed

end

end

# Theory ETP_RSA_OT

subsubsection ‹ RSA instantiation ›

text‹It is known that the RSA collection forms an ETP. Here we instantiate our proof of security for OT,
which uses a general ETP, with RSA. We use the proof of the general construction of OT. The main proof effort
here is in showing that the RSA collection meets the requirements of an ETP; mainly, this involves showing that the
RSA mapping is a bijection.›

theory ETP_RSA_OT imports
ETP_OT
Number_Theory_Aux
Uniform_Sampling
begin

(* Concrete types for the RSA instantiation.  An index is the public key
   (N, e) and the trapdoor is the private exponent d (see I below); domain and
   range elements are natural numbers.  The view, distinguisher and adversary
   types specialise the generic ones to these types. *)
type_synonym index = "(nat × nat)"
type_synonym trap = nat
type_synonym range = nat
type_synonym domain = nat
type_synonym viewP1 = "((bool × bool) × nat × nat) spmf"
type_synonym viewP2 = "(bool × index × (bool × bool)) spmf"
type_synonym dist2 = "(bool × index × bool × bool) ⇒ bool spmf"
type_synonym advP2 = "index ⇒ bool ⇒ bool ⇒ dist2 ⇒ bool spmf"

(* RSA parameters: a finite set of primes, each greater than 2 and with more
   than two elements, from which P and Q are drawn; and a hard-core predicate
   B on RSA indices and range elements. *)
locale rsa_base =
fixes prime_set :: "nat set" ― ‹the set of primes used›
and B :: "index ⇒ nat ⇒ bool"
assumes prime_set_ass: "prime_set ⊆ {x. prime x ∧ x > 2}"
and finite_prime_set: "finite prime_set"
and prime_set_gt_2: "card prime_set > 2"
begin

(* prime_set is non-empty, since its cardinality exceeds 2. *)
lemma prime_set_non_empty: "prime_set ≠ {}"
using prime_set_gt_2 by auto

(* Candidate public exponents for modulus N: the numbers strictly between 1
   and N that are coprime to N. *)
definition coprime_set :: "nat ⇒ nat set"
where "coprime_set N ≡ {x. coprime x N ∧ x > 1 ∧ x < N}"

(* For N > 2, coprime_set N has at least one element (the proof uses
   coprime_Suc_right_nat, presumably with the witness N - 1). *)
lemma coprime_set_non_empty:
assumes "N > 2"
shows "coprime_set N ≠ {}"
by(simp add: coprime_set_def; metis assms(1) Suc_lessE coprime_Suc_right_nat lessI numeral_2_eq_2)

(* Samples uniformly from coprime_set N. *)
definition sample_coprime :: "nat ⇒ nat spmf"
where "sample_coprime N = spmf_of_set (coprime_set (N))"

(* Every value sampled by sample_coprime is greater than 1. *)
lemma sample_coprime_e_gt_1:
assumes "e ∈ set_spmf (sample_coprime N)"
shows "e > 1"
using assms by(simp add: sample_coprime_def coprime_set_def)

(* sample_coprime N is lossless for N > 2: coprime_set N is non-empty and
   finite, so spmf_of_set on it has total mass 1.  The assumption ¬ prime N
   is not used by this proof.
   Both intermediate facts previously lacked proofs; `moreover` replaces the
   calculational `also`, since the two facts are not chained by transitivity. *)
lemma lossless_sample_coprime:
assumes "¬ prime N"
and "N > 2"
shows  "lossless_spmf (sample_coprime N)"
proof-
have "coprime_set N ≠ {}"
by(rule coprime_set_non_empty[OF assms(2)])
(* coprime_set N is a subset of {..<N}, hence finite. *)
moreover have "finite (coprime_set N)"
by(rule finite_subset[of _ "{..<N}"]; auto simp add: coprime_set_def)
ultimately show ?thesis by(simp add: sample_coprime_def)
qed

(* The support of sample_coprime N is exactly coprime_set N.  The set is
   finite (a subset of {..<N}), so spmf_of_set keeps all of it.
   The original statement had no proof. *)
lemma set_spmf_sample_coprime:
shows "set_spmf (sample_coprime N) = {x. coprime x N ∧ x > 1 ∧ x < N}"
proof-
have "finite (coprime_set N)"
by(rule finite_subset[of _ "{..<N}"]; auto simp add: coprime_set_def)
then show ?thesis by(simp add: sample_coprime_def coprime_set_def set_spmf_of_set)
qed

(* Samples uniformly from prime_set. *)
definition sample_primes :: "nat spmf"
where "sample_primes = spmf_of_set prime_set"

(* sample_primes is lossless, because prime_set is finite and non-empty.
   The original statement had no proof. *)
lemma lossless_sample_primes:
shows "lossless_spmf sample_primes"
by(simp add: sample_primes_def finite_prime_set prime_set_non_empty)

(* Every sampled prime is a prime greater than 2. *)
lemma set_spmf_sample_primes:
shows "set_spmf sample_primes ⊆ {x. prime x ∧ x > 2}"
by(auto simp add: sample_primes_def prime_set_ass finite_prime_set)

(* Pointwise form: a sampled prime is greater than 2. *)
lemma mem_samp_primes_gt_2:
shows "x ∈ set_spmf sample_primes ⟹ x > 2"
using prime_set_ass by blast

(* Pointwise form: a sampled value is prime. *)
lemma mem_samp_primes_prime:
shows "x ∈ set_spmf sample_primes ⟹ prime x"
apply (simp add: finite_prime_set sample_primes_def prime_set_ass)
using prime_set_ass by blast

(* Samples uniformly from prime_set with the primes in P removed. *)
definition sample_primes_excl :: "nat set ⇒ nat spmf"
where "sample_primes_excl P = spmf_of_set (prime_set - P)"

(* Removing a single prime still leaves a non-empty set, because prime_set
   has more than two elements. *)
lemma lossless_sample_primes_excl:
shows "lossless_spmf (sample_primes_excl {P})"
using prime_set_gt_2 subset_singletonD by fastforce

(* Samples uniformly from Q with the elements of P removed. *)
definition sample_set_excl :: "nat set ⇒ nat set ⇒ nat spmf"
where "sample_set_excl Q P = spmf_of_set (Q - P)"

(* When Q - P is finite, the support of sample_set_excl Q P is exactly Q - P. *)
lemma set_spmf_sample_set_excl [simp]:
assumes "finite (Q - P)"
shows "set_spmf (sample_set_excl Q P) = (Q - P)"
unfolding  sample_set_excl_def
by (metis set_spmf_of_set assms)+

(* sample_set_excl Q {P} is lossless when Q is finite with more than two
   elements, so that removing P leaves a non-empty set. *)
lemma lossless_sample_set_excl:
assumes "finite Q"
and "card Q > 2"
shows "lossless_spmf (sample_set_excl Q {P})"
unfolding sample_set_excl_def
using assms subset_singletonD by fastforce

(* A prime sampled from prime_set minus {y} is greater than 2. *)
lemma mem_samp_primes_excl_gt_2:
shows "x ∈ set_spmf (sample_set_excl prime_set {y}) ⟹ x > 2"
apply(simp add: finite_prime_set sample_set_excl_def  prime_set_ass )
using prime_set_ass by blast

(* A value sampled from prime_set minus {y} is prime. *)
lemma mem_samp_primes_excl_prime :
shows "x ∈ set_spmf (sample_set_excl prime_set {y}) ⟹ prime x"
using prime_set_ass by blast

(* For primes x and y drawn during key generation, sampling a public exponent
   coprime to (x-1)*(y-1) is lossless.  Both factors exceed 1, so the product
   is not prime and is greater than 2. *)
lemma sample_coprime_lem:
assumes "x ∈ set_spmf sample_primes"
and " y ∈ set_spmf (sample_set_excl prime_set {x}) "
shows "lossless_spmf (sample_coprime ((x - Suc 0) * (y - Suc 0)))"
proof-
have gt_2: "x > 2" "y > 2"
using mem_samp_primes_gt_2 assms mem_samp_primes_excl_gt_2 by auto
have "¬ prime ((x-1)*(y-1))"
proof-
have "prime x" "prime y"
using mem_samp_primes_prime mem_samp_primes_excl_prime assms by auto
then show ?thesis using prod_not_prime gt_2 by simp
qed
also have "((x-1)*(y-1)) > 2"
by (metis (no_types, lifting) gt_2 One_nat_def Suc_diff_1 assms(1) assms(2) calculation
divisors_zero less_2_cases nat_1_eq_mult_iff nat_neq_iff not_numeral_less_one numeral_2_eq_2
prime_gt_0_nat rsa_base.mem_samp_primes_excl_prime rsa_base.mem_samp_primes_prime rsa_base_axioms two_is_prime_nat)
ultimately show ?thesis using lossless_sample_coprime by simp
qed

(* RSA key generation.  Sample distinct primes P and Q, and set N = P*Q and
   N' = (P-1)*(Q-1).  Sample a public exponent e coprime to N'.  The private
   exponent d is the Bézout coefficient of e (from bezw) reduced modulo N'.
   The index is (N, e) and the trapdoor is d. *)
definition I :: "(index × trap) spmf"
where "I = do {
P ← sample_primes;
Q ← sample_set_excl prime_set {P};
let N = P*Q;
let N' = (P-1)*(Q-1);
e ← sample_coprime N';
let d = nat ((fst (bezw e N')) mod N');
return_spmf ((N, e), d)}"

(* Key generation is lossless: each sampling step is lossless. *)
lemma lossless_I: "lossless_spmf I"
by(auto simp add: I_def lossless_sample_primes lossless_sample_set_excl finite_prime_set prime_set_gt_2 Let_def sample_coprime_lem)

(* Elimination rule for key generation: any ((N,e),d) in the support of I
   comes from distinct primes P and Q with N = P*Q, e coprime to
   (P-1)*(Q-1), and d computed by bezw as in I.
   NOTE(review): the final `using ...` line has no proof method in this copy. *)
lemma set_spmf_I_N:
assumes "((N,e),d) ∈ set_spmf I"
obtains P Q where "N = P * Q"
and "P ≠ Q"
and "prime P"
and "prime Q"
and "coprime e ((P - 1)*(Q - 1))"
and "d = nat (fst (bezw e ((P-1)*(Q-1))) mod int ((P-1)*(Q-1)))"
using assms apply(auto simp add: I_def Let_def)
using finite_prime_set mem_samp_primes_prime sample_set_excl_def rsa_base_axioms sample_primes_def

(* Both exponents produced by key generation are greater than 1. *)
lemma set_spmf_I_e_d:
assumes "((N,e),d) ∈ set_spmf I"
shows "e > 1" and "d > 1"
using assms sample_coprime_e_gt_1
by (smt Euclidean_Division.pos_mod_sign Num.of_nat_simps(5) Suc_diff_1 bezw_inverse cong_def coprime_imp_gcd_eq_1 gr0I less_1_mult less_numeral_extra(2) mem_Collect_eq mod_by_0 mod_less more_arith_simps(6) nat_0 nat_0_less_mult_iff nat_int nat_neq_iff numerals(2) of_nat_0_le_iff of_nat_1 rsa_base.mem_samp_primes_gt_2 rsa_base_axioms set_spmf_sample_coprime zero_less_diff)

(* Domain of the RSA map for index (N, e): the residues {0, ..., N - 1}. *)
definition domain :: "index ⇒ nat set"
where "domain index ≡ {..< fst index}"

(* Range of the RSA map for index (N, e): the residues {0, ..., N - 1}. *)
definition range :: "index ⇒ nat set"
where "range index ≡ {..< fst index}"

(* The range {..< N} is finite.  The original statement had no proof. *)
lemma finite_range: "finite (range index)"
by(simp add: range_def)

(* Domain and range of the RSA map coincide.  The original statement had no
   proof. *)
lemma dom_eq_ran: "domain index = range index"
by(simp add: domain_def range_def)

(* The RSA map for index (N, e): x ↦ x^e mod N. *)
definition F :: "index ⇒ (nat ⇒ nat)"
where "F index x = x ^ (snd index) mod (fst index)"

(* Inverse of the RSA map using the trapdoor d: y ↦ y^d mod N. *)
definition F⇩i⇩n⇩v :: "index ⇒ trap ⇒ nat ⇒ nat"
where "F⇩i⇩n⇩v α τ y = y ^ τ mod (fst α)"

text ‹ We must prove that the RSA function is a bijection. ›

(* Injectivity of the RSA map on {..< P*Q}: if x^e ≡ y^e (mod P*Q) with e
   coprime to (P-1)*(Q-1), then x = y.  Strategy: reduce modulo each prime,
   raise to an inverse of e to get an exponent 1 + k*(p-1), collapse it with
   Fermat's little theorem, and combine via the coprimality of P and Q.
   NOTE(review): in this copy several intermediate steps have no proof
   method.  Examples are the `ye_gt_xe` and `pow_eq` facts, every
   `exp_rewrite` fact and the `hence` step after it, the `[x = y] (mod _)`
   steps ending in `using ...`, and the `[x^(e*d) = y^(e*d)]` steps.  The
   lemma will not check as given. *)
lemma rsa_bijection:
assumes coprime: "coprime e ((P-1)*(Q-1))"
and prime_P: "prime (P::nat)"
and prime_Q: "prime Q"
and P_neq_Q: "P ≠ Q"
and x_lt_pq: "x < P * Q"
and y_lt_pd: "y < P * Q"
and rsa_map_eq: "x ^ e mod (P * Q) = y ^ e mod (P * Q)"
shows "x = y"
proof-
(* Fermat's little theorem for x and y modulo both primes. *)
have flt_xP: "[x ^ P = x] (mod P)"
using fermat_little prime_P by blast
have flt_yP: "[y ^ P = y] (mod P)"
using fermat_little prime_P by blast
have flt_xQ: "[x ^ Q = x] (mod Q)"
using fermat_little prime_Q by blast
have flt_yQ: "[y ^ Q = y] (mod Q)"
using fermat_little prime_Q by blast
(* Case split on the order of x and y, so that the natural-number
   differences used with cong_altdef_nat are meaningful. *)
show ?thesis
proof(cases "y ≥ x")
case True
hence ye_gt_xe: "y^e ≥ x^e"
have x_y_exp_e: "[x^e = y^e] (mod P)"
using cong_modulus_mult_nat  cong_altdef_nat True ye_gt_xe cong_sym cong_def assms by blast
obtain d where d:  "[e*d = 1] (mod (P-1)) ∧ d ≠ 0"
using ex_inverse assms by blast
then obtain k where k: "e*d = 1 + k*(P-1)"
using ex_k_mod assms by blast
hence xk_yk: "[x^(1 + k*(P-1)) = y^(1 + k*(P-1))] (mod P)"
by(metis k power_mult x_y_exp_e cong_pow)
(* x^(1 + k*(P-1)) ≡ x (mod P), by induction on k using Fermat. *)
have xk_x: "[x^(1 + k*(P-1)) = x] (mod P)"
proof(induct k)
case 0
then show ?case by simp
next
case (Suc k)
assume  asm: "[x ^ (1 + k * (P - 1)) = x] (mod P)"
then show ?case
proof-
have exp_rewrite: "(k * (P - 1) + P) = (1 + (k + 1) * (P - 1))"
have "[x * x ^ (k * (P - 1)) = x] (mod P)" using asm by simp
hence "[x ^ (k * (P - 1)) * x ^ P = x] (mod P)" using flt_xP
by (metis cong_scalar_right cong_trans mult.commute)
hence "[x ^ (k * (P - 1) + P) = x] (mod P)"
hence "[x ^ (1 + (k + 1) * (P - 1)) = x] (mod P)"
using exp_rewrite by argo
thus ?thesis by simp
qed
qed
(* The same for y. *)
have yk_y: "[y^(1 + k*(P-1)) = y] (mod P)"
proof(induct k)
case 0
then show ?case by simp
next
case (Suc k)
assume  asm: "[y ^ (1 + k * (P - 1)) = y] (mod P)"
then show ?case
proof-
have exp_rewrite: "(k * (P - 1) + P) = (1 + (k + 1) * (P - 1))"
have "[y * y ^ (k * (P - 1)) = y] (mod P)" using asm by simp
hence "[y ^ (k * (P - 1)) * y ^ P = y] (mod P)" using flt_yP
by (metis cong_scalar_right cong_trans mult.commute)
hence "[y ^ (k * (P - 1) + P) = y] (mod P)"
hence "[y ^ (1 + (k + 1) * (P - 1)) = y] (mod P)"
using exp_rewrite by argo
thus ?thesis by simp
qed
qed
(* Repeat the argument modulo Q. *)
have "[x^e = y^e] (mod Q)"
by (metis rsa_map_eq cong_modulus_mult_nat cong_def mult.commute)
obtain d' where d':  "[e*d' = 1] (mod (Q-1)) ∧ d' ≠ 0"
by (metis mult.commute ex_inverse prime_P prime_Q P_neq_Q coprime)
then obtain k' where k': "e*d' = 1 + k'*(Q-1)"
by(metis ex_k_mod mult.commute prime_P prime_Q P_neq_Q coprime)
hence xk_yk': "[x^(1 + k'*(Q-1)) = y^(1 + k'*(Q-1))] (mod Q)"
by(metis k' power_mult ‹[x ^ e = y ^ e] (mod Q)› cong_pow)
have xk_x': "[x^(1 + k'*(Q-1)) = x] (mod Q)"
proof(induct k')
case 0
then show ?case by simp
next
case (Suc k')
assume  asm: "[x ^ (1 + k' * (Q - 1)) = x] (mod Q)"
then show ?case
proof-
have exp_rewrite: "(k' * (Q - 1) + Q) = (1 + (k' + 1) * (Q - 1))"
have "[x * x ^ (k' * (Q - 1)) = x] (mod Q)" using asm by simp
hence "[x ^ (k' * (Q - 1)) * x ^ Q = x] (mod Q)" using flt_xQ
by (metis cong_scalar_right cong_trans mult.commute)
hence "[x ^ (k' * (Q - 1) + Q) = x] (mod Q)"
hence "[x ^ (1 + (k' + 1) * (Q - 1)) = x] (mod Q)"
using exp_rewrite by argo
thus ?thesis by simp
qed
qed
have yk_y': "[y^(1 + k'*(Q-1)) = y] (mod Q)"
proof(induct k')
case 0
then show ?case by simp
next
case (Suc k')
assume  asm: "[y ^ (1 + k' * (Q - 1)) = y] (mod Q)"
then show ?case
proof-
have exp_rewrite: "(k' * (Q - 1) + Q) = (1 + (k' + 1) * (Q - 1))"
have "[y * y ^ (k' * (Q - 1)) = y] (mod Q)" using asm by simp
hence "[y ^ (k' * (Q - 1)) * y ^ Q = y] (mod Q)" using flt_yQ
by (metis cong_scalar_right cong_trans mult.commute)
hence "[y ^ (k' * (Q - 1) + Q) = y] (mod Q)"
hence "[y ^ (1 + (k' + 1) * (Q - 1)) = y] (mod Q)"
using exp_rewrite by argo
thus ?thesis by simp
qed
qed
(* Both primes divide y - x, hence so does P*Q; since x, y < P*Q, x = y. *)
have P_dvd_xy: "P dvd (y - x)"
proof-
have "[x = y] (mod P)"
using xk_x yk_y xk_yk
thus ?thesis
using cong_altdef_nat cong_sym True by blast
qed
have Q_dvd_xy: "Q dvd (y - x)"
proof-
have "[x = y] (mod Q)"
using xk_x' yk_y' xk_yk'
thus ?thesis
using cong_altdef_nat cong_sym True by blast
qed
show ?thesis
proof-
have "P*Q dvd (y - x)" using P_dvd_xy  Q_dvd_xy
by (simp add: assms divides_mult primes_coprime)
then have "[x = y] (mod P*Q)"
by (simp add: cong_altdef_nat cong_sym True)
hence "x mod P*Q = y mod P*Q"
using  cong_def xk_x xk_yk yk_y by metis
then show ?thesis
using ‹[x = y] (mod P * Q)› cong_less_modulus_unique_nat x_lt_pq y_lt_pd by blast
qed
next
case False
(* Symmetric argument with the roles of x and y exchanged. *)
hence ye_gt_xe: "x^e ≥ y^e"
have pow_eq: "[x^e = y^e] (mod P*Q)"
then have PQ_dvd_ye_xe: "(P*Q) dvd (x^e - y^e)"
using cong_altdef_nat False ye_gt_xe cong_sym by blast
then have "[x^e = y^e] (mod P)"
using cong_modulus_mult_nat pow_eq by blast
obtain d where d:  "[e*d = 1] (mod (P-1)) ∧ d ≠ 0" using ex_inverse assms
by blast
then obtain k where k: "e*d = 1 + k*(P-1)" using ex_k_mod assms by blast
have xk_yk: "[x^(1 + k*(P-1)) = y^(1 + k*(P-1))] (mod P)"
proof-
have "[(x^e)^d = (y^e)^d] (mod P)"
using ‹[x ^ e = y ^ e] (mod P)› cong_pow by blast
then have "[x^(e*d) = y^(e*d)] (mod P)"
thus ?thesis using k by simp
qed
have xk_x: "[x^(1 + k*(P-1)) = x] (mod P)"
proof(induct k)
case 0
then show ?case by simp
next
case (Suc k)
assume  asm: "[x ^ (1 + k * (P - 1)) = x] (mod P)"
then show ?case
proof-
have exp_rewrite: "(k * (P - 1) + P) = (1 + (k + 1) * (P - 1))"
have "[x * x ^ (k * (P - 1)) = x] (mod P)" using asm by simp
hence "[x ^ (k * (P - 1)) * x ^ P = x] (mod P)" using flt_xP
by (metis cong_scalar_right cong_trans mult.commute)
hence "[x ^ (k * (P - 1) + P) = x] (mod P)"
hence "[x ^ (1 + (k + 1) * (P - 1)) = x] (mod P)"
using exp_rewrite by argo
thus ?thesis by simp
qed
qed
have yk_y: "[y^(1 + k*(P-1)) = y] (mod P)"
proof(induct k)
case 0
then show ?case by simp
next
case (Suc k)
assume  asm: "[y ^ (1 + k * (P - 1)) = y] (mod P)"
then show ?case
proof-
have exp_rewrite: "(k * (P - 1) + P) = (1 + (k + 1) * (P - 1))"
have "[y * y ^ (k * (P - 1)) = y] (mod P)" using asm by simp
hence "[y ^ (k * (P - 1)) * y ^ P = y] (mod P)" using flt_yP
by (metis cong_scalar_right cong_trans mult.commute)
hence "[y ^ (k * (P - 1) + P) = y] (mod P)"
hence "[y ^ (1 + (k + 1) * (P - 1)) = y] (mod P)"
using exp_rewrite by argo
thus ?thesis by simp
qed
qed
have P_dvd_xy: "P dvd (x - y)"
proof-
have "[x = y] (mod P)" using xk_x yk_y xk_yk
thus ?thesis
using cong_altdef_nat cong_sym False by simp
qed
have "[x^e = y^e] (mod Q)"
using cong_modulus_mult_nat pow_eq PQ_dvd_ye_xe cong_dvd_modulus_nat dvd_triv_right by blast
obtain d' where d':  "[e*d' = 1] (mod (Q-1)) ∧ d' ≠ 0"
by (metis mult.commute ex_inverse prime_P prime_Q coprime P_neq_Q)
then obtain k' where k': "e*d' = 1 + k'*(Q-1)"
by(metis ex_k_mod mult.commute prime_P prime_Q coprime P_neq_Q)
have xk_yk': "[x^(1 + k'*(Q-1)) = y^(1 + k'*(Q-1))] (mod Q)"
proof-
have "[(x^e)^d' = (y^e)^d'] (mod Q)"
using ‹[x ^ e = y ^ e] (mod Q)› cong_pow by blast
then have "[x^(e*d') = y^(e*d')] (mod Q)"
thus ?thesis using k'
by simp
qed
have xk_x': "[x^(1 + k'*(Q-1)) = x] (mod Q)"
proof(induct k')
case 0
then show ?case by simp
next
case (Suc k')
assume  asm: "[x ^ (1 + k' * (Q - 1)) = x] (mod Q)"
then show ?case
proof-
have exp_rewrite: "(k' * (Q - 1) + Q) = (1 + (k' + 1) * (Q - 1))"
have "[x * x ^ (k' * (Q - 1)) = x] (mod Q)" using asm by simp
hence "[x ^ (k' * (Q - 1)) * x ^ Q = x] (mod Q)" using flt_xQ
by (metis cong_scalar_right cong_trans mult.commute)
hence "[x ^ (k' * (Q - 1) + Q) = x] (mod Q)"
hence "[x ^ (1 + (k' + 1) * (Q - 1)) = x] (mod Q)"
using exp_rewrite by argo
thus ?thesis by simp
qed
qed
have yk_y': "[y^(1 + k'*(Q-1)) = y] (mod Q)"
proof(induct k')
case 0
then show ?case by simp
next
case (Suc k')
assume  asm: "[y ^ (1 + k' * (Q - 1)) = y] (mod Q)"
then show ?case
proof-
have exp_rewrite: "(k' * (Q - 1) + Q) = (1 + (k' + 1) * (Q - 1))"
have "[y * y ^ (k' * (Q - 1)) = y] (mod Q)" using asm by simp
hence "[y ^ (k' * (Q - 1)) * y ^ Q = y] (mod Q)" using flt_yQ
by (metis cong_scalar_right cong_trans mult.commute)
hence "[y ^ (k' * (Q - 1) + Q) = y] (mod Q)"
hence "[y ^ (1 + (k' + 1) * (Q - 1)) = y] (mod Q)"
using exp_rewrite by argo
thus ?thesis by simp
qed
qed
have Q_dvd_xy: "Q dvd (x - y)"
proof-
have "[x = y] (mod Q)"
using xk_x' yk_y' xk_yk' by (simp add: cong_def)
thus ?thesis
using cong_altdef_nat cong_sym False by simp
qed
show ?thesis
proof-
have "P*Q dvd (x - y)"
using P_dvd_xy Q_dvd_xy by (simp add: assms divides_mult primes_coprime)
hence 1: "[x = y] (mod P*Q)"
using False cong_altdef_nat linear by blast
hence "x mod P*Q = y mod P*Q"
using cong_less_modulus_unique_nat x_lt_pq y_lt_pd by blast
thus ?thesis
using 1 cong_less_modulus_unique_nat x_lt_pq y_lt_pd by blast
qed
qed
qed

(* The RSA map F (P*Q, e) is a bijection from range (P*Q, e) onto itself:
   it is injective by rsa_bijection, and an injective endomap of a finite
   set is surjective (endo_inj_surj).
   NOTE(review): on the `moreover have` line, the image operator (`) between
   the lambda and {..<fst (P * Q, e)} appears to have been lost (only a
   double space remains).  As written, the lambda is applied to a set.
   Confirm against the source. *)
lemma rsa_bij_betw:
assumes "coprime e ((P - 1)*(Q - 1))"
and "prime P"
and "prime Q"
and "P ≠ Q"
shows "bij_betw (F ((P * Q), e)) (range ((P * Q), e)) (range ((P * Q), e))"
proof-
have PQ_not_0: "prime P ⟶ prime Q ⟶ P * Q ≠ 0"
using assms by auto
have "inj_on (λx. x ^ snd (P * Q, e) mod fst (P * Q, e)) {..<fst (P * Q, e)}"
using rsa_bijection assms by blast
moreover have "(λx. x ^ snd (P * Q, e) mod fst (P * Q, e))  {..<fst (P * Q, e)} = {..<fst (P * Q, e)}"
apply(simp add: assms(2) assms(3) prime_gt_0_nat PQ_not_0)
apply(rule endo_inj_surj; auto simp add: assms(2) assms(3) image_subsetI prime_gt_0_nat PQ_not_0 inj_on_def)
using rsa_bijection assms by blast
ultimately show ?thesis
unfolding bij_betw_def F_def range_def by blast
qed

(* For every public key in the support of I, the RSA map is a bijection on
   its range.  This follows from set_spmf_I_N and rsa_bij_betw. *)
lemma bij_betw1:
assumes "((N,e),d) ∈ set_spmf I"
shows "bij_betw (F ((N), e)) (range ((N), e)) (range ((N), e))"
proof-
obtain P Q where "N = P * Q" and "bij_betw (F ((P*Q), e)) (range ((P*Q), e)) (range ((P*Q), e))"
proof-
obtain P Q where "prime P" and "prime Q" and "N = P * Q" and "P ≠ Q" and "coprime e ((P - 1)*(Q - 1))"
using set_spmf_I_N assms by blast
then show ?thesis
using rsa_bij_betw that by blast
qed
thus ?thesis by blast
qed

(* Unnamed auxiliary fact: if P divides x then x ≡ 0 (mod P). *)
lemma
assumes "P dvd x"
shows "[x = 0] (mod P)"
using assms cong_def by force

(* RSA decryption inverts encryption: (x^e mod PQ)^d mod PQ = x mod PQ when
   d is the Bézout-derived inverse of e modulo (P-1)*(Q-1).  For x > 1 the
   proof shows that both P and Q divide z = x^(e*d) - x, using Fermat's
   theorem when x is not ≡ 0 modulo the prime.
   NOTE(review): in this copy several steps have no proof method: the
   rewrite of (x^e)^d - x, each `using prime_? fermat_theorem False` step,
   each `thus ?thesis using z_def z_gt_0` step, and the final `thus ?thesis`.
   In the Q-branch, the fact `[x ^ (e * d) = x * ((x ^ (Q - 1)) ^ k)]` is
   stated `(mod P)` where `(mod Q)` looks intended.  That fact is not used by
   the next step, which cites calculation(2). *)
lemma rsa_inv:
assumes d: "d = nat (fst (bezw e ((P-1)*(Q-1))) mod int ((P-1)*(Q-1)))"
and coprime: "coprime e ((P-1)*(Q-1))"
and prime_P: "prime (P::nat)"
and prime_Q: "prime Q"
and P_neq_Q: "P ≠ Q"
and e_gt_1: "e > 1"
and d_gt_1: "d > 1"
shows "((((x) ^ e) mod (P*Q)) ^ d) mod (P*Q) = x mod (P*Q)"
proof(cases "x = 0 ∨ x = 1")
case True
then show ?thesis
by (metis assms(6) assms(7) le_simps(1) nat_power_eq_Suc_0_iff neq0_conv not_one_le_zero numeral_nat(7) power_eq_0_iff power_mod)
next
case False
hence x_gt_1: "x > 1" by simp
(* z = (x^e)^d - x; it suffices to show that both P and Q divide z. *)
define z where "z = (x ^ e) ^ d - x"
hence z_gt_0: "z > 0"
proof-
have "(x ^ e) ^ d - x = x ^ (e * d) - x"
also have "... > 0"
by (metis x_gt_1 e_gt_1 d_gt_1 le_neq_implies_less less_one linorder_not_less n_less_m_mult_n not_less_zero numeral_nat(7) power_increasing_iff power_one_right zero_less_diff)
ultimately  show ?thesis using z_def by argo
qed
hence "[z = 0] (mod P)"
proof(cases "[x = 0] (mod P)")
case True
then show ?thesis
proof -
have "0 ≠ d * e"
by (metis (no_types) assms assms mult_is_0 not_one_less_zero)
then show ?thesis
by (metis (no_types) Groups.add_ac(2) True add_diff_inverse_nat cong_def cong_dvd_iff cong_mult_self_right dvd_0_right dvd_def dvd_trans mod_add_self2 more_arith_simps(5) nat_diff_split power_eq_if power_mult semiring_normalization_rules(7) z_def)
qed
next
case False
have "[e * d = 1] (mod ((P - 1) * (Q - 1)))"
by (metis d bezw_inverse coprime coprime_imp_gcd_eq_1 nat_int)
hence "[e * d = 1] (mod (P - 1))"
using assms cong_modulus_mult_nat by blast
then obtain k where k: "e*d = 1 + k*(P-1)"
using ex_k_mod assms by force
hence "x ^ (e * d) = x * ((x ^ (P - 1)) ^ k)"
by (metis power_add power_one_right mult.commute power_mult)
hence "[x ^ (e * d) = x * ((x ^ (P - 1)) ^ k)] (mod P)"
using cong_def by simp
moreover have "[x ^ (P - 1) = 1] (mod P)"
using prime_P fermat_theorem False
moreover have "[x ^ (e * d) = x * ((1) ^ k)] (mod P)"
by (metis ‹x ^ (e * d) = x * (x ^ (P - 1)) ^ k› calculation(2) cong_pow cong_scalar_left)
hence "[x ^ (e * d) = x] (mod P)" by simp
thus ?thesis using z_def z_gt_0
qed
moreover have "[z = 0] (mod Q)"
proof(cases "[x = 0] (mod Q)")
case True
then show ?thesis
proof -
have "0 ≠ d * e"
by (metis (no_types) assms mult_is_0 not_one_less_zero)
then show ?thesis
by (metis (no_types) Groups.add_ac(2) True add_diff_inverse_nat cong_def cong_dvd_iff cong_mult_self_right dvd_0_right dvd_def dvd_trans mod_add_self2 more_arith_simps(5) nat_diff_split power_eq_if power_mult semiring_normalization_rules(7) z_def)
qed
next
case False
have "[e * d = 1] (mod ((P - 1) * (Q - 1)))"
by (metis d bezw_inverse coprime coprime_imp_gcd_eq_1 nat_int)
hence "[e * d = 1] (mod (Q - 1))"
using assms cong_modulus_mult_nat mult.commute by metis
then obtain k where k: "e*d = 1 + k*(Q-1)"
using ex_k_mod assms by force
hence "x ^ (e * d) = x * ((x ^ (Q - 1)) ^ k)"
by (metis power_add power_one_right mult.commute power_mult)
hence "[x ^ (e * d) = x * ((x ^ (Q - 1)) ^ k)] (mod P)"
using cong_def by simp
moreover have "[x ^ (Q - 1) = 1] (mod Q)"
using prime_Q fermat_theorem False
moreover have "[x ^ (e * d) = x * ((1) ^ k)] (mod Q)"
by (metis ‹x ^ (e * d) = x * (x ^ (Q - 1)) ^ k› calculation(2) cong_pow cong_scalar_left)
hence "[x ^ (e * d) = x] (mod Q)" by simp
thus ?thesis using z_def z_gt_0
qed
(* Combine: P and Q are distinct primes, so P*Q divides z. *)
ultimately have "Q dvd (x ^ e) ^ d - x"
"P dvd (x ^ e) ^ d - x"
using z_def assms cong_0_iff by blast +
hence "P * Q dvd ((x ^ e) ^ d - x)"
using assms divides_mult primes_coprime_nat by blast
hence "[(x ^ e) ^ d = x] (mod (P * Q))"
using z_gt_0 cong_altdef_nat z_def by auto
thus ?thesis
qed

(* RSA decryption inverts encryption for every key produced by I: for
   ((N, e), d) in the support of I, raising (x^e mod N) to d modulo N gives
   x mod N.  The key structure (N = P * Q for distinct primes P and Q, e
   coprime to (P - 1) * (Q - 1), d obtained from bezw) is taken from
   set_spmf_I_N, the bounds e > 1 and d > 1 from set_spmf_I_e_d, and the
   number-theoretic argument is delegated to rsa_inv. *)
lemma rsa_inv_set_spmf_I:
assumes "((N, e), d) ∈ set_spmf I"
shows "((((x::nat) ^ e) mod N) ^ d) mod N = x mod N"
proof-
obtain P Q where "N = P * Q" and "d = nat (fst (bezw e ((P-1)*(Q-1))) mod int ((P-1)*(Q-1)))"
and "prime P"
and "prime Q"
and "coprime e ((P - 1)*(Q - 1))"
and "P ≠ Q"
using assms set_spmf_I_N
by blast
moreover have "e > 1" and "d > 1" using set_spmf_I_e_d assms by auto
ultimately show ?thesis using rsa_inv by blast
qed

(* The RSA collection (key generator I, domain, range, F and F⇩i⇩n⇩v) is an
   instance of the etp locale.  After unfolding etp_def, the side conditions
   are discharged by auto and metis.  The remaining goals about F and F⇩i⇩n⇩v
   are then attacked with rsa_inv_set_spmf_I.
   NOTE(review): the last line ends with `using rsa_inv_set_spmf_I` and no
   proof method in this copy; confirm the closing method was not lost. *)
sublocale etp_rsa: etp I domain range F F⇩i⇩n⇩v
unfolding etp_def apply(auto simp add: etp_def dom_eq_ran finite_range bij_betw1 lossless_I)
apply (metis fst_conv lessThan_iff mem_simps(2) nat_0_less_mult_iff prime_gt_0_nat range_def set_spmf_I_N)
apply(auto simp add: F_def F⇩i⇩n⇩v_def) using rsa_inv_set_spmf_I

(* Instantiate the generic ETP-based 1-out-of-2 OT (ETP_base) with the RSA
   collection.  B is presumably the hard-core predicate; confirm.
   NOTE(review): no proof method follows `unfolding ETP_base_def` in this
   copy. *)
sublocale etp: ETP_base I domain range B F F⇩i⇩n⇩v
unfolding ETP_base_def

text‹Having shown that the RSA collection is an ETP, the security proofs follow directly from the general results for ETP-based OT.›

(* Correctness of the ETP-based 1-out-of-2 OT instantiated with RSA, taken
   directly from the generic result etp.correct. *)
lemma correctness_rsa: "etp.OT_12.correctness m1 m2"
by (rule local.etp.correct)

(* Perfect security for party 1 in the RSA instantiation, inherited from the
   generic information-theoretic result etp.P1_security_inf_the. *)
lemma P1_security_rsa: "etp.OT_12.perfect_sec_P1 m1 m2"
by(rule local.etp.P1_security_inf_the)

(* Security for party 2 in the RSA instantiation, assuming every
   distinguisher output D a is lossless.
   NOTE(review): the conclusion and proof of this lemma are missing in this
   copy.  Judging from its use in P2_sec_asym, it presumably bounds
   etp.OT_12.adv_P2 m1 m2 D by 2 * hcp_advantage; confirm. *)
lemma P2_security_rsa:
assumes "∀ a. lossless_spmf (D a)"

end

(* Asymptotic RSA-based OT: prime_set is indexed by the security parameter
   n, and rsa_base is assumed to hold for every n. *)
locale rsa_asym =
fixes prime_set :: "nat ⇒ nat set"
and B :: "index ⇒ nat ⇒ bool"
assumes rsa_proof_assm: "⋀ n. rsa_base (prime_set n)"
begin

(* Import all rsa_base results, parameterised by the security parameter n. *)
sublocale rsa_base "(prime_set n)" B
using local.rsa_proof_assm  by simp

(* Correctness holds at every security parameter. *)
lemma correctness_rsa_asymp:
shows "etp.OT_12.correctness n m1 m2"
by(rule correctness_rsa)

(* Perfect security for party 1 holds at every security parameter. *)
lemma P1_sec_asymp: "etp.OT_12.perfect_sec_P1 n m1 m2"
by(rule local.P1_security_rsa)

(* Party 2's advantage is negligible in n.  It is bounded pointwise by
   2 * hcp_advantage n (P2_security_rsa), and the bound is transferred with
   negligible_le.
   NOTE(review): the proof opens with `moreover` without a preceding fact,
   and the `have` about the absolute value has no proof method.  Lines appear
   to be missing, presumably a fact that 2 * hcp_advantage is negligible;
   confirm. *)
lemma P2_sec_asym:
assumes "∀ a. lossless_spmf (D a)"
shows "negligible (λ n. etp.OT_12.adv_P2 n m1 m2 D)"
proof-
moreover have "¦etp.OT_12.adv_P2 n m1 m2 D¦ = etp.OT_12.adv_P2 n m1 m2 D"
moreover have "etp.OT_12.adv_P2 n m1 m2 D ≤ 2 * hcp_advantage n" for n
using P2_security_rsa assms by blast
ultimately show ?thesis
using assms negligible_le by presburger
qed

end

end

# Theory Noar_Pinkas_OT

subsection ‹Naor-Pinkas OT›

text‹Here we prove security for the Naor-Pinkas OT from \cite{DBLP:conf/soda/NaorP01}.›

theory Noar_Pinkas_OT imports
Cyclic_Group_Ext
Game_Based_Crypto.Diffie_Hellman
OT_Functionalities
Semi_Honest_Def
Uniform_Sampling
begin

(* Base locale for the Naor-Pinkas OT: 𝒢 is a finite cyclic group whose
   order is prime (and positive). *)
locale np_base =
fixes 𝒢 :: "'grp cyclic_group" (structure)
assumes finite_group: "finite (carrier 𝒢)"
and or_gt_0: "0 < order 𝒢"
and prime_order: "prime (order 𝒢)"
begin

(* Every nonzero exponent below the prime group order is coprime to it. *)
lemma prime_field: "a < (order 𝒢) ⟹ a ≠ 0 ⟹ coprime a (order 𝒢)"
by(metis dvd_imp_le neq0_conv not_le prime_imp_coprime prime_order coprime_commute)

(* Sampling a unit modulo the group order has total weight 1, because the
   order is a prime and hence greater than 1. *)
lemma weight_sample_uniform_units: "weight_spmf (sample_uniform_units (order 𝒢)) = 1"
using  lossless_spmf_def lossless_sample_uniform_units prime_order  prime_gt_1_nat by auto

(* The Naor-Pinkas protocol.  Party 1 holds M = (m0, m1) and party 2 the
   choice bit v.  Party 1 outputs () and party 2 outputs out_2.
   The exponents a and b are uniform, c⇩v = a*b mod order, and c⇩v' is
   uniform.  Position v (0 for False, 1 for True) receives c⇩v and the other
   position receives c⇩v'.  For each position i, fresh unit exponents r_i and
   s_i give
     w_i  = (❙g^a)^s_i ⊗ ❙g^r_i   and   z_i' = (❙g^c)^s_i ⊗ (❙g^b)^r_i,
   and m_i is sent as z_i' ⊗ m_i.  Party 2 multiplies the ciphertext at
   position v by inv (w_v^b). *)
definition protocol :: "('grp × 'grp) ⇒ bool ⇒ (unit × 'grp) spmf"
where "protocol M v = do {
let (m0,m1) = M;
a :: nat ← sample_uniform (order 𝒢);
b :: nat ← sample_uniform (order 𝒢);
let c⇩v = (a*b) mod (order 𝒢);
c⇩v' :: nat ← sample_uniform (order 𝒢);
r0 :: nat ← sample_uniform_units (order 𝒢);
s0 :: nat ← sample_uniform_units (order 𝒢);
let w0 = (❙g [^] a) [^] s0 ⊗ ❙g [^] r0;
let z0' = ((❙g [^] (if v then c⇩v' else c⇩v)) [^] s0) ⊗ ((❙g [^] b) [^] r0);
r1 :: nat ← sample_uniform_units (order 𝒢);
s1 :: nat ← sample_uniform_units (order 𝒢);
let w1 = (❙g [^] a) [^] s1 ⊗ ❙g [^] r1;
let z1' = ((❙g [^] ((if v then c⇩v else c⇩v'))) [^] s1) ⊗ ((❙g [^] b) [^] r1);
let enc_m0 = z0' ⊗ m0;
let enc_m1 = z1' ⊗ m1;
let out_2 = (if v then enc_m1 ⊗ inv (w1 [^] b) else enc_m0 ⊗ inv (w0 [^] b));
return_spmf ((), out_2)}"

(* The protocol never fails: its result is a lossless spmf. *)
lemma lossless_protocol: "lossless_spmf (protocol M σ)"
apply(auto simp add: protocol_def Let_def split_def lossless_sample_uniform_units or_gt_0)
using prime_order prime_gt_1_nat lossless_sample_uniform_units by simp

(* view1: party 1's input messages paired with four group elements.
   dist_adversary: a distinguisher returning a boolean on such a view. *)
type_synonym 'grp' view1 = "(('grp' × 'grp') × ('grp' × 'grp' × 'grp' × 'grp')) spmf"

type_synonym 'grp' dist_adversary = "(('grp' × 'grp') × 'grp' × 'grp' × 'grp' × 'grp') ⇒ bool spmf"

(* Real view of party 1: its input together with ❙g^a, ❙g^b and the pair
   ❙g^(a*b), ❙g^c⇩σ', ordered by σ (for σ = True the random element comes
   first). *)
definition R1 :: "('grp × 'grp) ⇒ bool ⇒ 'grp view1"
where "R1 msgs σ = do {
let (m0, m1) = msgs;
a ← sample_uniform (order 𝒢);
b ← sample_uniform (order 𝒢);
let c⇩σ = a*b;
c⇩σ' ← sample_uniform (order 𝒢);
return_spmf (msgs, (❙g [^] a, ❙g [^] b, (if σ then ❙g [^] c⇩σ' else ❙g [^] c⇩σ), (if σ then ❙g [^] c⇩σ else ❙g [^] c⇩σ')))}"

lemma lossless_R1: "lossless_spmf (R1 M σ)"
by(simp add: R1_def Let_def lossless_sample_uniform_units or_gt_0)

(* Hybrid view: the input with four independent uniform powers of ❙g. *)
definition inter :: "('grp × 'grp) ⇒ 'grp view1"
where "inter msgs = do {
a ← sample_uniform (order 𝒢);
b ← sample_uniform (order 𝒢);
c ← sample_uniform (order 𝒢);
d ← sample_uniform (order 𝒢);
return_spmf (msgs, ❙g [^] a, ❙g [^] b, ❙g [^] c, ❙g [^] d)}"

(* Simulator for party 1: ignores out1 and returns ❙g^a, ❙g^b, ❙g^c and
   ❙g^(a*b) for fresh a, b and c.  This has the same shape as R1 msgs True. *)
definition S1 :: "('grp × 'grp) ⇒ unit ⇒ 'grp view1"
where "S1 msgs out1 = do {
let (m0, m1) = msgs;
a ← sample_uniform (order 𝒢);
b ← sample_uniform (order 𝒢);
c ← sample_uniform (order 𝒢);
return_spmf (msgs, (❙g [^] a, ❙g [^] b, ❙g [^] c, ❙g [^] (a*b)))}"

lemma lossless_S1: "lossless_spmf (S1 M out1)"
by(simp add: S1_def Let_def lossless_sample_uniform_units or_gt_0)

(* DDH adversary for the hybrid step from R1 (σ = False) to inter.  The
   challenge (α, β, γ) fills the first three group positions and a fresh
   uniform power of ❙g fills the fourth. *)
fun R1_inter_adversary :: "'grp dist_adversary ⇒ ('grp × 'grp) ⇒ 'grp ⇒ 'grp ⇒ 'grp ⇒ bool spmf"
where "R1_inter_adversary 𝒜 msgs α β γ = do {
c ← sample_uniform (order 𝒢);
𝒜 (msgs, α, β, γ, ❙g [^] c)}"

(* DDH adversary for the hybrid step from inter to S1.  A fresh uniform
   power of ❙g is placed third and the challenge γ fourth. *)
fun inter_S1_adversary :: "'grp dist_adversary ⇒ ('grp × 'grp) ⇒ 'grp ⇒ 'grp ⇒ 'grp ⇒ bool spmf"
where "inter_S1_adversary 𝒜 msgs α β γ = do {
c ← sample_uniform (order 𝒢);
𝒜 (msgs, α, β, ❙g [^] c, γ)}"

(* DDH game over 𝒢. *)
sublocale ddh: ddh 𝒢 .

(* Real view of party 2: the choice bit v, ❙g^a, ❙g^b, ❙g^c⇩v and two
   (w, ciphertext) pairs.  The first ciphertext masks the unchosen message
   using the uniform exponent c⇩v'.  The second masks the chosen message
   using c⇩v = a*b mod order.
   NOTE(review): here the pair built from c⇩v is always second, whereas in
   protocol it sits at position v (first when v = False).  Confirm that this
   ordering of the view is intended. *)
definition R2 :: "('grp × 'grp) ⇒ bool ⇒ (bool × 'grp × 'grp ×  'grp × 'grp × 'grp × 'grp × 'grp) spmf"
where "R2 M v  = do {
let (m0,m1) = M;
a :: nat ← sample_uniform (order 𝒢);
b :: nat ← sample_uniform (order 𝒢);
let c⇩v = (a*b) mod (order 𝒢);
c⇩v' :: nat ← sample_uniform (order 𝒢);
r0 :: nat ← sample_uniform_units (order 𝒢);
s0 :: nat ← sample_uniform_units (order 𝒢);
let w0 = (❙g [^] a) [^] s0 ⊗ ❙g [^] r0;
let z = ((❙g [^] c⇩v') [^] s0) ⊗ ((❙g [^] b) [^] r0);
r1 :: nat ← sample_uniform_units (order 𝒢);
s1 :: nat ← sample_uniform_units (order 𝒢);
let w1 = (❙g [^] a) [^] s1 ⊗ ❙g [^] r1;
let z' = ((❙g [^] (c⇩v)) [^] s1) ⊗ ((❙g [^] b) [^] r1);
let enc_m = z ⊗ (if v then m0 else m1);
let enc_m' = z' ⊗ (if v then m1 else m0) ;
return_spmf(v, ❙g [^] a, ❙g [^] b, ❙g [^] c⇩v, w0, enc_m, w1, enc_m')}"

lemma lossless_R2: "lossless_spmf (R2 M σ)"
apply(simp add: R2_def Let_def split_def lossless_sample_uniform_units or_gt_0)
using prime_order prime_gt_1_nat lossless_sample_uniform_units by simp

(* Simulator for party 2, given its choice v and its output m.  The second
   pair is computed honestly from m.  The first ciphertext is replaced by a
   uniformly random power ❙g^s'. *)
definition S2 :: "bool ⇒ 'grp ⇒ (bool × 'grp × 'grp × 'grp × 'grp × 'grp × 'grp × 'grp) spmf"
where "S2 v m =  do {
a :: nat ← sample_uniform (order 𝒢);
b :: nat ← sample_uniform (order 𝒢);
let c⇩v = (a*b) mod (order 𝒢);
r0 :: nat ← sample_uniform_units (order 𝒢);
s0 :: nat ← sample_uniform_units (order 𝒢);
let w0 = (❙g [^] a) [^] s0 ⊗ ❙g [^] r0;
r1 :: nat ← sample_uniform_units (order 𝒢);
s1 :: nat ← sample_uniform_units (order 𝒢);
let w1 = (❙g [^] a) [^] s1 ⊗ ❙g [^] r1;
let z' = ((❙g [^] (c⇩v)) [^] s1) ⊗ ((❙g [^] b) [^] r1);
s' ← sample_uniform (order 𝒢);
let enc_m =  ❙g [^] s';
let enc_m' = z' ⊗ m ;
return_spmf(v, ❙g [^] a, ❙g [^] b, ❙g [^] c⇩v, w0, enc_m, w1, enc_m')}"

lemma lossless_S2: "lossless_spmf (S2 σ out2)"
apply(simp add: S2_def Let_def lossless_sample_uniform_units or_gt_0)
using prime_order prime_gt_1_nat lossless_sample_uniform_units by simp

(* Register the real views, simulators, functionality funct_OT_12 and
   protocol with the sim_det_def framework.  The side conditions are the
   losslessness lemmas above. *)
sublocale sim_def: sim_det_def R1 S1 R2 S2 funct_OT_12 protocol
unfolding sim_det_def_def
by(auto simp add: lossless_R1 lossless_S1 lossless_R2 lossless_S2 lossless_protocol lossless_funct_OT_12)

end

(* np_base together with the cyclic_group structure of 𝒢, which provides
   the generator ❙g and the group laws used in the proofs below. *)
locale np = np_base + cyclic_group 𝒢
begin

(* Algebraic core of correctness: the mask
   (❙g^(a*b mod order))^s1 ⊗ (❙g^b)^r1 is cancelled by
   inv (((❙g^a)^s1 ⊗ ❙g^r1)^b), since both exponents equal (a*s1 + r1)*b
   (fact 1 below).
   NOTE(review): the first `have` and two of the `also have` steps have no
   proof method in this copy; confirm nothing was lost. *)
lemma protocol_inverse:
assumes "m0 ∈ carrier 𝒢" "m1 ∈ carrier 𝒢"
shows" ((❙g [^] ((a*b) mod (order 𝒢))) [^] (s1 :: nat)) ⊗ ((❙g [^] b) [^] (r1::nat)) ⊗ (if v then m0 else m1) ⊗ inv (((❙g [^] a) [^] s1 ⊗ ❙g [^] r1) [^] b)
= (if v then m0 else m1)"
(is "?lhs = ?rhs")
proof-
have  1: "(a*b)*(s1) + b*r1 =((a::nat)*(s1) + r1)*b " using mult.commute mult.assoc  add_mult_distrib by auto
have "?lhs =
((❙g [^] (a*b)) [^] s1) ⊗ ((❙g [^] b) [^] r1) ⊗ (if v then m0 else m1) ⊗ inv (((❙g [^] a) [^] s1 ⊗ ❙g [^] r1) [^] b)"
also have "... = (❙g [^] ((a*b)*(s1))) ⊗ ((❙g [^] (b*r1))) ⊗ ((if v then m0 else m1) ⊗ inv (((❙g [^] ((a*(s1) + r1)*b)))))"
by(auto simp add: nat_pow_pow nat_pow_mult assms cyclic_group_assoc)
also have "... = ❙g [^] ((a*b)*(s1)) ⊗ ❙g [^] (b*r1) ⊗ ((inv (((❙g [^] ((a*(s1) + r1)*b))))) ⊗ (if v then m0 else m1))"
also have "... = (❙g [^] ((a*b)*(s1) + b*r1) ⊗ inv (((❙g [^] ((a*(s1) + r1)*b))))) ⊗ (if v then m0 else m1)"
also have "... = (❙g [^] ((a*b)*(s1) + b*r1) ⊗ inv (((❙g [^] (((a*b)*(s1) + r1*b)))))) ⊗ (if v then m0 else m1)"
using 1 by (simp add: mult.commute)
ultimately show ?thesis
using l_cancel_inv assms  by (simp add: mult.commute)
qed

(* Correctness: for messages in the carrier, the protocol's output
   distribution equals funct_OT_12.  Unused samples are removed with
   bind_spmf_const, and the mask is cancelled with protocol_inverse.
   NOTE(review): the final `thus ?thesis` has no proof method in this copy. *)
lemma correctness:
assumes "m0 ∈ carrier 𝒢" "m1 ∈ carrier 𝒢"
shows "sim_def.correctness (m0,m1) σ"
proof-
have "protocol (m0, m1) σ = funct_OT_12 (m0, m1) σ"
proof-
have "protocol (m0, m1) σ = do {
a :: nat ← sample_uniform (order 𝒢);
b :: nat ← sample_uniform (order 𝒢);
r1 :: nat ← sample_uniform_units (order 𝒢);
s1 :: nat ← sample_uniform_units (order 𝒢);
let out_2 = ((❙g [^] ((a*b) mod (order 𝒢))) [^] s1) ⊗ ((❙g [^] b) [^] r1) ⊗ (if σ then m1 else m0) ⊗ inv (((❙g [^] a) [^] s1 ⊗ ❙g [^] r1) [^] b);
return_spmf ((), out_2)}"
by(simp add: protocol_def lossless_sample_uniform_units bind_spmf_const weight_sample_uniform_units or_gt_0)
also have "... = do {
let out_2 = (if σ then m1 else m0);
return_spmf ((), out_2)}"
by(simp add: protocol_inverse assms lossless_sample_uniform_units bind_spmf_const weight_sample_uniform_units or_gt_0)
ultimately show ?thesis by(simp add: Let_def funct_OT_12_def)
qed
thus ?thesis
qed

(* Security for party 1.  For σ = True the real view R1 coincides with S1,
   so the advantage is 0.  For σ = False the hybrid R1 -> inter -> S1 bounds
   it by the DDH advantages of R1_inter_adversary and inter_S1_adversary.
   NOTE(review): only the `is` pattern of the lemma statement survives.
   Several proof methods are missing, and the False case never shows
   ?thesis in this copy. *)
lemma security_P1:
(is "?lhs ≤ ?rhs")
proof(cases σ)
case True
have "R1 msgs σ = S1 msgs out1" for out1
then have "sim_def.adv_P1 msgs σ D = 0"
also have "ddh.advantage A ≥ 0" for A using ddh.advantage_def by simp
ultimately show ?thesis by simp
next
case False
have bounded_advantage: "¦(a :: real) - b¦ = e1 ⟹ ¦b - c¦ = e2 ⟹ ¦a - c¦ ≤ e1 + e2"
for a b e1 c e2 by simp
also have R1_inter_dist: "¦spmf (R1 msgs False ⤜ D) True - spmf ((inter msgs) ⤜ D) True¦ = ddh.advantage (R1_inter_adversary D msgs)"
unfolding R1_def inter_def ddh.advantage_def ddh.ddh_0_def ddh.ddh_1_def Let_def split_def by(simp)
also  have inter_S1_dist: "¦spmf ((inter msgs) ⤜ D) True - spmf (S1 msgs out1 ⤜ D) True¦ = ddh.advantage (inter_S1_adversary D msgs)"
ultimately have "¦spmf (R1 msgs False ⤜ (λview. D view)) True - spmf (S1 msgs out1 ⤜ (λview. D view)) True¦ ≤ ?rhs"
for out1 using R1_inter_dist by auto
qed

(* For a unit exponent s0 (0 < s0 < order), mapping c⇩v' to
   (b*r0 + s0*c⇩v') mod order leaves the uniform distribution on exponents
   unchanged, because s0 is coprime to the prime order (prime_field).
   NOTE(review): the `lemma` header line and the final proof method of this
   statement are missing in this copy. *)
assumes "s0 < order 𝒢"
and "s0 ≠ 0"
shows "map_spmf (λ c⇩v'. (((b* r0) + (s0 * c⇩v')) mod(order 𝒢))) (sample_uniform (order 𝒢)) = sample_uniform (order 𝒢)"
proof-
have "gcd s0 (order 𝒢) = 1"
using assms prime_field by simp
thus ?thesis
qed

(* Perfect security for party 2: R2 (m0, m1) σ equals S2 applied to σ and
   party 2's output.  The exponent of the first mask,
   (b*r0 + c⇩v'*s0) mod order, is uniform because c⇩v' is uniform and s0 is a
   unit.  The resulting uniform power of ❙g hides the unchosen message, so
   the message factor can be dropped from the map over sample_uniform.
   NOTE(review): the first `have`, the intermediate `also have` steps and
   the final `thus ?thesis` have no proof methods in this copy. *)
lemma security_P2:
assumes "m0 ∈ carrier 𝒢" "m1 ∈ carrier 𝒢"
shows "sim_def.perfect_sec_P2 (m0,m1) σ"
proof-
have "R2 (m0, m1) σ = S2 σ (if σ then m1 else m0)"
proof-
have "R2 (m0, m1) σ = do {
a :: nat ← sample_uniform (order 𝒢);
b :: nat ← sample_uniform (order 𝒢);
let c⇩v = (a*b) mod (order 𝒢);
c⇩v' :: nat ← sample_uniform (order 𝒢);
r0 :: nat ← sample_uniform_units (order 𝒢);
s0 :: nat ← sample_uniform_units (order 𝒢);
let w0 = (❙g [^] a) [^] s0 ⊗ ❙g [^] r0;
let s' = (((b* r0) + ((c⇩v')*(s0))) mod(order 𝒢));
let z = ❙g [^] s' ;
r1 :: nat ← sample_uniform_units (order 𝒢);
s1 :: nat ← sample_uniform_units (order 𝒢);
let w1 = (❙g [^] a) [^] s1 ⊗ ❙g [^] r1;
let z' = ((❙g [^] (c⇩v)) [^] s1) ⊗ ((❙g [^] b) [^] r1);
let enc_m = z ⊗ (if σ then m0 else m1);
let enc_m' = z' ⊗ (if σ then m1 else m0) ;
return_spmf(σ, ❙g [^] a, ❙g [^] b, ❙g [^] c⇩v, w0, enc_m, w1, enc_m')}"
also have "... =  do {
a :: nat ← sample_uniform (order 𝒢);
b :: nat ← sample_uniform (order 𝒢);
let c⇩v = (a*b) mod (order 𝒢);
r0 :: nat ← sample_uniform_units (order 𝒢);
s0 :: nat ← sample_uniform_units (order 𝒢);
let w0 = (❙g [^] a) [^] s0 ⊗ ❙g [^] r0;
s' ← map_spmf (λ c⇩v'. (((b* r0) + ((c⇩v')*(s0))) mod(order 𝒢))) (sample_uniform (order 𝒢));
let z = ❙g [^] s';
r1 :: nat ← sample_uniform_units (order 𝒢);
s1 :: nat ← sample_uniform_units (order 𝒢);
let w1 = (❙g [^] a) [^] s1 ⊗ ❙g [^] r1;
let z' = ((❙g [^] (c⇩v)) [^] s1) ⊗ ((❙g [^] b) [^] r1);
let enc_m = z ⊗ (if σ then m0 else m1);
let enc_m' = z' ⊗ (if σ then m1 else m0) ;
return_spmf(σ, ❙g [^] a, ❙g [^] b, ❙g [^] c⇩v, w0, enc_m, w1, enc_m')}"
also have "... =  do {
a :: nat ← sample_uniform (order 𝒢);
b :: nat ← sample_uniform (order 𝒢);
let c⇩v = (a*b) mod (order 𝒢);
r0 :: nat ← sample_uniform_units (order 𝒢);
s0 :: nat ← sample_uniform_units (order 𝒢);
let w0 = (❙g [^] a) [^] s0 ⊗ ❙g [^] r0;
s' ←  (sample_uniform (order 𝒢));
let z = ❙g [^] s';
r1 :: nat ← sample_uniform_units (order 𝒢);
s1 :: nat ← sample_uniform_units (order 𝒢);
let w1 = (❙g [^] a) [^] s1 ⊗ ❙g [^] r1;
let z' = ((❙g [^] (c⇩v)) [^] s1) ⊗ ((❙g [^] b) [^] r1);
let enc_m = z ⊗ (if σ then m0 else m1);
let enc_m' = z' ⊗ (if σ then m1 else m0) ;
return_spmf(σ, ❙g [^] a, ❙g [^] b, ❙g [^] c⇩v, w0, enc_m, w1, enc_m')}"
also have "... =  do {
a :: nat ← sample_uniform (order 𝒢);
b :: nat ← sample_uniform (order 𝒢);
let c⇩v = (a*b) mod (order 𝒢);
r0 :: nat ← sample_uniform_units (order 𝒢);
s0 :: nat ← sample_uniform_units (order 𝒢);
let w0 = (❙g [^] a) [^] s0 ⊗ ❙g [^] r0;
r1 :: nat ← sample_uniform_units (order 𝒢);
s1 :: nat ← sample_uniform_units (order 𝒢);
let w1 = (❙g [^] a) [^] s1 ⊗ ❙g [^] r1;
let z' = ((❙g [^] (c⇩v)) [^] s1) ⊗ ((❙g [^] b) [^] r1);
enc_m ← map_spmf (λ s'.  ❙g [^] s' ⊗ (if σ then m0 else m1)) (sample_uniform (order 𝒢));
let enc_m' = z' ⊗ (if σ then m1 else m0) ;
return_spmf(σ, ❙g [^] a, ❙g [^] b, ❙g [^] c⇩v, w0, enc_m, w1, enc_m')}"
also have "... =  do {
a :: nat ← sample_uniform (order 𝒢);
b :: nat ← sample_uniform (order 𝒢);
let c⇩v = (a*b) mod (order 𝒢);
r0 :: nat ← sample_uniform_units (order 𝒢);
s0 :: nat ← sample_uniform_units (order 𝒢);
let w0 = (❙g [^] a) [^] s0 ⊗ ❙g [^] r0;
r1 :: nat ← sample_uniform_units (order 𝒢);
s1 :: nat ← sample_uniform_units (order 𝒢);
let w1 = (❙g [^] a) [^] s1 ⊗ ❙g [^] r1;
let z' = ((❙g [^] (c⇩v)) [^] s1) ⊗ ((❙g [^] b) [^] r1);
enc_m ← map_spmf (λ s'.  ❙g [^] s') (sample_uniform (order 𝒢));
let enc_m' = z' ⊗ (if σ then m1 else m0) ;
return_spmf(σ, ❙g [^] a, ❙g [^] b, ❙g [^] c⇩v, w0, enc_m, w1, enc_m')}"
ultimately show ?thesis by(simp add: S2_def Let_def bind_map_spmf o_def)
qed
thus ?thesis
qed

end

(* Asymptotic setting: a family of groups 𝒢 η indexed by the security
   parameter, each satisfying np. *)
locale np_asymp =
fixes 𝒢 :: "security ⇒ 'grp cyclic_group"
assumes np: "⋀η. np (𝒢 η)"
begin

(* Import the np results for every security parameter η. *)
sublocale np "𝒢 η" for η by(simp add: np)

(* Correctness at every security parameter.
   NOTE(review): no proof is present in this copy.  It presumably follows
   from correctness. *)
theorem correctness_asymp:
assumes "m0 ∈ carrier (𝒢 η)" "m1 ∈ carrier (𝒢 η)"
shows "sim_def.correctness η (m0, m1) σ"

(* Party 1's advantage is negligible in η.  The surviving fragments bound it
   pointwise via security_P1.
   NOTE(review): this proof is heavily truncated: there is a dangling
   `for η`, a `using assms` although the theorem has no assumptions, and no
   final proof method.  Restore it from the upstream source. *)
theorem security_P1_asymp:
shows "negligible (λ η. sim_def.adv_P1 η msgs σ D)"
proof-
for η
using security_P1 by simp
using assms
ultimately show ?thesis
qed

(* Perfect security for party 2 at every security parameter.
   NOTE(review): no proof is present in this copy.  It presumably follows
   from security_P2. *)
theorem security_P2_asymp:
assumes "m0 ∈ carrier (𝒢 η)" "m1 ∈ carrier (𝒢 η)"
shows "sim_def.perfect_sec_P2 η (m0,m1) σ"

end

end

# Theory OT14

subsection ‹1-out-of-2 OT to 1-out-of-4 OT›

text ‹Here we construct a protocol that achieves 1-out-of-4 OT from 1-out-of-2 OT. We follow the construction
of 1-out-of-n OT from 1-out-of-2 OT given in \cite{DBLP:books/cu/Goldreich2004}. We assume the security
properties of 1-out-of-2 OT.›

theory OT14 imports
Semi_Honest_Def
OT_Functionalities
Uniform_Sampling
begin

(* input1: party 1's four message bits (m00, m01, m10, m11).
   input2: party 2's two choice bits (c0, c1).
   view1 pairs party 1's input with six booleans (the random bits S0 .. S5
   in R1_14) and its views in the three OT12 runs.  view2 has the analogous
   shape for party 2, with four booleans. *)
type_synonym input1 = "bool × bool × bool × bool"
type_synonym input2 = "bool × bool"
type_synonym 'v_OT121' view1 = "(input1 × (bool × bool × bool × bool × bool × bool) × 'v_OT121' × 'v_OT121' × 'v_OT121')"
type_synonym 'v_OT122' view2 = "(input2 × (bool × bool × bool × bool) × 'v_OT122' × 'v_OT122' × 'v_OT122')"

(* Assumptions on an arbitrary 1-out-of-2 OT from which the 1-out-of-4 OT is
   built.  The components are the real views R1_OT12 and R2_OT12, the
   simulators S1_OT12 and S2_OT12, and protocol_OT12.  The assumptions are
   perfect security for party 2, correctness with respect to funct_OT_12,
   and losslessness of every component.
   NOTE(review): the `assumes` keyword is missing before inf_th_OT12_P2, and
   adv_OT12 (used in OT_12_P1_assms_bound') is not declared here.  Lines
   appear to be missing.  Also confirm that funct_OT12, used in
   inf_th_OT12_P2 and lossless_funct_OT12, is a declared constant rather than
   a misspelling of funct_OT_12.  An undeclared name would presumably be
   generalized, turning lossless_funct_OT12 into a claim about arbitrary
   functions. *)
locale ot14_base =
fixes S1_OT12 :: "(bool × bool) ⇒ unit ⇒ 'v_OT121 spmf" ― ‹simulator for party 1 in OT12›
and R1_OT12 :: "(bool × bool) ⇒ bool ⇒ 'v_OT121 spmf" ― ‹real view for party 1 in OT12›
and S2_OT12 :: "bool ⇒ bool ⇒ 'v_OT122 spmf"
and R2_OT12 :: "(bool × bool) ⇒ bool ⇒ 'v_OT122 spmf"
and protocol_OT12 :: "(bool × bool) ⇒ bool ⇒ (unit × bool) spmf"
and inf_th_OT12_P2:  "sim_det_def.perfect_sec_P2 R2_OT12 S2_OT12 funct_OT12 (m0,m1) σ" ― ‹information theoretic security for party 2›
and correct: "protocol_OT12 msgs b = funct_OT_12 msgs b"
and lossless_R1_12: "lossless_spmf (R1_OT12 m c)"
and lossless_S1_12: "lossless_spmf (S1_OT12 m out1)"
and lossless_S2_12: "lossless_spmf (S2_OT12 c out2)"
and lossless_R2_12: "lossless_spmf (R2_OT12 M c)"
and lossless_funct_OT12: "lossless_spmf (funct_OT12 (m0,m1) c)"
and lossless_protocol_OT12: "lossless_spmf (protocol_OT12 M C)"
begin

(* Package the assumed OT12 components, together with funct_OT_12, as an
   instance of sim_det_def. *)
sublocale OT_12_sim: sim_det_def R1_OT12 S1_OT12 R2_OT12 S2_OT12 funct_OT_12 protocol_OT12
unfolding sim_det_def_def
by(simp add: lossless_R1_12 lossless_S1_12 lossless_funct_OT12 lossless_R2_12 lossless_S2_12)

(* A distinguisher's advantage between party 1's real OT12 view and the
   simulator run on () is at most adv_OT12.  The second step uses that party
   1's output has type unit, so the simulator only ever receives ().
   NOTE(review): neither `have` nor the final `show` has a proof method in
   this copy. *)
lemma OT_12_P1_assms_bound': "¦spmf (bind_spmf (R1_OT12 (m0,m1) c) (λ view. ((D::'v_OT121 ⇒ bool spmf) view ))) True
- spmf (bind_spmf (S1_OT12 (m0,m1) ()) (λ view. (D view ))) True¦ ≤ adv_OT12"
proof-
have "sim_det_def.adv_P1 R1_OT12 S1_OT12 funct_OT_12 (m0,m1) c D =
¦spmf (bind_spmf (R1_OT12 (m0,m1) c) (λ view. (D view ))) True
- spmf (funct_OT_12 (m0,m1) c ⤜ (λ ((out1::unit), (out2::bool)).
S1_OT12 (m0,m1) out1 ⤜ (λ view. D view))) True¦"
also have "... = ¦spmf (bind_spmf (R1_OT12 (m0,m1) c) (λ view. ((D::'v_OT121 ⇒ bool spmf) view ))) True
- spmf (bind_spmf (S1_OT12 (m0,m1) ()) (λ view. (D view ))) True¦"
ultimately show ?thesis
qed

(* The assumed perfect security for party 2 in OT12, unfolded: the real
   view equals the simulator run on party 2's functionality output. *)
lemma OT_12_P2_assm: "R2_OT12 (m0,m1) σ = funct_OT_12 (m0,m1) σ ⤜ (λ (out1, out2). S2_OT12 σ out2)"
using inf_th_OT12_P2 OT_12_sim.perfect_sec_P2_def by blast

(* 1-out-of-4 OT built from three 1-out-of-2 OTs.  Six random bits S0 .. S5
   mask the four messages as a0 .. a3.  Party 2 obtains Si from (S0, S1)
   using c0, and Sj from (S2, S3) and Sk from (S4, S5) using c1.  It then
   XORs Si, one of Sj/Sk (selected by c0) and the masked message selected by
   (c0, c1).  Party 1 outputs (). *)
definition protocol_14_OT :: "input1 ⇒ input2 ⇒ (unit × bool) spmf"
where "protocol_14_OT M C = do {
let (c0,c1) = C;
let (m00, m01, m10, m11) = M;
S0 ← coin_spmf;
S1 ← coin_spmf;
S2 ← coin_spmf;
S3 ← coin_spmf;
S4 ← coin_spmf;
S5 ← coin_spmf;
let a0 = S0 ⊕ S2 ⊕ m00;
let a1 = S0 ⊕ S3 ⊕ m01;
let a2 = S1 ⊕ S4 ⊕ m10;
let a3 = S1 ⊕ S5 ⊕ m11;
(_,Si) ← protocol_OT12 (S0, S1) c0;
(_,Sj) ← protocol_OT12 (S2, S3) c1;
(_,Sk) ← protocol_OT12 (S4, S5) c1;
let s2 = Si ⊕ (if c0 then Sk else Sj) ⊕ (if c0 then (if c1 then a3 else a2) else (if c1 then a1 else a0));
return_spmf ((), s2)}"

(* The 1-out-of-4 protocol is lossless.
   NOTE(review): no proof is present in this copy.  It presumably follows
   from lossless_protocol_OT12 and the losslessness of coin_spmf. *)
lemma lossless_protocol_14_OT: "lossless_spmf (protocol_14_OT M C)"

(* Real view of party 1: its input, the six random bits, and its real views
   in the three OT12 runs.  The runs use choices c0, c1 and c1, and inputs
   (S0, S1), (S2, S3) and (S4, S5). *)
definition R1_14 :: "input1 ⇒ input2 ⇒ 'v_OT121 view1 spmf"
where "R1_14 msgs choice = do {
let (m00, m01, m10, m11) = msgs;
let (c0, c1) = choice;
S0 :: bool ← coin_spmf;
S1 :: bool ← coin_spmf;
S2 :: bool ← coin_spmf;
S3 :: bool ← coin_spmf;
S4 :: bool ← coin_spmf;
S5 :: bool ← coin_spmf;
a :: 'v_OT121 ← R1_OT12 (S0, S1) c0;
b :: 'v_OT121 ← R1_OT12 (S2, S3) c1;
c :: 'v_OT121 ← R1_OT12 (S4, S5) c1;
return_spmf (msgs, (S0, S1, S2, S3, S4, S5), a, b, c)}"

(* Party 1's real view is lossless.
   NOTE(review): no proof is present in this copy.  It presumably follows
   from lossless_R1_12 by unfolding R1_14_def. *)
lemma lossless_R1_14: "lossless_spmf (R1_14 msgs C)"

(* First hybrid: as R1_14, but the view of the first OT12 run (on (S0, S1))
   is produced by the simulator S1_OT12. *)
definition R1_14_interm1 :: "input1 ⇒ input2 ⇒ 'v_OT121 view1 spmf"
where "R1_14_interm1 msgs choice = do {
let (m00, m01, m10, m11) = msgs;
let (c0, c1) = choice;
S0 :: bool ← coin_spmf;
S1 :: bool ← coin_spmf;
S2 :: bool ← coin_spmf;
S3 :: bool ← coin_spmf;
S4 :: bool ← coin_spmf;
S5 :: bool ← coin_spmf;
a :: 'v_OT121 ← S1_OT12 (S0, S1) ();
b :: 'v_OT121 ← R1_OT12 (S2, S3) c1;
c :: 'v_OT121 ← R1_OT12 (S4, S5) c1;
return_spmf (msgs, (S0, S1, S2, S3, S4, S5), a, b, c)}"

(* The first hybrid is lossless because R1_OT12 and S1_OT12 are. *)
lemma lossless_R1_14_interm1: "lossless_spmf (R1_14_interm1 msgs C)"
by(simp add: R1_14_interm1_def split_def lossless_R1_12 lossless_S1_12)

(* Second hybrid: the first two OT12 views (on (S0, S1) and (S2, S3)) are
   simulated.  Only the third (on (S4, S5)) is real. *)
definition R1_14_interm2 :: "input1 ⇒ input2 ⇒ 'v_OT121 view1 spmf"
where "R1_14_interm2 msgs choice = do {
let (m00, m01, m10, m11) = msgs;
let (c0, c1) = choice;
S0 :: bool ← coin_spmf;
S1 :: bool ← coin_spmf;
S2 :: bool ← coin_spmf;
S3 :: bool ← coin_spmf;
S4 :: bool ← coin_spmf;
S5 :: bool ← coin_spmf;
a :: 'v_OT121 ← S1_OT12 (S0, S1) ();
b :: 'v_OT121 ← S1_OT12 (S2, S3) ();
c :: 'v_OT121 ← R1_OT12 (S4, S5) c1;
return_spmf (msgs, (S0, S1, S2, S3, S4, S5), a, b, c)}"

(* The second hybrid is lossless because R1_OT12 and S1_OT12 are. *)
lemma lossless_R1_14_interm2: "lossless_spmf (R1_14_interm2 msgs C)"
by(simp add: R1_14_interm2_def split_def lossless_R1_12 lossless_S1_12)

(* Simulator for party 1: all three OT12 views are produced by S1_OT12, and
   the unit output argument is ignored. *)
definition S1_14 :: "input1 ⇒ unit ⇒ 'v_OT121 view1 spmf"
where "S1_14 msgs _ = do {
let (m00, m01, m10, m11) = msgs;
S0 :: bool ← coin_spmf;
S1 :: bool ← coin_spmf;
S2 :: bool ← coin_spmf;
S3 :: bool ← coin_spmf;
S4 :: bool ← coin_spmf;
S5 :: bool ← coin_spmf;
a :: 'v_OT121 ← S1_OT12 (S0, S1) ();
b :: 'v_OT121 ← S1_OT12 (S2, S3) ();
c :: 'v_OT121 ← S1_OT12 (S4, S5) ();
return_spmf (msgs, (S0, S1, S2, S3, S4, S5), a, b, c)}"

(* The party 1 simulator is lossless.
   NOTE(review): no proof is present in this copy.  It presumably follows
   from lossless_S1_12 by unfolding S1_14_def. *)
lemma lossless_S1_14: "lossless_spmf (S1_14 m out)"

(* The gap between R1_14 and the first hybrid is exactly a distinguishing
   advantage against a single OT12 instance whose inputs are two fair
   coins.  The adversary A1' samples S2 .. S5, runs the real second and
   third OT12 views, places the challenge view first and passes the
   assembled view to D.  The bind_commute_spmf steps move the simulator call
   into position. *)
lemma reduction_step1:
shows "∃ A1. ¦spmf (bind_spmf (R1_14 M (c0, c1)) D) True - spmf (bind_spmf (R1_14_interm1 M (c0, c1)) D) True¦ =
¦spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (R1_OT12 (m0,m1) c0) (λ view. (A1 view (m0,m1))))) True -
spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (S1_OT12 (m0,m1) ()) (λ view. (A1 view (m0,m1))))) True¦"
proof-
define A1' where "A1' == λ (view :: 'v_OT121) (m0,m1). do {
S2 :: bool ← coin_spmf;
S3 :: bool ← coin_spmf;
S4 :: bool ← coin_spmf;
S5 :: bool ← coin_spmf;
b :: 'v_OT121 ← R1_OT12 (S2, S3) c1;
c :: 'v_OT121 ← R1_OT12 (S4, S5) c1;
let R = (M, (m0,m1, S2, S3, S4, S5), view, b, c);
D R}"
have "¦spmf (bind_spmf (R1_14 M (c0, c1)) D) True - spmf (bind_spmf (R1_14_interm1 M (c0, c1)) D) True¦ =
¦spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (R1_OT12 (m0,m1) c0) (λ view. (A1' view (m0,m1))))) True -
spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (S1_OT12 (m0,m1) ()) (λ view. (A1' view (m0,m1))))) True¦"
apply(simp add: pair_spmf_alt_def R1_14_def R1_14_interm1_def A1'_def Let_def split_def)
apply(subst bind_commute_spmf[of "S1_OT12 _ _"])
apply(subst bind_commute_spmf[of "S1_OT12 _ _"])
apply(subst bind_commute_spmf[of "S1_OT12 _ _"])
apply(subst bind_commute_spmf[of "S1_OT12 _ _"])
apply(subst bind_commute_spmf[of "S1_OT12 _ _"])
by auto
then show ?thesis by auto
qed

(* An advantage of the shape produced by the reduction steps (OT12 inputs
   drawn as two fair coins) is at most adv_OT12.  The difference of
   integrals is bounded by the integral of absolute differences, and each of
   those is bounded pointwise by OT_12_P1_assms_bound'.
   NOTE(review): the statement lacks its closing bound and `is` pattern
   (?lhs is used below) in this copy. *)
lemma reduction_step1':
shows "¦spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (R1_OT12 (m0,m1) c0) (λ view. (A1 view (m0,m1))))) True -
spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (S1_OT12 (m0,m1) ()) (λ view. (A1 view (m0,m1))))) True¦
proof-
have int1: "integrable (measure_spmf (pair_spmf coin_spmf coin_spmf)) (λx. spmf (case x of (m0, m1) ⇒ R1_OT12 (m0, m1) c0 ⤜ (λview. A1 view (m0, m1))) True)"
and int2: "integrable (measure_spmf (pair_spmf coin_spmf coin_spmf)) (λx. spmf (case x of (m0, m1) ⇒ S1_OT12 (m0, m1) () ⤜ (λview. A1 view (m0, m1))) True)"
by(rule measure_spmf.integrable_const_bound[where B=1]; simp add: pmf_le_1)+
have "?lhs =
¦LINT x|measure_spmf (pair_spmf coin_spmf coin_spmf). spmf (case x of (m0, m1) ⇒ R1_OT12 (m0, m1) c0 ⤜ (λview. A1 view (m0, m1))) True
- spmf (case x of (m0, m1) ⇒ S1_OT12 (m0, m1) () ⤜ (λview. A1 view (m0, m1))) True¦"
apply(subst (1 2) spmf_bind) using int1 int2 by simp
also have "... ≤ LINT x|measure_spmf (pair_spmf coin_spmf coin_spmf).
¦spmf (R1_OT12 x c0 ⤜ (λview. A1 view x)) True - spmf (S1_OT12 x () ⤜ (λview. A1 view x)) True¦"
by(rule integral_abs_bound[THEN order_trans]; simp add: split_beta)
ultimately have "?lhs ≤ LINT x|measure_spmf (pair_spmf coin_spmf coin_spmf).
¦spmf (R1_OT12 x c0 ⤜ (λview. A1 view x)) True - spmf (S1_OT12 x () ⤜ (λview. A1 view x)) True¦"
by simp
also have "LINT x|measure_spmf (pair_spmf coin_spmf coin_spmf).
¦spmf (R1_OT12 x c0 ⤜ (λview::'v_OT121. A1 view x)) True
- spmf (S1_OT12 x () ⤜ (λview::'v_OT121. A1 view x)) True¦ ≤ adv_OT12"
apply(rule integral_mono[THEN order_trans])
apply(rule measure_spmf.integrable_const_bound[where B=2])
apply clarsimp
apply(rule abs_triangle_ineq4[THEN order_trans])
subgoal for m0 m1
using pmf_le_1[of "R1_OT12 (m0, m1) c0 ⤜ (λview. A1 view (m0, m1))" "Some True"]
pmf_le_1[of "S1_OT12 (m0, m1) () ⤜ (λview. A1 view (m0, m1))" "Some True"]
by simp
apply simp
apply(rule measure_spmf.integrable_const)
apply clarify
apply(rule OT_12_P1_assms_bound'[rule_format])
by simp
ultimately show ?thesis by simp
qed

(* Second hybrid step: the gap between the first and second hybrids is an
   advantage against the OT12 instance on (S2, S3) with choice c1.  The
   adversary A1' samples four bits (its local S2, S3 play the role of the
   first pair), runs a simulated first view and a real third view, and
   places the challenge view second.
   NOTE(review): the first inner `have` ends after two `apply` steps with no
   closing method in this copy. *)
lemma reduction_step2:
shows "∃ A1. ¦spmf (bind_spmf (R1_14_interm1 M (c0, c1)) D) True - spmf (bind_spmf (R1_14_interm2 M (c0, c1)) D) True¦ =
¦spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (R1_OT12 (m0,m1) c1) (λ view. (A1 view (m0,m1))))) True -
spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (S1_OT12 (m0,m1) ()) (λ view. (A1 view (m0,m1))))) True¦"
proof-
define A1' where "A1' == λ (view :: 'v_OT121) (m0,m1). do {
S2 :: bool ← coin_spmf;
S3 :: bool ← coin_spmf;
S4 :: bool ← coin_spmf;
S5 :: bool ← coin_spmf;
a :: 'v_OT121 ← S1_OT12 (S2,S3) ();
c :: 'v_OT121 ← R1_OT12 (S4, S5) c1;
let R = (M, (S2,S3, m0, m1, S4, S5), a, view, c);
D R}"
have "¦spmf (bind_spmf (R1_14_interm1 M (c0, c1)) D) True - spmf (bind_spmf (R1_14_interm2 M (c0, c1)) D) True¦ =
¦spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (R1_OT12 (m0,m1) c1) (λ view. (A1' view (m0,m1))))) True -
spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (S1_OT12 (m0,m1) ()) (λ view. (A1' view (m0,m1))))) True¦"
proof-
have "(bind_spmf (R1_14_interm1 M (c0, c1)) D) = (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (R1_OT12 (m0,m1) c1) (λ view. (A1' view (m0,m1)))))"
unfolding R1_14_interm1_def R1_14_interm2_def A1'_def Let_def split_def
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "_ = ⌑" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑"  in "bind_spmf _ ⌑" in "_ = ⌑" bind_commute_spmf)
also have "(bind_spmf (R1_14_interm2 M (c0, c1)) D) =  (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (S1_OT12 (m0,m1) ()) (λ view. (A1' view (m0,m1)))))"
unfolding R1_14_interm1_def R1_14_interm2_def A1'_def Let_def split_def
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "_ = ⌑" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑"  in "bind_spmf _ ⌑" in "_ = ⌑" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑"  in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "_ = ⌑" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑"  in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "_ = ⌑" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑"  in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "_ = ⌑" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑"  in "_ = ⌑" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "_ = ⌑" bind_commute_spmf)
apply(rewrite in "_ = ⌑" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑"  in "_ = ⌑" bind_commute_spmf)
by(simp)
ultimately show ?thesis by simp
qed
then show ?thesis by auto
qed

(* Third hybrid step: the gap between the second hybrid and S1_14 is an
   advantage against the OT12 instance on (S4, S5) with choice c1.  The
   adversary A1' runs two simulated views and places the challenge view
   last.
   NOTE(review): the first inner `have` ends after its `apply` steps with no
   closing method in this copy. *)
lemma reduction_step3:
shows "∃ A1. ¦spmf (bind_spmf (R1_14_interm2 M (c0, c1)) D) True - spmf (bind_spmf (S1_14 M out) D) True¦ =
¦spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (R1_OT12 (m0,m1) c1) (λ view. (A1 view (m0,m1))))) True -
spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (S1_OT12 (m0,m1) ()) (λ view. (A1 view (m0,m1))))) True¦"
proof-
define A1' where "A1' == λ (view :: 'v_OT121) (m0,m1). do {
S2 :: bool ← coin_spmf;
S3 :: bool ← coin_spmf;
S4 :: bool ← coin_spmf;
S5 :: bool ← coin_spmf;
a :: 'v_OT121 ← S1_OT12 (S2,S3) ();
b :: 'v_OT121 ← S1_OT12 (S4, S5) ();
let R = (M, (S2,S3, S4, S5,m0, m1), a, b, view);
D R}"
have "¦spmf (bind_spmf (R1_14_interm2 M (c0, c1)) D) True - spmf (bind_spmf (S1_14 M out) D) True¦ =
¦spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (R1_OT12 (m0,m1) c1) (λ view. (A1' view (m0,m1))))) True -
spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (S1_OT12 (m0,m1) ()) (λ view. (A1' view (m0,m1))))) True¦"
proof-
have "(bind_spmf (R1_14_interm2 M (c0, c1)) D) = (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (R1_OT12 (m0,m1) c1) (λ view. (A1' view (m0,m1)))))"
unfolding  R1_14_interm2_def A1'_def Let_def split_def
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "_ = ⌑" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑"  in "bind_spmf _ ⌑" in "_ = ⌑" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑"  in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "_ = ⌑" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑"  in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "_ = ⌑" bind_commute_spmf)
also have "(bind_spmf (S1_14 M out) D) = (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (S1_OT12 (m0,m1) ()) (λ view. (A1' view (m0,m1)))))"
unfolding S1_14_def Let_def A1'_def split_def
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "_ = ⌑" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑"  in "bind_spmf _ ⌑" in "_ = ⌑" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑"  in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "_ = ⌑" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑"  in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "_ = ⌑" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "_ = ⌑" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "_ = ⌑" bind_commute_spmf)
apply(rewrite in "⌑ = _" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "⌑ = _" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "⌑ = _" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "⌑ = _" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "⌑ = _" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "⌑ = _" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "⌑ = _" bind_commute_spmf)
ultimately show ?thesis by simp
qed
then show ?thesis by auto
qed

(* Concrete security bound for party 1: the distance between party 1's real
   view (R1_14) and the simulated view (S1_14) is at most 3 * adv_OT12.
   The proof is a hybrid argument R1_14 -> R1_14_interm1 -> R1_14_interm2 -> S1_14;
   each hop is rewritten as the advantage of a distinguisher (A1, A2, A3) between
   R1_OT12 and S1_OT12, and each such advantage is bounded by adv_OT12. *)
lemma reduction_P1_interm:
  shows "¦spmf (bind_spmf (R1_14 M (c0,c1)) (D)) True - spmf (bind_spmf (S1_14 M out) (D)) True¦ ≤ 3 * adv_OT12"
    (is "?lhs ≤ ?rhs")
proof-
  (* Triangle inequality through the two intermediate hybrids. *)
  have lhs: "?lhs ≤ ¦spmf (bind_spmf (R1_14 M (c0, c1)) D) True - spmf (bind_spmf (R1_14_interm1 M (c0, c1)) D) True¦ +
                    ¦spmf (bind_spmf (R1_14_interm1 M (c0, c1)) D) True - spmf (bind_spmf (R1_14_interm2 M (c0, c1)) D) True¦ +
                    ¦spmf (bind_spmf (R1_14_interm2 M (c0, c1)) D) True - spmf (bind_spmf (S1_14 M out) D) True¦"
    by simp
  (* Hop 1 (first OT12, choice bit c0) as a distinguisher A1. *)
  obtain A1 where A1: "¦spmf (bind_spmf (R1_14 M (c0, c1)) D) True - spmf (bind_spmf (R1_14_interm1 M (c0, c1)) D) True¦ =
      ¦spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (R1_OT12 (m0,m1) c0) (λ view. (A1 view (m0,m1))))) True -
        spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (S1_OT12 (m0,m1) ()) (λ view. (A1 view (m0,m1))))) True¦"
    using reduction_step1 by blast
  (* Hop 2 (second OT12, choice bit c1) as a distinguisher A2. *)
  obtain A2 where A2: "¦spmf (bind_spmf (R1_14_interm1 M (c0, c1)) D) True - spmf (bind_spmf (R1_14_interm2 M (c0, c1)) D) True¦ =
      ¦spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (R1_OT12 (m0,m1) c1) (λ view. (A2 view (m0,m1))))) True -
        spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (S1_OT12 (m0,m1) ()) (λ view. (A2 view (m0,m1))))) True¦"
    using reduction_step2 by blast
  (* Hop 3 (third OT12, choice bit c1) as a distinguisher A3. *)
  obtain A3 where A3: "¦spmf (bind_spmf (R1_14_interm2 M (c0, c1)) D) True - spmf (bind_spmf (S1_14 M out) D) True¦ =
      ¦spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (R1_OT12 (m0,m1) c1) (λ view. (A3 view (m0,m1))))) True -
        spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (S1_OT12 (m0,m1) ()) (λ view. (A3 view (m0,m1))))) True¦"
    using reduction_step3 by blast
  (* Substitute the three distinguisher advantages into the triangle bound. *)
  have lhs_bound: "?lhs ≤ ¦spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (R1_OT12 (m0,m1) c0) (λ view. (A1 view (m0,m1))))) True -
        spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (S1_OT12 (m0,m1) ()) (λ view. (A1 view (m0,m1))))) True¦ +
      ¦spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (R1_OT12 (m0,m1) c1) (λ view. (A2 view (m0,m1))))) True -
        spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (S1_OT12 (m0,m1) ()) (λ view. (A2 view (m0,m1))))) True¦ +
      ¦spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (R1_OT12 (m0,m1) c1) (λ view. (A3 view (m0,m1))))) True -
        spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (S1_OT12 (m0,m1) ()) (λ view. (A3 view (m0,m1))))) True¦"
    using A1 A2 A3 lhs by simp
  (* Each distinguisher's advantage is at most adv_OT12.
     Every proposition must carry its own "≤ adv_OT12" and closing quote;
     otherwise the quoted terms run into each other. *)
  have bound1: "¦spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (R1_OT12 (m0,m1) c0) (λ view. (A1 view (m0,m1))))) True -
        spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (S1_OT12 (m0,m1) ()) (λ view. (A1 view (m0,m1))))) True¦ ≤ adv_OT12"
    and bound2: "¦spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (R1_OT12 (m0,m1) c1) (λ view. (A2 view (m0,m1))))) True -
        spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (S1_OT12 (m0,m1) ()) (λ view. (A2 view (m0,m1))))) True¦ ≤ adv_OT12"
    and bound3: "¦spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (R1_OT12 (m0,m1) c1) (λ view. (A3 view (m0,m1))))) True -
        spmf (bind_spmf (pair_spmf coin_spmf coin_spmf) (λ(m0, m1). bind_spmf (S1_OT12 (m0,m1) ()) (λ view. (A3 view (m0,m1))))) True¦ ≤ adv_OT12"
    using reduction_step1' by auto
  (* Linear arithmetic: lhs_bound plus the three bounds gives ?lhs ≤ 3 * adv_OT12. *)
  thus ?thesis
    using reduction_step1' lhs_bound by argo
qed

(* Party 1 security against the ideal functionality: party 1's real view and the
   simulator run on the ideal output are at most 3 * adv_OT12 apart.
   The statement needs its "≤ 3 * adv_OT12" conclusion and closing quote;
   otherwise the proposition string is unterminated and swallows the proof. *)
lemma reduction_P1: "¦spmf (bind_spmf (R1_14 M (c0,c1)) (D)) True
                      - spmf (funct_OT_14 M (c0,c1) ⤜ (λ (out1,out2). S1_14 M out1 ⤜ (λ view. D view))) True¦ ≤ 3 * adv_OT12"
  (* Unfold the ideal functionality and reduce to reduction_P1_interm. *)
  by(simp add: funct_OT_14_def split_def Let_def reduction_P1_interm)

text‹Party 2 security.›

(* XOR-ing a fair coin S0 with the fixed bit S3 ⊕ m1 yields a fair coin again.
   Used in the party-2 proofs to replace a masked value by a fresh coin. *)
lemma coin_coin: "map_spmf (λ S0. S0 ⊕ S3 ⊕ m1) coin_spmf = coin_spmf"
(is "?lhs = ?rhs")
proof-
(* Re-associate so that the fixed part S3 ⊕ m1 is a single operand. *)
have lhs: "?lhs = map_spmf (λ S0. S0 ⊕ (S3 ⊕ m1)) coin_spmf" by blast
(* Rewrite as a section of ⊕ with the fixed operand first. *)
also have op_eq: "... = map_spmf ((⊕) (S3 ⊕ m1)) coin_spmf"
by (metis xor_bool_def)
(* Closed by xor_uni_samp (XOR with a fixed bit preserves coin_spmf). *)
also have "... = ?rhs"
using xor_uni_samp by fastforce
ultimately show ?thesis
using op_eq by auto
qed

(* Variant of coin_coin in which the uniformly sampled bit is the middle operand S3. *)
lemma coin_coin': "map_spmf (λ S3. S0 ⊕ S3 ⊕ m1) coin_spmf = coin_spmf"
proof-
(* Move the sampled bit to the front so that coin_coin applies. *)
have "map_spmf (λ S3. S0 ⊕ S3 ⊕ m1) coin_spmf = map_spmf (λ S3. S3 ⊕ S0 ⊕ m1) coin_spmf"
by (metis xor_left_commute)
thus ?thesis using coin_coin by simp
qed

(* Real view of party 2 in the 1-out-of-4 OT protocol.
   M = (m0,m1,m2,m3) is the input1 argument and C = (c0,c1) party 2's choice bits.
   Six coins S0..S5 mask the messages as
     a0 = S0 ⊕ S2 ⊕ m0,  a1 = S0 ⊕ S3 ⊕ m1,  a2 = S1 ⊕ S4 ⊕ m2,  a3 = S1 ⊕ S5 ⊕ m3.
   Three 1-out-of-2 OT runs follow: (S0,S1) under c0, and both (S2,S3) and (S4,S5)
   under c1. The view is (C, (a0,a1,a2,a3), a, b, c), where a, b, c are party 2's
   views of those runs (R2_OT12). *)
definition R2_14:: "input1 ⇒ input2 ⇒ 'v_OT122 view2 spmf"
where "R2_14 M C = do {
let (m0,m1,m2,m3) = M;
let (c0,c1) = C;
S0 :: bool ← coin_spmf;
S1 :: bool ← coin_spmf;
S2 :: bool ← coin_spmf;
S3 :: bool ← coin_spmf;
S4 :: bool ← coin_spmf;
S5 :: bool ← coin_spmf;
let a0 = S0 ⊕ S2 ⊕ m0;
let a1 = S0 ⊕ S3 ⊕ m1;
let a2 = S1 ⊕ S4 ⊕ m2;
let a3 = S1 ⊕ S5 ⊕ m3;
a :: 'v_OT122 ← R2_OT12 (S0,S1) c0;
b :: 'v_OT122 ← R2_OT12 (S2,S3) c1;
c :: 'v_OT122 ← R2_OT12 (S4,S5) c1;
return_spmf (C, (a0,a1,a2,a3), a,b,c)}"

(* Party 2's real view R2_14 M C is a lossless sub-probability distribution.
   NOTE(review): no proof follows this statement in this copy of the theory, so it
   will not check as written. TODO: restore the proof, presumably by unfolding
   R2_14_def and using losslessness of R2_OT12 (confirm). *)
lemma lossless_R2_14: "lossless_spmf (R2_14 M C)"

(* Simulator for party 2's view, given the choice bits C = (c0,c1) and the output bit out.
   It samples masks S0..S5 and four fresh coins a0..a3. Only the entry selected by
   (c0,c1) is recomputed as its mask XOR out; the other three stay random.
   The three OT views come from S2_OT12, applied to the relevant choice bit and the
   mask bit that choice selects. *)
definition S2_14 :: "input2 ⇒ bool ⇒ 'v_OT122 view2 spmf"
where "S2_14 C out = do {
let ((c0::bool),(c1::bool)) = C;
S0 :: bool ← coin_spmf;
S1 :: bool ← coin_spmf;
S2 :: bool ← coin_spmf;
S3 :: bool ← coin_spmf;
S4 :: bool ← coin_spmf;
S5 :: bool ← coin_spmf;
a0 :: bool ← coin_spmf;
a1 :: bool ← coin_spmf;
a2 :: bool ← coin_spmf;
a3 :: bool ← coin_spmf;
let a0' = (if ((¬ c0) ∧ (¬ c1)) then (S0 ⊕ S2 ⊕ out) else a0);
let a1' = (if ((¬ c0) ∧ c1) then (S0 ⊕ S3 ⊕ out) else a1);
let a2' = (if (c0 ∧ (¬ c1)) then (S1 ⊕ S4 ⊕ out) else a2);
let a3' = (if (c0 ∧ c1) then (S1 ⊕ S5 ⊕ out) else a3);
a :: 'v_OT122 ← S2_OT12 (c0::bool) (if c0 then S1 else S0);
b :: 'v_OT122 ← S2_OT12 (c1::bool) (if c1 then S3 else S2);
c :: 'v_OT122 ← S2_OT12 (c1::bool) (if c1 then S5 else S4);
return_spmf ((c0,c1), (a0',a1',a2',a3'), a,b,c)}"

(* The simulator S2_14 c out is a lossless sub-probability distribution.
   NOTE(review): no proof follows this statement in this copy of the theory, so it
   will not check as written. TODO: restore the proof, presumably by unfolding
   S2_14_def and using losslessness of S2_OT12 (confirm). *)
lemma lossless_S2_14: "lossless_spmf (S2_14 c out)"

(* Party 2 security for choice bits (False,True): the real view equals the ideal
   functionality's output fed to the simulator S2_14. The proof is an equality
   chain that turns every masked entry other than a1 into a fresh coin. *)
lemma P2_OT_14_FT: "R2_14 (m0,m1,m2,m3) (False,True) = funct_OT_14 (m0,m1,m2,m3) (False,True) ⤜ (λ (out1, out2). S2_14 (False,True) out2)"
proof-
(* Unfold R2_14 with the OT12 views rewritten to S2_OT12; a0 and a2 appear as maps over the coins S2 and S4. *)
have "R2_14 (m0,m1,m2,m3) (False,True) =  do {
S0 :: bool ← coin_spmf;
S1 :: bool ← coin_spmf;
S3 :: bool ← coin_spmf;
S5 :: bool ← coin_spmf;
a0 :: bool ← map_spmf (λ S2. S0 ⊕ S2 ⊕ m0) coin_spmf;
let a1 = S0 ⊕ S3 ⊕ m1;
a2 ← map_spmf (λ S4. S1 ⊕ S4 ⊕ m2) coin_spmf;
let a3 = S1 ⊕ S5 ⊕ m3;
a :: 'v_OT122 ← S2_OT12 False S0;
b :: 'v_OT122 ← S2_OT12 True S3;
c :: 'v_OT122 ← S2_OT12 True S5;
return_spmf ((False,True), (a0,a1,a2,a3), a,b,c)}"
by(simp add: bind_map_spmf o_def Let_def R2_14_def inf_th_OT12_P2 funct_OT_12_def OT_12_P2_assm)
(* coin_coin': the maps over S2 and S4 are fresh coins. *)
also have "... =  do {
S0 :: bool ← coin_spmf;
S1 :: bool ← coin_spmf;
S3 :: bool ← coin_spmf;
S5 :: bool ← coin_spmf;
a0 :: bool ← coin_spmf;
let a1 = S0 ⊕ S3 ⊕ m1;
a2 ← coin_spmf;
let a3 = S1 ⊕ S5 ⊕ m3;
a :: 'v_OT122 ← S2_OT12 False S0;
b :: 'v_OT122 ← S2_OT12 True S3;
c :: 'v_OT122 ← S2_OT12 True S5;
return_spmf ((False,True), (a0,a1,a2,a3), a,b,c)}"
using coin_coin' by simp
(* S1 is used only in a3; expose a3 as a map over S1. *)
also have "... =  do {
S0 :: bool ← coin_spmf;
S3 :: bool ← coin_spmf;
S5 :: bool ← coin_spmf;
a0 :: bool ← coin_spmf;
let a1 = S0 ⊕ S3 ⊕ m1;
a2 :: bool ← coin_spmf;
a3 ← map_spmf (λ S1. S1 ⊕ S5 ⊕ m3) coin_spmf;
a :: 'v_OT122 ← S2_OT12 False S0;
b :: 'v_OT122 ← S2_OT12 True S3;
c :: 'v_OT122 ← S2_OT12 True S5;
return_spmf ((False,True), (a0,a1,a2,a3), a,b,c)}"
(* NOTE(review): this step has no proof method in this copy; TODO restore its justification. *)
(* coin_coin: the map over S1 is a fresh coin. *)
also have "... =  do {
S0 :: bool ← coin_spmf;
S3 :: bool ← coin_spmf;
S5 :: bool ← coin_spmf;
a0 :: bool ← coin_spmf;
let a1 = S0 ⊕ S3 ⊕ m1;
a2 :: bool ← coin_spmf;
a3 ← coin_spmf;
a :: 'v_OT122 ← S2_OT12 False S0;
b :: 'v_OT122 ← S2_OT12 True S3;
c :: 'v_OT122 ← S2_OT12 True S5;
return_spmf ((False,True), (a0,a1,a2,a3), a,b,c)}"
using coin_coin by simp
ultimately show ?thesis
(* NOTE(review): `show` has no proof method here; TODO restore the step matching the chain against funct_OT_14 / S2_14. *)
qed

(* Party 2 security for choice bits (True,True): the real view equals the ideal
   functionality's output fed to the simulator S2_14. The proof is an equality
   chain that turns every masked entry other than a3 into a fresh coin. *)
lemma P2_OT_14_TT: "R2_14 (m0,m1,m2,m3) (True,True) = funct_OT_14 (m0,m1,m2,m3) (True,True) ⤜ (λ (out1, out2). S2_14 (True,True) out2)"
proof-
(* Unfold R2_14 with the OT12 views rewritten to S2_OT12; a0 and a2 appear as maps over the coins S2 and S4. *)
have "R2_14 (m0,m1,m2,m3) (True,True) =  do {
S0 :: bool ← coin_spmf;
S1 :: bool ← coin_spmf;
S3 :: bool ← coin_spmf;
S5 :: bool ← coin_spmf;
a0 :: bool ← map_spmf (λ S2. S0 ⊕ S2 ⊕ m0) coin_spmf;
let a1 = S0 ⊕ S3 ⊕ m1;
a2 ← map_spmf (λ S4. S1 ⊕ S4 ⊕ m2) coin_spmf;
let a3 = S1 ⊕ S5 ⊕ m3;
a :: 'v_OT122 ← S2_OT12 True S1;
b :: 'v_OT122 ← S2_OT12 True S3;
c :: 'v_OT122 ← S2_OT12 True S5;
return_spmf ((True,True), (a0,a1,a2,a3), a,b,c)}"
by(simp add: bind_map_spmf o_def R2_14_def inf_th_OT12_P2 funct_OT_12_def OT_12_P2_assm Let_def)
(* coin_coin': the maps over S2 and S4 are fresh coins. *)
also have "... = do {
S0 :: bool ← coin_spmf;
S1 :: bool ← coin_spmf;
S3 :: bool ← coin_spmf;
S5 :: bool ← coin_spmf;
a0 :: bool ← coin_spmf;
let a1 = S0 ⊕ S3 ⊕ m1;
a2 ← coin_spmf;
let a3 = S1 ⊕ S5 ⊕ m3;
a :: 'v_OT122 ← S2_OT12 True S1;
b :: 'v_OT122 ← S2_OT12 True S3;
c :: 'v_OT122 ← S2_OT12 True S5;
return_spmf ((True,True), (a0,a1,a2,a3), a,b,c)}"
using coin_coin' by simp
(* S0 is used only in a1; expose a1 as a map over S0. *)
also have "... = do {
S1 :: bool ← coin_spmf;
S3 :: bool ← coin_spmf;
S5 :: bool ← coin_spmf;
a0 :: bool ← coin_spmf;
a1 :: bool ← map_spmf (λ S0. S0 ⊕ S3 ⊕ m1) coin_spmf;
a2 ← coin_spmf;
let a3 = S1 ⊕ S5 ⊕ m3;
a :: 'v_OT122 ← S2_OT12 True S1;
b :: 'v_OT122 ← S2_OT12 True S3;
c :: 'v_OT122 ← S2_OT12 True S5;
return_spmf ((True,True), (a0,a1,a2,a3), a,b,c)}"
(* NOTE(review): this step has no proof method in this copy; TODO restore its justification. *)
(* coin_coin: the map over S0 is a fresh coin. *)
also have "... = do {
S1 :: bool ← coin_spmf;
S3 :: bool ← coin_spmf;
S5 :: bool ← coin_spmf;
a0 :: bool ← coin_spmf;
a1 :: bool ← coin_spmf;
a2 ← coin_spmf;
let a3 = S1 ⊕ S5 ⊕ m3;
a :: 'v_OT122 ← S2_OT12 True S1;
b :: 'v_OT122 ← S2_OT12 True S3;
c :: 'v_OT122 ← S2_OT12 True S5;
return_spmf ((True,True), (a0,a1,a2,a3), a,b,c)}"
using coin_coin by simp
ultimately show ?thesis
(* NOTE(review): `show` has no proof method here; TODO restore the step matching the chain against funct_OT_14 / S2_14. *)
qed

(* Party 2 security for choice bits (False,False): the real view equals the ideal
   functionality's output fed to the simulator S2_14. The proof is an equality
   chain that turns every masked entry other than a0 into a fresh coin. *)
lemma P2_OT_14_FF: "R2_14 (m0,m1,m2,m3) (False, False) = funct_OT_14 (m0,m1,m2,m3) (False, False) ⤜ (λ (out1, out2). S2_14 (False, False) out2)"
proof-
(* Unfold R2_14 with the OT12 views rewritten to S2_OT12; a1 and a3 appear as maps over the coins S3 and S5. *)
have "R2_14 (m0,m1,m2,m3) (False,False) =  do {
S0 :: bool ← coin_spmf;
S1 :: bool ← coin_spmf;
S2 :: bool ← coin_spmf;
S4 :: bool ← coin_spmf;
let a0 = S0 ⊕ S2 ⊕ m0;
a1 :: bool ← map_spmf (λ S3. S0 ⊕ S3 ⊕ m1) coin_spmf;
let a2 = S1 ⊕ S4 ⊕ m2;
a3 ← map_spmf (λ S5. S1 ⊕ S5 ⊕ m3) coin_spmf;
a :: 'v_OT122 ← S2_OT12 False S0;
b :: 'v_OT122 ← S2_OT12 False S2;
c :: 'v_OT122 ← S2_OT12 False S4;
return_spmf ((False,False), (a0,a1,a2,a3), a,b,c)}"
by(simp add: bind_map_spmf o_def R2_14_def inf_th_OT12_P2 funct_OT_12_def OT_12_P2_assm Let_def)
(* coin_coin': the maps over S3 and S5 are fresh coins. *)
also have "... = do {
S0 :: bool ← coin_spmf;
S1 :: bool ← coin_spmf;
S2 :: bool ← coin_spmf;
S4 :: bool ← coin_spmf;
let a0 = S0 ⊕ S2 ⊕ m0;
a1 :: bool ← coin_spmf;
let a2 = S1 ⊕ S4 ⊕ m2;
a3 ← coin_spmf;
a :: 'v_OT122 ← S2_OT12 False S0;
b :: 'v_OT122 ← S2_OT12 False S2;
c :: 'v_OT122 ← S2_OT12 False S4;
return_spmf ((False,False), (a0,a1,a2,a3), a,b,c)}"
using coin_coin' by simp
(* S1 is used only in a2; expose a2 as a map over S1. *)
also have "... = do {
S0 :: bool ← coin_spmf;
S2 :: bool ← coin_spmf;
S4 :: bool ← coin_spmf;
let a0 = S0 ⊕ S2 ⊕ m0;
a1 :: bool ← coin_spmf;
a2 :: bool ← map_spmf (λ S1. S1 ⊕ S4 ⊕ m2) coin_spmf;
a3 ← coin_spmf;
a :: 'v_OT122 ← S2_OT12 False S0;
b :: 'v_OT122 ← S2_OT12 False S2;
c :: 'v_OT122 ← S2_OT12 False S4;
return_spmf ((False,False), (a0,a1,a2,a3), a,b,c)}"
(* NOTE(review): this step has no proof method in this copy; TODO restore its justification. *)
(* coin_coin: the map over S1 is a fresh coin. *)
also have "... = do {
S0 :: bool ← coin_spmf;
S2 :: bool ← coin_spmf;
S4 :: bool ← coin_spmf;
let a0 = S0 ⊕ S2 ⊕ m0;
a1 :: bool ← coin_spmf;
a2 :: bool ← coin_spmf;
a3 ← coin_spmf;
a :: 'v_OT122 ← S2_OT12 False S0;
b :: 'v_OT122 ← S2_OT12 False S2;
c :: 'v_OT122 ← S2_OT12 False S4;
return_spmf ((False,False), (a0,a1,a2,a3), a,b,c)}"
using coin_coin by simp
ultimately show ?thesis
(* NOTE(review): `show` has no proof method here; TODO restore the step matching the chain against funct_OT_14 / S2_14. *)
qed

(* Party 2 security for choice bits (True,False): the real view equals the ideal
   functionality's output fed to the simulator S2_14. The proof is an equality
   chain that turns every masked entry other than a2 into a fresh coin. *)
lemma P2_OT_14_TF: "R2_14 (m0,m1,m2,m3) (True,False) = funct_OT_14 (m0,m1,m2,m3) (True,False) ⤜ (λ (out1, out2). S2_14 (True,False) out2)"
proof-
(* Unfold R2_14 with the OT12 views rewritten to S2_OT12, then reorder the binds with
   bind_commute_spmf; a1 and a3 appear as maps over the coins S3 and S5. *)
have "R2_14 (m0,m1,m2,m3) (True,False) = do {
S0 :: bool ← coin_spmf;
S1 :: bool ← coin_spmf;
S2 :: bool ← coin_spmf;
S4 :: bool ← coin_spmf;
let a0 = S0 ⊕ S2 ⊕ m0;
a1 :: bool ← map_spmf (λ S3. S0 ⊕ S3 ⊕ m1) coin_spmf;
let a2 = S1 ⊕ S4 ⊕ m2;
a3 ← map_spmf (λ S5. S1 ⊕ S5 ⊕ m3) coin_spmf;
a :: 'v_OT122 ← S2_OT12 True S1;
b :: 'v_OT122 ← S2_OT12 False S2;
c :: 'v_OT122 ← S2_OT12 False S4;
return_spmf ((True,False), (a0,a1,a2,a3), a,b,c)}"
apply(simp add: R2_14_def inf_th_OT12_P2 OT_12_P2_assm funct_OT_12_def Let_def)
apply(rewrite in "bind_spmf _ ⌑" in "⌑ = _" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "⌑ = _" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑"  in "bind_spmf _ ⌑" in "⌑ = _" bind_commute_spmf)
(* NOTE(review): this apply script is not closed by `done` or `by` before the next `also`; TODO restore the terminating step. *)
(* coin_coin': the maps over S3 and S5 are fresh coins. *)
also have "... = do {
S0 :: bool ← coin_spmf;
S1 :: bool ← coin_spmf;
S2 :: bool ← coin_spmf;
S4 :: bool ← coin_spmf;
let a0 = S0 ⊕ S2 ⊕ m0;
a1 :: bool ← coin_spmf;
let a2 = S1 ⊕ S4 ⊕ m2;
a3 ← coin_spmf;
a :: 'v_OT122 ← S2_OT12 True S1;
b :: 'v_OT122 ← S2_OT12 False S2;
c :: 'v_OT122 ← S2_OT12 False S4;
return_spmf ((True,False), (a0,a1,a2,a3), a,b,c)}"
using coin_coin' by simp
(* S0 is used only in a0; expose a0 as a map over S0. *)
also have "... = do {
S1 :: bool ← coin_spmf;
S2 :: bool ← coin_spmf;
S4 :: bool ← coin_spmf;
a0 :: bool ← map_spmf (λ S0. S0 ⊕ S2 ⊕ m0) coin_spmf;
a1 :: bool ← coin_spmf;
let a2 = S1 ⊕ S4 ⊕ m2;
a3 ← coin_spmf;
a :: 'v_OT122 ← S2_OT12 True S1;
b :: 'v_OT122 ← S2_OT12 False S2;
c :: 'v_OT122 ← S2_OT12 False S4;
return_spmf ((True,False), (a0,a1,a2,a3), a,b,c)}"
(* NOTE(review): this step has no proof method in this copy; TODO restore its justification. *)
(* coin_coin: the map over S0 is a fresh coin. *)
also have "... = do {
S1 :: bool ← coin_spmf;
S2 :: bool ← coin_spmf;
S4 :: bool ← coin_spmf;
a0 :: bool ← coin_spmf;
a1 :: bool ← coin_spmf;
let a2 = S1 ⊕ S4 ⊕ m2;
a3 ← coin_spmf;
a :: 'v_OT122 ← S2_OT12 True S1;
b :: 'v_OT122 ← S2_OT12 False S2;
c :: 'v_OT122 ← S2_OT12 False S4;
return_spmf ((True,False), (a0,a1,a2,a3), a,b,c)}"
using coin_coin by simp
(* Reorder the binds on the right-hand side (the ideal side) and close by simp. *)
ultimately show ?thesis
apply(rewrite in "bind_spmf _ ⌑"  in "_ = ⌑" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "_ = ⌑" bind_commute_spmf)
apply(rewrite in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "bind_spmf _ ⌑" in "_ = ⌑" bind_commute_spmf)
by simp
qed

(* Party 2 security for arbitrary choice bits (c0,c1), by case analysis on both bits
   using the four lemmas P2_OT_14_FF, P2_OT_14_TF, P2_OT_14_FT and P2_OT_14_TT. *)
lemma P2_sec_OT_14_split: "R2_14 (m0,m1,m2,m3) (c0,c1) = funct_OT_14 (m0,m1,m2,m3) (c0,c1) ⤜ (λ (out1, out2). S2_14 (c0,c1) out2)"
by(cases c0; cases c1; auto simp add: P2_OT_14_FF P2_OT_14_TF P2_OT_14_FT P2_OT_14_TT)

(* Party 2 security for an unstructured input M and choice C. surj_pair lets the
   tuples be split into components so that P2_sec_OT_14_split applies. *)
lemma P2_sec_OT_14: "R2_14 M C = funct_OT_14 M C ⤜ (λ (out1, out2). S2_14 C out2)"
by(metis P2_sec_OT_14_split surj_pair)

(* Interpret the sim_det_def framework with the 1-out-of-4 OT components under the
   prefix OT_14. Its side conditions are discharged by the losslessness lemmas. *)
sublocale OT_14: sim_det_def R1_14 S1_14 R2_14 S2_14 funct_OT_14 protocol_14_OT
unfolding sim_det_def_def
by(simp add: lossless_R1_14 lossless_S1_14 lossless_funct_14_OT lossless_R2_14 lossless_S2_14 )

(* Correctness: the protocol protocol_14_OT produces exactly the distribution of the
   functionality funct_OT_14. *)
lemma correctness_OT_14:
shows "funct_OT_14 M C = protocol_14_OT M C"
proof-
(* Auxiliary boolean identity, chained into the case split below to help simp. *)
have "S1 = (S5 = (S1 = (S5 = d))) = d" for S1 S5 d by auto
(* Case split on both choice bits. *)
thus ?thesis
by(cases "fst C"; cases "snd C"; simp add: funct_OT_14_def protocol_14_OT_def correct funct_OT_12_def lossless_funct_OT_12 bind_spmf_const split_def)
qed

(* Correctness in the sense of the OT_14 interpretation, from correctness_OT_14. *)
lemma OT_14_correct: "OT_14.correctness M C"
unfolding OT_14.correctness_def
using correctness_OT_14 by auto

(* Perfect security for party 2 in the sense of the OT_14 interpretation, from P2_sec_OT_14. *)
lemma OT_14_P2_sec: "OT_14.perfect_sec_P2 m1 m2"
unfolding OT_14.perfect_sec_P2_def
using P2_sec_OT_14 by blast

(* Security for party 1 in the sense of the OT_14 interpretation: the advantage is
   at most 3 * adv_OT12. This restores the lemma header that the dangling proof
   line belonged to; it mirrors OT_14_P1_sec in the asymptotic locale below.
   surj_pair splits m2 into the choice pair so that reduction_P1 applies. *)
lemma OT_14_P1_sec: "OT_14.adv_P1 m1 m2 D ≤ 3 * adv_OT12"
  unfolding OT_14.adv_P1_def
  by (metis reduction_P1 surj_pair)

end

(* Asymptotic setting: the 1-out-of-2 OT components and their advantage are indexed
   by a security parameter n, and ot14_base is assumed at every n. *)
locale OT_14_asymp = sim_det_def +
fixes S1_OT12 :: "nat ⇒ (bool × bool) ⇒ unit ⇒ 'v_OT121 spmf"
and R1_OT12 :: "nat ⇒ (bool × bool) ⇒ bool ⇒ 'v_OT121 spmf"
and adv_OT12 :: "nat ⇒ real"
and S2_OT12 :: "nat ⇒ bool ⇒ bool ⇒ 'v_OT122 spmf"
and R2_OT12 :: "nat ⇒ (bool × bool) ⇒ bool ⇒ 'v_OT122 spmf"
and protocol_OT12 :: "(bool × bool) ⇒ bool ⇒ (unit × bool) spmf"
(* NOTE(review): the assumption below uses R1_12_0T and R2_12OT. These are not among
   the fixed parameters; R1_OT12 and R2_OT12 are, so they look like typos. As written
   they occur as extra free names. Confirm how they are treated before renaming,
   since renaming changes the locale's parameters as seen by later theories. *)
assumes ot14_base: "⋀ (n::nat). ot14_base (S1_OT12 n) (R1_12_0T n) (adv_OT12 n) (S2_OT12 n) (R2_12OT n) (protocol_OT12)"
begin

(* Interpret ot14_base with the components at security parameter n. *)
sublocale ot14_base "(S1_OT12 n)" "(R1_12_0T n)" "(adv_OT12 n)" "(S2_OT12 n)" "(R2_12OT n)" using local.ot14_base by simp

(* Concrete party-1 bound at security parameter n, from reduction_P1. *)
lemma OT_14_P1_sec: "OT_14.adv_P1 (R1_12_0T n) n m1 m2 D ≤ 3 * (adv_OT12 n)"
unfolding OT_14.adv_P1_def using reduction_P1 surj_pair by metis

(* Party 1's advantage is negligible whenever the 1-out-of-2 OT advantage is negligible. *)
theorem OT_14_P1_asym_sec: "negligible (λ n. OT_14.adv_P1 (R1_12_0T n) n m1 m2 D)" if "negligible (λ n. adv_OT12 n)"
proof-
(* Scaling a negligible function by the constant 3 keeps it negligible. *)
have adv_neg: "negligible (λn. 3 * adv_OT12 n)" using that negligible_cmultI by simp
(* Pointwise comparison of absolute values, as required by negligible_le. *)
have "¦OT_14.adv_P1 (R1_12_0T n) n m1 m2 D¦ ≤ ¦3 * (adv_OT12 n)¦" for n
proof -
have "¦OT_14.adv_P1 (R1_12_0T n) n m1 m2 D¦ ≤ 3 * adv_OT12 n"
(* NOTE(review): this `have` has no proof method in this copy; TODO restore it (presumably via OT_14_P1_sec). *)
then show ?thesis
by (meson abs_ge_self order_trans)
qed
thus ?thesis using OT_14_P1_sec negligible_le adv_neg
by (metis (no_types, lifting) negligible_absI)
qed

(* Party 2 security is perfect at every security parameter, from OT_14_P2_sec. *)
theorem OT_14_P2_asym_sec: "OT_14.perfect_sec_P2 R2_OT12 n m1 m2"
using OT_14_P2_sec by simp

end

end



# Theory GMW

subsection ‹1-out-of-4 OT to GMW›

text‹We prove security for the gates of the GMW protocol in the semi-honest model. We assume security of
1-out-of-4 OT.›

theory`