Session Jordan_Normal_Form

Theory Missing_Ring

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Missing Ring›

text ‹This theory contains several lemmas which might be of interest to the Isabelle distribution.›

theory Missing_Ring
imports
  "HOL-Algebra.Ring"
begin

context comm_monoid
begin

(* Reindexing along a bijection: if h maps S bijectively onto T, the product
   of (g o h) over S equals the product of g over T. *)
lemma finprod_reindex_bij_betw: "bij_betw h S T 
  ⟹ g ∈ h ` S → carrier G 
  ⟹ finprod G (λx. g (h x)) S = finprod G g T"
  using finprod_reindex[of g h S] unfolding bij_betw_def by auto

(* Reindexing where the bijection between S and T is given by two witness
   functions j : S -> T and i : T -> S that are inverse to each other on
   these sets; eq relates the two integrands along j. *)
lemma finprod_reindex_bij_witness:
  assumes witness:
    "⋀a. a ∈ S ⟹ i (j a) = a"
    "⋀a. a ∈ S ⟹ j a ∈ T"
    "⋀b. b ∈ T ⟹ j (i b) = b"
    "⋀b. b ∈ T ⟹ i b ∈ S"
  assumes eq:
    "⋀a. a ∈ S ⟹ h (j a) = g a"
  assumes g: "g ∈ S → carrier G"
  and h: "h ∈ j ` S → carrier G"
  shows "finprod G g S = finprod G h T"
proof -
  (* the witnesses make j a bijection between S and T *)
  have b: "bij_betw j S T"
    using bij_betw_byWitness[where A=S and f=j and f'=i and A'=T] witness by auto
  (* rewrite the integrand g as h o j on S *)
  have fp: "finprod G g S = finprod G (λx. h (j x)) S"
    by (rule finprod_cong, insert eq g, auto)
  show ?thesis
    using finprod_reindex_bij_betw[OF b h] unfolding fp .
qed
end

(* Additive version of finprod_reindex_bij_witness, taken from the `add.`-prefixed
   copy of the comm_monoid lemma available inside abelian_monoid. *)
lemmas (in abelian_monoid) finsum_reindex_bij_witness = add.finprod_reindex_bij_witness

(* Commutative semiring: a semiring whose structure R is additionally a comm_monoid. *)
locale csemiring = semiring + comm_monoid R

(* Registers csemiring as a sublocale of cring, so csemiring facts are usable in cring. *)
context cring
begin
sublocale csemiring ..
end

(* A product whose factors are all 𝟭 on A equals 𝟭. The induction rule
   infinite_finite_induct covers both the finite and the infinite case of A. *)
lemma (in comm_monoid) finprod_one': 
  "(⋀ a. a ∈ A ⟹ f a = 𝟭) ⟹ finprod G f A = 𝟭"
  by (induct A rule: infinite_finite_induct, auto)

(* Split one factor f a (with a ∈ A) off a finite product over A. *)
lemma (in comm_monoid) finprod_split: 
  "finite A ⟹ f ` A ⊆ carrier G ⟹ a ∈ A ⟹ finprod G f A = f a ⊗ finprod G f (A - {a})"
  by (rule trans[OF trans[OF _ finprod_Un_disjoint[of "{a}" "A - {a}" f]]], auto,
  rule arg_cong[of _ _ "finprod G f"], auto)

(* A nested finite product over A and B equals a single product over the
   Cartesian product A × B. Proof: induction on A; the contribution of the new
   row {a'} × B is handled by an inner induction on B. *)
lemma (in comm_monoid) finprod_finprod:
  "finite A ⟹ finite B ⟹ (⋀ a b. a ∈ A  ⟹ b ∈ B ⟹ g a b ∈ carrier G) ⟹
  finprod G (λ a. finprod G (g a) B) A = finprod G (λ (a,b). g a b) (A × B)"
proof (induct A rule: finite_induct)
  case (insert a' A)
  note IH = this
  let ?l = "(⨂a∈insert a' A. finprod G (g a) B)"
  let ?r = "(⨂a∈insert a' A × B. case a of (a, b) ⇒ g a b)"
  (* left-hand side: peel off the factor for a' and apply the induction hypothesis *)
  have "?l = finprod G (g a') B ⊗ (⨂a∈A. finprod G (g a) B)"
    using IH by simp
  also have "(⨂a∈A. finprod G (g a) B) = finprod G (λ (a,b). g a b) (A × B)"
    by (rule IH(3), insert IH, auto)
  finally have idl: "?l = finprod G (g a') B ⊗ finprod G (λ (a,b). g a b) (A × B)" .
  (* right-hand side: split the index set into the new row and the rest *)
  from IH(2) have "insert a' A × B = {a'} × B ∪ A × B" by auto
  hence "?r = (⨂a∈{a'} × B ∪ A × B. case a of (a, b) ⇒ g a b)" by simp
  also have "… = (⨂a∈{a'} × B. case a of (a, b) ⇒ g a b) ⊗ (⨂a∈ A × B. case a of (a, b) ⇒ g a b)"
    by (rule finprod_Un_disjoint, insert IH, auto)
  also have "(⨂a∈{a'} × B. case a of (a, b) ⇒ g a b) = finprod G (g a') B"
    using IH(4) IH(5)
  proof (induct B rule: finite_induct)
    case (insert b' B)
    note IH = this
    have id: "(⨂a∈{a'} × B. case a of (a, b) ⇒ g a b) = finprod G (g a') B"
      by (rule IH(3)[OF IH(4)], auto)
    have id2: "⋀ x F. {a'} × insert x F = insert (a',x) ({a'} × F)" by auto
    have id3: "(⨂a∈insert (a', b') ({a'} × B). case a of (a, b) ⇒ g a b)
      = g a' b' ⊗ (⨂a∈({a'} × B). case a of (a, b) ⇒ g a b)"
      by (rule trans[OF finprod_insert], insert IH, auto)
    show ?case unfolding id2 id3 id
      by (rule sym, rule finprod_insert, insert IH, auto)
  qed simp
  finally have idr: "?r = finprod G (g a') B ⊗ (⨂a∈A × B. case a of (a, b) ⇒ g a b)" .
  show ?case unfolding idl idr ..
qed simp

(* A product over B × A with swapped arguments equals the product over A × B.
   Proved by reindexing along the pair-swapping map, which is injective. *)
lemma (in comm_monoid) finprod_swap:
  assumes "finite A" "finite B" "⋀ a b. a ∈ A ⟹ b ∈ B ⟹ g a b ∈ carrier G"
  shows "finprod G (λ (b,a). g a b) (B × A) = finprod G (λ (a,b). g a b) (A × B)"
proof -
  have [simp]: "(λ(a, b). (b, a)) ` (A × B) = B × A" by auto
  have [simp]: "(λ x. case case x of (a, b) ⇒ (b, a) of (a, b) ⇒ g b a) = (λ (a,b). g a b)"
    by (intro ext, auto)
  show ?thesis 
    by (rule trans[OF trans[OF _ finprod_reindex[of "λ (a,b). g b a" "λ (a,b). (b,a)"]]],
    insert assms, auto simp: inj_on_def)
qed

(* The order of two nested finite products may be exchanged; combines
   finprod_finprod (in both orders) with finprod_swap. *)
lemma (in comm_monoid) finprod_finprod_swap:
  "finite A ⟹ finite B ⟹ (⋀ a b. a ∈ A  ⟹ b ∈ B ⟹ g a b ∈ carrier G) ⟹
  finprod G (λ a. finprod G (g a) B) A = finprod G (λ b. finprod G (λ a. g a b) A) B"
  using finprod_finprod[of A B] finprod_finprod[of B A] finprod_swap[of A B]
  by simp



(* Additive (finsum) versions of finprod_one', finprod_split and
   finprod_finprod_swap, taken from the `add.`-prefixed lemmas inside semiring. *)
lemmas (in semiring) finsum_zero' = add.finprod_one' 
lemmas (in semiring) finsum_split = add.finprod_split 
lemmas (in semiring) finsum_finsum_swap = add.finprod_finprod_swap


(* A finite product in a commutative semiring is 𝟬 as soon as one factor is 𝟬.
   Induction on A; in the step, either the new factor f a is 𝟬 or the zero
   factor lies in the remaining set. *)
lemma (in csemiring) finprod_zero: 
  "finite A ⟹ f ∈ A → carrier R ⟹ ∃a∈A. f a = 𝟬
   ⟹ finprod R f A = 𝟬"
proof (induct A rule: finite_induct)
  case (insert a A)
  from finprod_insert[OF insert(1-2), of f] insert(4)
  have ins: "finprod R f (insert a A) = f a ⊗ finprod R f A" by simp
  have fA: "finprod R f A ∈ carrier R"
    by (rule finprod_closed, insert insert, auto)
  show ?case
  proof (cases "f a = 𝟬")
    case True
    with fA show ?thesis unfolding ins by simp
  next
    case False
    with insert(5) have "∃ a ∈ A. f a = 𝟬" by auto
    from insert(3)[OF _ this] insert have "finprod R f A = 𝟬" by auto
    with insert show ?thesis unfolding ins by auto
  qed
qed simp

(* The product of two finite sums is the double sum of the pairwise products;
   obtained by applying left and then right distributivity of finsum. *)
lemma (in semiring) finsum_product:
  assumes A: "finite A" and B: "finite B"
  and f: "f ∈ A → carrier R" and g: "g ∈ B → carrier R" 
  shows "finsum R f A ⊗ finsum R g B = (⨁i∈A. ⨁j∈B. f i ⊗ g j)"
  unfolding finsum_ldistr[OF A finsum_closed[OF g] f]
proof (rule finsum_cong'[OF refl])
  fix a
  assume a: "a ∈ A"
  show "f a ⊗ finsum R g B = (⨁j∈B. f a ⊗ g j)"
  by (rule finsum_rdistr[OF B _ g], insert a f, auto)
qed (insert f g B, auto intro: finsum_closed)
    
(* If a unit p is a one-sided inverse of a (from the left or from the right),
   then a itself is a unit. *)
lemma (in semiring) Units_one_side_I: 
  "a ∈ carrier R ⟹ p ∈ Units R ⟹ p ⊗ a = 𝟭 ⟹ a ∈ Units R"
  "a ∈ carrier R ⟹ p ∈ Units R ⟹ a ⊗ p = 𝟭 ⟹ a ∈ Units R"
  by (metis Units_closed Units_inv_Units Units_l_inv inv_unique)+

(* Records ordered_cancel_ab_semigroup_add as a superclass of
   ordered_cancel_semiring; the proof is immediate (..). *)
context ordered_cancel_semiring begin
subclass ordered_cancel_ab_semigroup_add ..
end

text ‹partially ordered variant›
(* Semiring in which multiplication by a strictly positive element is strictly
   monotone on both sides. The order comes from ordered_cancel_ab_semigroup_add. *)
class ordered_semiring_strict = semiring + comm_monoid_add + ordered_cancel_ab_semigroup_add +
  assumes mult_strict_left_mono: "a < b ⟹ 0 < c ⟹ c * a < c * b"
  assumes mult_strict_right_mono: "a < b ⟹ 0 < c ⟹ a * c < b * c"
begin

subclass semiring_0_cancel ..

(* Weak monotonicity follows from the strict version by splitting ≤ into
   < or = and treating c = 0 separately. *)
subclass ordered_semiring
proof
  fix a b c :: 'a
  assume A: "a ≤ b" "0 ≤ c"
  from A show "c * a ≤ c * b"
    unfolding le_less
    using mult_strict_left_mono by (cases "c = 0") auto
  from A show "a * c ≤ b * c"
    unfolding le_less
    using mult_strict_right_mono by (cases "c = 0") auto
qed

(* Sign rules for products. *)
lemma mult_pos_pos[simp]: "0 < a ⟹ 0 < b ⟹ 0 < a * b"
using mult_strict_left_mono [of 0 b a] by simp

lemma mult_pos_neg: "0 < a ⟹ b < 0 ⟹ a * b < 0"
using mult_strict_left_mono [of b 0 a] by simp

lemma mult_neg_pos: "a < 0 ⟹ 0 < b ⟹ a * b < 0"
using mult_strict_right_mono [of a 0 b] by simp

text ‹Legacy - use mult_neg_pos›
lemma mult_pos_neg2: "0 < a ⟹ b < 0 ⟹ b * a < 0" 
by (drule mult_strict_right_mono [of b 0], auto)

text‹Strict monotonicity in both arguments›
lemma mult_strict_mono:
  assumes "a < b" and "c < d" and "0 < b" and "0 ≤ c"
  shows "a * c < b * d"
  using assms apply (cases "c=0")
  apply (simp)
  apply (erule mult_strict_right_mono [THEN less_trans])
  apply (force simp add: le_less)
  apply (erule mult_strict_left_mono, assumption)
  done

text‹This weaker variant has more natural premises›
lemma mult_strict_mono':
  assumes "a < b" and "c < d" and "0 ≤ a" and "0 ≤ c"
  shows "a * c < b * d"
by (rule mult_strict_mono) (insert assms, auto)

(* Mixed strict / non-strict monotonicity. *)
lemma mult_less_le_imp_less:
  assumes "a < b" and "c ≤ d" and "0 ≤ a" and "0 < c"
  shows "a * c < b * d"
  using assms apply (subgoal_tac "a * c < b * c")
  apply (erule less_le_trans)
  apply (erule mult_left_mono)
  apply simp
  apply (erule mult_strict_right_mono)
  apply assumption
  done

lemma mult_le_less_imp_less:
  assumes "a ≤ b" and "c < d" and "0 < a" and "0 ≤ c"
  shows "a * c < b * d"
  using assms apply (subgoal_tac "a * c ≤ b * c")
  apply (erule le_less_trans)
  apply (erule mult_strict_left_mono)
  apply simp
  apply (erule mult_right_mono)
  apply simp
  done

end

(* Integral domain that is an ordered_semiring_strict with 0 < 1. The block
   derives further subclass relations and shows that of_nat is injective,
   which gives ring_char_0. *)
class ordered_idom = idom + ordered_semiring_strict +
  assumes zero_less_one [simp]: "0 < 1" begin

subclass semiring_1 ..
subclass comm_ring_1 ..
subclass ordered_ring ..
subclass ordered_comm_semiring by(unfold_locales, fact mult_left_mono)
subclass ordered_ab_semigroup_add ..

(* Images of natural numbers are non-negative. *)
lemma of_nat_ge_0[simp]: "of_nat x ≥ 0"
proof (induct x)
  case 0 thus ?case by auto
  next case (Suc x)
    hence "0 ≤ of_nat x" by auto
    also have "of_nat x < of_nat (Suc x)" by auto
    finally show ?case by auto
qed

(* of_nat maps only 0 to 0. *)
lemma of_nat_eq_0[simp]: "of_nat x = 0 ⟷ x = 0"
proof(induct x,simp)
  case (Suc x)
    have "of_nat (Suc x) > 0" apply(rule le_less_trans[of _ "of_nat x"]) by auto
    thus ?case by auto
qed

(* Injectivity of of_nat, by simultaneous case analysis on both arguments. *)
lemma inj_of_nat: "inj (of_nat :: nat ⇒ 'a)"
proof(rule injI)
  fix x y show "of_nat x = of_nat y ⟹ x = y"
  proof (induct x arbitrary: y)
    case 0 thus ?case
      proof (induct y)
        case 0 thus ?case by auto
        next case (Suc y)
          hence "of_nat (Suc y) = 0" by auto
          hence "Suc y = 0" unfolding of_nat_eq_0 by auto
          hence False by auto
          thus ?case by auto
      qed
    next case (Suc x)
      thus ?case
      proof (induct y)
        case 0
          hence "of_nat (Suc x) = 0" by auto
          hence "Suc x = 0" unfolding of_nat_eq_0 by auto
          hence False by auto
          thus ?case by auto
        next case (Suc y) thus ?case by auto
      qed
  qed
qed

subclass ring_char_0 by(unfold_locales, fact inj_of_nat)

end

(*
instance linordered_idom ⊆ ordered_semiring_strict by (intro_classes,auto)
instance linordered_idom ⊆ ordered_idom by (intro_classes, auto)
*)

end

Theory Missing_Permutations

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Missing Permutations›

text ‹This theory provides some definitions and lemmas on permutations which we did not find in the 
  Isabelle distribution.›

theory Missing_Permutations
imports
  Missing_Ring
  "HOL-Combinatorics.Permutations"
begin

(* Sign of a permutation on nat as an element of an arbitrary ring_1:
   1 if `sign p = 1`, and - 1 otherwise. *)
definition signof :: "(nat ⇒ nat) ⇒ 'a :: ring_1" where
  "signof p = (if sign p = 1 then 1 else - 1)"

(* The identity has sign 1; stated for both `id` and the explicit lambda form. *)
lemma signof_id[simp]: "signof id = 1" "signof (λ x. x) = 1"
  unfolding signof_def sign_id id_def[symmetric] by auto

(* A permutation of a finite set and its inverse have the same sign. *)
lemma signof_inv: "finite S ⟹ p permutes S ⟹ signof (Hilbert_Choice.inv p) = signof p"
  unfolding signof_def using sign_inverse permutation_permutes by metis

(* signof only takes the values 1 and - 1. *)
lemma signof_pm_one: "signof p ∈ {1, - 1}"
  unfolding signof_def by auto

(* signof is multiplicative on compositions of permutations of initial
   segments {0..<n} and {0..<m} of the naturals (n and m may differ). *)
lemma signof_compose: assumes "p permutes {0..<(n :: nat)}"
  and "q permutes {0 ..<(m :: nat)}"
  shows "signof (p o q) = signof p * signof q"
proof -
  (* both maps are permutations in the sense required by sign_compose *)
  from assms have pp: "permutation p" "permutation q"
    by (auto simp: permutation_permutes)
  show "signof (p o q) = signof p * signof q"
    unfolding signof_def sign_compose[OF pp] 
    by (auto simp: sign_def split: if_splits)
qed

(* Since p ` A = A for a permutation p of A, the function spaces out of
   p ` A and out of A coincide. *)
lemma permutes_funcset: "p permutes A ⟹ (p ` A → B) = (A → B)"
  by (simp add: permutes_image)

context comm_monoid
begin
(* A product over S is unchanged when the index is composed with a
   permutation of S. *)
lemma finprod_permute:
  assumes p: "p permutes S"
  and f: "f ∈ S → carrier G"
  shows "finprod G f S = finprod G (f ∘ p) S"
proof -
  from ‹p permutes S› have "inj p"
    by (rule permutes_inj)
  then have "inj_on p S"
    by (auto intro: subset_inj_on)
  (* reindex along p and use p ` S = S *)
  from finprod_reindex[OF _ this, unfolded permutes_image[OF p], OF f]
  show ?thesis unfolding o_def .
qed

(* The product over a singleton index set is its single factor. *)
lemma finprod_singleton_set[simp]: assumes "f a ∈ carrier G"
  shows "finprod G f {a} = f a"
proof -
  have "finprod G f {a} = f a ⊗ finprod G f {}"
    by (rule finprod_insert, insert assms, auto)
  also have "… = f a" using assms by auto
  finally show ?thesis .
qed
end

(* Additive (finsum) versions of finprod_permute and finprod_singleton_set,
   taken from the `add.`-prefixed lemmas inside semiring. *)
lemmas (in semiring) finsum_permute = add.finprod_permute
lemmas (in semiring) finsum_singleton_set = add.finprod_singleton_set

(* Basic facts about a permutation p of {0..<n}: p and its inverse stay
   below n on arguments below n, and p and inv p cancel each other. *)
lemma permutes_less[simp]: assumes p: "p permutes {0..<(n :: nat)}"
  shows "i < n ⟹ p i < n" "i < n ⟹ Hilbert_Choice.inv p i < n" 
  "p (Hilbert_Choice.inv p i) = i"
  "Hilbert_Choice.inv p (p i) = i"
proof -
  assume i: "i < n"
  show "p i < n" using permutes_in_image[OF p] i by auto
  let ?inv = "Hilbert_Choice.inv p" 
  have "⋀n. ?inv (p n) = n"
      using permutes_inverses[OF p] by simp
  thus "?inv i < n" 
      by (metis (no_types) atLeastLessThan_iff f_inv_into_f inv_into_into le0 permutes_image[OF p] i)
qed (insert permutes_inverses[OF p], auto)
    
context cring
begin

(* Summing f over all permutations of S gives the same result as summing
   f o inv, because inversion is injective on and maps the set of
   permutations of S onto itself. *)
lemma finsum_permutations_inverse: 
  assumes f: "f ∈ {p. p permutes S} → carrier R"
  shows "finsum R f {p. p permutes S} = finsum R (λp. f(Hilbert_Choice.inv p)) {p. p permutes S}"
  (is "?lhs = ?rhs")
proof -
  let ?inv = "Hilbert_Choice.inv"
  let ?S = "{p . p permutes S}"
  have th0: "inj_on ?inv ?S"
  proof (auto simp add: inj_on_def)
    fix q r
    assume q: "q permutes S"
      and r: "r permutes S"
      and qr: "?inv q = ?inv r"
    then have "?inv (?inv q) = ?inv (?inv r)"
      by simp
    with permutes_inv_inv[OF q] permutes_inv_inv[OF r] show "q = r"
      by metis
  qed
  have th1: "?inv ` ?S = ?S"
    using image_inverse_permutations by blast
  have th2: "?rhs = finsum R (f ∘ ?inv) ?S"
    by (simp add: o_def)
  from finsum_reindex[OF _ th0, of f] show ?thesis unfolding th1 th2 using f .
qed

(* Summing f over all permutations of S gives the same result as summing
   p ↦ f (p o q) for a fixed permutation q of S. *)
lemma finsum_permutations_compose_right: assumes q: "q permutes S"
  and *: "f ∈ {p. p permutes S} → carrier R"
  shows "finsum R f {p. p permutes S} = finsum R (λp. f(p ∘ q)) {p. p permutes S}"
  (is "?lhs = ?rhs")
proof -
  let ?S = "{p. p permutes S}"
  let ?inv = "Hilbert_Choice.inv"
  have th0: "?rhs = finsum R (f ∘ (λp. p ∘ q)) ?S"
    by (simp add: o_def)
  (* right composition with q is injective: cancel q using its inverse *)
  have th1: "inj_on (λp. p ∘ q) ?S"
  proof (auto simp add: inj_on_def)
    fix p r
    assume "p permutes S"
      and r: "r permutes S"
      and rp: "p ∘ q = r ∘ q"
    then have "p ∘ (q ∘ ?inv q) = r ∘ (q ∘ ?inv q)"
      by (simp add: o_assoc)
    with permutes_surj[OF q, unfolded surj_iff] show "p = r"
      by simp
  qed
  have th3: "(λp. p ∘ q) ` ?S = ?S"
    using image_compose_permutations_right[OF q] by auto
  from finsum_reindex[OF _ th1, of f]
  show ?thesis unfolding th0 th1 th3 using * .
qed

end

text ‹The following lemma is slightly generalized from Determinants.thy in HMA.›

(* For finite S and T there are only finitely many functions that map T
   into S and are the identity outside T. Induction on T: functions for
   insert a T are built from a value for a and a function for T. *)
lemma finite_bounded_functions:
  assumes fS: "finite S"
  shows "finite T ⟹ finite {f. (∀i ∈ T. f i ∈ S) ∧ (∀i. i ∉ T ⟶ f i = i)}"
proof (induct T rule: finite_induct)
  case empty
  have th: "{f. ∀i. f i = i} = {id}"
    by auto
  show ?case
    by (auto simp add: th)
next
  case (insert a T)
  let ?f = "λ(y,g) i. if i = a then y else g i"
  let ?S = "?f ` (S × {f. (∀i∈T. f i ∈ S) ∧ (∀i. i ∉ T ⟶ f i = i)})"
  have "?S = {f. (∀i∈ insert a T. f i ∈ S) ∧ (∀i. i ∉ insert a T ⟶ f i = i)}"
    apply (auto simp add: image_iff)
    apply (rule_tac x="x a" in bexI)
    apply (rule_tac x = "λi. if i = a then i else x i" in exI)
    apply (insert insert, auto)
    done
  with finite_imageI[OF finite_cartesian_product[OF fS insert.hyps(3)], of ?f]
  show ?case
    by metis
qed

(* Variant of finite_bounded_functions in which the functions take a fixed
   default value j outside T instead of being the identity there. *)
lemma finite_bounded_functions':
  assumes fS: "finite S"
  shows "finite T ⟹ finite {f. (∀i ∈ T. f i ∈ S) ∧ (∀i. i ∉ T ⟶ f i = j)}"
proof (induct T rule: finite_induct)
  case empty
  have th: "{f. ∀i. f i = j} = {(λ x. j)}"
    by auto
  show ?case
    by (auto simp add: th)
next
  case (insert a T)
  let ?f = "λ(y,g) i. if i = a then y else g i"
  let ?S = "?f ` (S × {f. (∀i∈T. f i ∈ S) ∧ (∀i. i ∉ T ⟶ f i = j)})"
  have "?S = {f. (∀i∈ insert a T. f i ∈ S) ∧ (∀i. i ∉ insert a T ⟶ f i = j)}"
    apply (auto simp add: image_iff)
    apply (rule_tac x="x a" in bexI)
    apply (rule_tac x = "λi. if i = a then j else x i" in exI)
    apply (insert insert, auto)
    done
  with finite_imageI[OF finite_cartesian_product[OF fS insert.hyps(3)], of ?f]
  show ?case
    by metis
qed

(* Local context: a_to_b and b_to_a are mutually inverse maps between the
   sets A and B. *)
context
  fixes A :: "'a set" 
    and B :: "'b set"
    and a_to_b :: "'a ⇒ 'b"
    and b_to_a :: "'b ⇒ 'a"
  assumes ab: "⋀ a. a ∈ A ⟹ a_to_b a ∈ B"
    and ba: "⋀ b. b ∈ B ⟹ b_to_a b ∈ A"
    and ab_ba: "⋀ a. a ∈ A ⟹ b_to_a (a_to_b a) = a"
    and ba_ab: "⋀ b. b ∈ B ⟹ a_to_b (b_to_a b) = b"
begin

(* Membership facts for an element a of A transported to B and moved by a
   permutation p of B or by its inverse. *)
qualified lemma permutes_memb: fixes p :: "'b ⇒ 'b"
  assumes p: "p permutes B"
  and a: "a ∈ A"
  defines "ip ≡ Hilbert_Choice.inv p"
  shows "a ∈ A" "a_to_b a ∈ B" "ip (a_to_b a) ∈ B" "p (a_to_b a) ∈ B" 
    "b_to_a (p (a_to_b a)) ∈ A" "b_to_a (ip (a_to_b a)) ∈ A"
proof -
  let ?b = "a_to_b a"
  from p have ip: "ip permutes B" unfolding ip_def by (rule permutes_inv)
  note in_ip = permutes_in_image[OF ip]
  note in_p = permutes_in_image[OF p]
  show a: "a ∈ A" by fact
  show b: "?b ∈ B" by (rule ab[OF a])
  show pb: "p ?b ∈ B" unfolding in_p by (rule b)
  show ipb: "ip ?b ∈ B" unfolding in_ip by (rule b)
  show "b_to_a (p ?b) ∈ A" by (rule ba[OF pb])
  show "b_to_a (ip ?b) ∈ A" by (rule ba[OF ipb])
qed

(* One inclusion of the correspondence: conjugating a permutation of B with
   a_to_b / b_to_a (and acting as identity outside A) yields a permutation of A. *)
lemma permutes_bij_main: 
  "{p . p permutes A} ⊇ (λ p a. if a ∈ A then b_to_a (p (a_to_b a)) else a) ` {p . p permutes B}" 
  (is "?A ⊇ ?f ` ?B")
proof 
  note d = permutes_def
  let ?g = "λ q b. if b ∈ B then a_to_b (q (b_to_a b)) else b"
  let ?inv = "Hilbert_Choice.inv"
  fix p
  assume p: "p ∈ ?f ` ?B"
  then obtain q where q: "q permutes B" and p: "p = ?f q" by auto    
  let ?iq = "?inv q"
  from q have iq: "?iq permutes B" by (rule permutes_inv)
  note in_iq = permutes_in_image[OF iq]
  note in_q = permutes_in_image[OF q]
  have qiB: "⋀ b. b ∈ B ⟹ q (?iq b) = b" using q by (rule permutes_inverses)
  have iqB: "⋀ b. b ∈ B ⟹ ?iq (q b) = b" using q by (rule permutes_inverses)
  from q[unfolded d] 
  have q1: "⋀ b. b ∉ B ⟹ q b = b" 
   and q2: "⋀ b. ∃!b'. q b' = b" by auto
  note memb = permutes_memb[OF q]
  show "p ∈ ?A" unfolding p d
  proof (rule, intro conjI impI allI, force)
    fix a
    (* every a has exactly one preimage under ?f q *)
    show "∃!a'. ?f q a' = a"
    proof (cases "a ∈ A")
      case True
      note a = memb[OF True]
      let ?a = "b_to_a (?iq (a_to_b a))"
      show ?thesis
      proof 
        show "?f q ?a = a" using a by (simp add: ba_ab qiB ab_ba)
      next
        fix a'
        assume id: "?f q a' = a"
        show "a' = ?a"
        proof (cases "a' ∈ A")
          case False
          thus ?thesis using id a by auto
        next
          case True
          note a' = memb[OF this]
          from id True have "b_to_a (q (a_to_b a')) = a" by simp
          from arg_cong[OF this, of "a_to_b"] a' a
          have "q (a_to_b a') = a_to_b a" by (simp add: ba_ab)
          from arg_cong[OF this, of ?iq]
          have "a_to_b a' = ?iq (a_to_b a)" unfolding iqB[OF a'(2)] .
          from arg_cong[OF this, of b_to_a] show ?thesis unfolding ab_ba[OF True] .
        qed
      qed
    next
      case False note a = this
      show ?thesis
      proof
        show "?f q a = a" using a by simp
      next
        fix a'
        assume id: "?f q a' = a"
        show "a' = a"
        proof (cases "a' ∈ A")
          case False
          with id show ?thesis by simp
        next
          case True
          note a' = memb[OF True]
          with id False show ?thesis by auto
        qed
      qed
    qed
  qed
qed
end

(* Given mutually inverse maps between A and B, the permutations of A are
   exactly the conjugated permutations of B. Both inclusions use
   permutes_bij_main, once in each direction, plus the fact that the two
   conjugation maps ?f and ?g cancel on permutations of A. *)
lemma  permutes_bij': assumes ab: "⋀ a. a ∈ A ⟹ a_to_b a ∈ B"
    and ba: "⋀ b. b ∈ B ⟹ b_to_a b ∈ A"
    and ab_ba: "⋀ a. a ∈ A ⟹ b_to_a (a_to_b a) = a"
    and ba_ab: "⋀ b. b ∈ B ⟹ a_to_b (b_to_a b) = b"
  shows "{p . p permutes A} = (λ p a. if a ∈ A then b_to_a (p (a_to_b a)) else a) ` {p . p permutes B}" 
  (is "?A = ?f ` ?B")
proof -
  note one_dir = ab ba ab_ba ba_ab
  note other_dir = ba ab ba_ab ab_ba
  let ?g = "(λ p b. if b ∈ B then a_to_b (p (b_to_a b)) else b)"
  define PA where "PA = ?A"
  define f where "f = ?f"
  define g where "g = ?g"
  {
    fix p
    assume "p ∈ PA"
    hence p: "p permutes A" unfolding PA_def by simp
    from p[unfolded permutes_def] have pnA: "⋀ a. a ∉ A ⟹ p a = a" by auto
    have "?f (?g p) = p"
    proof (rule ext)
      fix a
      show "?f (?g p) a = p a"
      proof (cases "a ∈ A")
        case False
        thus ?thesis by (simp add: pnA)
      next
        case True note a = this
        hence "p a ∈ A" unfolding permutes_in_image[OF p] .
        thus ?thesis using a by (simp add: ab_ba ba_ab ab)
      qed
    qed
    hence "f (g p) = p" unfolding f_def g_def .
  }
  hence "f ` g ` PA = PA" by force
  hence id: "?f ` ?g ` ?A = ?A" unfolding PA_def f_def g_def .
  have "?f ` ?B ⊆ ?A" by (rule permutes_bij_main[OF one_dir])
  moreover have "?g ` ?A ⊆ ?B" by (rule permutes_bij_main[OF ba ab ba_ab ab_ba])
  hence "?f ` ?g ` ?A ⊆ ?f ` ?B" by auto
  hence "?A ⊆ ?f ` ?B" unfolding id .
  ultimately show ?thesis by blast
qed    

(* A function that is injective on a finite set S of naturals, maps S into S
   and is the identity outside S is a permutation of S. *)
lemma inj_on_nat_permutes: assumes i: "inj_on f (S :: nat set)"
  and fS: "f ∈ S → S"
  and fin: "finite S"
  and f: "⋀ i. i ∉ S ⟹ f i = i"
  shows "f permutes S"
  unfolding permutes_def
proof (intro conjI allI impI, rule f)
  fix y
  (* an injective endofunction of a finite set is surjective on it *)
  from endo_inj_surj[OF fin _ i] fS have fs: "f ` S = S" by auto
  show "∃!x. f x = y"
  proof (cases "y ∈ S")
    case False
    thus ?thesis by (intro ex1I[of _ y], insert fS f, auto)
  next
    case True
    with fs obtain x where x: "x ∈ S" and fx: "f x = y" by force
    show ?thesis
    proof (rule ex1I, rule fx)
      fix x'
      assume fx': "f x' = y"
      with True f[of x'] have "x' ∈ S" by metis
      from inj_onD[OF i fx[folded fx'] x this]
      show "x' = x" by simp
    qed
  qed
qed


(* The graph of a permutation p of S, written with p on the first component,
   equals the graph written with inv p on the second component. *)
lemma permutes_pair_eq:
  assumes p: "p permutes S"
  shows "{ (p s, s) | s. s ∈ S } = { (s, Hilbert_Choice.inv p s) | s. s ∈ S }"
    (is "?L = ?R")
proof
  show "?L ⊆ ?R"
  proof
    fix x assume "x ∈ ?L"
    then obtain s where x: "x = (p s, s)" and s: "s ∈ S" by auto
    note x
    also have "(p s, s) = (p s, Hilbert_Choice.inv p (p s))"
      using permutes_inj[OF p] inv_f_f by auto
    also have "... ∈ ?R" using s permutes_in_image[OF p] by auto
    finally show "x ∈ ?R".
  qed
  show "?R ⊆ ?L"
  proof
    fix x assume "x ∈ ?R"
    then obtain s
      where x: "x = (s, Hilbert_Choice.inv p s)" (is "_ = (s, ?ips)")
        and s: "s ∈ S" by auto
    note x
    also have "(s, ?ips) = (p ?ips, ?ips)"
      using inv_f_f[OF permutes_inj[OF permutes_inv[OF p]]]
      using inv_inv_eq[OF permutes_bij[OF p]] by auto
    also have "... ∈ ?L"
      using s permutes_in_image[OF permutes_inv[OF p]] by auto
    finally show "x ∈ ?L".
  qed
qed

(* For f injective on A, the image f ` A is finite iff A is finite.
   The non-trivial direction argues via the cardinality of the image. *)
lemma inj_on_finite[simp]:
  assumes inj: "inj_on f A" shows "finite (f ` A) = finite A"
proof
  assume fin: "finite (f ` A)"
  show "finite A"
  proof (cases "card (f ` A) = 0")
    case True thus ?thesis using fin by auto
    next case False 
      (* card_image transfers the non-zero cardinality from f ` A to A *)
      hence "card A > 0" unfolding card_image[OF inj] by auto
      thus ?thesis using card.infinite by force
  qed
qed auto

(* A product over S of f (p s) s can be rewritten as a product of
   f s (inv p s): both are products of the uncurried f over the same set of
   pairs (permutes_pair_eq). *)
lemma permutes_prod:
  assumes p: "p permutes S"
  shows "(∏s∈S. f (p s) s) = (∏s∈S. f s (Hilbert_Choice.inv p s))"
    (is "?l = ?r")
proof -
  let ?f = "λ(x,y). f x y"
  let ?ps = "λs. (p s, s)"
  let ?ips = "λs. (s, Hilbert_Choice.inv p s)"
  have inj1: "inj_on ?ps S" by (rule inj_onI;auto)
  have inj2: "inj_on ?ips S" by (rule inj_onI;auto)
  have "?l = prod ?f (?ps ` S)"
    using prod.reindex[OF inj1, of ?f] by simp
  also have "?ps ` S = {(p s, s) |s. s ∈ S}" by auto
  also have "... = {(s, Hilbert_Choice.inv p s) | s. s ∈ S}"
    unfolding permutes_pair_eq[OF p] by simp
  also have "... = ?ips ` S" by auto
  also have "prod ?f ... = ?r"
    using prod.reindex[OF inj2, of ?f] by simp
  finally show ?thesis.
qed

(* Sum version of permutes_prod: a sum over S of f (p s) s equals the sum of
   f s (inv p s), via the same set of index pairs (permutes_pair_eq). *)
lemma permutes_sum:
  assumes p: "p permutes S"
  shows "(∑s∈S. f (p s) s) = (∑s∈S. f s (Hilbert_Choice.inv p s))"
    (is "?l = ?r")
proof -
  let ?f = "λ(x,y). f x y"
  let ?ps = "λs. (p s, s)"
  let ?ips = "λs. (s, Hilbert_Choice.inv p s)"
  have inj1: "inj_on ?ps S" by (rule inj_onI;auto)
  have inj2: "inj_on ?ips S" by (rule inj_onI;auto)
  have "?l = sum ?f (?ps ` S)"
    using sum.reindex[OF inj1, of ?f] by simp
  also have "?ps ` S = {(p s, s) |s. s ∈ S}" by auto
  also have "... = {(s, Hilbert_Choice.inv p s) | s. s ∈ S}"
    unfolding permutes_pair_eq[OF p] by simp
  also have "... = ?ips ` S" by auto
  also have "sum ?f ... = ?r"
    using sum.reindex[OF inj2, of ?f] by simp
  finally show ?thesis.
qed

(* Taking inverses is injective on the set of permutations of S: apply inv
   to both sides once more and use that inv (inv p) = p for bijections. *)
lemma inv_inj_on_permutes: "inj_on Hilbert_Choice.inv { p. p permutes S }"
proof (intro inj_onI, unfold mem_Collect_eq)
  let ?i = "Hilbert_Choice.inv"
  fix p q
  assume p: "p permutes S" and q: "q permutes S" and eq: "?i p = ?i q"
  have "?i (?i p) = ?i (?i q)" using eq by simp
  thus "p = q"
    using inv_inv_eq[OF permutes_bij] p q by metis
qed

(* A permutation of S fixes every point outside S. *)
lemma permutes_others:
  assumes p: "p permutes S" and x: "x ∉ S" shows "p x = x"
  using p unfolding permutes_def using x by simp

end

Theory Conjugate

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
theory Conjugate
  imports HOL.Complex
begin

(* Types with a conjugation operation that is an involution; the second
   axiom states that conjugate is injective. Both are simp rules. *)
class conjugate =
  fixes conjugate :: "'a ⇒ 'a"
  assumes conjugate_id[simp]: "conjugate (conjugate a) = a"
      and conjugate_cancel_iff[simp]: "conjugate a = conjugate b ⟷ a = b"

(* Rings whose conjugation distributes over *, + and unary minus and maps 0 to 0. *)
class conjugatable_ring = ring + conjugate +
  assumes conjugate_dist_mul: "conjugate (a * b) = conjugate a * conjugate b"
      and conjugate_dist_add: "conjugate (a + b) = conjugate a + conjugate b"
      and conjugate_neg: "conjugate (-a) = - conjugate a"
      and conjugate_zero[simp]: "conjugate 0 = 0"
begin
  (* instance of conjugate_cancel_iff with b = 0 *)
  lemma conjugate_zero_iff[simp]: "conjugate a = 0 ⟷ a = 0"
    using conjugate_cancel_iff[of _ 0, unfolded conjugate_zero].
end

(* A field that is also a conjugatable ring; no additional axioms. *)
class conjugatable_field = conjugatable_ring + field

(* Conjugation commutes with finite sums. *)
lemma sum_conjugate:
  fixes f :: "'b ⇒ 'a :: conjugatable_ring"
  assumes finX: "finite X"
  shows "conjugate (sum f X) = sum (λx. conjugate (f x)) X"
  using finX by (induct set:finite, auto simp: conjugate_dist_add)

(* Ordered conjugatable rings in which a * conjugate a is non-negative. *)
class conjugatable_ordered_ring = conjugatable_ring + ordered_comm_monoid_add +
  assumes conjugate_square_positive: "a * conjugate a ≥ 0"

(* A field that is a conjugatable ordered ring; it is in particular a
   conjugatable_field, which is registered as a subclass. *)
class conjugatable_ordered_field = conjugatable_ordered_ring + field
begin
  subclass conjugatable_field..
end

(* Without zero divisors, a * conjugate a = 0 forces a = 0. *)
lemma conjugate_square_0:
  fixes a :: "'a :: {conjugatable_ordered_ring, semiring_no_zero_divisors}"
  shows "a * conjugate a = 0 ⟹ a = 0" by auto


subsection ‹Instantiations›

(* Complex numbers: conjugate is cnj. The order compares real parts and is
   only defined between numbers with equal imaginary part (a partial order). *)
instantiation complex :: conjugatable_ordered_field
begin
  definition [simp]: "conjugate ≡ cnj"
  definition [simp]: "x < y ≡ Im x = Im y ∧ Re x < Re y"
  definition [simp]: "x ≤ y ≡ Im x = Im y ∧ Re x ≤ Re y"
  
  instance by (intro_classes, auto simp: complex.expand)
end

(* On the reals, conjugation is the identity. *)
instantiation real :: conjugatable_ordered_field
begin
  definition [simp]: "conjugate (x::real) ≡ x"
  instance by (intro_classes, auto)
end

(* On the rationals, conjugation is the identity. *)
instantiation rat :: conjugatable_ordered_field
begin
  definition [simp]: "conjugate (x::rat) ≡ x"
  instance by (intro_classes, auto)
end

(* On the integers, conjugation is the identity (ring instance only). *)
instantiation int :: conjugatable_ordered_ring
begin
  definition [simp]: "conjugate (x::int) ≡ x"
  instance by (intro_classes, auto)
end

(* Without zero divisors, x times its conjugate (in either order) is 0 iff x is 0. *)
lemma conjugate_square_eq_0 [simp]:
  fixes x :: "'a :: {conjugatable_ring,semiring_no_zero_divisors}"
  shows "x * conjugate x = 0 ⟷ x = 0" "conjugate x * x = 0 ⟷ x = 0"
  by auto

(* x * conjugate x is strictly positive exactly for non-zero x. *)
lemma conjugate_square_greater_0 [simp]:
  fixes x :: "'a :: {conjugatable_ordered_ring,ring_no_zero_divisors}"
  shows "x * conjugate x > 0 ⟷ x ≠ 0" 
  using conjugate_square_positive[of x]
  by (auto simp: le_less)

(* x * conjugate x is never strictly negative; direct consequence of
   conjugate_square_positive. *)
lemma conjugate_square_smaller_0 [simp]:
  fixes x :: "'a :: {conjugatable_ordered_ring,ring_no_zero_divisors}"
  shows "¬ x * conjugate x < 0"
  using conjugate_square_positive[of x] by auto

end

Theory Matrix

(*
    Author:      René Thiemann
                 Akihisa Yamada
    License:     BSD
*)
(* with contributions from Alexander Bentkamp, Universität des Saarlandes *)

section‹Vectors and Matrices›

text ‹We define vectors as pairs of dimension and a characteristic function from natural numbers
to elements.
Similarly, matrices are defined as triples of two dimensions and one
characteristic function from pairs of natural numbers to elements.
Via a subtype we ensure that the characteristic function always behaves the same
on indices outside the intended range. Hence, every matrix has a unique representation.

In this part we define basic operations like matrix-addition, -multiplication, scalar-product,
etc. We connect these operations to HOL-Algebra with its explicit carrier sets.›

theory Matrix
imports
  Missing_Ring
  "HOL-Algebra.Module"
  Polynomial_Interpolation.Ring_Hom
  Conjugate
begin

subsection‹Vectors›

text ‹Here we specify which value should be returned in case
  an index is out of bounds. The current solution has the advantage
  that in the implementation later on, no index comparison has to be performed.›

(* Value returned for an out-of-bounds index: the i-th element of the empty list. *)
definition undef_vec :: "nat ⇒ 'a" where
  "undef_vec i ≡ [] ! i"

(* Normalises a characteristic function for dimension n: f below n, and
   undef_vec (i - n) for indices i at or beyond n. *)
definition mk_vec :: "nat ⇒ (nat ⇒ 'a) ⇒ (nat ⇒ 'a)" where
  "mk_vec n f ≡ λ i. if i < n then f i else undef_vec (i - n)"

(* Vectors: pairs of a dimension n and a normalised characteristic function mk_vec n f. *)
typedef 'a vec = "{(n, mk_vec n f) | n f :: nat ⇒ 'a. True}"
  by auto

setup_lifting type_definition_vec

(* Dimension of a vector: first component of the representation. *)
lift_definition dim_vec :: "'a vec ⇒ nat" is fst .
(* Indexing v $ i: second component of the representation. *)
lift_definition vec_index :: "'a vec ⇒ (nat ⇒ 'a)" (infixl "$" 100) is snd .
(* Constructor from a dimension and an index function. *)
lift_definition vec :: "nat ⇒ (nat ⇒ 'a) ⇒ 'a vec"
  is "λ n f. (n, mk_vec n f)" by auto

(* Conversion from a list; the dimension is the list length. *)
lift_definition vec_of_list :: "'a list ⇒ 'a vec" is
  "λ v. (length v, mk_vec (length v) (nth v))" by auto

(* Conversion to a list: the entries at indices 0 ..< n. *)
lift_definition list_of_vec :: "'a vec ⇒ 'a list" is
  "λ (n,v). map v [0 ..< n]" .

(* The set of all vectors of dimension n. *)
definition carrier_vec :: "nat ⇒ 'a vec set" where
  "carrier_vec n = { v . dim_vec v = n}"

(* Basic facts relating vec, dim_vec, indexing and carrier_vec. *)
lemma carrier_vec_dim_vec[simp]: "v ∈ carrier_vec (dim_vec v)" unfolding carrier_vec_def by auto

lemma dim_vec[simp]: "dim_vec (vec n f) = n" by transfer simp
lemma vec_carrier[simp]: "vec n f ∈ carrier_vec n" unfolding carrier_vec_def by auto
lemma index_vec[simp]: "i < n ⟹ vec n f $ i = f i" by transfer (simp add: mk_vec_def)
(* Extensionality: equal dimension and equal entries below it imply equality. *)
lemma eq_vecI[intro]: "(⋀ i. i < dim_vec w ⟹ v $ i = w $ i) ⟹ dim_vec v = dim_vec w
  ⟹ v = w"
  by (transfer, auto simp: mk_vec_def)

lemma carrier_dim_vec: "v ∈ carrier_vec n ⟷ dim_vec v = n"
  unfolding carrier_vec_def by auto

lemma carrier_vecD[simp]: "v ∈ carrier_vec n ⟹ dim_vec v = n" using carrier_dim_vec by auto

lemma carrier_vecI: "dim_vec v = n ⟹ v ∈ carrier_vec n" using carrier_dim_vec by auto

(* Pointwise addition; the result takes the dimension of the second argument. *)
instantiation vec :: (plus) plus
begin
definition plus_vec :: "'a vec ⇒ 'a vec ⇒ 'a :: plus vec" where
  "v1 + v2 ≡ vec (dim_vec v2) (λ i. v1 $ i + v2 $ i)"
instance ..
end

(* Pointwise subtraction; the result takes the dimension of the second argument. *)
instantiation vec :: (minus) minus
begin
definition minus_vec :: "'a vec ⇒ 'a vec ⇒ 'a :: minus vec" where
  "v1 - v2 ≡ vec (dim_vec v2) (λ i. v1 $ i - v2 $ i)"
instance ..
end

(* The zero vector of dimension n, with notation 0⇩v n. *)
definition
  zero_vec :: "nat ⇒ 'a :: zero vec" ("0⇩v")
  where "0⇩v n ≡ vec n (λ i. 0)"

(* Carrier membership, entries and dimension of the zero vector. *)
lemma zero_carrier_vec[simp]: "0⇩v n ∈ carrier_vec n"
  unfolding zero_vec_def carrier_vec_def by auto

lemma index_zero_vec[simp]: "i < n ⟹ 0⇩v n $ i = 0" "dim_vec (0⇩v n) = n"
  unfolding zero_vec_def by auto

(* The only vector of dimension 0 is the empty zero vector. *)
lemma vec_of_dim_0[simp]: "dim_vec v = 0 ⟷ v = 0⇩v 0" by auto

(* i-th unit vector of dimension n: 1 at index i, 0 elsewhere. *)
definition
  unit_vec :: "nat ⇒ nat ⇒ ('a :: zero_neq_one) vec"
  where "unit_vec n i = vec n (λ j. if j = i then 1 else 0)"

(* Entries and dimension of unit vectors. *)
lemma index_unit_vec[simp]:
  "i < n ⟹ j < n ⟹ unit_vec n i $ j = (if j = i then 1 else 0)"
  "i < n ⟹ unit_vec n i $ i = 1"
  "dim_vec (unit_vec n i) = n"
  unfolding unit_vec_def by auto

(* For i < n, unit vectors are equal iff their indices are: for i ≠ j they
   differ at position i. *)
lemma unit_vec_eq[simp]:
  assumes i: "i < n"
  shows "(unit_vec n i = unit_vec n j) = (i = j)"
proof -
  have "i ≠ j ⟹ unit_vec n i $ i ≠ unit_vec n j $ i"
    unfolding unit_vec_def using i by simp
  then show ?thesis by metis
qed

(* For i < n, a unit vector differs from the zero vector (at position i). *)
lemma unit_vec_nonzero[simp]:
  assumes i_n: "i < n" shows "unit_vec n i ≠ zero_vec n" (is "?l ≠ ?r")
proof -
  have "?l $ i = 1" "?r $ i = 0" using i_n by auto
  thus "?l ≠ ?r" by auto
qed

(* Unit vectors of dimension n lie in carrier_vec n (for any index i). *)
lemma unit_vec_carrier[simp]: "unit_vec n i ∈ carrier_vec n"
  unfolding unit_vec_def carrier_vec_def by auto

(* The list of all n unit vectors of dimension n, in index order. *)
definition unit_vecs:: "nat ⇒ 'a :: zero_neq_one vec list"
  where "unit_vecs n = map (unit_vec n) [0..<n]"

text "List of the first i unit vectors"

(* unit_vecs_first n i: the unit vectors with indices 0 ..< i, built by
   appending at the end. *)
fun unit_vecs_first:: "nat ⇒ nat ⇒ 'a::zero_neq_one vec list"
  where "unit_vecs_first n 0 = []"
    |   "unit_vecs_first n (Suc i) = unit_vecs_first n i @ [unit_vec n i]"

(* unit_vecs agrees with unit_vecs_first taken up to n; shown via the
   generalisation to every prefix length m ≤ n. *)
lemma unit_vecs_first: "unit_vecs n = unit_vecs_first n n"
  unfolding unit_vecs_def set_map set_upt
proof -
  {fix m
    have "m ≤ n ⟹ map (unit_vec n) [0..<m] = unit_vecs_first n m"
    proof (induct m)
      case (Suc m) then have mn:"m≤n" by auto
        show ?case unfolding upt_Suc using Suc(1)[OF mn] by auto
    qed auto
  }
  thus "map (unit_vec n) [0..<n] = unit_vecs_first n n" by auto
qed

text "List of the last i unit vectors"

(* unit_vecs_last n i: the unit vectors with indices n - i ..< n, built by
   consing at the front. *)
fun unit_vecs_last:: "nat ⇒ nat ⇒ 'a :: zero_neq_one vec list"
  where "unit_vecs_last n 0 = []"
    |   "unit_vecs_last n (Suc i) = unit_vec n (n - Suc i) # unit_vecs_last n i"

(* Relates the elements of unit_vecs_last n i to carrier_vec n; proved by
   induction on i followed by auto. *)
lemma unit_vecs_last_carrier: "set (unit_vecs_last n i)  carrier_vec n"
  by (induct i;auto)

(* unit_vecs n coincides with unit_vecs_last n n; declared with the [code]
   attribute. The proof shows, by induction on m, that mapping unit_vec n over
   [n-m..<n] yields unit_vecs_last n m, and then uses it for m = n.
   NOTE(review): the anonymous assumption "m = n" is not chained into the
   inner statement; it only becomes a premise of the exported fact. *)
lemma unit_vecs_last[code]: "unit_vecs n = unit_vecs_last n n"
proof -
  { fix m assume "m = n"
    have "m  n  map (unit_vec n) [n-m..<n] = unit_vecs_last n m"
      proof (induction m)
      case (Suc m)
        then have nm:"n - Suc m < n" by auto
        (* split the head element off the interval list *)
        have ins: "[n - Suc m ..< n] = (n - Suc m) # [n - m ..< n]"
          unfolding upt_conv_Cons[OF nm]
          by (auto simp: Suc.prems Suc_diff_Suc Suc_le_lessD)
        show ?case
          unfolding ins
          unfolding unit_vecs_last.simps
          unfolding list.map
          using Suc
          unfolding Suc by auto
      qed simp
  }
  thus "unit_vecs n = unit_vecs_last n n"
    unfolding unit_vecs_def by auto
qed

(* Relates the elements of unit_vecs n to carrier_vec n: any element u of the
   list has the form unit_vec n i, so unit_vec_carrier applies. *)
lemma unit_vecs_carrier: "set (unit_vecs n)  carrier_vec n"
proof
  fix u :: "'a vec"  assume u: "u  set (unit_vecs n)"
  then obtain i where "u = unit_vec n i" unfolding unit_vecs_def by auto
  then show "u  carrier_vec n"
    using unit_vec_carrier by auto
qed

(* Relates unit_vec n i, for i < n - j, to the elements of unit_vecs_last n j;
   induction on j with i arbitrary. *)
lemma unit_vecs_last_distinct:
  "j  n  i < n - j  unit_vec n i  set (unit_vecs_last n j)"
  by (induction j arbitrary:i, auto)

(* Counterpart for unit_vecs_first n i and unit_vec n j with j < n;
   induction on i with j arbitrary. *)
lemma unit_vecs_first_distinct:
  "i  j  j < n  unit_vec n j  set (unit_vecs_first n i)"
  by (induction i arbitrary:j, auto)

(* map_vec f v: vector of the same dimension as v whose i-th entry is f (v $ i). *)
definition map_vec where "map_vec f v  vec (dim_vec v) (λi. f (v $ i))"

(* Unary minus on vectors: same dimension as v, each entry negated. *)
instantiation vec :: (uminus) uminus
begin
definition uminus_vec :: "'a :: uminus vec  'a vec" where
  "- v  vec (dim_vec v) (λ i. - (v $ i))"
instance ..
end

(* Scalar multiplication (infixl, priority 70): keeps the dimension of v and
   multiplies every entry by a from the left. *)
definition smult_vec :: "'a :: times  'a vec  'a vec" (infixl "v" 70)
  where "a v v  vec (dim_vec v) (λ i. a * v $ i)"

(* Scalar product (infix, priority 70): combines the products v $ i * w $ i
   over the index range {0 ..< dim_vec w}; note that the range is taken from
   the second argument. *)
definition scalar_prod :: "'a vec  'a vec  'a :: semiring_0" (infix "" 70)
  where "v  w   i  {0 ..< dim_vec w}. v $ i * w $ i"

(* monoid_vec ty n: monoid record with carrier carrier_vec n, vector addition
   as the mult field and 0v n as the one field. The itself-argument ty only
   fixes the element type. *)
definition monoid_vec :: "'a itself  nat  ('a :: monoid_add vec) monoid" where
  "monoid_vec ty n  
    carrier = carrier_vec n,
    mult = (+),
    one = 0v n"

(* module_vec ty n: module record over carrier_vec n with zero 0v n, vector
   addition and scalar multiplication. The mult and one fields are left
   undefined. *)
definition module_vec ::
  "'a :: semiring_1 itself  nat  ('a,'a vec) module" where
  "module_vec ty n  
    carrier = carrier_vec n,
    mult = undefined,
    one = undefined,
    zero = 0v n,
    add = (+),
    smult = (⋅v)"

(* Field projections of monoid_vec (mult, carrier, one), by unfolding its
   definition. Not declared simp; used explicitly elsewhere in this file. *)
lemma monoid_vec_simps:
  "mult (monoid_vec ty n) = (+)"
  "carrier (monoid_vec ty n) = carrier_vec n"
  "one (monoid_vec ty n) = 0v n"
  unfolding monoid_vec_def by auto

(* Field projections of module_vec (add, zero, carrier, smult). *)
lemma module_vec_simps:
  "add (module_vec ty n) = (+)"
  "zero (module_vec ty n) = 0v n"
  "carrier (module_vec ty n) = carrier_vec n"
  "smult (module_vec ty n) = (⋅v)"
  unfolding module_vec_def by auto

(* finsum_vec ty n: finprod taken in monoid_vec ty n, i.e. a finite
   combination of vectors using the monoid's operation (vector addition). *)
definition finsum_vec :: "'a :: monoid_add itself  nat  ('c  'a vec)  'c set  'a vec" where
  "finsum_vec ty n = finprod (monoid_vec ty n)"

(* Entry and dimension of a sum; both refer to the second operand v2. *)
lemma index_add_vec[simp]:
  "i < dim_vec v2  (v1 + v2) $ i = v1 $ i + v2 $ i" "dim_vec (v1 + v2) = dim_vec v2"
  unfolding plus_vec_def by auto

(* Entry and dimension of a difference; again in terms of v2. *)
lemma index_minus_vec[simp]:
  "i < dim_vec v2  (v1 - v2) $ i = v1 $ i - v2 $ i" "dim_vec (v1 - v2) = dim_vec v2"
  unfolding minus_vec_def by auto

(* Entry and dimension of map_vec f v. *)
lemma index_map_vec[simp]:
  "i < dim_vec v  map_vec f v $ i = f (v $ i)"
  "dim_vec (map_vec f v) = dim_vec v"
  unfolding map_vec_def by auto

(* Carrier statement for map_vec h v is equivalent to the one for v. *)
lemma map_carrier_vec[simp]: "map_vec h v  carrier_vec n = (v  carrier_vec n)"
  unfolding map_vec_def carrier_vec_def by auto

(* Entry and dimension of a negated vector. *)
lemma index_uminus_vec[simp]:
  "i < dim_vec v  (- v) $ i = - (v $ i)"
  "dim_vec (- v) = dim_vec v"
  unfolding uminus_vec_def by auto

(* Entry and dimension of a scalar multiple. *)
lemma index_smult_vec[simp]:
  "i < dim_vec v  (a v v) $ i = a * v $ i" "dim_vec (a v v) = dim_vec v"
  unfolding smult_vec_def by auto

(* Carrier statement for v1 + v2, given the ones for v1 and v2. *)
lemma add_carrier_vec[simp]:
  "v1  carrier_vec n  v2  carrier_vec n  v1 + v2  carrier_vec n"
  unfolding carrier_vec_def by auto

(* Carrier statement for v1 - v2, given the ones for v1 and v2. *)
lemma minus_carrier_vec[simp]:
  "v1  carrier_vec n  v2  carrier_vec n  v1 - v2  carrier_vec n"
  unfolding carrier_vec_def by auto

(* Commutativity of vector addition for operands related to carrier_vec n;
   added to ac_simps. Entrywise proof via eq_vecI. *)
lemma comm_add_vec[ac_simps]:
  "(v1 :: 'a :: ab_semigroup_add vec)  carrier_vec n  v2  carrier_vec n  v1 + v2 = v2 + v1"
  by (intro eq_vecI, auto simp: ac_simps)

(* Associativity of vector addition (simp rule), same proof pattern. *)
lemma assoc_add_vec[simp]:
  "(v1 :: 'a :: semigroup_add vec)  carrier_vec n  v2  carrier_vec n  v3  carrier_vec n
   (v1 + v2) + v3 = v1 + (v2 + v3)"
  by (intro eq_vecI, auto simp: ac_simps)

(* 0v n - v = - v over a group_add element type. *)
lemma zero_minus_vec[simp]: "(v :: 'a :: group_add vec)  carrier_vec n  0v n - v = - v"
  by (intro eq_vecI, auto)

(* v - 0v n = v. *)
lemma minus_zero_vec[simp]: "(v :: 'a :: group_add vec)  carrier_vec n  v - 0v n = v"
  by (intro eq_vecI, auto)

(* v - v = 0v n. *)
lemma minus_cancel_vec[simp]: "(v :: 'a :: group_add vec)  carrier_vec n  v - v = 0v n"
  by (intro eq_vecI, auto)

(* Subtraction expressed as addition of the negation; not a simp rule. *)
lemma minus_add_uminus_vec: "(v :: 'a :: group_add vec)  carrier_vec n 
  w  carrier_vec n  v - w = v + (- w)"
  by (intro eq_vecI, auto)

(* monoid_vec over a comm_monoid_add element type satisfies the comm_monoid
   locale; proved by unfold_locales and auto with monoid_vec_def, ac_simps. *)
lemma comm_monoid_vec: "comm_monoid (monoid_vec TYPE ('a :: comm_monoid_add) n)"
  by (unfold_locales, auto simp: monoid_vec_def ac_simps)

(* 0v n is a left unit for addition. *)
lemma left_zero_vec[simp]: "(v :: 'a :: monoid_add vec)  carrier_vec n   0v n + v = v" by auto

(* 0v n is a right unit for addition. *)
lemma right_zero_vec[simp]: "(v :: 'a :: monoid_add vec)  carrier_vec n   v + 0v n = v" by auto


(* Carrier statement for - v is equivalent to the one for v. *)
lemma uminus_carrier_vec[simp]:
  "(- v  carrier_vec n) = (v  carrier_vec n)"
  unfolding carrier_vec_def by auto

(* v + - v = 0v n. *)
lemma uminus_r_inv_vec[simp]:
  "(v :: 'a :: group_add vec)  carrier_vec n  (v + - v) = 0v n"
  by (intro eq_vecI, auto)

(* - v + v = 0v n. *)
lemma uminus_l_inv_vec[simp]:
  "(v :: 'a :: group_add vec)  carrier_vec n  (- v + v) = 0v n"
  by (intro eq_vecI, auto)

(* Existence of a two-sided additive inverse w; the witness given to bexI
   is - v. *)
lemma add_inv_exists_vec:
  "(v :: 'a :: group_add vec)  carrier_vec n   w  carrier_vec n. w + v = 0v n  v + w = 0v n"
  by (intro bexI[of _ "- v"], auto)

(* monoid_vec over an ab_group_add element type satisfies comm_group; the
   inverse obligation is discharged with add_inv_exists_vec and Units_def. *)
lemma comm_group_vec: "comm_group (monoid_vec TYPE ('a :: ab_group_add) n)"
  by (unfold_locales, insert add_inv_exists_vec, auto simp: monoid_vec_def ac_simps Units_def)

(* The next three facts specialise comm_monoid.finprod_insert, finprod_closed
   and finprod_empty to monoid_vec (via comm_monoid_vec). Each is restated
   with finsum_vec (folded definition) and with the record projections
   rewritten by monoid_vec_simps. *)
lemmas finsum_vec_insert =
  comm_monoid.finprod_insert[OF comm_monoid_vec, folded finsum_vec_def, unfolded monoid_vec_simps]

lemmas finsum_vec_closed =
  comm_monoid.finprod_closed[OF comm_monoid_vec, folded finsum_vec_def, unfolded monoid_vec_simps]

lemmas finsum_vec_empty =
  comm_monoid.finprod_empty[OF comm_monoid_vec, folded finsum_vec_def, unfolded monoid_vec_simps]

(* Carrier statement for a scalar multiple of v is equivalent to the one
   for v. *)
lemma smult_carrier_vec[simp]: "(a v v  carrier_vec n) = (v  carrier_vec n)"
  unfolding carrier_vec_def by auto

(* Scalar product with 0v n on the left is 0 (every summand vanishes,
   hence sum.neutral). *)
lemma scalar_prod_left_zero[simp]: "v  carrier_vec n  0v n  v = 0"
  unfolding scalar_prod_def
  by (rule sum.neutral, auto)

(* Scalar product with 0v n on the right is 0. *)
lemma scalar_prod_right_zero[simp]: "v  carrier_vec n  v  0v n = 0"
  unfolding scalar_prod_def
  by (rule sum.neutral, auto)

(* Scalar product with unit_vec n i on the left selects v $ i (for i < n and
   v related to carrier_vec n). The proof splits the summand for index i off
   the sum over {0..<n} (sum.remove) and shows the remainder is 0
   (sum.neutral). *)
lemma scalar_prod_left_unit[simp]: assumes v: "(v :: 'a :: semiring_1 vec)  carrier_vec n" and i: "i < n"
  shows "unit_vec n i  v = v $ i"
proof -
  let ?f = "λ k. unit_vec n i $ k * v $ k"
  have id: "(k{0..<n}. ?f k) = unit_vec n i $ i * v $ i + (k{0..<n} - {i}. ?f k)"
    by (rule sum.remove, insert i, auto)
  also have "( k{0..<n} - {i}. ?f k) = 0"
    by (rule sum.neutral, insert i, auto)
  finally
  show ?thesis unfolding scalar_prod_def using i v by simp
qed

(* Scalar product with unit_vec n i on the right selects v $ i for i < n.
   No carrier assumption on v is made here: the summation range of
   scalar_prod comes from its second argument. Same proof pattern as the
   left-unit lemma (sum.remove, then sum.neutral). *)
lemma scalar_prod_right_unit[simp]: assumes i: "i < n"
  shows "(v :: 'a :: semiring_1 vec)  unit_vec n i = v $ i"
proof -
  let ?f = "λ k. v $ k * unit_vec n i $ k"
  have id: "(k{0..<n}. ?f k) = v $ i * unit_vec n i $ i + (k{0..<n} - {i}. ?f k)"
    by (rule sum.remove, insert i, auto)
  also have "(k{0..<n} - {i}. ?f k) = 0"
    by (rule sum.neutral, insert i, auto)
  finally
  show ?thesis unfolding scalar_prod_def using i by simp
qed

(* Scalar product distributes over addition in the left argument. First the
   summands are rewritten pointwise (sum.cong), then the sum of sums is split
   with sum.distrib. *)
lemma add_scalar_prod_distrib: assumes v: "v1  carrier_vec n" "v2  carrier_vec n" "v3  carrier_vec n"
  shows "(v1 + v2)  v3 = v1  v3 + v2  v3"
proof -
  have "(i{0..<dim_vec v3}. (v1 + v2) $ i * v3 $ i) = (i{0..<dim_vec v3}. v1 $ i * v3 $ i + v2 $ i * v3 $ i)"
    by (rule sum.cong, insert v, auto simp: algebra_simps)
  thus ?thesis unfolding scalar_prod_def using v by (auto simp: sum.distrib)
qed

(* Scalar product distributes over addition in the right argument. Pointwise
   rewriting by sum.cong, then sum.distrib (supplied here as an intro rule,
   unlike the simp use in add_scalar_prod_distrib). *)
lemma scalar_prod_add_distrib: assumes v: "v1  carrier_vec n" "v2  carrier_vec n" "v3  carrier_vec n"
  shows "v1  (v2 + v3) = v1  v2 + v1  v3"
proof -
  have "(i{0..<dim_vec v3}. v1 $ i * (v2 + v3) $ i) = (i{0..<dim_vec v3}. v1 $ i * v2 $ i + v1 $ i * v3 $ i)"
    by (rule sum.cong, insert v, auto simp: algebra_simps)
  thus ?thesis unfolding scalar_prod_def using v by (auto intro: sum.distrib)
qed

(* A scalar factor on the left argument can be pulled out of the scalar
   product. sum_distrib_left moves the factor into the sum; the summands are
   then matched by sum.cong. *)
lemma smult_scalar_prod_distrib[simp]: assumes v: "v1  carrier_vec n" "v2  carrier_vec n"
  shows "(a v v1)  v2 = a * (v1  v2)"
  unfolding scalar_prod_def sum_distrib_left
  by (rule sum.cong, insert v, auto simp: ac_simps)

(* Same for a scalar factor on the right argument; the element type is
   constrained to comm_ring here. *)
lemma scalar_prod_smult_distrib[simp]: assumes v: "v1  carrier_vec n" "v2  carrier_vec n"
  shows "v1  (a v v2) = (a :: 'a :: comm_ring) * (v1  v2)"
  unfolding scalar_prod_def sum_distrib_left
  by (rule sum.cong, insert v, auto simp: ac_simps)

(* Commutativity of the scalar product over comm_semiring_0; not a simp rule. *)
lemma comm_scalar_prod: assumes "(v1 :: 'a :: comm_semiring_0 vec)  carrier_vec n" "v2  carrier_vec n"
  shows "v1  v2 = v2  v1"
  unfolding scalar_prod_def
  by (rule sum.cong, insert assms, auto simp: ac_simps)

(* Scalar multiplication distributes over addition of scalars
   (distrib_right entrywise). *)
lemma add_smult_distrib_vec:
  "((a::'a::ring) + b) v v = a v v + b v v"
  unfolding smult_vec_def plus_vec_def
  by (rule eq_vecI, auto simp: distrib_right)

(* Scalar multiplication distributes over addition of vectors
   (distrib_left entrywise). *)
lemma smult_add_distrib_vec:
  assumes "v  carrier_vec n" "w  carrier_vec n"
  shows "(a::'a::ring) v (v + w) = a v v + a v w"
  apply (rule eq_vecI)
  unfolding smult_vec_def plus_vec_def
  using assms distrib_left by auto

(* Nested scalar multiplications combine into one with the product a * b;
   the equation is flipped by sym before eq_vecI is applied. *)
lemma smult_smult_assoc:
  "a v (b v v) = (a * b::'a::ring) v v"
  apply (rule sym, rule eq_vecI)
  unfolding smult_vec_def plus_vec_def using mult.assoc by auto

(* Multiplying by the scalar 1 leaves v unchanged. *)
lemma one_smult_vec [simp]:
  "(1::'a::ring_1) v v = v" unfolding smult_vec_def
  by (rule eq_vecI,auto)

(* Negating the zero vector gives the zero vector. *)
lemma uminus_zero_vec[simp]: "- (0v n) = (0v n :: 'a :: group_add vec)" 
  by (intro eq_vecI, auto)

(* The i-th entry (i < n) of finsum_vec over a finite index set F equals the
   ordinary sum of the i-th entries of the vectors vs f. Proved by induction
   over the finite set F. *)
lemma index_finsum_vec: assumes "finite F" and i: "i < n"
  and vs: "vs  F  carrier_vec n"
  shows "finsum_vec TYPE('a :: comm_monoid_add) n vs F $ i = sum (λ f. vs f $ i) F"
  using ‹finite F vs
proof (induct F)
  case (insert f F)
  (* induction hypothesis plus the carrier facts for F and for the new f *)
  hence IH: "finsum_vec TYPE('a) n vs F $ i = (fF. vs f $ i)"
    and vs: "vs  F  carrier_vec n" "vs f  carrier_vec n" by auto
  (* unfold one insertion step on both sides, then compare entrywise *)
  show ?case unfolding finsum_vec_insert[OF insert(1-2) vs]
    unfolding sum.insert[OF insert(1-2)]
    unfolding IH[symmetric]
    by (rule index_add_vec, insert i, insert finsum_vec_closed[OF vs(1)], auto)
qed (insert i, auto simp: finsum_vec_empty)

text ‹The non-strict part of the ordering on vectors is defined pointwise; the
  strict version is defined in such a way that the @{class order} constraints are satisfied.›

(* Order operations on vectors. less_eq_vec combines equality of dimensions
   with an entrywise comparison over all indices below dim_vec w; less_vec is
   derived from less_eq_vec in both directions. *)
instantiation vec :: (ord) ord
begin

definition less_eq_vec :: "'a vec  'a vec  bool" where
  "less_eq_vec v w = (dim_vec v = dim_vec w  ( i < dim_vec w. v $ i  w $ i))" 

definition less_vec :: "'a vec  'a vec  bool" where
  "less_vec v w = (v  w  ¬ (w  v))"
instance ..
end

(* Vectors over a preorder form a preorder; the proof unfolds both order
   definitions and uses order_trans. *)
instantiation vec :: (preorder) preorder
begin
instance
  by (standard, auto simp: less_vec_def less_eq_vec_def order_trans)
end

(* Vectors over an order form an order; the remaining obligation is shown
   entrywise via eq_vecI and order.antisym. *)
instantiation vec :: (order) order
begin
instance
  by (standard, intro eq_vecI, auto simp: less_eq_vec_def order.antisym)
end


subsection‹Matrices›

text ‹As for vectors, we specify which value should be returned in case
  an index is out of bounds. It is defined in such a way that only a few
  index comparisons have to be performed in the implementation.›

(* undef_mat nr nc f: maps a pair (i,j) to the result of two list lookups
   (! i, then ! j) in the nested list of the values f (i,j) for i in
   [0 ..< nr] and j in [0 ..< nc]. mk_mat uses it for the out-of-range case. *)
definition undef_mat :: "nat  nat  (nat × nat  'a)  nat × nat  'a" where
  "undef_mat nr nc f  λ (i,j). [[f (i,j). j <- [0 ..< nc]] . i <- [0 ..< nr]] ! i ! j"

(* Congruence for undef_mat: if f and f' agree on all pairs with i < nr and
   j < nc, then undef_mat gives the same value for both at any argument x. *)
lemma undef_cong_mat: assumes " i j. i < nr  j < nc  f (i,j) = f' (i,j)"
  shows "undef_mat nr nc f x = undef_mat nr nc f' x"
proof (cases x)
  case (Pair i j)
  (* a lookup beyond the end of xs equals a lookup in the empty list at the
     index shifted by length xs *)
  have nth_map_ge: " i xs. ¬ i < length xs  xs ! i = [] ! (i - length xs)"
    by (metis append_Nil2 nth_append)
  note [simp] = Pair undef_mat_def nth_map_ge[of i] nth_map_ge[of j]
  show ?thesis
    by (cases "i < nr", simp, cases "j < nc", insert assms, auto)
qed

(* mk_mat nr nc f: returns f (i,j) when the bounds test on i < nr and j < nc
   succeeds, and undef_mat nr nc f (i,j) otherwise. *)
definition mk_mat :: "nat  nat  (nat × nat  'a)  (nat × nat  'a)" where
  "mk_mat nr nc f  λ (i,j). if i < nr  j < nc then f (i,j) else undef_mat nr nc f (i,j)"

(* Congruence for mk_mat: functions that agree within the bounds yield the
   same mk_mat; uses undef_cong_mat for the out-of-range part. *)
lemma cong_mk_mat: assumes " i j. i < nr  j < nc  f (i,j) = f' (i,j)"
  shows "mk_mat nr nc f = mk_mat nr nc f'"
  using undef_cong_mat[of nr nc f f', OF assms]
  using assms unfolding mk_mat_def
  by auto

(* The type 'a mat consists of the triples (nr, nc, mk_mat nr nc f) for
   arbitrary nr, nc and f. setup_lifting registers the type so that the
   lift_definition and transfer commands used below work on it. *)
typedef 'a mat = "{(nr, nc, mk_mat nr nc f) | nr nc f :: nat × nat  'a. True}"
  by auto

setup_lifting type_definition_mat

(* Accessors lifted from the representing triple: dim_row is the first
   component, dim_col the second, and index_mat (infixl $$, priority 100)
   the third. *)
lift_definition dim_row :: "'a mat  nat" is fst .
lift_definition dim_col :: "'a mat  nat" is "fst o snd" .
lift_definition index_mat :: "'a mat  (nat × nat  'a)" (infixl "$$" 100) is "snd o snd" .
(* Constructor from dimensions and an entry function on index pairs. *)
lift_definition mat :: "nat  nat  (nat × nat  'a)  'a mat"
  is "λ nr nc f. (nr, nc, mk_mat nr nc f)" by auto
(* Constructor from a function giving, for each i, a vector f i; entry (i,j)
   is f i $ j. *)
lift_definition mat_of_row_fun :: "nat  nat  (nat  'a vec)  'a mat" ("matr")
  is "λ nr nc f. (nr, nc, mk_mat nr nc (λ (i,j). f i $ j))" by auto

(* mat_to_list A: nested list of the entries of A; the outer list ranges over
   i in [0 ..< dim_row A], each inner list over j in [0 ..< dim_col A]. *)
definition mat_to_list :: "'a mat  'a list list" where
  "mat_to_list A = [ [A $$ (i,j) . j <- [0 ..< dim_col A]] . i <- [0 ..< dim_row A]]"

(* square_mat A holds when dim_col A and dim_row A are equal. *)
fun square_mat :: "'a mat  bool" where "square_mat A = (dim_col A = dim_row A)"

(* upper_triangular A: for i < dim_row A and j < i the entry A $$ (i,j) is 0.
   Only dim_row is consulted; no bound involving dim_col appears. *)
definition upper_triangular :: "'a::zero mat  bool"
  where "upper_triangular A 
    i < dim_row A.  j < i. A $$ (i,j) = 0"

(* Destruction rule (declared elim): read off a zero entry. *)
lemma upper_triangularD[elim] :
  "upper_triangular A  j < i  i < dim_row A  A $$ (i,j) = 0"
unfolding upper_triangular_def by auto

(* Introduction rule (declared intro). *)
lemma upper_triangularI[intro] :
  "(i j. j < i  i < dim_row A  A $$ (i,j) = 0)  upper_triangular A"
unfolding upper_triangular_def by auto

(* Row dimension of matrices built with mat and with matr; via transfer. *)
lemma dim_row_mat[simp]: "dim_row (mat nr nc f) = nr" "dim_row (matr nr nc g) = nr"
  by (transfer, simp)+

(* Column dimension of matrices built with mat and with matr. *)
lemma dim_col_mat[simp]: "dim_col (mat nr nc f) = nc" "dim_col (matr nr nc g) = nc"
  by (transfer, simp)+

(* carrier_mat nr nc: the set of matrices m constrained by dim_row m = nr and
   dim_col m = nc. *)
definition carrier_mat :: "nat  nat  'a mat set"
  where "carrier_mat nr nc = { m . dim_row m = nr  dim_col m = nc}"

(* Carrier statement for m with its own dimensions. *)
lemma carrier_mat_triv[simp]: "m  carrier_mat (dim_row m) (dim_col m)"
  unfolding carrier_mat_def by auto

(* Carrier statement for mat nr nc f. *)
lemma mat_carrier[simp]: "mat nr nc f  carrier_mat nr nc"
  unfolding carrier_mat_def by auto

(* elements_mat A: the set of all entries A $$ (i,j) with i ranging over
   [0 ..< dim_row A] and j over [0 ..< dim_col A]. *)
definition elements_mat :: "'a mat  'a set"
  where "elements_mat A = set [A $$ (i,j). i <- [0 ..< dim_row A], j <- [0 ..< dim_col A]]"

(* Destruction rule (declared dest): an element comes from some in-range
   position (i,j). *)
lemma elements_matD [dest]:
  "a  elements_mat A  i j. i < dim_row A  j < dim_col A  a = A $$ (i,j)"
  unfolding elements_mat_def by force

(* Introduction rule (declared intro): an in-range entry is an element. *)
lemma elements_matI [intro]:
  "A  carrier_mat nr nc  i < nr  j < nc  a = A $$ (i,j)  a  elements_mat A"
  unfolding elements_mat_def carrier_mat_def by force

(* In-range entries of mat nr nc f and matr nr nc g; by transfer' and
   mk_mat_def. *)
lemma index_mat[simp]:  "i < nr  j < nc  mat nr nc f $$ (i,j) = f (i,j)"
  "i < nr  j < nc  matr nr nc g $$ (i,j) = g i $ j"
  by (transfer', simp add: mk_mat_def)+

(* Extensionality (declared intro): matrices with equal dimensions and equal
   in-range entries are equal. The bounds are stated in terms of B. *)
lemma eq_matI[intro]: "( i j . i < dim_row B  j < dim_col B  A $$ (i,j) = B $$ (i,j))
   dim_row A = dim_row B
   dim_col A = dim_col B
   A = B"
  by (transfer, auto intro!: cong_mk_mat, auto simp: mk_mat_def)

(* Introduction rule for carrier_mat from the two dimension equations. *)
lemma carrier_matI[intro]:
  assumes "dim_row A = nr" "dim_col A = nc" shows  "A  carrier_mat nr nc"
  using assms unfolding carrier_mat_def by auto

(* Converse direction, declared both dest and simp. *)
lemma carrier_matD[dest,simp]: assumes "A  carrier_mat nr nc"
  shows "dim_row A = nr" "dim_col A = nc" using assms
  unfolding carrier_mat_def by auto

(* Congruence for mat: equal dimensions and entry functions that agree within
   the bounds give equal matrices. *)
lemma cong_mat: assumes "nr = nr'" "nc = nc'" " i j. i < nr  j < nc 
  f (i,j) = f' (i,j)" shows "mat nr nc f = mat nr' nc' f'"
  by (rule eq_matI, insert assms, auto)

(* row A i: vector of dimension dim_col A with j-th entry A $$ (i,j). *)
definition row :: "'a mat  nat  'a vec" where
  "row A i = vec (dim_col A) (λ j. A $$ (i,j))"

(* rows A: the rows for i in [0..<dim_row A], as a list. *)
definition rows :: "'a mat  'a vec list" where
  "rows A = map (row A) [0..<dim_row A]"

(* Carrier statement for a single row; no bound on i is required. *)
lemma row_carrier[simp]: "row A i  carrier_vec (dim_col A)" unfolding row_def by auto

(* Carrier statement for the elements of rows A. *)
lemma rows_carrier[simp]: "set (rows A)  carrier_vec (dim_col A)" unfolding rows_def by auto

(* rows A has dim_row A elements. *)
lemma length_rows[simp]: "length (rows A) = dim_row A" unfolding rows_def by auto

(* The i-th element of rows A is row A i, for i < dim_row A. *)
lemma nth_rows[simp]: "i < dim_row A  rows A ! i = row A i"
  unfolding rows_def by auto

(* Row i of matr nr nc f is f i, provided i < nr and dim_vec (f i) = nc. *)
lemma row_mat_of_row_fun[simp]: "i < nr  dim_vec (f i) = nc  row (matr nr nc f) i = f i"
  by (rule eq_vecI, auto simp: row_def)

(* Carrier statement for an element v of rows A when A is related to
   carrier_mat m n. *)
lemma set_rows_carrier:
  assumes "A  carrier_mat m n" and "v  set (rows A)" shows "v  carrier_vec n"
  using assms by (auto simp: rows_def row_def)

(* mat_of_rows n rs: matrix with length rs rows and n columns; entry (i,j) is
   rs ! i $ j. *)
definition mat_of_rows :: "nat  'a vec list  'a mat"
  where "mat_of_rows n rs = mat (length rs) n (λ(i,j). rs ! i $ j)"

(* Variant taking rows as plain lists; entry (i,j) is rs ! i ! j. *)
definition mat_of_rows_list :: "nat  'a list list  'a mat" where
  "mat_of_rows_list nc rs = mat (length rs) nc (λ (i,j). rs ! i ! j)"

(* Carrier statement and dimensions of mat_of_rows n vs. *)
lemma mat_of_rows_carrier[simp]:
  "mat_of_rows n vs  carrier_mat (length vs) n"
  "dim_row (mat_of_rows n vs) = length vs"
  "dim_col (mat_of_rows n vs) = n"
  unfolding mat_of_rows_def by auto

(* Row i of mat_of_rows n vs is vs ! i when i is in range and vs ! i is
   related to carrier_vec n. *)
lemma mat_of_rows_row[simp]:
  assumes i:"i < length vs" and n: "vs ! i  carrier_vec n"
  shows "row (mat_of_rows n vs) i = vs ! i"
  unfolding mat_of_rows_def row_def using n i by auto

(* rows undoes mat_of_rows under the carrier assumption on the elements of vs;
   shown elementwise with nth_equalityI. *)
lemma rows_mat_of_rows[simp]:
  assumes "set vs  carrier_vec n" shows "rows (mat_of_rows n vs) = vs"
  unfolding rows_def apply (rule nth_equalityI)
  using assms unfolding subset_code(1) by auto

(* mat_of_rows undoes rows when given dim_col A. *)
lemma mat_of_rows_rows[simp]:
  "mat_of_rows (dim_col A) (rows A) = A"
  unfolding mat_of_rows_def by (rule, auto simp: row_def)


(* col A j: vector of dimension dim_row A with i-th entry A $$ (i,j). *)
definition col :: "'a mat  nat  'a vec" where
  "col A j = vec (dim_row A) (λ i. A $$ (i,j))"

(* cols A: the columns for j in [0..<dim_col A], as a list. *)
definition cols :: "'a mat  'a vec list" where
  "cols A = map (col A) [0..<dim_col A]"

(* mat_of_cols n cs: matrix with n rows and length cs columns; entry (i,j) is
   cs ! j $ i. *)
definition mat_of_cols :: "nat  'a vec list  'a mat"
  where "mat_of_cols n cs = mat n (length cs) (λ(i,j). cs ! j $ i)"

(* Variant taking columns as plain lists; entry (i,j) is cs ! j ! i. *)
definition mat_of_cols_list :: "nat  'a list list  'a mat" where
  "mat_of_cols_list nr cs = mat nr (length cs) (λ (i,j). cs ! j ! i)"

(* Carrier statement for a single column; no bound on the index is required. *)
lemma col_dim[simp]: "col A i  carrier_vec (dim_row A)" unfolding col_def by auto

(* A column has dimension dim_row A. *)
lemma dim_col[simp]: "dim_vec (col A i) = dim_row A" by auto

(* Carrier statement for the elements of cols A. *)
lemma cols_dim[simp]: "set (cols A)  carrier_vec (dim_row A)" unfolding cols_def by auto

(* cols A has dim_col A elements. *)
lemma cols_length[simp]: "length (cols A) = dim_col A" unfolding cols_def by auto

(* The i-th element of cols A is col A i, for i < dim_col A. *)
lemma cols_nth[simp]: "i < dim_col A  cols A ! i = col A i"
  unfolding cols_def by auto

(* Carrier statement and dimensions of mat_of_cols n vs. *)
lemma mat_of_cols_carrier[simp]:
  "mat_of_cols n vs  carrier_mat n (length vs)"
  "dim_row (mat_of_cols n vs) = n"
  "dim_col (mat_of_cols n vs) = length vs"
  unfolding mat_of_cols_def by auto

(* Column j of mat_of_cols n vs is vs ! j when j is in range and vs ! j is
   related to carrier_vec n. *)
lemma col_mat_of_cols[simp]:
  assumes j:"j < length vs" and n: "vs ! j  carrier_vec n"
  shows "col (mat_of_cols n vs) j = vs ! j"
  unfolding mat_of_cols_def col_def using j n by auto

(* cols undoes mat_of_cols under the carrier assumption on the elements
   of vs. *)
lemma cols_mat_of_cols[simp]:
  assumes "set vs  carrier_vec n" shows "cols (mat_of_cols n vs) = vs"
  unfolding cols_def apply(rule nth_equalityI)
  using assms unfolding subset_code(1) by auto

(* mat_of_cols undoes cols when given dim_row A. *)
lemma mat_of_cols_cols[simp]:
  "mat_of_cols (dim_row A) (cols A) = A"
  unfolding mat_of_cols_def by (rule, auto simp: col_def)


(* Order operations on matrices. less_eq_mat combines equality of both
   dimensions with an entrywise comparison over all positions within the
   dimensions of B; less_mat is derived from less_eq_mat in both directions. *)
instantiation mat :: (ord) ord
begin

definition less_eq_mat :: "'a mat  'a mat  bool" where
  "less_eq_mat A B = (dim_row A = dim_row B  dim_col A = dim_col B  
      ( i < dim_row B.  j < dim_col B. A $$ (i,j)  B $$ (i,j)))" 

definition less_mat :: "'a mat  'a mat  bool" where
  "less_mat A B = (A  B  ¬ (B  A))"
instance ..
end

(* Matrices over a preorder form a preorder. After unfolding the definitions
   one goal remains, which is closed by order_trans on the
   three entries at position (i,j). *)
instantiation mat :: (preorder) preorder
begin
instance
proof (standard, auto simp: less_mat_def less_eq_mat_def, goal_cases)
  case (1 A B C i j)
  thus ?case using order_trans[of "A $$ (i,j)" "B $$ (i,j)" "C $$ (i,j)"] by auto
qed
end

(* Matrices over an order form an order; the remaining obligation is shown
   entrywise via eq_matI and order.antisym. *)
instantiation mat :: (order) order
begin
instance
  by (standard, intro eq_matI, auto simp: less_eq_mat_def order.antisym)
end

(* Matrix addition: entrywise sum; the result takes its dimensions from the
   second operand B. *)
instantiation mat :: (plus) plus
begin
definition plus_mat :: "('a :: plus) mat  'a mat  'a mat" where
  "A + B  mat (dim_row B) (dim_col B) (λ ij. A $$ ij + B $$ ij)"
instance ..
end

(* map_mat f A: same dimensions as A, every entry passed through f. *)
definition map_mat :: "('a  'b)  'a mat  'b mat" where
  "map_mat f A  mat (dim_row A) (dim_col A) (λ ij. f (A $$ ij))"

(* Scalar multiplication (infixl, priority 70): map_mat with left
   multiplication by a. *)
definition smult_mat :: "'a :: times  'a mat  'a mat" (infixl "m" 70)
  where "a m A  map_mat (λ b. a * b) A"

(* Zero matrix with nr rows and nc columns (notation 0m). *)
definition zero_mat :: "nat  nat  'a :: zero mat" ("0m") where
  "0m nr nc  mat nr nc (λ ij. 0)"

(* Relates the elements of a zero matrix to the singleton {0}. *)
lemma elements_0_mat [simp]: "elements_mat (0m nr nc)  {0}"
  unfolding elements_mat_def zero_mat_def by auto

(* Transpose: dimensions swapped, entry (i,j) taken from A $$ (j,i). *)
definition transpose_mat :: "'a mat  'a mat" where
  "transpose_mat A  mat (dim_col A) (dim_row A) (λ (i,j). A $$ (j,i))"

(* n-by-n matrix with 1 where i = j and 0 elsewhere (notation 1m). *)
definition one_mat :: "nat  'a :: {zero,one} mat" ("1m") where
  "1m n  mat n n (λ (i,j). if i = j then 1 else 0)"

(* Unary minus on matrices: same dimensions as A, each entry negated. *)
instantiation mat :: (uminus) uminus
begin
definition uminus_mat :: "'a :: uminus mat  'a mat" where
  "- A  mat (dim_row A) (dim_col A) (λ ij. - (A $$ ij))"
instance ..
end

(* Matrix subtraction: entrywise difference; dimensions come from the second
   operand B. *)
instantiation mat :: (minus) minus
begin
definition minus_mat :: "('a :: minus) mat  'a mat  'a mat" where
  "A - B  mat (dim_row B) (dim_col B) (λ ij. A $$ ij - B $$ ij)"
instance ..
end

(* Matrix product: dim_row A rows and dim_col B columns; entry (i,j) is the
   scalar product of row A i with col B j. *)
instantiation mat :: (semiring_0) times
begin
definition times_mat :: "'a :: semiring_0 mat  'a mat  'a mat"
  where "A * B  mat (dim_row A) (dim_col B) (λ (i,j). row A i  col B j)"
instance ..
end

(* Matrix-vector product (infixl, priority 70): vector of dimension dim_row A
   whose i-th entry is the scalar product of row A i with v. *)
definition mult_mat_vec :: "'a :: semiring_0 mat  'a vec  'a vec" (infixl "*v" 70)
  where "A *v v  vec (dim_row A) (λ i. row A i  v)"

(* inverts_mat A B: the product A * B equals the identity matrix of size
   dim_row A. *)
definition inverts_mat :: "'a :: semiring_1 mat  'a mat  bool" where
  "inverts_mat A B  A * B = 1m (dim_row A)"

(* invertible_mat A: stated in terms of square_mat A and a matrix B with
   inverts_mat A B and inverts_mat B A. *)
definition invertible_mat :: "'a :: semiring_1 mat  bool"
  where "invertible_mat A  square_mat A  (B. inverts_mat A B  inverts_mat B A)"

(* monoid_mat ty nr nc: monoid record over carrier_mat nr nc with matrix
   addition as the mult field and the zero matrix as the one field. *)
definition monoid_mat :: "'a :: monoid_add itself  nat  nat  'a mat monoid" where
  "monoid_mat ty nr nc  
    carrier = carrier_mat nr nc,
    mult = (+),
    one = 0m nr nc"

(* ring_mat ty n b: ring record over the square carrier carrier_mat n n with
   matrix product, 1m n, 0m n n and matrix addition; b fills the record's
   extension slot. *)
definition ring_mat :: "'a :: semiring_1 itself  nat  'b  ('a mat,'b) ring_scheme" where
  "ring_mat ty n b  
    carrier = carrier_mat n n,
    mult = (*),
    one = 1m n,
    zero = 0m n n,
    add = (+),
     = b"

(* module_mat ty nr nc: module record over carrier_mat nr nc; additionally
   sets mult to the matrix product, one to 1m nr, and smult to scalar
   multiplication of matrices. *)
definition module_mat :: "'a :: semiring_1 itself  nat  nat  ('a,'a mat)module" where
  "module_mat ty nr nc  
    carrier = carrier_mat nr nc,
    mult = (*),
    one = 1m nr,
    zero = 0m nr nc,
    add = (+),
    smult = (⋅m)"

(* Field projections of ring_mat (mult, add, one, zero, carrier), by
   unfolding its definition. *)
lemma ring_mat_simps:
  "mult (ring_mat ty n b) = (*)"
  "add (ring_mat ty n b) = (+)"
  "one (ring_mat ty n b) = 1m n"
  "zero (ring_mat ty n b) = 0m n n"
  "carrier (ring_mat ty n b) = carrier_mat n n"
  unfolding ring_mat_def by auto

(* Field projections of module_mat, including smult. *)
lemma module_mat_simps:
  "mult (module_mat ty nr nc) = (*)"
  "add (module_mat ty nr nc) = (+)"
  "one (module_mat ty nr nc) = 1m nr"
  "zero (module_mat ty nr nc) = 0m nr nc"
  "carrier (module_mat ty nr nc) = carrier_mat nr nc"
  "smult (module_mat ty nr nc) = (⋅m)"
  unfolding module_mat_def by auto

(* In-range entry and dimensions of the zero matrix. *)
lemma index_zero_mat[simp]: "i < nr  j < nc  0m nr nc $$ (i,j) = 0"
  "dim_row (0m nr nc) = nr" "dim_col (0m nr nc) = nc"
  unfolding zero_mat_def by auto

(* In-range entry and dimensions of the identity matrix. *)
lemma index_one_mat[simp]: "i < n  j < n  1m n $$ (i,j) = (if i = j then 1 else 0)"
  "dim_row (1m n) = n" "dim_col (1m n) = n"
  unfolding one_mat_def by auto

(* Entry and dimensions of a sum; bounds and dimensions refer to B. *)
lemma index_add_mat[simp]:
  "i < dim_row B  j < dim_col B  (A + B) $$ (i,j) = A $$ (i,j) + B $$ (i,j)"
  "dim_row (A + B) = dim_row B" "dim_col (A + B) = dim_col B"
  unfolding plus_mat_def by auto

(* Entry and dimensions of a difference; bounds and dimensions refer to B. *)
lemma index_minus_mat[simp]:
  "i < dim_row B  j < dim_col B  (A - B) $$ (i,j) = A $$ (i,j) - B $$ (i,j)"
  "dim_row (A - B) = dim_row B" "dim_col (A - B) = dim_col B"
  unfolding minus_mat_def by auto

(* Entry and dimensions of map_mat f A. *)
lemma index_map_mat[simp]:
  "i < dim_row A  j < dim_col A  map_mat f A $$ (i,j) = f (A $$ (i,j))"
  "dim_row (map_mat f A) = dim_row A" "dim_col (map_mat f A) = dim_col A"
  unfolding map_mat_def by auto

(* Entry and dimensions of a scalar multiple. *)
lemma index_smult_mat[simp]:
  "i < dim_row A  j < dim_col A  (a m A) $$ (i,j) = a * A $$ (i,j)"
  "dim_row (a m A) = dim_row A" "dim_col (a m A) = dim_col A"
  unfolding smult_mat_def by auto

(* Entry and dimensions of a negated matrix. *)
lemma index_uminus_mat[simp]:
  "i < dim_row A  j < dim_col A  (- A) $$ (i,j) = - (A $$ (i,j))"
  "dim_row (- A) = dim_row A" "dim_col (- A) = dim_col A"
  unfolding uminus_mat_def by auto

(* Entry and dimensions of the transpose; note the swapped bounds on i and j. *)
lemma index_transpose_mat[simp]:
  "i < dim_col A  j < dim_row A  transpose_mat A $$ (i,j) = A $$ (j,i)"
  "dim_row (transpose_mat A) = dim_col A" "dim_col (transpose_mat A) = dim_row A"
  unfolding transpose_mat_def by auto

(* Entry and dimensions of a product. *)
lemma index_mult_mat[simp]:
  "i < dim_row A  j < dim_col B  (A * B) $$ (i,j) = row A i  col B j"
  "dim_row (A * B) = dim_row A" "dim_col (A * B) = dim_col B"
  by (auto simp: times_mat_def)

(* Dimension of a matrix-vector product. *)
lemma dim_mult_mat_vec[simp]: "dim_vec (A *v v) = dim_row A"
  by (auto simp: mult_mat_vec_def)

(* i-th entry of a matrix-vector product, for i < dim_row A. *)
lemma index_mult_mat_vec[simp]: "i < dim_row A  (A *v v) $ i = row A i  v"
  by (auto simp: mult_mat_vec_def)

(* Entry and dimension of a row. *)
lemma index_row[simp]:
  "i < dim_row A  j < dim_col A  row A i $ j = A $$ (i,j)"
  "dim_vec (row A i) = dim_col A"
  by (auto simp: row_def)

(* Entry of a column. *)
lemma index_col[simp]: "i < dim_row A  j < dim_col A  col A j $ i = A $$ (i,j)"
  by (auto simp: col_def)

(* The identity matrix is upper triangular. *)
lemma upper_triangular_one[simp]: "upper_triangular (1m n)"
  by (rule, auto)

(* The square zero matrix is upper triangular. *)
lemma upper_triangular_zero[simp]: "upper_triangular (0m n n)"
  by (rule, auto)

(* Carrier statement for matr nr nc r (declared intro and simp). *)
lemma mat_row_carrierI[intro,simp]: "matr nr nc r  carrier_mat nr nc"
  by (unfold carrier_mat_def carrier_vec_def, auto)

(* Row-wise extensionality: matrices with equal dimensions whose rows agree
   for every i < dim_row B are equal. Reduces to eq_matI by reading each
   entry off the corresponding row with index_row. *)
lemma eq_rowI: assumes rows: " i. i < dim_row B  row A i = row B i"
  and dims: "dim_row A = dim_row B" "dim_col A = dim_col B"
  shows "A = B"
proof (rule eq_matI[OF _ dims])
  fix i j
  assume i: "i < dim_row B" and j: "j < dim_col B"
  from rows[OF i] have id: "row A i $ j = row B i $ j" by simp
  show "A $$ (i, j) = B $$ (i, j)"
    using index_row(1)[OF i j, folded id] index_row(1)[of i A j] i j dims
    by auto
qed

(* Row i of mat nr nc f as an explicit vec, for i < nr. *)
lemma row_mat[simp]: "i < nr  row (mat nr nc f) i = vec nc (λ j. f (i,j))"
  by auto

(* Column j of mat nr nc f as an explicit vec, for j < nc. *)
lemma col_mat[simp]: "j < nc  col (mat nr nc f) j = vec nr (λ i. f (i,j))"
  by auto

(* Carrier statement for the zero matrix. *)
lemma zero_carrier_mat[simp]: "0m nr nc  carrier_mat nr nc"
  unfolding carrier_mat_def by auto

(* Carrier statement for a scalar multiple, given the one for A. *)
lemma smult_carrier_mat[simp]:
  "A  carrier_mat nr nc  k m A  carrier_mat nr nc"
  unfolding carrier_mat_def by auto

(* Carrier statement for A + B; only B is constrained, since the sum takes
   its dimensions from B. *)
lemma add_carrier_mat[simp]:
  "B  carrier_mat nr nc  A + B  carrier_mat nr nc"
  unfolding carrier_mat_def by force

(* Carrier statement for the identity matrix. *)
lemma one_carrier_mat[simp]: "1m n  carrier_mat n n"
  unfolding carrier_mat_def by auto

(* Carrier statement for - A, given the one for A (not a simp rule). *)
lemma uminus_carrier_mat:
  "A  carrier_mat nr nc  (- A  carrier_mat nr nc)"
  unfolding carrier_mat_def by auto

(* The same as an equivalence, registered as simp rule. *)
lemma uminus_carrier_iff_mat[simp]:
  "(- A  carrier_mat nr nc) = (A  carrier_mat nr nc)"
  unfolding carrier_mat_def by auto

(* Carrier statement for A - B; only B is constrained. *)
lemma minus_carrier_mat:
  "B  carrier_mat nr nc  (A - B  carrier_mat nr nc)"
  unfolding carrier_mat_def by auto

(* Carrier statement for the transpose with swapped dimensions, as an
   equivalence. *)
lemma transpose_carrier_mat[simp]: "(transpose_mat A  carrier_mat nc nr) = (A  carrier_mat nr nc)"
  unfolding carrier_mat_def by auto

(* Carrier statement for a row of A. *)
lemma row_carrier_vec[simp]: "i < nr  A  carrier_mat nr nc  row A i  carrier_vec nc"
  unfolding carrier_vec_def by auto

(* Carrier statement for a column of A. *)
lemma col_carrier_vec[simp]: "j < nc  A  carrier_mat nr nc  col A j  carrier_vec nr"
  unfolding carrier_vec_def by auto

(* Carrier statement for a product with matching inner dimension n. *)
lemma mult_carrier_mat[simp]:
  "A  carrier_mat nr n  B  carrier_mat n nc  A * B  carrier_mat nr nc"
  unfolding carrier_mat_def by auto

(* Carrier statement for a matrix-vector product. *)
lemma mult_mat_vec_carrier[simp]:
  "A  carrier_mat nr n  v  carrier_vec n  A *v v  carrier_vec nr"
  unfolding carrier_mat_def carrier_vec_def by auto


(* Commutativity of matrix addition; added to ac_simps. *)
lemma comm_add_mat[ac_simps]:
  "(A :: 'a :: comm_monoid_add mat)  carrier_mat nr nc  B  carrier_mat nr nc  A + B = B + A"
  by (intro eq_matI, auto simp: ac_simps)


(* A - A is the zero matrix. *)
lemma minus_r_inv_mat[simp]:
  "(A :: 'a :: group_add mat)  carrier_mat nr nc  (A - A) = 0m nr nc"
  by (intro eq_matI, auto)

(* - A + A is the zero matrix. *)
lemma uminus_l_inv_mat[simp]:
  "(A :: 'a :: group_add mat)  carrier_mat nr nc  (- A + A) = 0m nr nc"
  by (intro eq_matI, auto)

(* Existence of a two-sided additive inverse B; the witness given to bexI
   is - A. *)
lemma add_inv_exists_mat:
  "(A :: 'a :: group_add mat)  carrier_mat nr nc   B  carrier_mat nr nc. B + A = 0m nr nc  A + B = 0m nr nc"
  by (intro bexI[of _ "- A"], auto)

(* Associativity of matrix addition. *)
lemma assoc_add_mat[simp]:
  "(A :: 'a :: monoid_add mat)  carrier_mat nr nc  B  carrier_mat nr nc  C  carrier_mat nr nc
   (A + B) + C = A + (B + C)"
  by (intro eq_matI, auto simp: ac_simps)

(* Negation of a sum; the summands appear in reversed order on the right,
   matching minus_add for a possibly non-commutative group_add. *)
lemma uminus_add_mat: fixes A :: "'a :: group_add mat"
  assumes "A  carrier_mat nr nc"
  and "B  carrier_mat nr nc"
  shows "- (A + B) = - B + - A"
  by (intro eq_matI, insert assms, auto simp: minus_add)

(* Transposing twice gives back A. *)
lemma transpose_transpose[simp]:
  "transpose_mat (transpose_mat A) = A"
  by (intro eq_matI, auto)

(* The identity matrix is its own transpose. *)
lemma transpose_one[simp]: "transpose_mat (1m n) = (1m n)"
  by auto

(* Row j of the transpose is column j of A, for j < dim_col A. *)
lemma row_transpose[simp]:
  "j < dim_col A  row (transpose_mat A) j = col A j"
  unfolding row_def col_def
  by (intro eq_vecI, auto)

(* Column i of the transpose is row i of A, for i < dim_row A. *)
lemma col_transpose[simp]:
  "i < dim_row A  col (transpose_mat A) i = row A i"
  unfolding row_def col_def
  by (intro eq_vecI, auto)

(* Rows of the zero matrix are zero vectors. *)
lemma row_zero[simp]:
  "i < nr  row (0m nr nc) i = 0v nc"
   by (intro eq_vecI, auto)

(* Columns of the zero matrix are zero vectors. *)
lemma col_zero[simp]:
  "j < nc  col (0m nr nc) j = 0v nr"
   by (intro eq_vecI, auto)

(* Rows of the identity matrix are unit vectors. *)
lemma row_one[simp]:
  "i < n  row (1m n) i = unit_vec n i"
  by (intro eq_vecI, auto)

(* Columns of the identity matrix are unit vectors. *)
lemma col_one[simp]:
  "j < n  col (1m n) j = unit_vec n j"
  by (intro eq_vecI, auto)

(* Transpose commutes with addition. *)
lemma transpose_add: "A  carrier_mat nr nc  B  carrier_mat nr nc
   transpose_mat (A + B) = transpose_mat A + transpose_mat B"
  by (intro eq_matI, auto)

(* Transpose commutes with subtraction. *)
lemma transpose_minus: "A  carrier_mat nr nc  B  carrier_mat nr nc
   transpose_mat (A - B) = transpose_mat A - transpose_mat B"
  by (intro eq_matI, auto)

(* Transpose commutes with negation. *)
lemma transpose_uminus: "A  carrier_mat nr nc  transpose_mat (- A) = - (transpose_mat A)"
  by (intro eq_matI, auto)

(* Row of a sum is the sum of the rows. Stated twice: once with carrier
   assumptions, once with explicit dimension equations. *)
lemma row_add[simp]:
  "A  carrier_mat nr nc  B  carrier_mat nr nc  i < nr
   row (A + B) i = row A i + row B i"
  "i < dim_row A  dim_row B = dim_row A  dim_col B = dim_col A  row (A + B) i = row A i + row B i"
  by (rule eq_vecI, auto)

(* Column of a sum is the sum of the columns. *)
lemma col_add[simp]:
  "A  carrier_mat nr nc  B  carrier_mat nr nc  j < nc
   col (A + B) j = col A j + col B j"
  by (rule eq_vecI, auto)

(* Row i of a product: the vector of scalar products of row A i with the
   columns of B. *)
lemma row_mult[simp]: assumes m: "A  carrier_mat nr n" "B  carrier_mat n nc"
  and i: "i < nr"
  shows "row (A * B) i = vec nc (λ j. row A i  col B j)"
  by (rule eq_vecI, insert m i, auto)

(* Column j of a product: the vector of scalar products of the rows of A with
   col B j. *)
lemma col_mult[simp]: assumes m: "A  carrier_mat nr n" "B  carrier_mat n nc"
  and j: "j < nc"
  shows "col (A * B) j = vec nr (λ i. row A i  col B j)"
  by (rule eq_vecI, insert m j, auto)

(* Transpose of a product is the product of the transposes in reversed order.
   Requires a commutative semiring; the entrywise proof uses comm_scalar_prod
   at dimension n. *)
lemma transpose_mult:
  "(A :: 'a :: comm_semiring_0 mat)  carrier_mat nr n  B  carrier_mat n nc
   transpose_mat (A * B) = transpose_mat B * transpose_mat A"
  by (intro eq_matI, auto simp: comm_scalar_prod[of _ n])

(* The zero matrix is a left unit for addition. *)
lemma left_add_zero_mat[simp]:
  "(A :: 'a :: monoid_add mat)  carrier_mat nr nc   0m nr nc + A = A"
  by (intro eq_matI, auto)

(* Adding the negation equals subtracting; not a simp rule. *)
lemma add_uminus_minus_mat: "A  carrier_mat nr nc  B  carrier_mat nr nc  
  A + (- B) = A - (B :: 'a :: group_add mat)" 
  by (intro eq_matI, auto)

(* The zero matrix is a right unit for addition. *)
lemma right_add_zero_mat[simp]: "A  carrier_mat nr nc  
  A + 0m nr nc = (A :: 'a :: monoid_add mat)" 
  by (intro eq_matI, auto)

(* Each law below comes in two forms: one with a carrier assumption, and a
   primed simp variant phrased with a single dimension equation, derived from
   the first by unfolding carrier_mat_def. *)

(* Zero matrix times A is a zero matrix. *)
lemma left_mult_zero_mat:
  "A  carrier_mat n nc  0m nr n * A = 0m nr nc"
  by (intro eq_matI, auto)

lemma left_mult_zero_mat'[simp]: "dim_row A = n  0m nr n * A = 0m nr (dim_col A)"
  by (rule left_mult_zero_mat, unfold carrier_mat_def, simp)

(* A times zero matrix is a zero matrix. *)
lemma right_mult_zero_mat:
  "A  carrier_mat nr n  A * 0m n nc = 0m nr nc"
  by (intro eq_matI, auto)

lemma right_mult_zero_mat'[simp]: "dim_col A = n  A * 0m n nc = 0m (dim_row A) nc"
  by (rule right_mult_zero_mat, unfold carrier_mat_def, simp)

(* The identity matrix is a left unit for multiplication. *)
lemma left_mult_one_mat:
  "(A :: 'a :: semiring_1 mat)  carrier_mat nr nc  1m nr * A = A"
  by (intro eq_matI, auto)

lemma left_mult_one_mat'[simp]: "dim_row (A :: 'a :: semiring_1 mat) = n  1m n * A = A"
  by (rule left_mult_one_mat, unfold carrier_mat_def, simp)

(* The identity matrix is a right unit for multiplication. *)
lemma right_mult_one_mat:
  "(A :: 'a :: semiring_1 mat)  carrier_mat nr nc  A * 1m nc = A"
  by (intro eq_matI, auto)

lemma right_mult_one_mat'[simp]: "dim_col (A :: 'a :: semiring_1 mat) = n  A * 1m n = A"
  by (rule right_mult_one_mat, unfold carrier_mat_def, simp)

(* Multiplying a vector by the identity matrix leaves it unchanged. *)
lemma one_mult_mat_vec[simp]:
  "(v :: 'a :: semiring_1 vec)  carrier_vec n  1m n *v v = v"
  by (intro eq_vecI, auto)

(* Subtraction expressed as addition of the negation; not a simp rule. *)
lemma minus_add_uminus_mat: fixes A :: "'a :: group_add mat"
  shows "A  carrier_mat nr nc  B  carrier_mat nr nc 
  A - B = A + (- B)"
  by (intro eq_matI, auto)

lemma add_mult_distrib_mat[algebra_simps]: assumes m: "A  carrier_mat nr n"
  "B  carrier_mat nr n" "C  carrier_mat n nc"
  shows "(A + B) * C = A * C + B * C"
  using m by (intro eq_matI, auto simp: add_scalar_prod_distrib[of _ n])

lemma mult_add_distrib_mat[algebra_simps]: assumes m: "A  carrier_mat nr n"
  "B  carrier_mat n nc" "C  carrier_mat n nc"
  shows "A * (B + C) = A * B + A * C"
  using m by (intro eq_matI, auto simp: scalar_prod_add_distrib[of _ n])

(* Matrix-vector multiplication distributes over matrix addition:
   A, B are nr x nc and v has dimension nc. Member of algebra_simps. *)
lemma add_mult_distrib_mat_vec[algebra_simps]: assumes m: "A  carrier_mat nr nc"
  "B  carrier_mat nr nc" "v  carrier_vec nc"
  shows "(A + B) *v v = A *v v + B *v v"
  using m by (intro eq_vecI, auto intro!: add_scalar_prod_distrib)

(* Matrix-vector multiplication distributes over vector addition:
   A is nr x nc and v1, v2 have dimension nc. Member of algebra_simps. *)
lemma mult_add_distrib_mat_vec[algebra_simps]: assumes m: "A  carrier_mat nr nc"
  "v1  carrier_vec nc" "v2  carrier_vec nc"
  shows "A *v (v1 + v2) = A *v v1 + A *v v2"
  using m by (intro eq_vecI, auto simp: scalar_prod_add_distrib[of _ nc])

(* Matrix-vector multiplication commutes with scalar multiplication of the
   vector: multiplying A with the k-scaled v equals k-scaling A *v v.
   The proof only needs commutativity of multiplication and distributivity
   over finite sums (mult.left_commute, sum_distrib_left), so the element
   type is comm_semiring_0 instead of field. Every field is an instance of
   comm_semiring_0, hence all existing uses remain valid. *)
lemma mult_mat_vec:
  assumes m: "(A::'a::comm_semiring_0 mat)  carrier_mat nr nc" and v: "v  carrier_vec nc"
  shows "A *v (k v v) = k v (A *v v)" (is "?l = ?r")
proof
  (* Both sides are vectors of dimension nr. *)
  have nr: "dim_vec ?l = nr" using m v by auto
  also have "... = dim_vec ?r" using m v by auto
  finally show "dim_vec ?l = dim_vec ?r".

  show "i. i < dim_vec ?r  ?l $ i = ?r $ i"
  proof -
    fix i assume "i < dim_vec ?r"
    hence i: "i < dim_row A" using nr m by auto
    hence i2: "i < dim_vec (A *v v)" using m by auto
    (* Unfold the i-th component on both sides down to sums over scalar
       products, then pull the factor k out of the sum. *)
    show "?l $ i = ?r $ i"
    apply (subst (1) mult_mat_vec_def)
    apply (subst (2) smult_vec_def)
    unfolding index_vec[OF i] index_vec[OF i2]
    unfolding mult_mat_vec_def smult_vec_def
    unfolding scalar_prod_def index_vec[OF i]
    by (simp add: mult.left_commute sum_distrib_left)
  qed
qed

(* Core identity behind associativity of matrix multiplication:
   for v1 of dimension nr, A of size nr x nc and v2 of dimension nc, the
   scalar product of (the vector of v1-against-column products) with v2
   equals the scalar product of v1 with (the vector of row-against-v2
   products). The proof unfolds scalar_prod_def, swaps the two finite sums
   (sum.swap) and re-associates the products. *)
lemma assoc_scalar_prod: assumes *: "v1  carrier_vec nr" "A  carrier_mat nr nc" "v2  carrier_vec nc"
  shows "vec nc (λj. v1  col A j)  v2 = v1  vec nr (λi. row A i  v2)"
proof -
  (* Expand both scalar products on the left-hand side into nested sums. *)
  have "vec nc (λj. v1  col A j)  v2 = (i{0..<nc}. vec nc (λj. k{0..<nr}. v1 $ k * col A j $ k) $ i * v2 $ i)"
    unfolding scalar_prod_def using * by auto
  also have " = (i{0..<nc}. (k{0..<nr}. v1 $ k * col A i $ k) * v2 $ i)"
    by (rule sum.cong, auto)
  (* Push the factor v2 $ i into the inner sum. *)
  also have " = (i{0..<nc}. (k{0..<nr}. v1 $ k * col A i $ k * v2 $ i))"
    unfolding sum_distrib_right ..
  (* Exchange the order of summation. *)
  also have " = (k{0..<nr}. (i{0..<nc}. v1 $ k * col A i $ k * v2 $ i))"
    by (rule sum.swap)
  also have " = (k{0..<nr}. (i{0..<nc}. v1 $ k * (col A i $ k * v2 $ i)))"
    by (simp add: ac_simps)
  (* Pull the factor v1 $ k out of the inner sum. *)
  also have " = (k{0..<nr}. v1 $ k * (i{0..<nc}. col A i $ k * v2 $ i))"
    unfolding sum_distrib_left ..
  (* Rewrite column access as row access and fold back into vec. *)
  also have " = (k{0..<nr}. v1 $ k * vec nr (λk. i{0..<nc}. row A k $ i * v2 $ i) $ k)"
    using * by auto
  also have " = v1  vec nr (λi. row A i  v2)" unfolding scalar_prod_def using * by simp
  finally show ?thesis .
qed

(* Associativity of matrix multiplication for dimension-compatible A, B, C
   (n1 x n2, n2 x n3, n3 x n4). Simp rule; relies on assoc_scalar_prod. *)
lemma assoc_mult_mat[simp]:
  "A  carrier_mat n1 n2  B  carrier_mat n2 n3  C  carrier_mat n3 n4
   (A * B) * C = A * (B * C)"
  by (intro eq_matI, auto simp: assoc_scalar_prod)

(* Mixed associativity: applying the product A * B to a vector equals
   applying B and then A (A: n1 x n2, B: n2 x n3, v of dimension n3).
   Simp rule; relies on assoc_scalar_prod. *)
lemma assoc_mult_mat_vec[simp]:
  "A  carrier_mat n1 n2  B  carrier_mat n2 n3  v  carrier_vec n3
   (A * B) *v v = A *v (B *v v)"
  by (intro eq_vecI, auto simp add: mult_mat_vec_def assoc_scalar_prod)

(* The structure monoid_mat for nr x nc matrices over a comm_monoid_add
   element type satisfies the HOL-Algebra locale comm_monoid.
   (monoid_mat itself is defined outside this chunk.) *)
lemma comm_monoid_mat: "comm_monoid (monoid_mat TYPE('a :: comm_monoid_add) nr nc)"
  by (unfold_locales, auto simp: monoid_mat_def ac_simps)

(* Over an ab_group_add element type, monoid_mat satisfies the locale
   comm_group. Existence of inverses comes from add_inv_exists_mat, which is
   inserted as a fact before auto. *)
lemma comm_group_mat: "comm_group (monoid_mat TYPE('a :: ab_group_add) nr nc)"
  by (unfold_locales, insert add_inv_exists_mat, auto simp: monoid_mat_def ac_simps Units_def)

(* The structure ring_mat over a semiring_1 element type satisfies the
   HOL-Algebra locale semiring, for arbitrary parameters n and b.
   (ring_mat itself is defined outside this chunk.) *)
lemma semiring_mat: "semiring (ring_mat TYPE('a :: semiring_1) n b)"
  by (unfold_locales, auto simp: ring_mat_def algebra_simps)

(* The structure ring_mat over a comm_ring_1 element type satisfies the
   HOL-Algebra locale ring. Additive inverses come from add_inv_exists_mat. *)
lemma ring_mat: "ring (ring_mat TYPE('a :: comm_ring_1) n b)"
  by (unfold_locales, insert add_inv_exists_mat, auto simp: ring_mat_def algebra_simps Units_def)

(* The structure module_mat for nr x nc matrices over a comm_ring_1 element
   type satisfies the HOL-Algebra locale abelian_group.
   (module_mat itself is defined outside this chunk.) *)
lemma abelian_group_mat: "abelian_group (module_mat TYPE('a :: comm_ring_1) nr nc)"
  by (unfold_locales, insert add_inv_exists_mat, auto simp: module_mat_def Units_def)

(* Row i of a scalar multiple of A is the scalar multiple of row i of A,
   for a valid row index i. Simp rule. *)
lemma row_smult[simp]: assumes i: "i < dim_row A"
  shows "row (k m A) i = k v (row A i)"
  by (rule eq_vecI, insert i, auto)

(* Column i of a scalar multiple of A is the scalar multiple of column i of
   A, for a valid column index i. Simp rule. *)
lemma col_smult[simp]: assumes i: "i < dim_col A"
  shows "col (k m A) i = k v (col A i)"
  by (rule eq_vecI, insert i, auto)

(* Row i of the negated matrix is the negation of row i, for a valid row
   index i. Simp rule. *)
lemma row_uminus[simp]: assumes i: "i < dim_row A"
  shows "row (- A) i = - (row A i)"
  by (rule eq_vecI, insert i, auto)

(* Negation can be pulled out of the left argument of a scalar product when
   both vectors have the same dimension (element type: ring). Simp rule.
   The proof rewrites the negated sum with sum_negf and compares summands. *)
lemma scalar_prod_uminus_left[simp]: assumes dim: "dim_vec v = dim_vec (w :: 'a :: ring vec)"
  shows "- v  w = - (v  w)"
  unfolding scalar_prod_def dim[symmetric]
  by (subst sum_negf[symmetric], rule sum.cong, auto)

(* Column i of the negated matrix is the negation of column i, for a valid
   column index i. Simp rule. *)
lemma col_uminus[simp]: assumes i: "i < dim_col A"
  shows "col (- A) i = - (col A i)"
  by (rule eq_vecI, insert i, auto)

(* Negation can be pulled out of the right argument of a scalar product when
   both vectors have the same dimension (element type: ring). Simp rule.
   Mirror image of scalar_prod_uminus_left. *)
lemma scalar_prod_uminus_right[simp]: assumes dim: "dim_vec v = dim_vec (w :: 'a :: ring vec)"
  shows "v  - w = - (v  w)"
  unfolding scalar_prod_def dim
  by (subst sum_negf[symmetric], rule sum.cong, auto)

(* Local context: two matrices over a ring whose inner dimensions agree
   (dim_col A = dim_row B). Inside it, negation is shown to commute with
   matrix multiplication in either argument. After the context is closed the
   dimension assumption becomes a premise of both simp rules. *)
context fixes A B :: "'a :: ring mat"
  assumes dim: "dim_col A = dim_row B"
begin
(* Negating the left factor negates the product. *)
lemma uminus_mult_left_mat[simp]: "(- A * B) = - (A * B)"
  by (intro eq_matI, insert dim, auto)

(* Negating the right factor negates the product. *)
lemma uminus_mult_right_mat[simp]: "(A * - B) = - (A * B)"
  by (intro eq_matI, insert dim, auto)
end

(* Right distributivity of matrix multiplication over subtraction
   (A, B: nr x n; C: n x nc; element type ring). Member of algebra_simps.
   Derived from add_mult_distrib_mat by rewriting A - B to A + (- B) and
   then moving the negation out of the product. *)
lemma minus_mult_distrib_mat[algebra_simps]: fixes A :: "'a :: ring mat"
  assumes m: "A  carrier_mat nr n" "B  carrier_mat nr n" "C  carrier_mat n nc"
  shows "(A - B) * C = A * C - B * C"
  unfolding minus_add_uminus_mat[OF m(1,2)]
    add_mult_distrib_mat[OF m(1) uminus_carrier_mat[OF m(2)] m(3)]
  by (subst uminus_mult_left_mat, insert m, auto)

(* Matrix-vector multiplication distributes over matrix subtraction
   (A, B: nr x nc; v of dimension nc; element type ring). Member of
   algebra_simps. Derived from add_mult_distrib_mat_vec after rewriting
   A - B to A + (- B). *)
lemma minus_mult_distrib_mat_vec[algebra_simps]: assumes A: "(A :: 'a :: ring mat)  carrier_mat nr nc"
  and B: "B  carrier_mat nr nc"
  and v: "v  carrier_vec nc"
shows "(A - B) *v v = A *v v - B *v v"
  unfolding minus_add_uminus_mat[OF A B]
  by (subst add_mult_distrib_mat_vec[OF A _ v], insert A B v, auto)

(* Matrix-vector multiplication distributes over vector subtraction
   (A: nr x nc; v, w of dimension nc; element type ring). Member of
   algebra_simps. Derived from mult_add_distrib_mat_vec after rewriting
   v - w with minus_add_uminus_vec. *)
lemma mult_minus_distrib_mat_vec[algebra_simps]: assumes A: "(A :: 'a :: ring mat)  carrier_mat nr nc"
  and v: "v  carrier_vec nc"
  and w: "w  carrier_vec nc"
shows "A *v (v - w) = A *v v - A *v w"
  unfolding minus_add_uminus_vec[OF v w]
  by (subst mult_add_distrib_mat_vec[OF A], insert A v w, auto)

(* Left distributivity of matrix multiplication over subtraction
   (A: nr x n; B, C: n x nc; element type ring). Member of algebra_simps.
   Derived from mult_add_distrib_mat by rewriting B - C to B + (- C) and
   then moving the negation out of the product. *)
lemma mult_minus_distrib_mat[algebra_simps]: fixes A :: "'a :: ring mat"
  assumes m: "A  carrier_mat nr n" "B  carrier_mat n nc" "C  carrier_mat n nc"
  shows "A * (B - C) = A * B - A * C"
  unfolding minus_add_uminus_mat[OF m(2,3)]
    mult_add_distrib_mat[OF m(1) m(2) uminus_carrier_mat[OF m(3)]]
  by (subst uminus_mult_right_mat, insert m, auto)

(* Negating the matrix negates the matrix-vector product, provided the
   vector dimension matches dim_col A (element type ring). Simp rule. *)
lemma uminus_mult_mat_vec[simp]: assumes v: "dim_vec v = dim_col (A :: 'a :: ring mat)"
  shows "- A *v v = - (A *v v)"
  using v by (intro eq_vecI, auto)

(* For a vector v of dimension n over group_add: - v is the zero vector
   exactly when v is the zero vector. The non-trivial direction is shown
   componentwise via double negation; the converse is discharged by auto. *)
lemma uminus_zero_vec_eq: assumes v: "(v :: 'a :: group_add vec)  carrier_vec n"
  shows "(- v = 0v n) = (v = 0v n)"
proof
  assume z: "- v = 0v n"
  {
    fix i
    assume i: "i < n"
    (* v $ i = - (- (v $ i)) = - 0 = 0, using the i-th component of z. *)
    have "v $ i = - (- (v $ i))" by simp
    also have "- (v $ i) = 0" using arg_cong[OF z, of "λ v. v $ i"] i v by auto
    also have "- 0 = (0 :: 'a)" by simp
    finally have "v $ i = 0" .
  }
  thus "v = 0v n" using v
    by (intro eq_vecI, auto)
qed auto

(* map_mat preserves dimensions: membership of map_mat f A in
   carrier_mat nr nc is equivalent to membership of A. Simp rule. *)
lemma map_carrier_mat[simp]:
  "(map_mat f A  carrier_mat nr nc) = (A  carrier_mat nr nc)"
  unfolding carrier_mat_def by auto

(* Taking column j commutes with mapping f over the entries, for a valid
   column index j. Simp rule. *)
lemma col_map_mat[simp]:
  assumes "j < dim_col A" shows "col (map_mat f A) j = map_vec f (col A j)"
  unfolding map_mat_def map_vec_def using assms by auto

(* Scalar multiplication of a vector by 1 is the identity (semiring_1).
   Simp rule. *)
lemma scalar_vec_one[simp]: "1 v (v :: 'a :: semiring_1 vec) = v"
  by (rule eq_vecI, auto)

(* A scalar factor k can be pulled out of the right argument of a scalar
   product when both vectors have the same dimension (comm_semiring_0).
   Simp rule. *)
lemma scalar_prod_smult_right[simp]:
  "dim_vec w = dim_vec v  w  (k v v) = (k :: 'a :: comm_semiring_0) * (w  v)"
  unfolding scalar_prod_def sum_distrib_left
  by (auto intro: sum.cong simp: ac_simps)

(* A scalar factor k can be pulled out of the left argument of a scalar
   product when both vectors have the same dimension (comm_semiring_0).
   Simp rule. *)
lemma scalar_prod_smult_left[simp]:
  "dim_vec w = dim_vec v  (k v w)  v = (k :: 'a :: comm_semiring_0) * (w  v)"
  unfolding scalar_prod_def sum_distrib_left
  by (auto intro: sum.cong simp: ac_simps)

(* A scalar factor on the right matrix can be moved in front of the product:
   A * (k-scaled B) = k-scaled (A * B), for A: nr x n and B: n x nc over a
   comm_semiring_0. *)
lemma mult_smult_distrib: assumes A: "A  carrier_mat nr n" and B: "B  carrier_mat n nc"
  shows "A * (k m B) = (k :: 'a :: comm_semiring_0) m (A * B)"
  by (rule eq_matI, insert A B, auto)

(* Scalar multiplication distributes over matrix addition, for A and B of
   equal dimensions over a semiring. *)
lemma add_smult_distrib_left_mat: assumes "A  carrier_mat nr nc" "B  carrier_mat nr nc"
  shows "k m (A + B) = (k :: 'a :: semiring) m A + k m B"
  by (rule eq_matI, insert assms, auto simp: field_simps)

(* Scalar multiplication of a matrix distributes over addition of scalars:
   scaling by k + l equals the sum of scaling by k and by l (semiring). *)
lemma add_smult_distrib_right_mat: assumes "A  carrier_mat nr nc"
  shows "(k + l) m A = (k :: 'a :: semiring) m A + l m A"
  by (rule eq_matI, insert assms, auto simp: field_simps)

(* A scalar factor on the left matrix can be moved in front of the product:
   (k-scaled A) * B = k-scaled (A * B), for A: nr x n and B: n x nc over a
   comm_semiring_0. Counterpart of mult_smult_distrib. *)
lemma mult_smult_assoc_mat: assumes A: "A  carrier_mat nr n" and B: "B  carrier_mat n nc"
  shows "(k m A) * B = (k :: 'a :: comm_semiring_0) m (A * B)"
  by (rule eq_matI, insert A B, auto)

(* Similarity of A and B with explicit witnesses P and Q:
   with n = dim_row A, all of A, B, P, Q are n x n matrices, P and Q are
   mutually inverse (P * Q and Q * P are both the identity 1m n), and
   A = P * B * Q. *)
definition similar_mat_wit :: "'a :: semiring_1 mat  'a mat  'a mat  'a mat  bool" where
  "similar_mat_wit A B P Q = (let n = dim_row A in {A,B,P,Q}  carrier_mat n n  P * Q = 1m n  Q * P = 1m n 
    A = P * B * Q)"

(* A and B are similar if some pair of witnesses P, Q satisfies
   similar_mat_wit A B P Q. *)
definition similar_mat :: "'a :: semiring_1 mat  'a mat  bool" where
  "similar_mat A B = ( P Q. similar_mat_wit A B P Q)"

(* Destruction rule for similar_mat: from similarity one obtains a dimension
   n and witnesses P, Q with all the properties listed in similar_mat_wit,
   with the let-binding eliminated. *)
lemma similar_matD: assumes "similar_mat A B"
  shows " n P Q. {A,B,P,Q}  carrier_mat n n  P * Q = 1m n  Q * P = 1m n  A = P * B * Q"
  using assms unfolding similar_mat_def similar_mat_wit_def[abs_def] Let_def by blast

(* Introduction rule for similar_mat: n x n matrices A, B, P, Q with P and Q
   mutually inverse and A = P * B * Q establish that A and B are similar.
   P and Q are supplied as the existential witnesses. *)
lemma similar_matI: assumes "{A,B,P,Q}  carrier_mat n n" "P * Q = 1m n" "Q * P = 1m n" "A = P * B * Q"
  shows "similar_mat A B" unfolding similar_mat_def
  by (rule exI[of _ P], rule exI[of _ Q], unfold similar_mat_wit_def Let_def, insert assms, auto)

(* Matrix power, written infix as ^m (right-associative, priority 75).
   The 0-th power is the identity matrix of size dim_row A; each successor
   step multiplies by A on the right. Note that the definition itself does
   not require A to be square. *)
fun pow_mat :: "'a :: semiring_1 mat  nat  'a mat" (infixr "^m" 75) where
  "A ^m 0 = 1m (dim_row A)"
| "A ^m (Suc k) = A ^m k * A"

(* Dimensions of a matrix power without assuming squareness: the row count
   is always dim_row A; the column count is dim_row A for exponent 0 (the
   identity case) and dim_col A otherwise. Simp rules. *)
lemma pow_mat_dim[simp]:
  "dim_row (A ^m k) = dim_row A"
  "dim_col (A ^m k) = (if k = 0 then dim_row A else dim_col A)"
  by (induct k, auto)

(* For a square n x n matrix, every power has n rows and n columns.
   Simp rules specialising pow_mat_dim. *)
lemma pow_mat_dim_square[simp]:
  "A  carrier_mat n n  dim_row (A ^m k) = n"
  "A  carrier_mat n n  dim_col (A ^m k) = n"
  by auto

(* Powers of an n x n matrix stay in carrier_mat n n. Simp rule. *)
lemma pow_carrier_mat[simp]: "A  carrier_mat n n  A ^m k  carrier_mat n n"
  unfolding carrier_mat_def by auto

(* The list of diagonal entries A $$ (i,i) for i = 0, ..., dim_row A - 1.
   The range is determined by the row count only. *)
definition diag_mat :: "'a mat  'a list" where
  "diag_mat A = map (λ i. A $$ (i,i)) [0 ..< dim_row A]"

(* The list product of the diagonal entries equals the indexed product over
   i in {0 ..< dim_row A}. Proved by converting the set-indexed product to a
   list product (prod.distinct_set_conv_list). *)
lemma prod_list_diag_prod: "prod_list (diag_mat A) = ( i = 0 ..< dim_row A. A $$ (i,i))"
  unfolding diag_mat_def
  by (subst prod.distinct_set_conv_list[symmetric], auto)

(* Transposition does not change the diagonal of a matrix whose row and
   column counts agree. Simp rule. *)
lemma diag_mat_transpose[simp]: "dim_row A = dim_col A 
  diag_mat (transpose_mat A) = diag_mat A" unfolding diag_mat_def by auto

(* The diagonal of the n x n zero matrix is a list of n zeros. Simp rule. *)
lemma diag_mat_zero[simp]: "diag_mat (0m n n) = replicate n 0"
  unfolding diag_mat_def
  by (rule nth_equalityI, auto)

(* The diagonal of the n x n identity matrix is a list of n ones.
   Simp rule. *)
lemma diag_mat_one[simp]: "diag_mat (1m n) = replicate n 1"
  unfolding diag_mat_def
  by (rule nth_equalityI, auto)

(* For a square n x n matrix, pow_mat coincides with the natural-number
   power [^] of the structure ring_mat TYPE('a) n b (abbreviated ?C), for any
   value of the parameter b. The semiring locale is interpreted for ?C so
   that its facts are available in the induction over k. *)
lemma pow_mat_ring_pow: assumes A: "(A :: ('a :: semiring_1)mat)  carrier_mat n n"
  shows "A ^m k = A [^]ring_mat TYPE('a) n b k"
  (is "_ = A [^]?C k")
proof -
  interpret semiring ?C by (rule semiring_mat)
  show ?thesis
    by (induct k, insert A, auto simp: ring_mat_def nat_pow_def)
qed

(* A matrix is diagonal if every entry at a valid position (i,j) with i
   different from j is 0. No squareness is required by the definition. *)
definition diagonal_mat :: "'a::zero mat  bool" where
  "diagonal_mat A  i<dim_row A. j<dim_col A. i  j  A $$ (i,j) = 0"

(* Sum of all entries of a matrix, taken over the index pairs in
   {0 ..< dim_row A} x {0 ..< dim_col A}. Defined inside the class
   comm_monoid_add. *)
definition (in comm_monoid_add) sum_mat :: "'a mat  'a" where
  "sum_mat A = sum (λ ij. A $$ ij) ({0 ..< dim_row A} × {0 ..< dim_col A})"

(* The entry sum of a zero matrix is 0. Simp rule. *)
lemma sum_mat_0[simp]: "sum_mat (0m nr nc) = (0 :: 'a :: comm_monoid_add)"
  unfolding sum_mat_def
  by (rule sum.neutral, auto)

(* sum_mat is additive: for A and B of equal dimensions, the entry sum of
   A + B is the sum of the two entry sums. The dimension facts in id are
   used to make both index sets syntactically equal before sum.distrib is
   applied. *)
lemma sum_mat_add: assumes A: "(A :: 'a :: comm_monoid_add mat)  carrier_mat nr nc" and B: "B  carrier_mat nr nc"
  shows "sum_mat (A + B) = sum_mat A + sum_mat B"
proof -
  from A B have id: "dim_row A = nr" "dim_row B = nr" "dim_col A = nc" "dim_col B = nc"
    by auto
  show ?thesis unfolding sum_mat_def id
    by (subst sum.distrib[symmetric], rule sum.cong, insert A B, auto)
qed

subsection ‹Update Operators›

(* Point update of a vector: the result has the same dimension as v, holds a
   at index i and agrees with v at every other index. If i is not a valid
   index, no position of the result carries a. *)
definition update_vec :: "'a vec  nat  'a  'a vec" ("_ |v _  _" [60,61,62] 60)
  where "v |v i  a = vec (dim_vec v) (λi'. if i' = i then a else v $ i')"

(* Point update of a matrix: the result has the same dimensions as A, holds
   a at index pair ij and agrees with A at every other index pair. *)
definition update_mat :: "'a mat  nat × nat  'a  'a mat" ("_ |m _  _" [60,61,62] 60)
  where "A |m ij  a = mat (dim_row A) (dim_col A) (λij'. if ij' = ij then a else A $$ ij')"

(* Updating a vector does not change its dimension. Simp rule. *)
lemma dim_update_vec[simp]:
  "dim_vec (v |v i  a) = dim_vec v" unfolding update_vec_def by simp

(* Reading the updated position yields the new value, provided the index is
   within bounds. Simp rule. *)
lemma index_update_vec1[simp]:
  assumes "i < dim_vec v" shows "(v |v i  a) $ i = a"
  unfolding update_vec_def using assms by simp

(* Reading any other position yields the old value. No bound on i' is
   assumed; the proof goes through transfer and mk_vec_def to cover indices
   outside the vector as well. Simp rule. *)
lemma index_update_vec2[simp]:
  assumes "i'  i" shows "(v |v i  a) $ i' = v $ i'"
  unfolding update_vec_def
  using assms apply transfer unfolding mk_vec_def by auto

(* Updating a matrix changes neither its row count nor its column count.
   Simp rules. *)
lemma dim_update_mat[simp]:
  "dim_row (A |m ij  a) = dim_row A"
  "dim_col (A |m ij  a) = dim_col A" unfolding update_mat_def by simp+

(* Reading the updated position (i,j) yields the new value, provided both
   indices are within bounds. Simp rule. *)
lemma index_update_mat1[simp]:
  assumes "i < dim_row A" "j < dim_col A" shows "(A |m (i,j)  a) $$ (i,j) = a"
  unfolding update_mat_def using assms by simp

(* Reading a valid position (i',j') different from the updated pair ij
   yields the old value. Unlike index_update_vec2, bounds on i' and j' are
   assumed here. Simp rule. *)
lemma index_update_mat2[simp]:
  assumes i': "i' < dim_row A" and j': "j' < dim_col A" and neq: "(i',j')  ij"
  shows "(A |m ij  a) $$ (i',j') = A $$ (i',j')"
  unfolding update_mat_def using assms by auto

subsection ‹Block Vectors and Matrices›

(* Concatenation of two vectors, written infix as @v (right-associative,
   priority 65). The result has dimension dim_vec v + dim_vec w; indices
   below dim_vec v read from v, the remaining ones read from w shifted by
   dim_vec v. *)
definition append_vec :: "'a vec  'a vec  'a vec" (infixr "@v" 65) where
  "v @v w  let n = dim_vec v; m = dim_vec w in
    vec (n + m) (λ i. if i < n then v $ i else w $ (i - n))"

(* Characterisation of append_vec without the let-binding: indexing (for an
   index within the combined dimension) and the dimension of the result.
   Simp rules. *)
lemma index_append_vec[simp]: "i < dim_vec v + dim_vec w
   (v @v w) $ i = (if i < dim_vec v then v $ i else w $ (i - dim_vec v))"
  "dim_vec (v @v w) = dim_vec v + dim_vec w"
  unfolding append_vec_def Let_def by auto

(* Appending vectors of dimensions n1 and n2 yields a vector of dimension
   n1 + n2. Registered as simp and intro rule. *)
lemma append_carrier_vec[simp,intro]:
  "v  carrier_vec n1  w  carrier_vec n2  v @v w  carrier_vec (n1 + n2)"
  unfolding carrier_vec_def by auto

(* The scalar product of two appended vectors splits into the sum of the
   scalar products of the corresponding parts, provided the parts have
   matching dimensions (v1, w1 of dimension n1; v2, w2 of dimension n2).
   The proof splits the index range at n1 and re-indexes the upper part. *)
lemma scalar_prod_append: assumes "v1  carrier_vec n1" "v2  carrier_vec n2"
  "w1  carrier_vec n1" "w2  carrier_vec n2"
  shows "(v1 @v v2)  (w1 @v w2) = v1  w1 + v2  w2"
proof -
  from assms have dim: "dim_vec v1 = n1" "dim_vec v2 = n2" "dim_vec w1 = n1" "dim_vec w2 = n2" by auto
  (* id: split of the full index range; id2: the upper part as a shifted
     copy of {0 ..< n2}. *)
  have id: "{0 ..< n1 + n2} = {0 ..< n1}  {n1 ..< n1 + n2}" by auto
  have id2: "{n1 ..< n1 + n2} = (plus n1) ` {0 ..< n2}"
    by (simp add: ac_simps)
  have "(v1 @v v2)  (w1 @v w2) = (i = 0..<n1. v1 $ i * w1 $ i) +
    (i = n1..<n1 + n2. v2 $ (i - n1) * w2 $ (i - n1))"
  unfolding scalar_prod_def
    by (auto simp: dim id, subst sum.union_disjoint, insert assms, force+)
  (* Shift the upper sum down to the range {0 ..< n2}. *)
  also have "(i = n1..<n1 + n2. v2 $ (i - n1) * w2 $ (i - n1))
    = (i = 0..< n2. v2 $ i * w2 $ i)"
    by (rule sum.reindex_cong [OF _ id2]) simp_all
  finally show ?thesis by (simp, insert assms, auto simp: scalar_prod_def)
qed

(* vec_first v n: the vector of dimension n holding v $ i at index i,
   i.e. the first n entries of v. *)
definition "vec_first v n  vec n (λi. v $ i)"
(* vec_last v n: the vector of dimension n holding v $ (dim_vec v - n + i)
   at index i, i.e. the last n entries of v. The subtraction is on natural
   numbers, so it truncates at 0 when n exceeds dim_vec v. *)
definition "vec_last v n  vec n (λi. v $ (dim_vec v - n + i))"

(* vec_first and vec_last always produce a vector of the requested
   dimension n, independent of the dimension of v. Simp rules. *)
lemma dim_vec_first[simp]: "dim_vec (vec_first v n) = n" unfolding vec_first_def by auto
lemma dim_vec_last[simp]: "dim_vec (vec_last v n) = n" unfolding vec_last_def by auto

(* Carrier versions of dim_vec_first and dim_vec_last. Simp rules. *)
lemma vec_first_carrier[simp]: "vec_first v n  carrier_vec n" by (rule carrier_vecI, auto)
lemma vec_last_carrier[simp]: "vec_last v n  carrier_vec n" by (rule carrier_vecI, auto)

(* A vector of dimension n + m is recovered by appending its first n entries
   and its last m entries. Simp rule. *)
lemma vec_first_last_append[simp]:
  assumes "v  carrier_vec (n+m)" shows "vec_first v n @v vec_last v m = v"
  apply(rule) unfolding vec_first_def vec_last_def using assms by auto

(* The vector order (less_eq_vec_def) on appended vectors decomposes into the
   order on the two parts, provided the first parts v and w both have
   dimension n. The inner block re-indexes the comparison of the second
   parts: position i of v', w' corresponds to position n + i of the
   appended vectors. *)
lemma append_vec_le: assumes "v  carrier_vec n" and w: "w  carrier_vec n" 
  shows "v @v v'  w @v w'  v  w  v'  w'" 
proof -
  {
    fix i
    assume *: "i. (¬ i < n  i < n + dim_vec w'  v' $ (i - n)  w' $ (i - n))"
      and i: "i < dim_vec w'" 
    have "v' $ i  w' $ i" using *[rule_format, of "n + i"] i by auto
  }
  thus ?thesis using assms unfolding less_eq_vec_def by auto
qed

(* A property holds for all vectors of dimension n + m exactly when it holds
   for all vectors of the form x1 @v x2 with x1 of dimension n and x2 of
   dimension m. The first direction is solved by force; for the other one,
   every x of dimension n + m is written as the append of its first n and
   remaining m entries. *)
lemma all_vec_append: "( x  carrier_vec (n + m). P x)  ( x1  carrier_vec n.  x2  carrier_vec m. P (x1 @v x2))" 
proof (standard, force, intro ballI, goal_cases)
  case (1 x)
  have "x = vec n (λ i. x $ i) @v vec m (λ i. x $ (n + i))" 
    by (rule eq_vecI, insert 1(2), auto)
  hence "P x = P (vec n (λ i. x $ i) @v vec m (λ i. x $ (n + i)))" by simp
  also have "" using 1 by auto
  finally show ?case .
qed


(* Block layout of the matrix built by four_block_mat A B C D:
     A B
     C D *)
(* Assembles a matrix from four blocks: A top-left, B top-right, C
   bottom-left, D bottom-right. The dimensions of the result are taken from
   A and D only (rows: dim_row A + dim_row D, columns: dim_col A + dim_col D);
   B and C are merely indexed at the shifted positions, so the definition
   does not itself enforce that their dimensions fit. *)
definition four_block_mat :: "'a mat  'a mat  'a mat  'a mat  'a mat" where
  "four_block_mat A B C D =
    (let nra = dim_row A; nrd = dim_row D;
         nca = dim_col A; ncd = dim_col D
       in
    mat (nra + nrd) (nca + ncd) (λ (i,j). if i < nra then
      if j < nca then A $$ (i,j) else B $$ (i,j - nca)
      else if j < nca then C $$ (i - nra, j) else D $$ (i - nra, j - nca)))"

(* Characterisation of four_block_mat without the let-binding: entry lookup
   (for indices within the combined dimensions) by case distinction on the
   quadrant, and the row and column counts of the result. Simp rules. *)
lemma index_mat_four_block[simp]:
  "i < dim_row A + dim_row D  j < dim_col A + dim_col D  four_block_mat A B C D $$ (i,j)
  = (if i < dim_row A then
      if j < dim_col A then A $$ (i,j) else B $$ (i,j - dim_col A)
      else if j < dim_col A then C $$ (i - dim_row A, j) else D $$ (i - dim_row A, j - dim_col A))"
  "dim_row (four_block_mat A B C D) = dim_row A + dim_row D"
  "dim_col (four_block_mat A B C D) = dim_col A + dim_col D"
  unfolding four_block_mat_def Let_def by auto

(* Carrier of a four-block matrix; only the dimensions of A and D are
   relevant. Simp rule. *)
lemma four_block_carrier_mat[simp]:
  "A  carrier_mat nr1 nc1  D  carrier_mat nr2 nc2 
  four_block_mat A B C D  carrier_mat (nr1 + nr2) (nc1 + nc2)"
  unfolding carrier_mat_def by auto

(* Congruence rule: equal blocks give equal four-block matrices. *)
lemma cong_four_block_mat: "A1 = B1  A2 = B2  A3 = B3  A4 = B4 
  four_block_mat A1 A2 A3 A4 = four_block_mat B1 B2 B3 B4" by auto

(* Identity blocks on the diagonal and zero blocks off the diagonal assemble
   to the identity matrix of size n1 + n2. Simp rule. *)
lemma four_block_one_mat[simp]:
  "four_block_mat (1m n1) (0m n1 n2) (0m n2 n1) (1m n2) = 1m (n1 + n2)"
  by (rule eq_matI, auto)

(* Four zero blocks of fitting dimensions assemble to the zero matrix of the
   combined dimensions. Simp rule. *)
lemma four_block_zero_mat[simp]:
  "four_block_mat (0m nr1 nc1) (0m nr1 nc2) (0m nr2 nc1) (0m nr2 nc2) = 0m (nr1 + nr2) (nc1 + nc2)"
  by (rule eq_matI, auto)

(* Rows of a four-block matrix with dimension-compatible blocks: a row in
   the upper part is the append of the corresponding rows of A and B; a row
   in the lower part is the append of the rows of C and D at the index
   shifted by nr1. *)
lemma row_four_block_mat:
  assumes c: "A  carrier_mat nr1 nc1" "B  carrier_mat nr1 nc2"
  "C  carrier_mat nr2 nc1" "D  carrier_mat nr2 nc2"
  shows
  "i < nr1  row (four_block_mat A B C D) i = row A i @v row B i" (is "_  ?AB")
  "¬ i < nr1  i < nr1 + nr2  row (four_block_mat A B C D) i = row C (i - nr1) @v row D (i - nr1)"
  (is "_  _  ?CD")
proof -
  assume i: "i < nr1"
  show ?AB by (rule eq_vecI, insert i c, auto)
next
  assume i: "¬ i < nr1" "i < nr1 + nr2"
  show ?CD by (rule eq_vecI, insert i c, auto)
qed

(* Columns of a four-block matrix with dimension-compatible blocks: a column
   in the left part is the append of the corresponding columns of A and C; a
   column in the right part is the append of the columns of B and D at the
   index shifted by nc1. *)
lemma col_four_block_mat:
  assumes c: "A  carrier_mat nr1 nc1" "B  carrier_mat nr1 nc2"
  "C  carrier_mat nr2 nc1" "D  carrier_mat nr2 nc2"
  shows
  "j < nc1  col (four_block_mat A B C D) j = col A j @v col C j" (is "_  ?AC")
  "¬ j < nc1  j < nc1 + nc2  col (four_block_mat A B C D) j = col B (j - nc1) @v col D (j - nc1)"
  (is "_  _  ?BD")
proof -
  assume j: "j < nc1"
  show ?AC by (rule eq_vecI, insert j c, auto)
next
  assume j: "¬ j < nc1" "j < nc1 + nc2"
  show ?BD by (rule eq_vecI, insert j c, auto)
qed

(* Block-wise multiplication: the product of two four-block matrices with
   dimension-compatible blocks is the four-block matrix of the usual 2 x 2
   block product formula. The proof computes the scalar product
   "row of ?M1 against column of ?M2" separately for each of the four
   quadrants of the result, each time splitting the appended row and column
   with scalar_prod_append, and then concludes entrywise. *)
lemma mult_four_block_mat: assumes
  c1: "A1  carrier_mat nr1 n1" "B1  carrier_mat nr1 n2" "C1  carrier_mat nr2 n1" "D1  carrier_mat nr2 n2" and
  c2: "A2  carrier_mat n1 nc1" "B2  carrier_mat n1 nc2" "C2  carrier_mat n2 nc1" "D2  carrier_mat n2 nc2"
  shows "four_block_mat A1 B1 C1 D1 * four_block_mat A2 B2 C2 D2
  = four_block_mat (A1 * A2 + B1 * C2) (A1 * B2 + B1 * D2)
    (C1 * A2 + D1 * C2) (C1 * B2 + D1 * D2)" (is "?M1 * ?M2 = _")
proof -
  note row = row_four_block_mat[OF c1]
  note col = col_four_block_mat[OF c2]
  (* Top-left quadrant. *)
  {
    fix i j
    assume i: "i < nr1" and j: "j < nc1"
    have "row ?M1 i  col ?M2 j = row A1 i  col A2 j + row B1 i  col C2 j"
      unfolding row(1)[OF i] col(1)[OF j]
      by (rule scalar_prod_append[of _ n1 _ n2], insert c1 c2 i j, auto)
  }
  moreover
  (* Bottom-left quadrant. *)
  {
    fix i j
    assume i: "¬ i < nr1" "i < nr1 + nr2" and j: "j < nc1"
    hence i': "i - nr1 < nr2" by auto
    have "row ?M1 i  col ?M2 j = row C1 (i - nr1)  col A2 j + row D1 (i - nr1)  col C2 j"
      unfolding row(2)[OF i] col(1)[OF j]
      by (rule scalar_prod_append[of _ n1 _ n2], insert c1 c2 i i' j, auto)
  }
  moreover
  (* Top-right quadrant. *)
  {
    fix i j
    assume i: "i < nr1" and j: "¬ j < nc1" "j < nc1 + nc2"
    hence j': "j - nc1 < nc2" by auto
    have "row ?M1 i  col ?M2 j = row A1 i  col B2 (j - nc1) + row B1 i  col D2 (j - nc1)"
      unfolding row(1)[OF i] col(2)[OF j]
      by (rule scalar_prod_append[of _ n1 _ n2], insert c1 c2 i j' j, auto)
  }
  moreover
  (* Bottom-right quadrant. *)
  {
    fix i j
    assume i: "¬ i < nr1" "i < nr1 + nr2" and j: "¬ j < nc1" "j < nc1 + nc2"
    hence i': "i - nr1 < nr2" and j': "j - nc1 < nc2" by auto
    have "row ?M1 i  col ?M2 j = row C1 (i - nr1)  col B2 (j - nc1) + row D1 (i - nr1)  col D2 (j - nc1)"
      unfolding row(2)[OF i] col(2)[OF j]
      by (rule scalar_prod_append[of _ n1 _ n2], insert c1 c2 i i' j' j, auto)
  }
  ultimately show ?thesis
    by (intro eq_matI, insert c1 c2, auto)
qed

(* Vertical stacking of two matrices, written infix as @r (right-associative,
   priority 65): A on top of B. Implemented as a four_block_mat whose right
   blocks are zero matrices with 0 columns, so the column count of the
   result is dim_col A + 0. *)
definition append_rows :: "'a :: zero mat  'a mat  'a mat" (infixr "@r" 65)where
  "A @r B = four_block_mat A (0m (dim_row A) 0) B (0m (dim_row B) 0)" 

(* Stacking an nr1 x nc matrix on an nr2 x nc matrix yields an
   (nr1 + nr2) x nc matrix. Registered as simp and intro rule. *)
lemma carrier_append_rows[simp,intro]: "A  carrier_mat nr1 nc  B  carrier_mat nr2 nc 
  A @r B  carrier_mat (nr1 + nr2) nc" 
  unfolding append_rows_def by auto

(* Column j of a matrix product A * B is A applied to column j of B
   (A: nr x n, B: n x nc, j < nc). Simp rule. The Isar proof shows equality
   of the components; the dimension goal is left to the closing auto. *)
lemma col_mult2[simp]:
  assumes A: "A : carrier_mat nr n"
      and B: "B : carrier_mat n nc"
      and j: "j < nc"
  shows "col (A * B) j = A *v col B j"
proof
  have AB: "A * B : carrier_mat nr nc" using A B by auto
  fix i assume i: "i < dim_vec (A *v col B j)"
  show "col (A * B) j $ i = (A *v col B j) $ i"
    using A B AB j i by simp
qed auto

(* A matrix-vector product expressed as a matrix-matrix product: A *v v is
   column 0 of A multiplied with the one-column matrix mat_of_cols nc [v].
   Follows from col_mult2. *)
lemma mat_vec_as_mat_mat_mult: assumes A: "A  carrier_mat nr nc" 
  and v: "v  carrier_vec nc" 
shows "A *v v = col (A * mat_of_cols nc [v]) 0"  
  by (subst col_mult2[OF A], insert v, auto)

(* Applying vertically stacked matrices to a vector yields the append of the
   two individual results. Proof idea: view v as the one-column matrix
   mat_of_cols nc [v], write both factors as four-block matrices (?Fb1 for
   the stacked matrix, ?Fb2 for the column, padded with empty zero blocks),
   multiply block-wise with mult_four_block_mat, simplify the blocks, and
   read off column 0 with col_four_block_mat. *)
lemma mat_mult_append: assumes A: "A  carrier_mat nr1 nc" 
  and B: "B  carrier_mat nr2 nc" 
  and v: "v  carrier_vec nc" 
shows "(A @r B) *v v = (A *v v) @v (B *v v)" 
proof -
  let ?Fb1 = "four_block_mat A (0m nr1 0) B (0m nr2 0)" 
  let ?Fb2 = "four_block_mat (mat_of_cols nc [v]) (0m nc 0) (0m 0 1) (0m 0 0)" 
  (* The padding blocks of ?Fb2 are empty, so it is just the column matrix. *)
  have id: "?Fb2 = mat_of_cols nc [v]" 
    using v by auto
  have "(A @r B) *v v = col (?Fb1 * ?Fb2) 0" unfolding id
    by (subst mat_vec_as_mat_mat_mult[OF _ v], insert A B, auto simp: append_rows_def)
  also have "?Fb1 * ?Fb2 = four_block_mat (A * mat_of_cols nc [v] + 0m nr1 0 * 0m 0 1) (A * 0m nc 0 + 0m nr1 0 * 0m 0 0)
     (B * mat_of_cols nc [v] + 0m nr2 0 * 0m 0 1) (B * 0m nc 0 + 0m nr2 0 * 0m 0 0)" 
    by (rule mult_four_block_mat[OF A _ B], auto)
  (* Simplify each of the four result blocks. *)
  also have "(A * mat_of_cols nc [v] + 0m nr1 0 * 0m 0 1) = A * mat_of_cols nc [v]" 
    using A v by auto
  also have "(B * mat_of_cols nc [v] + 0m nr2 0 * 0m 0 1) = B * mat_of_cols nc [v]" 
    using B v by auto
  also have "(A * 0m nc 0 + 0m nr1 0 * 0m 0 0) = 0m nr1 0" using A by auto 
  also have "(B * 0m nc 0 + 0m nr2 0 * 0m 0 0) = 0m nr2 0" using B by auto
  finally have "(A @r B) *v v = col (four_block_mat (A * mat_of_cols nc [v]) (0m nr1 0) (B * mat_of_cols nc [v]) (0m nr2 0)) 0" .
  also have " = col (A * mat_of_cols nc [v]) 0 @v col (B * mat_of_cols nc [v]) 0" 
    by (rule col_four_block_mat, insert A B v, auto)
  (* Translate the two matrix-matrix products back to matrix-vector form. *)
  also have "col (A * mat_of_cols nc [v]) 0 = A *v v" 
    by (rule mat_vec_as_mat_mat_mult[symmetric, OF A v])
  also have "col (B * mat_of_cols nc [v]) 0 = B *v v" 
    by (rule mat_vec_as_mat_mat_mult[symmetric, OF B v])
  finally show ?thesis .
qed
 
(* A vector inequality for stacked matrices against an appended bound
   decomposes into the two inequalities for the individual matrices, provided
   the first bound a has dimension nr1. Combines mat_mult_append with
   append_vec_le. *)
lemma append_rows_le: assumes A: "A  carrier_mat nr1 nc" 
  and B: "B  carrier_mat nr2 nc" 
  and a: "a  carrier_vec nr1" 
  and v: "v  carrier_vec nc"
shows "(A @r B) *v v  (a @v b)  A *v v  a  B *v v  b" 
  unfolding mat_mult_append[OF A B v]
  by (rule append_vec_le[OF _ a], insert A v, auto)


(* Every element of a four-block matrix with dimension-compatible blocks is
   an element of one of the four blocks. The proof picks an index pair (i,j)
   witnessing the element and distinguishes the four quadrants by comparing
   i with nr1 and j with nc1. *)
lemma elements_four_block_mat:
  assumes c: "A  carrier_mat nr1 nc1" "B  carrier_mat nr1 nc2"
  "C  carrier_mat nr2 nc1" "D  carrier_mat nr2 nc2"
  shows
  "elements_mat (four_block_mat A B C D) 
   elements_mat A  elements_mat B  elements_mat C  elements_mat D"
   (is "elements_mat ?four  _")
proof rule
  fix a assume "a  elements_mat ?four"
  then obtain i j
    where i4: "i < dim_row ?four" and j4: "j < dim_col ?four" and a: "a = ?four $$ (i, j)"
    by auto
  show "a  elements_mat A  elements_mat B  elements_mat C  elements_mat D"
  proof (cases "i < nr1")
    (* Upper half: the element comes from A or B. *)
    case True note i1 = this
    show ?thesis
    proof (cases "j < nc1")
      case True
      then have "a = A $$ (i,j)" using c i1 a by simp
      thus ?thesis using c i1 True by auto next
      case False
      then have "a = B $$ (i,j-nc1)" using c i1 a j4 by simp
      moreover have "j - nc1 < nc2" using c j4 False by auto
      ultimately show ?thesis using c i1 by auto
    qed next
    (* Lower half: the element comes from C or D. *)
    case False note i1 = this
    have i2: "i - nr1 < nr2" using c i1 i4 by auto
    show ?thesis
    proof (cases "j < nc1")
      case True
      then have "a = C $$ (i-nr1,j)" using c i2 a i1 by simp
      thus ?thesis using c i2 True by auto next
      case False
      then have "a = D $$ (i-nr1,j-nc1)" using c i2 a i1 j4 by simp
      moreover have "j - nc1 < nc2" using c j4 False by auto
      ultimately show ?thesis using c i2 by auto
    qed
  qed
qed

(* Associativity of the block-diagonal combination FB, where FB Bb Cc places
   Bb and Cc on the diagonal of a four_block_mat and fills the off-diagonal
   blocks with zero matrices of fitting dimensions. Both sides have the same
   dimensions (?r x ?c); the entries are compared by case distinction on
   whether the row/column index falls into the A-part or beyond. *)
lemma assoc_four_block_mat: fixes FB :: "'a mat  'a mat  'a :: zero mat"
  defines FB: "FB  λ Bb Cc. four_block_mat Bb (0m (dim_row Bb) (dim_col Cc)) (0m (dim_row Cc) (dim_col Bb)) Cc"
  shows "FB A (FB B C) = FB (FB A B) C" (is "?L = ?R")
proof -
  let ?ar = "dim_row A" let ?ac = "dim_col A"
  let ?br = "dim_row B" let ?bc = "dim_col B"
  let ?cr = "dim_row C" let ?cc = "dim_col C"
  let ?r = "?ar + ?br + ?cr" let ?c = "?ac + ?bc + ?cc"
  let ?BC = "FB B C" let ?AB = "FB A B"
  (* Dimensions of both sides and of the two intermediate combinations. *)
  have dL: "dim_row ?L = ?r" "dim_col ?L = ?c" unfolding FB by auto
  have dR: "dim_row ?R = ?ar + ?br + ?cr" "dim_col ?R = ?ac + ?bc + ?cc" unfolding FB by auto
  have dBC: "dim_row ?BC = ?br + ?cr" "dim_col ?BC = ?bc + ?cc" unfolding FB by auto
  have dAB: "dim_row ?AB = ?ar + ?br" "dim_col ?AB = ?ac + ?bc" unfolding FB by auto
  show ?thesis
  proof (intro eq_matI[of ?R ?L, unfolded dL dR, OF _ refl refl])
    fix i j
    assume i: "i < ?r" and j: "j < ?c"
    show "?L $$ (i,j) = ?R $$ (i,j)"
    proof (cases "i < ?ar")
      case True note i = this
      thus ?thesis using j
        by (cases "j < ?ac", auto simp: FB)
    next
      case False note ii = this
      show ?thesis
      proof (cases "j < ?ac")
        case True
        with i ii show ?thesis unfolding FB by auto
      next
        (* Both indices lie beyond the A-part: both sides reduce to the
           same entry of ?BC. *)
        case False note jj = this
        from j jj i ii have L: "?L $$ (i,j) = ?BC $$ (i - ?ar, j - ?ac)" unfolding FB by auto
        have R: "?R $$ (i,j) = ?BC $$ (i - ?ar, j - ?ac)" using ii jj i j
          by (cases "i < ?ar + ?br"; cases "j < ?ac + ?bc", auto simp: FB)
        show ?thesis unfolding L R ..
      qed
    qed
  qed
qed

(* Splits A at row sr and column sc into four blocks (A1,A2,A3,A4):
   A1 is the top-left sr x sc part, A2 the top-right part, A3 the bottom-left
   part and A4 the bottom-right part. The sizes of the remaining parts are
   computed with natural-number subtraction (nr - sr, nc - sc), which
   truncates at 0 if sr or sc exceed the dimensions of A. *)
definition split_block :: "'a mat  nat  nat  ('a mat × 'a mat × 'a mat × 'a mat)"
  where "split_block A sr sc = (let
    nr = dim_row A; nc = dim_col A;
    nr2 = nr - sr; nc2 = nc - sc;
    A1 = mat sr sc (λ ij. A $$ ij);
    A2 = mat sr nc2 (λ (i,j). A $$ (i,j+sc));
    A3 = mat nr2 sc (λ (i,j). A $$ (i+sr,j));
    A4 = mat nr2 nc2 (λ (i,j). A $$ (i+sr,j+sc))
  in (A1,A2,A3,A4))"

(* Correctness of split_block: if the dimensions of A decompose as
   sr1 + sr2 rows and sc1 + sc2 columns, the four resulting blocks have the
   expected dimensions and reassemble to A via four_block_mat. *)
lemma split_block: assumes res: "split_block A sr1 sc1 = (A1,A2,A3,A4)"
  and dims: "dim_row A = sr1 + sr2" "dim_col A = sc1 + sc2"
  shows "A1  carrier_mat sr1 sc1" "A2  carrier_mat sr1 sc2"
    "A3  carrier_mat sr2 sc1" "A4  carrier_mat sr2 sc2"
    "A = four_block_mat A1 A2 A3 A4"
  using res unfolding split_block_def Let_def
  by (auto simp: dims)

text ‹Using @{const four_block_mat} we define block-diagonal matrices.›

(* Block-diagonal matrix built from a list of blocks. The empty list gives
   the 0 x 0 matrix; for A # As, A is placed top-left, the recursively built
   matrix B bottom-right, and the off-diagonal blocks are zero matrices of
   fitting dimensions. *)
fun diag_block_mat :: "'a :: zero mat list  'a mat" where
  "diag_block_mat [] = 0m 0 0"
| "diag_block_mat (A # As) = (let
     B = diag_block_mat As
     in four_block_mat A (0m (dim_row A) (dim_col B)) (0m (dim_row B) (dim_col A)) B)"

(* The row (column) count of a block-diagonal matrix is the sum of the row
   (column) counts of its blocks. Both statements are proved together by one
   induction over the list and then split. *)
lemma dim_diag_block_mat:
  "dim_row (diag_block_mat As) = sum_list (map dim_row As)" (is "?row")
  "dim_col (diag_block_mat As) = sum_list (map dim_col As)" (is "?col")
proof -
  have "?row  ?col"
    by (induct As, auto simp: Let_def)
  thus ?row and ?col by auto
qed

(* A block-diagonal matrix with a single block is that block. Simp rule. *)
lemma diag_block_mat_singleton[simp]: "diag_block_mat [A] = A"
  by auto

(* The block-diagonal matrix of an appended list is the block-diagonal
   combination of the block-diagonal matrices of the two parts. Proved by
   induction on the first list; the step case is exactly
   assoc_four_block_mat. *)
lemma diag_block_mat_append: "diag_block_mat (As @ Bs) =
  (let A = diag_block_mat As; B = diag_block_mat Bs
  in four_block_mat A (0m (dim_row A) (dim_col B)) (0m (dim_row B) (dim_col A)) B)"
  unfolding Let_def
proof (induct As)
  case (Cons A As)
  show ?case
    unfolding append.simps
    unfolding diag_block_mat.simps Let_def
    unfolding Cons
    by (rule assoc_four_block_mat)
qed auto

(* Special case of diag_block_mat_append for appending a single block B at
   the end of the list. *)
lemma diag_block_mat_last: "diag_block_mat (As @ [B]) =
  (let A = diag_block_mat As
  in four_block_mat A (0m (dim_row A) (dim_col B)) (0m (dim_row B) (dim_col A)) B)"
  unfolding diag_block_mat_append diag_block_mat_singleton by auto


(* If every block in the list is square, the block-diagonal matrix is
   square. *)
lemma diag_block_mat_square:
  "Ball (set As) square_mat  square_mat (diag_block_mat As)"
by (induct As, auto simp:Let_def)

(* Replacing every block by the identity matrix of its row count yields the
   identity matrix whose size is the sum of the row counts. Simp rule. *)
lemma diag_block_one_mat[simp]:
  "diag_block_mat (map (λA. 1m (dim_row A)) As) = (1m (sum_list (map dim_row As)))"
  by (induct As, auto simp: Let_def)

(* Every element of a block-diagonal matrix is either 0 or an element of one
   of the blocks. Induction on the list; the step case applies
   elements_four_block_mat to the four blocks A, ?B, ?C, ?D and uses
   elements_0_mat for the two zero blocks. *)
lemma elements_diag_block_mat:
  "elements_mat (diag_block_mat As)  {0}   (set (map elements_mat As))"
proof (induct As)
  case Nil then show ?case using dim_diag_block_mat[of Nil] by auto next
  case (Cons A As)
    let ?D = "diag_block_mat As"
    let ?B = "0m (dim_row A) (dim_col ?D)"
    let ?C = "0m (dim_row ?D) (dim_col A)"
    (* Carrier facts needed to instantiate elements_four_block_mat. *)
    have A: "A  carrier_mat (dim_row A) (dim_col A)" by auto
    have B: "?B  carrier_mat (dim_row A) (dim_col ?D)" by auto
    have C: "?C  carrier_mat (dim_row ?D) (dim_col A)" by auto
    have D: "?D  carrier_mat (dim_row ?D) (dim_col ?D)" by auto
    have
      "elements_mat (diag_block_mat (A#As)) 
       elements_mat A  elements_mat ?B  elements_mat ?C  elements_mat ?D"
      unfolding diag_block_mat.simps Let_def
      using elements_four_block_mat[OF A B C D] elements_0_mat
      by auto
    also have "...  {0}  elements_mat A  elements_mat ?D"
      using elements_0_mat by auto
    finally show ?case using Cons by auto
qed

(* Powers of a block-diagonal matrix with square blocks are computed
   block-wise. Outer induction on the exponent n: the base case rewrites the
   identity matrix as a block-diagonal matrix of identities; the step case
   reduces to showing that multiplying the block-diagonal matrix of n-th
   powers by the original block-diagonal matrix multiplies each block by the
   corresponding original block, which is proved by an inner induction on
   the list using mult_four_block_mat. *)
lemma diag_block_pow_mat: assumes sq: "Ball (set As) square_mat"
  shows "diag_block_mat As ^m n = diag_block_mat (map (λ A. A ^m n) As)" (is "?As ^m _ = _")
proof (induct n)
  case 0
  have "?As ^m 0 = 1m (dim_row ?As)" by simp
  also have "dim_row ?As = sum_list (map dim_row As)"
    using diag_block_mat_square[OF sq] unfolding dim_diag_block_mat by auto
  also have "1m  = diag_block_mat (map (λA. 1m (dim_row A)) As)" by simp
  also have " = diag_block_mat (map (λ A. A ^m 0) As)" by simp
  finally show ?case .
next
  case (Suc n)
  (* ?An: block-diagonal matrix of n-th powers;
     ?Asn: block-diagonal matrix of (n-th power * block). *)
  let ?An = "λ As. diag_block_mat (map (λA. A ^m n) As)"
  let ?Asn = "λ As. diag_block_mat (map (λA. A ^m n * A) As)"
  from Suc have "?case = (?An As * diag_block_mat As = ?Asn As)" by simp
  also have "" using sq
  proof (induct As)
    case (Cons A As)
    hence IH: "?An As * diag_block_mat As = ?Asn As"
      and sq: "Ball (set As) square_mat" and A: "dim_col A = dim_row A" by auto
    have sq2: "Ball (set (List.map (λA. A ^m n) As)) square_mat"
      and sq3: "Ball (set (List.map (λA. A ^m n * A) As)) square_mat"
      using sq by auto
    (* n1: size of the head block; n2: total size of the remaining blocks. *)
    define n1 where "n1 = dim_row A"
    define n2 where "n2 = sum_list (map dim_row As)"
    from A have A: "A  carrier_mat n1 n1" unfolding n1_def carrier_mat_def by simp
    (* Local simp facts: all block-diagonal matrices over the tail are
       n2 x n2. *)
    have [simp]: "dim_col (?An As) = n2" "dim_row (?An As) = n2"
      unfolding n2_def
      using diag_block_mat_square[OF sq2,unfolded square_mat.simps]
      unfolding dim_diag_block_mat map_map by (auto simp:o_def)
    have [simp]: "dim_col (?Asn As) = n2" "dim_row (?Asn As) = n2"
      unfolding n2_def
      using diag_block_mat_square[OF sq3,unfolded square_mat.simps]
      unfolding dim_diag_block_mat map_map by (auto simp:o_def)
    have [simp]:
      "dim_row (diag_block_mat As) = n2"
      "dim_col (diag_block_mat As) = n2"
      unfolding n2_def
      using diag_block_mat_square[OF sq,unfolded square_mat.simps]
      unfolding dim_diag_block_mat by auto

    have [simp]: "diag_block_mat As  carrier_mat n2 n2" unfolding carrier_mat_def by simp
    have [simp]: "?An As  carrier_mat n2 n2" unfolding carrier_mat_def by simp
    (* Multiply block-wise; the off-diagonal products vanish and the
       bottom-right block is handled by IH. *)
    show ?case unfolding diag_block_mat.simps Let_def list.simps
      by (subst mult_four_block_mat[of _ n1 n1 _ n2 _ n2 _ _ n1 _ n2],
      insert A, auto simp: IH)
  qed auto
  finally show ?case by simp
qed

(* If every block is square and has only zeros strictly below its diagonal,
   then the block-diagonal matrix has only zeros strictly below its diagonal
   (entry (i,j) with j < i). Induction on the list with arbitrary i, j; the
   cases distinguish whether i and j fall into the head block (size ?n1) or
   the remaining part (size ?n2). *)
lemma diag_block_upper_triangular: assumes
    " A i j. A  set As  j < i  i < dim_row A  A $$ (i,j) = 0"
  and "Ball (set As) square_mat"
  and "j < i" "i < dim_row (diag_block_mat As)"
  shows "diag_block_mat As $$ (i,j) = 0"
  using assms
proof (induct As arbitrary: i j)
  case (Cons A As i j)
  let ?n1 = "dim_row A"
  let ?n2 = "sum_list (map dim_row As)"
  from Cons have [simp]: "dim_col A = ?n1" by simp
  from Cons have "Ball (set As) square_mat" by auto
  note [simp] = diag_block_mat_square[OF this,unfolded square_mat.simps]
  note [simp] = dim_diag_block_mat(1)
  from Cons(5) have i: "i < ?n1 + ?n2" by simp
  show ?case
  proof (cases "i < ?n1")
    (* Both indices lie in the head block A. *)
    case True
    with Cons(4) have j: "j < ?n1" by auto
    with True Cons(2)[of A, OF _ Cons(4)] show ?thesis
      by (simp add: Let_def)
  next
    case False note iAs = this
    show ?thesis
    proof (cases "j < ?n1")
      (* Row in the tail part, column in the head part: a zero block. *)
      case True
      with i iAs show ?thesis by (simp add: Let_def)
    next
      (* Both indices lie in the tail part: use the induction hypothesis. *)
      case False note jAs = this
      from Cons(4) i have j: "j < ?n1 + ?n2" by auto
      show ?thesis using iAs jAs i j
        by (simp add: Let_def, subst Cons(1), insert Cons(2-4), auto)
    qed
  qed
qed simp

(* Scalar multiplication of a four_block_mat equals the four_block_mat of the
   scalar-multiplied blocks; the carrier assumptions fix the block dimensions. *)
lemma smult_four_block_mat: assumes c: "A ∈ carrier_mat nr1 nc1" "B ∈ carrier_mat nr1 nc2"
  "C ∈ carrier_mat nr2 nc1" "D ∈ carrier_mat nr2 nc2"
  shows "a ⋅⇩m four_block_mat A B C D = four_block_mat (a ⋅⇩m A) (a ⋅⇩m B) (a ⋅⇩m C) (a ⋅⇩m D)"
  by (rule eq_matI, insert c, auto)

(* map_mat f applied to a four_block_mat equals the four_block_mat of the
   blocks mapped with f, given blocks of compatible dimensions. *)
lemma map_four_block_mat: assumes c: "A ∈ carrier_mat nr1 nc1" "B ∈ carrier_mat nr1 nc2"
  "C ∈ carrier_mat nr2 nc1" "D ∈ carrier_mat nr2 nc2"
  shows "map_mat f (four_block_mat A B C D) = four_block_mat (map_mat f A) (map_mat f B) (map_mat f C) (map_mat f D)"
  by (rule eq_matI, insert c, auto)

(* The sum of two four_block_mat matrices whose corresponding blocks have the
   same dimensions is the four_block_mat of the block-wise sums. *)
lemma add_four_block_mat: assumes
  c1: "A1 ∈ carrier_mat nr1 nc1" "B1 ∈ carrier_mat nr1 nc2" "C1 ∈ carrier_mat nr2 nc1" "D1 ∈ carrier_mat nr2 nc2" and
  c2: "A2 ∈ carrier_mat nr1 nc1" "B2 ∈ carrier_mat nr1 nc2" "C2 ∈ carrier_mat nr2 nc1" "D2 ∈ carrier_mat nr2 nc2"
  shows "four_block_mat A1 B1 C1 D1 + four_block_mat A2 B2 C2 D2
  = four_block_mat (A1 + A2) (B1 + B2) (C1 + C2) (D1 + D2)"
  by (rule eq_matI, insert assms, auto)


(* For square A and D, diag_mat of four_block_mat A B C D is diag_mat A
   followed by diag_mat D; no assumption on B and C is stated. *)
lemma diag_four_block_mat: assumes c: "A ∈ carrier_mat n1 n1"
   "D ∈ carrier_mat n2 n2"
  shows "diag_mat (four_block_mat A B C D) = diag_mat A @ diag_mat D"
  by (rule nth_equalityI, insert c, auto simp: diag_mat_def nth_append)

(* mk_diagonal as turns every element a of the list into the 1-by-1 matrix
   with entry a and combines these matrices with diag_block_mat. *)
definition mk_diagonal :: "'a::zero list ⇒ 'a mat"
  where "mk_diagonal as = diag_block_mat (map (λa. mat (Suc 0) (Suc 0) (λ_. a)) as)"

(* Both the number of rows and the number of columns of mk_diagonal as equal
   length as; proved by induction on as after unfolding mk_diagonal_def. *)
lemma mk_diagonal_dim:
  "dim_row (mk_diagonal as) = length as" "dim_col (mk_diagonal as) = length as"
  unfolding mk_diagonal_def by(induct as, auto simp: Let_def)

(* mk_diagonal as satisfies diagonal_mat: every entry at (i,j) with i ≠ j is 0.
   Induction on as; in the Cons case the proof distinguishes i = 0, j = 0, and
   the remaining case, where the induction hypothesis is used at (i-1, j-1). *)
lemma mk_diagonal_diagonal: "diagonal_mat (mk_diagonal as)"
  unfolding mk_diagonal_def
proof (induct as)
  case Nil show ?case unfolding mk_diagonal_def diagonal_mat_def by simp next
  case (Cons a as)
    let ?n = "length (a#as)"
    let ?A = "mat (Suc 0) (Suc 0) (λ_. a)"
    let ?f = "map (λa. mat (Suc 0) (Suc 0) (λ_. a))"
    let ?AS = "diag_block_mat (?f as)"
    let ?AAS = "diag_block_mat (?f (a#as))"
    show ?case
      unfolding diagonal_mat_def
    proof(intro allI impI)
      fix i j assume ir: "i < dim_row ?AAS" and jc: "j < dim_col ?AAS" and ij: "i ≠ j"
      hence ir2: "i < 1 + dim_row ?AS" and jc2: "j < 1 + dim_col ?AS"
        unfolding dim_row_mat list.map diag_block_mat.simps Let_def
        by auto
      show "?AAS $$ (i,j) = 0"
      proof (cases "i = 0")
        case True
          then show ?thesis using jc ij by (auto simp: Let_def) next
        case False note i0 = this
          show ?thesis
          proof (cases "j = 0")
            case True
              then show ?thesis using ir ij by (auto simp: Let_def) next
            case False
              (* both indices are positive: shift them by one and use the IH *)
              have ir3: "i-1 < dim_row ?AS" and jc3: "j-1 < dim_col ?AS"
                using ir2 jc2 i0 False by auto
              have IH: "⋀i j. i < dim_row ?AS ⟹ j < dim_col ?AS ⟹ i ≠ j ⟹
                ?AS $$ (i,j) = 0"
                using Cons unfolding diagonal_mat_def by auto
              have "?AS $$ (i-1,j-1) = 0"
                using IH[OF ir3 jc3] i0 False ij by auto
              thus ?thesis using ir jc ij by (simp add: Let_def)
          qed
      qed
    qed
qed

(* orthogonal_mat A holds iff B = transpose_mat A * A satisfies diagonal_mat
   and every entry B $$ (i,i) with i < dim_col A is non-zero. *)
definition orthogonal_mat :: "'a::semiring_0 mat ⇒ bool"
  where "orthogonal_mat A ≡
    let B = transpose_mat A * A in
    diagonal_mat B ∧ (∀i<dim_col A. B $$ (i,i) ≠ 0)"

(* Destruction rule for orthogonal_mat: for column indices i, j < dim_col A,
   the scalar product of columns i and j is 0 exactly when i ≠ j. *)
lemma orthogonal_matD[elim]:
  "orthogonal_mat A ⟹
   i < dim_col A ⟹ j < dim_col A ⟹ (col A i ∙ col A j = 0) = (i ≠ j)"
  unfolding orthogonal_mat_def diagonal_mat_def by auto

(* Introduction rule for orthogonal_mat: it suffices that for all column
   indices i, j < dim_col A the scalar product of columns i and j is 0
   exactly when i ≠ j. *)
lemma orthogonal_matI[intro]:
  "(⋀i j. i < dim_col A ⟹ j < dim_col A ⟹ (col A i ∙ col A j = 0) = (i ≠ j)) ⟹
   orthogonal_mat A"
  unfolding orthogonal_mat_def diagonal_mat_def by auto

(* orthogonal vs holds iff for all positions i, j < length vs the scalar
   product of vs ! i and vs ! j is 0 exactly when i ≠ j. *)
definition orthogonal :: "'a::semiring_0 vec list ⇒ bool"
  where "orthogonal vs ≡
    ∀i j. i < length vs ⟶ j < length vs ⟶
      (vs ! i ∙ vs ! j = 0) = (i ≠ j)"

(* Destruction rule for orthogonal: instantiates the definition at two valid
   positions i and j of the list vs. *)
lemma orthogonalD[elim]:
  "orthogonal vs ⟹ i < length vs ⟹ j < length vs ⟹
  (nth vs i ∙ nth vs j = 0) = (i ≠ j)"
  unfolding orthogonal_def by auto

(* Introduction rule for orthogonal: establishes the defining property from
   its statement for arbitrary valid positions i and j of vs. *)
lemma orthogonalI[intro]:
  "(⋀i j. i < length vs ⟹ j < length vs ⟹ (nth vs i ∙ nth vs j = 0) = (i ≠ j)) ⟹
   orthogonal vs"
  unfolding orthogonal_def by auto


(* Transposing a four_block_mat transposes each block and swaps the two
   off-diagonal blocks B and C. *)
lemma transpose_four_block_mat: assumes *: "A ∈ carrier_mat nr1 nc1" "B ∈ carrier_mat nr1 nc2"
  "C ∈ carrier_mat nr2 nc1" "D ∈ carrier_mat nr2 nc2"
  shows "transpose_mat (four_block_mat A B C D) =
    four_block_mat (transpose_mat A) (transpose_mat C) (transpose_mat B) (transpose_mat D)"
  by (rule eq_matI, insert *, auto)

(* The transpose of the n-by-m zero matrix is the m-by-n zero matrix
   (declared as a simp rule). *)
lemma zero_transpose_mat[simp]: "transpose_mat (0⇩m n m) = (0⇩m m n)"
  by (rule eq_matI, auto)

(* A four_block_mat with upper triangular square diagonal blocks A and D and
   a zero lower-left block is upper triangular; B is arbitrary.
   The proof splits on whether the row index i and the column index j are
   smaller than n. *)
lemma upper_triangular_four_block: assumes AD: "A ∈ carrier_mat n n" "D ∈ carrier_mat m m"
  and ut: "upper_triangular A" "upper_triangular D"
  shows "upper_triangular (four_block_mat A B (0⇩m m n) D)"
proof -
  let ?C = "four_block_mat A B (0⇩m m n) D"
  from AD have dim: "dim_row ?C = n + m" "dim_col ?C = n + m" "dim_row A = n" by auto
  show ?thesis
  proof (rule upper_triangularI, unfold dim)
    fix i j
    assume *: "j < i" "i < n + m"
    show "?C $$ (i,j) = 0"
    proof (cases "i < n")
      case True
      with upper_triangularD[OF ut(1) *(1)] * AD show ?thesis by auto
    next
      case False note i = this
      show ?thesis by (cases "j < n", insert upper_triangularD[OF ut(2)] * i AD, auto)
    qed
  qed
qed

(* The k-th power of a four_block_mat with square diagonal blocks A and B and
   zero off-diagonal blocks is the four_block_mat of the k-th powers of A and B.
   Induction on k; the step case multiplies the blocks via mult_four_block_mat. *)
lemma pow_four_block_mat: assumes A: "A ∈ carrier_mat n n"
  and B: "B ∈ carrier_mat m m"
  shows "(four_block_mat A (0⇩m n m) (0⇩m m n) B) ^⇩m k =
    four_block_mat (A ^⇩m k) (0⇩m n m) (0⇩m m n) (B ^⇩m k)"
proof (induct k)
  case (Suc k)
  let ?FB = "λ A B. four_block_mat A (0⇩m n m) (0⇩m m n) B"
  let ?A = "?FB A B"
  let ?B = "?FB (A ^⇩m k) (B ^⇩m k)"
  from A B have Ak: "A ^⇩m k ∈ carrier_mat n n" and Bk: "B ^⇩m k ∈ carrier_mat m m" by auto
  have "?A ^⇩m Suc k = ?A ^⇩m k * ?A" by simp
  also have "?A ^⇩m k = ?B " by (rule Suc)
  also have "?B * ?A = ?FB (A ^⇩m Suc k) (B ^⇩m Suc k)"
    by (subst mult_four_block_mat[OF Ak _ _ Bk A _ _ B], insert A B, auto)
  finally show ?case .
qed (insert A B, auto)

(* For vectors v and w of the same dimension n over a field, negating the
   scalar product equals the scalar product of the negated v with w.
   After unfolding the definitions the claim is shown summand by summand. *)
lemma uminus_scalar_prod:
  assumes [simp]: "v : carrier_vec n" "w : carrier_vec n"
  shows "- ((v::'a::field vec) ∙ w) = (- v) ∙ w"
  unfolding scalar_prod_def uminus_vec_def
  apply (subst sum_negf[symmetric])
proof (rule sum.cong[OF refl])
  fix i assume i: "i : {0 ..<dim_vec w}"
  have [simp]: "dim_vec v = n" "dim_vec w = n" by auto
  show "- (v $ i * w $ i) = vec (dim_vec v) (λi. - v $ i) $ i * w $ i"
    unfolding minus_mult_left using i by auto
qed


(* If v and v' have the same dimension n, then two appended vectors are equal
   iff their first parts and their second parts are equal (simp rule).
   Left-to-right: the first parts are compared at index i, the second parts
   at index n + i. *)
lemma append_vec_eq:
  assumes [simp]: "v : carrier_vec n" "v' : carrier_vec n"
  shows [simp]: "v @⇩v w = v' @⇩v w' ⟷ v = v' ∧ w = w'" (is "?L ⟷ ?R")
proof
  have [simp]: "dim_vec v = n" "dim_vec v' = n" by auto
  { assume L: ?L
    have vv': "v = v'"
    proof
      fix i assume i: "i < dim_vec v'"
      have "(v @⇩v w) $ i = (v' @⇩v w') $ i" using L by auto
      thus "v $ i = v' $ i" using i by auto
    qed auto
    moreover have "w = w'"
    proof
      show "dim_vec w = dim_vec w'" using vv' L
        by (metis add_diff_cancel_left' index_append_vec(2))
      moreover fix i assume i: "i < dim_vec w'"
      have "(v @⇩v w) $ (n + i) = (v' @⇩v w') $ (n + i)" using L by auto
      ultimately show "w $ i = w' $ i" using i by simp
    qed
    ultimately show ?R by simp
  }
qed auto

(* Adding two appended vectors with matching part dimensions (n and m) equals
   appending the sums of the parts; shown index-wise with a split on i < n. *)
lemma append_vec_add:
  assumes [simp]: "v : carrier_vec n" "v' : carrier_vec n"
      and [simp]: "w : carrier_vec m" "w' : carrier_vec m"
  shows "(v @⇩v w) + (v' @⇩v w') = (v + v') @⇩v (w + w')" (is "?L = ?R")
proof
  have [simp]: "dim_vec v = n" "dim_vec v' = n" by auto
  have [simp]: "dim_vec w = m" "dim_vec w' = m" by auto
  fix i assume i: "i < dim_vec ?R"
  thus "?L $ i = ?R $ i" by (cases "i < n",auto)
qed auto


(* Multiplying a four_block_mat with square diagonal blocks A, D and zero
   off-diagonal blocks by an appended vector a @ d yields (A * a) @ (D * d).
   Shown index-wise: for i < n the row is row A i followed by zeros, otherwise
   it is zeros followed by row D (i - n); scalar_prod_append splits the
   scalar product accordingly. *)
lemma mult_mat_vec_split:
  assumes A: "A : carrier_mat n n"
      and D: "D : carrier_mat m m"
      and a: "a : carrier_vec n"
      and d: "d : carrier_vec m"
  shows "four_block_mat A (0⇩m n m) (0⇩m m n) D *⇩v (a @⇩v d) = A *⇩v a @⇩v D *⇩v d"
    (is "?A00D *⇩v _ = ?r")
proof
  have A00D: "?A00D : carrier_mat (n+m) (n+m)" using four_block_carrier_mat[OF A D].
  fix i assume i: "i < dim_vec ?r"
  show "(?A00D *⇩v (a @⇩v d)) $ i = ?r $ i" (is "?li = _")
  proof (cases "i < n")
    case True
      have "?li = (row A i @⇩v 0⇩v m) ∙ (a @⇩v d)"
        using A row_four_block_mat[OF A _ _ D] True by simp
      also have "... = row A i ∙ a + 0⇩v m ∙ d"
        apply (rule scalar_prod_append) using A D a d True by auto
      also have "... = row A i ∙ a" using d by simp
      finally show ?thesis using A True by auto
    next case False
      let ?i = "i - n"
      have "?li = (0⇩v n @⇩v row D ?i) ∙ (a @⇩v d)"
        using i row_four_block_mat[OF A _ _ D] False A D by simp
      also have "... = 0⇩v n ∙ a + row D ?i ∙ d"
        apply (rule scalar_prod_append) using A D a d False by auto
      also have "... = row D ?i ∙ d" using a by simp
      finally show ?thesis using A D False i by auto
  qed
qed auto

(* Introduction rule for similar_mat_wit: P and Q are mutually inverse
   n-by-n matrices, A = P * B * Q, and A, B, P, Q all lie in carrier_mat n n. *)
lemma similar_mat_witI: assumes "P * Q = 1⇩m n" "Q * P = 1⇩m n" "A = P * B * Q"
  "A ∈ carrier_mat n n" "B ∈ carrier_mat n n" "P ∈ carrier_mat n n" "Q ∈ carrier_mat n n"
  shows "similar_mat_wit A B P Q" using assms unfolding similar_mat_wit_def Let_def by auto

(* Destruction rule for similar_mat_wit with n fixed as dim_row A: yields the
   two inverse equations, A = P * B * Q, and the carrier facts for A, B, P, Q. *)
lemma similar_mat_witD: assumes "n = dim_row A" "similar_mat_wit A B P Q"
  shows "P * Q = 1⇩m n" "Q * P = 1⇩m n" "A = P * B * Q"
  "A ∈ carrier_mat n n" "B ∈ carrier_mat n n" "P ∈ carrier_mat n n" "Q ∈ carrier_mat n n"
  using assms(2) unfolding similar_mat_wit_def Let_def assms(1)[symmetric] by auto

(* Variant of similar_mat_witD where the dimension n is taken from a carrier
   assumption on A; only the row count n of that assumption is used, the
   column count m is arbitrary. *)
lemma similar_mat_witD2: assumes "A ∈ carrier_mat n m" "similar_mat_wit A B P Q"
  shows "P * Q = 1⇩m n" "Q * P = 1⇩m n" "A = P * B * Q"
  "A ∈ carrier_mat n n" "B ∈ carrier_mat n n" "P ∈ carrier_mat n n" "Q ∈ carrier_mat n n"
  using similar_mat_witD[OF _ assms(2), of n] assms(1)[unfolded carrier_mat_def] by auto

(* Symmetry of similar_mat_wit: swapping A with B requires swapping the
   witnesses P and Q.  The key step rewrites Q * A * P to
   (Q * P) * B * (Q * P), which is B because Q * P is the identity. *)
lemma similar_mat_wit_sym: assumes sim: "similar_mat_wit A B P Q"
  shows "similar_mat_wit B A Q P"
proof -
  from similar_mat_witD[OF refl sim] obtain n where
    AB: "{A, B, P, Q} ⊆ carrier_mat n n" "P * Q = 1⇩m n" "Q * P = 1⇩m n" and A: "A = P * B * Q" by blast
  hence *: "{B, A, Q, P} ⊆ carrier_mat n n" "Q * P = 1⇩m n" "P * Q = 1⇩m n" by auto
  let ?c = "λ A. A ∈ carrier_mat n n"
  from * have Carr: "?c B" "?c P" "?c Q" by auto
  note [simp] = assoc_mult_mat[of _ n n _ n _ n]
  show ?thesis
  proof (rule similar_mat_witI[of _ _ n])
    have "Q * A * P = (Q * P) * B * (Q * P)"
      using Carr unfolding A by simp
    also have "… = B" using Carr unfolding AB by simp
    finally show "B = Q * A * P" by simp
  qed (insert * AB, auto)
qed

(* Reflexivity of similar_mat_wit: a square matrix A is related to itself
   with the identity matrix as both witnesses. *)
lemma similar_mat_wit_refl: assumes A: "A ∈ carrier_mat n n"
  shows "similar_mat_wit A A (1⇩m n) (1⇩m n)"
  by (rule similar_mat_witI[OF _ _ _ A], insert A, auto)

(* Transitivity of similar_mat_wit: from witnesses P, Q for (A, B) and P', Q'
   for (B, C) the witnesses for (A, C) are P * P' and Q' * Q.
   The proof establishes A = ?P * C * ?Q and both inverse equations for
   ?P and ?Q using associativity of matrix multiplication. *)
lemma similar_mat_wit_trans: assumes AB: "similar_mat_wit A B P Q"
  and BC: "similar_mat_wit B C P' Q'"
  shows "similar_mat_wit A C (P * P') (Q' * Q)"
proof -
  from similar_mat_witD[OF refl AB] obtain n where
    AB: "{A, B, P, Q} ⊆ carrier_mat n n" "P * Q = 1⇩m n" "Q * P = 1⇩m n" "A = P * B * Q" by blast
  hence B: "B ∈ carrier_mat n n" by auto
  from similar_mat_witD2[OF B BC] have
    BC: "{C, P', Q'} ⊆ carrier_mat n n" "P' * Q' = 1⇩m n" "Q' * P' = 1⇩m n" "B = P' * C * Q'" by auto
  let ?c = "λ A. A ∈ carrier_mat n n"
  let ?P = "P * P'"
  let ?Q = "Q' * Q"
  from AB BC have carr: "?c A" "?c B" "?c C" "?c P" "?c P'" "?c Q" "?c Q'"
    and Carr: "{A, C, ?P, ?Q} ⊆ carrier_mat n n" by auto
  note [simp] = assoc_mult_mat[of _ n n _ n _ n]
  have id: "A = ?P * C * ?Q" unfolding AB(4)[unfolded BC(4)] using carr
    by simp
  have "?P * ?Q = P * (P' * Q') * Q" using carr by simp
  also have "… = 1⇩m n" unfolding BC using carr AB by simp
  finally have PQ: "?P * ?Q = 1⇩m n" .
  have "?Q * ?P = Q' * (Q * P) * P'" using carr by simp
  also have "… = 1⇩m n" unfolding AB using carr BC by simp
  finally have QP: "?Q * ?P = 1⇩m n" .
  show ?thesis
    by (rule similar_mat_witI[OF PQ QP id], insert Carr, auto)
qed

(* Reflexivity of similar_mat for square matrices, obtained from
   similar_mat_wit_refl. *)
lemma similar_mat_refl: "A ∈ carrier_mat n n ⟹ similar_mat A A"
  using similar_mat_wit_refl unfolding similar_mat_def by blast

(* Transitivity of similar_mat, obtained from similar_mat_wit_trans. *)
lemma similar_mat_trans: "similar_mat A B ⟹ similar_mat B C ⟹ similar_mat A C"
  using similar_mat_wit_trans unfolding similar_mat_def by blast

(* Symmetry of similar_mat, obtained from similar_mat_wit_sym. *)
lemma similar_mat_sym: "similar_mat A B ⟹ similar_mat B A"
  using similar_mat_wit_sym unfolding similar_mat_def by blast

(* Block-wise similarity: if A1 is related to B1 via P1, Q1 and A2 to B2 via
   P2, Q2, then four_block_mat A1 URA LLA A2 is related to
   four_block_mat B1 UR LL B2, where URA = P1 * UR * Q2 and
   LLA = P2 * LL * Q1.  The witnesses are the four_block_mat matrices with
   diagonal blocks P1, P2 and Q1, Q2 and zero off-diagonal blocks.
   The proof multiplies the block matrices with mult_four_block_mat.
   Note that the fact names 1, 2 and id are re-bound several times below. *)
lemma similar_mat_wit_four_block: assumes
      1: "similar_mat_wit A1 B1 P1 Q1"
  and 2: "similar_mat_wit A2 B2 P2 Q2"
  and URA: "URA = (P1 * UR * Q2)"
  and LLA: "LLA = (P2 * LL * Q1)"
  and A1: "A1 ∈ carrier_mat n n"
  and A2: "A2 ∈ carrier_mat m m"
  and LL: "LL ∈ carrier_mat m n"
  and UR: "UR ∈ carrier_mat n m"
  shows "similar_mat_wit (four_block_mat A1 URA LLA A2) (four_block_mat B1 UR LL B2)
    (four_block_mat P1 (0⇩m n m) (0⇩m m n) P2) (four_block_mat Q1 (0⇩m n m) (0⇩m m n) Q2)"
  (is "similar_mat_wit ?A ?B ?P ?Q")
proof -
  let ?n = "n + m"
  let ?O1 = "1⇩m n"   let ?O2 = "1⇩m m"   let ?O = "1⇩m ?n"
  from similar_mat_witD2[OF A1 1] have 11: "P1 * Q1 = ?O1" "Q1 * P1 = ?O1"
    and P1: "P1 ∈ carrier_mat n n" and Q1: "Q1 ∈ carrier_mat n n"
    and B1: "B1 ∈ carrier_mat n n" and 1: "A1 = P1 * B1 * Q1" by auto
  from similar_mat_witD2[OF A2 2] have 21: "P2 * Q2 = ?O2" "Q2 * P2 = ?O2"
    and P2: "P2 ∈ carrier_mat m m" and Q2: "Q2 ∈ carrier_mat m m"
    and B2: "B2 ∈ carrier_mat m m" and 2: "A2 = P2 * B2 * Q2" by auto
  (* the two block witnesses are mutually inverse *)
  have PQ1: "?P * ?Q = ?O"
    by (subst mult_four_block_mat[OF P1 _ _ P2 Q1 _ _ Q2], unfold 11 21, insert P1 P2 Q1 Q2,
      auto intro!: eq_matI)
  have QP1: "?Q * ?P = ?O"
    by (subst mult_four_block_mat[OF Q1 _ _ Q2 P1 _ _ P2], unfold 11 21, insert P1 P2 Q1 Q2,
      auto intro!: eq_matI)
  let ?PB = "?P * ?B"
  have P: "?P ∈ carrier_mat ?n ?n" using P1 P2 by auto
  have Q: "?Q ∈ carrier_mat ?n ?n" using Q1 Q2 by auto
  have B: "?B ∈ carrier_mat ?n ?n" using B1 UR LL B2 by auto
  have PB: "?PB ∈ carrier_mat ?n ?n" using P B by auto
  have PB1: "P1 * B1 ∈ carrier_mat n n" using P1 B1 by auto
  have PB2: "P2 * B2 ∈ carrier_mat m m" using P2 B2 by auto
  have P1UR: "P1 * UR ∈ carrier_mat n m" using P1 UR by auto
  have P2LL: "P2 * LL ∈ carrier_mat m n" using P2 LL by auto
  (* compute ?P * ?B, then (?P * ?B) * ?Q, block by block *)
  have id: "?PB = four_block_mat (P1 * B1) (P1 * UR) (P2 * LL) (P2 * B2)"
    by (subst mult_four_block_mat[OF P1 _ _ P2 B1 UR LL B2], insert P1 P2 B1 B2 LL UR, auto)
  have id: "?PB * ?Q = four_block_mat (P1 * B1 * Q1) (P1 * UR * Q2)
    (P2 * LL * Q1) (P2 * B2 * Q2)" unfolding id
    by (subst mult_four_block_mat[OF PB1 P1UR P2LL PB2 Q1 _ _ Q2],
    insert P1 P2 B1 B2 Q1 Q2 UR LL, auto)
  have id: "?A = ?P * ?B * ?Q" unfolding id 1 2 URA LLA ..
  show ?thesis
    by (rule similar_mat_witI[OF PQ1 QP1 id], insert A1 A2 B1 B2 Q1 Q2 P1 P2, auto)
qed


lemma similar_mat_four_block_0_ex: assumes
      1: "similar_mat A1 B1"
  and 2: "similar_mat A2 B2"
  and A0: "A0  carrier_mat n m"
  and A1: "A1  carrier_mat n n"
  and A2: "A2  carrier_mat m m"
  shows " B0. B0  carrier_mat n m  similar_mat (four_block_mat A1 A0 (0m m n) A2)
    (four_block_mat B1 B0 (0m m n) B2)"
proof -
  from 1[unfolded similar_mat_def] obtain P1 Q1 where 1: "similar_mat_wit A1 B1 P1 Q1" by auto
  note