(* Session: QHLProver, theory: Complex_Matrix *)

section ‹Complex matrices›

theory Complex_Matrix
  imports 
    "Jordan_Normal_Form.Matrix" 
    "Jordan_Normal_Form.Conjugate" 
    "Jordan_Normal_Form.Jordan_Normal_Form_Existence"
begin

subsection ‹Trace of a matrix›

(* Trace of a matrix: the sum of the diagonal entries A $$ (i,i),
   with i ranging over the row indices. *)
definition trace :: "'a::ring mat ⇒ 'a" where
  "trace A = (∑ i ∈ {0 ..< dim_row A}. A $$ (i,i))"

(* The zero square matrix has trace 0. *)
lemma trace_zero [simp]:
  "trace (0⇩m n n) = 0"
  by (simp add: trace_def)

(* The n-by-n identity matrix has trace n. *)
lemma trace_id [simp]:
  "trace (1⇩m n) = n"
  by (simp add: trace_def)

(* trace (A * B) = trace (B * A) for square matrices of equal size over a
   commutative ring: expand both diagonals as double sums and swap the sums. *)
lemma trace_comm:
  fixes A B :: "'a::comm_ring mat"
  assumes A: "A ∈ carrier_mat n n" and B: "B ∈ carrier_mat n n"
  shows "trace (A * B) = trace (B * A)"
proof (simp add: trace_def)
  have "(∑i = 0..<n. (A * B) $$ (i, i)) = (∑i = 0..<n. ∑j = 0..<n. A $$ (i,j) * B $$ (j,i))"
    apply (rule sum.cong) using assms by (auto simp add: scalar_prod_def)
  also have "… = (∑j = 0..<n. ∑i = 0..<n. A $$ (i,j) * B $$ (j,i))"
    by (rule sum.swap)
  also have "… = (∑j = 0..<n. col A j ∙ row B j)"
    by (metis (no_types, lifting) A B atLeastLessThan_iff carrier_matD index_col index_row scalar_prod_def sum.cong)
  also have "… = (∑j = 0..<n. row B j ∙ col A j)"
    apply (rule sum.cong) apply auto
    apply (subst comm_scalar_prod[where n=n]) apply auto
    using assms by auto
  also have "… = (∑j = 0..<n. (B * A) $$ (j, j))"
    apply (rule sum.cong) using assms by auto
  finally show "(∑i = 0..<dim_row A. (A * B) $$ (i, i)) = (∑i = 0..<dim_row B. (B * A) $$ (i, i))"
    using A B by auto
qed

(* Trace is additive on n-by-n matrices: trace (A + B) = trace A + trace B. *)
lemma trace_add_linear:
  fixes A B :: "'a::comm_ring mat"
  assumes A: "A ∈ carrier_mat n n" and B: "B ∈ carrier_mat n n"
  shows "trace (A + B) = trace A + trace B" (is "?lhs = ?rhs")
proof -
  have "?lhs = (∑i=0..<n. A$$(i, i) + B$$(i, i))" unfolding trace_def using A B by auto
  also have "… = (∑i=0..<n. A$$(i, i)) + (∑i=0..<n. B$$(i, i))" by (auto simp add: sum.distrib)
  finally have l: "?lhs = (∑i=0..<n. A$$(i, i)) + (∑i=0..<n. B$$(i, i))".
  have r: "?rhs = (∑i=0..<n. A$$(i, i)) + (∑i=0..<n. B$$(i, i))" unfolding trace_def using A B by auto
  from l r show ?thesis by auto
qed

(* Trace commutes with subtraction on n-by-n matrices: trace (A - B) = trace A - trace B. *)
lemma trace_minus_linear:
  fixes A B :: "'a::comm_ring mat"
  assumes A: "A ∈ carrier_mat n n" and B: "B ∈ carrier_mat n n"
  shows "trace (A - B) = trace A - trace B" (is "?lhs = ?rhs")
proof -
  have "?lhs = (∑i=0..<n. A$$(i, i) - B$$(i, i))" unfolding trace_def using A B by auto
  also have "… = (∑i=0..<n. A$$(i, i)) - (∑i=0..<n. B$$(i, i))" by (auto simp add: sum_subtractf)
  finally have l: "?lhs = (∑i=0..<n. A$$(i, i)) - (∑i=0..<n. B$$(i, i))".
  have r: "?rhs = (∑i=0..<n. A$$(i, i)) - (∑i=0..<n. B$$(i, i))" unfolding trace_def using A B by auto
  from l r show ?thesis by auto
qed

(* Scalars factor out of the trace: trace (c ⋅⇩m A) = c * trace A. *)
lemma trace_smult: 
  assumes "A ∈ carrier_mat n n"
  shows "trace (c ⋅⇩m A) = c * trace A"
proof -
  have "trace (c ⋅⇩m A) = (∑ i = 0..<dim_row A. c * A $$ (i, i))" unfolding trace_def using assms by auto
  also have "… = c * (∑ i = 0..<dim_row A. A $$ (i, i))"
    by (simp add: sum_distrib_left)
  also have "… = c * trace A" unfolding trace_def by auto
  (* the calculational chain already yields the goal verbatim *)
  finally show ?thesis .
qed

subsection ‹Conjugate of a vector›

(* Conjugation distributes over the (bilinear) scalar product of equal-length vectors. *)
lemma conjugate_scalar_prod:
  fixes v w :: "'a::conjugatable_ring vec"
  assumes "dim_vec v = dim_vec w"
  shows "conjugate (v ∙ w) = conjugate v ∙ conjugate w"
  using assms by (simp add: scalar_prod_def sum_conjugate conjugate_dist_mul)

subsection ‹Inner product›

(* Inner product, conjugate-linear in its FIRST argument:
   inner_prod v w abbreviates w ∙c v. *)
abbreviation inner_prod :: "'a vec ⇒ 'a vec ⇒ 'a :: conjugatable_ring"
  where "inner_prod v w ≡ w ∙c v"

(* The self inner product v ∙c v has zero imaginary part (simp rule). *)
lemma conjugate_scalar_prod_Im [simp]:
  "Im (v ∙c v) = 0"
  by (simp add: scalar_prod_def conjugate_vec_def sum.neutral)

(* The real part of the self inner product v ∙c v is non-negative (simp rule). *)
lemma conjugate_scalar_prod_Re [simp]:
  "Re (v ∙c v) ≥ 0"
  by (simp add: scalar_prod_def conjugate_vec_def sum_nonneg)

(* Over a conjugatable ordered field, v ∙c v ≥ 0: it is a sum of
   terms each shown non-negative by conjugate_square_positive. *)
lemma self_cscalar_prod_geq_0:
  fixes v ::  "'a::conjugatable_ordered_field vec"
  shows "v ∙c v ≥ 0"
  by (auto simp add: scalar_prod_def, rule sum_nonneg, rule conjugate_square_positive)

(* inner_prod is additive in its first (conjugated) argument. *)
lemma inner_prod_distrib_left:
  fixes u v w :: "('a::conjugatable_field) vec"
  assumes dimu: "u ∈ carrier_vec n" and dimv:"v ∈ carrier_vec n" and dimw: "w ∈ carrier_vec n" 
  shows "inner_prod (v + w) u = inner_prod v u + inner_prod w u" (is "?lhs = ?rhs")
proof -
  have dimcv: "conjugate v ∈ carrier_vec n" and dimcw: "conjugate w ∈ carrier_vec n" using assms by auto
  have "u ∙ (conjugate (v + w)) = u ∙ conjugate v + u ∙ conjugate w"
    using dimv dimw dimu dimcv dimcw 
    by (metis conjugate_add_vec scalar_prod_add_distrib)
  then show ?thesis by auto
qed

(* inner_prod is additive in its second argument. *)
lemma inner_prod_distrib_right:
  fixes u v w :: "('a::conjugatable_field) vec"
  assumes dimu: "u ∈ carrier_vec n" and dimv:"v ∈ carrier_vec n" and dimw: "w ∈ carrier_vec n" 
  shows "inner_prod u (v + w) = inner_prod u v + inner_prod u w" (is "?lhs = ?rhs")
proof -
  have dimvw: "v + w ∈ carrier_vec n" using assms by auto
  have dimcu: "conjugate u ∈ carrier_vec n" using assms by auto
  (* commute the scalar product so the sum sits on the side where distributivity applies *)
  have "(v + w) ∙ (conjugate u) = v ∙ conjugate u + w ∙ conjugate u"
    apply (simp add: comm_scalar_prod[OF dimvw dimcu])
    apply (simp add: scalar_prod_add_distrib[OF dimcu dimv dimw])
    apply (insert dimv dimw dimcu, simp add: comm_scalar_prod[of _ n])
    done
  then show ?thesis by auto
qed

(* inner_prod distributes over subtraction in its second argument. *)
lemma inner_prod_minus_distrib_right:
  fixes u v w :: "('a::conjugatable_field) vec"
  assumes dimu: "u ∈ carrier_vec n" and dimv:"v ∈ carrier_vec n" and dimw: "w ∈ carrier_vec n" 
  shows "inner_prod u (v - w) = inner_prod u v - inner_prod u w" (is "?lhs = ?rhs")
proof -
  have dimvw: "v - w ∈ carrier_vec n" using assms by auto
  have dimcu: "conjugate u ∈ carrier_vec n" using assms by auto
  have "(v - w) ∙ (conjugate u) = v ∙ conjugate u - w ∙ conjugate u"
    apply (simp add: comm_scalar_prod[OF dimvw dimcu])
    apply (simp add: scalar_prod_minus_distrib[OF dimcu dimv dimw])
    apply (insert dimv dimw dimcu, simp add: comm_scalar_prod[of _ n])
    done
  then show ?thesis by auto
qed

(* A scalar in the first argument of inner_prod comes out conjugated. *)
lemma inner_prod_smult_right:
  fixes u v :: "complex vec"
  assumes dimu: "u ∈ carrier_vec n" and dimv:"v ∈ carrier_vec n" 
  shows "inner_prod (a ⋅⇩v u) v = conjugate a * inner_prod u v" (is "?lhs = ?rhs")
  using assms apply (simp add: scalar_prod_def conjugate_dist_mul) 
  apply (subst sum_distrib_left) by (rule sum.cong, auto)

(* A scalar in the second argument of inner_prod comes out unchanged. *)
lemma inner_prod_smult_left:
  fixes u v :: "complex vec"
  assumes dimu: "u ∈ carrier_vec n" and dimv: "v ∈ carrier_vec n" 
  shows "inner_prod u (a ⋅⇩v v) = a * inner_prod u v" (is "?lhs = ?rhs")
  using assms apply (simp add: scalar_prod_def) 
  apply (subst sum_distrib_left) by (rule sum.cong, auto)

(* Scalars on both sides: the first is conjugated, the second is not. *)
lemma inner_prod_smult_left_right:
  fixes u v :: "complex vec"
  assumes dimu: "u ∈ carrier_vec n" and dimv: "v ∈ carrier_vec n" 
  shows "inner_prod (a ⋅⇩v u) (b ⋅⇩v v) = conjugate a * b * inner_prod u v" (is "?lhs = ?rhs")
  using assms apply (simp add: scalar_prod_def) 
  apply (subst sum_distrib_left) by (rule sum.cong, auto)

(* Conjugate symmetry of the complex inner product. *)
lemma inner_prod_swap:
  fixes x y :: "complex vec"
  assumes "y ∈ carrier_vec n" and "x ∈ carrier_vec n" 
  shows "inner_prod y x = conjugate (inner_prod x y)"
  apply (simp add: scalar_prod_def)
  apply (rule sum.cong) using assms by auto

text ‹Cauchy-Schwarz theorem for complex vectors. This is analogous to aux\_Cauchy
  and Cauchy\_Schwarz\_ineq in Generalizations2.thy in QR\_Decomposition. Consider
  merging and moving to Isabelle library.›
(* Expands inner_prod (x + a ⋅⇩v y) (x + a ⋅⇩v y), which is non-negative by
   self_cscalar_prod_geq_0; used to derive Cauchy-Schwarz below. *)
lemma aux_Cauchy:
  fixes x y :: "complex vec"
  assumes "x ∈ carrier_vec n" and "y ∈ carrier_vec n"
  shows "0 ≤ inner_prod x x + a * (inner_prod x y) + (cnj a) * ((cnj (inner_prod x y)) + a * (inner_prod y y))"
proof -
  have "(inner_prod (x+ a ⋅⇩v y) (x+a ⋅⇩v y)) = (inner_prod (x+a ⋅⇩v y) x) + (inner_prod (x+a ⋅⇩v y) (a ⋅⇩v y))" 
    apply (subst inner_prod_distrib_right) using assms by auto
  also have "… = inner_prod x x + (a) * (inner_prod x y) + cnj a * ((cnj (inner_prod x y)) + (a) * (inner_prod y y))"
    apply (subst (1 2) inner_prod_distrib_left[of _ n]) apply (auto simp add: assms)
    apply (subst (1 2) inner_prod_smult_right[of _ n]) apply (auto simp add: assms)
    apply (subst inner_prod_smult_left[of _ n]) apply (auto simp add: assms)
    apply (subst inner_prod_swap[of y n x]) apply (auto simp add: assms)
    unfolding distrib_left
    by auto
  finally show ?thesis by (metis self_cscalar_prod_geq_0)
qed

(* Cauchy-Schwarz for complex vectors:
   inner_prod x y * inner_prod y x ≤ inner_prod x x * inner_prod y y.
   Proof: instantiate aux_Cauchy with the scalar a chosen so that the
   bracketed term cnj (inner_prod x y) + a * inner_prod y y vanishes. *)
lemma Cauchy_Schwarz_complex_vec:
  fixes x y :: "complex vec"
  assumes "x ∈ carrier_vec n" and "y ∈ carrier_vec n"
  shows "inner_prod x y * inner_prod y x ≤ inner_prod x x * inner_prod y y"
proof -
  define cnj_a where "cnj_a = - (inner_prod x y)/ cnj (inner_prod y y)"
  define a where "a = cnj (cnj_a)"
  have cnj_rw: "(cnj a) = cnj_a" 
    unfolding a_def by (simp)
  have rw_0: "cnj (inner_prod x y) + a * (inner_prod y y) = 0"
    unfolding a_def cnj_a_def using assms(1) assms(2) conjugate_square_eq_0_vec by fastforce
  have "0 ≤ (inner_prod x x + a * (inner_prod x y) + (cnj a) * ((cnj (inner_prod x y)) + a * (inner_prod y y)))"
    using aux_Cauchy assms by auto
  also have "… = (inner_prod x x + a * (inner_prod x y))" unfolding rw_0 by auto
  also have "… = (inner_prod x x - (inner_prod x y) * cnj (inner_prod x y) / (inner_prod y y))" 
    unfolding a_def cnj_a_def by simp
  finally have " 0 ≤ (inner_prod x x - (inner_prod x y) * cnj (inner_prod x y) / (inner_prod y y)) " .
  hence "0 ≤ (inner_prod x x - (inner_prod x y) * cnj (inner_prod x y) / (inner_prod y y)) * (inner_prod y y)" by auto
  also have "… = ((inner_prod x x)*(inner_prod y y) - (inner_prod x y) * cnj (inner_prod x y))"
    by (smt add.inverse_neutral add_diff_cancel diff_0 diff_divide_eq_iff divide_cancel_right mult_eq_0_iff nonzero_mult_div_cancel_right rw_0)
  finally have "(inner_prod x y) * cnj (inner_prod x y) ≤ (inner_prod x x)*(inner_prod y y)" by auto
  then show ?thesis 
    apply (subst inner_prod_swap[of y n x]) by (auto simp add: assms)
qed

subsection ‹Hermitian adjoint of a matrix›

(* Short name for the Hermitian adjoint (conjugate transpose) mat_adjoint. *)
abbreviation adjoint where "adjoint ≡ mat_adjoint"

(* The adjoint has as many rows as A has columns (simp rule). *)
lemma adjoint_dim_row [simp]:
  "dim_row (adjoint A) = dim_col A" by (simp add: mat_adjoint_def)

(* The adjoint has as many columns as A has rows (simp rule). *)
lemma adjoint_dim_col [simp]:
  "dim_col (adjoint A) = dim_row A" by (simp add: mat_adjoint_def)

(* The adjoint of an n-by-n matrix is again n-by-n. *)
lemma adjoint_dim:
  "A ∈ carrier_mat n n ⟹ adjoint A ∈ carrier_mat n n"
  using adjoint_dim_col adjoint_dim_row by blast

(* Entry-wise description of the adjoint: entry (i,j) is the conjugate of
   A $$ (j,i), with dimensions swapped. *)
lemma adjoint_def:
  "adjoint A = mat (dim_col A) (dim_row A) (λ(i,j). conjugate (A $$ (j,i)))"
  unfolding mat_adjoint_def mat_of_rows_def by auto

(* Evaluate one entry of the adjoint, for in-range indices. *)
lemma adjoint_eval:
  assumes "i < dim_col A" "j < dim_row A"
  shows "(adjoint A) $$ (i,j) = conjugate (A $$ (j,i))"
  using assms by (simp add: adjoint_def)

(* Row i of the adjoint is the conjugate of column i of A. *)
lemma adjoint_row:
  assumes "i < dim_col A"
  shows "row (adjoint A) i = conjugate (col A i)"
  apply (rule eq_vecI)
  using assms by (auto simp add: adjoint_eval)

(* Column i of the adjoint is the conjugate of row i of A. *)
lemma adjoint_col:
  assumes "i < dim_row A"
  shows "col (adjoint A) i = conjugate (row A i)"
  apply (rule eq_vecI)
  using assms by (auto simp add: adjoint_eval)

text ‹The identity <v, A w> = <A* v, w>›
(* Both sides are rewritten into the same double sum over entries. *)
lemma adjoint_def_alter:
  fixes v w :: "'a::conjugatable_field vec"
    and A :: "'a::conjugatable_field mat"
  assumes dims: "v ∈ carrier_vec n" "w ∈ carrier_vec m" "A ∈ carrier_mat n m"
  shows "inner_prod v (A *⇩v w) = inner_prod (adjoint A *⇩v v) w" (is "?lhs = ?rhs")
proof -
  from dims have "?lhs = (∑i=0..<dim_vec v. (∑j=0..<dim_vec w.
                conjugate (v$i) * A$$(i, j) * w$j))"
    apply (simp add: scalar_prod_def sum_distrib_right )
    apply (rule sum.cong, simp)
    apply (rule sum.cong, auto)
    done
  moreover from assms have "?rhs = (∑i=0..<dim_vec v. (∑j=0..<dim_vec w.
                conjugate (v$i) * A$$(i, j) * w$j))"
    apply (simp add: scalar_prod_def  adjoint_eval 
                     sum_conjugate conjugate_dist_mul sum_distrib_left)
    (* the adjoint side sums over j first; swap to match the left-hand side *)
    apply (subst sum.swap[where ?A = "{0..<n}"])
    apply (rule sum.cong, simp)
    apply (rule sum.cong, auto)
    done
  ultimately show ?thesis by simp
qed

(* The complex identity matrix is self-adjoint. *)
lemma adjoint_one:
  shows "adjoint (1⇩m n) = (1⇩m n::complex mat)"
  apply (rule eq_matI) 
  by (auto simp add: adjoint_eval)

(* Scalars pass through the adjoint conjugated. *)
lemma adjoint_scale:
  fixes A :: "'a::conjugatable_field mat"
  shows "adjoint (a ⋅⇩m A) = (conjugate a) ⋅⇩m adjoint A"
  apply (rule eq_matI) using conjugatable_ring_class.conjugate_dist_mul
  by (auto simp add: adjoint_eval)

(* The adjoint is additive on matrices of equal shape. *)
lemma adjoint_add:
  fixes A B :: "'a::conjugatable_field mat"
  assumes "A ∈ carrier_mat n m" "B ∈ carrier_mat n m"
  shows "adjoint (A + B) = adjoint A + adjoint B"
  apply (rule eq_matI)
  using assms conjugatable_ring_class.conjugate_dist_add 
  by( auto simp add: adjoint_eval)

(* The adjoint commutes with subtraction on matrices of equal shape. *)
lemma adjoint_minus:
  fixes A B :: "'a::conjugatable_field mat"
  assumes "A ∈ carrier_mat n m" "B ∈ carrier_mat n m"
  shows "adjoint (A - B) = adjoint A - adjoint B"
  apply (rule eq_matI)
  using assms apply(auto simp add: adjoint_eval)
  by (metis add_uminus_conv_diff conjugate_dist_add conjugate_neg)

(* The adjoint reverses products: adjoint (A * B) = adjoint B * adjoint A. *)
lemma adjoint_mult:
  fixes A B :: "'a::conjugatable_field mat"
  assumes "A ∈ carrier_mat n m" "B ∈ carrier_mat m l"
  shows "adjoint (A * B) = adjoint B * adjoint A"
proof (rule eq_matI, auto simp add: adjoint_eval adjoint_row adjoint_col)
  fix i j
  assume "i < dim_col B" "j < dim_row A"
  show "conjugate (row A j ∙ col B i) = conjugate (col B i) ∙ conjugate (row A j)"
    using assms apply (simp add: conjugate_scalar_prod)
    apply (subst comm_scalar_prod[where n="dim_row B"])
    by (auto simp add: carrier_vecI)
qed

(* Taking the adjoint twice gives back the original matrix. *)
lemma adjoint_adjoint:
  fixes A :: "'a::conjugatable_field mat"
  shows "adjoint (adjoint A) = A"
  by (rule eq_matI, auto simp add: adjoint_eval)

(* trace (A * adjoint A) is non-negative: each diagonal entry is a self inner product. *)
lemma trace_adjoint_positive:
  fixes A :: "complex mat"
  shows "trace (A * adjoint A) ≥ 0"
  apply (auto simp add: trace_def adjoint_col)
  apply (rule sum_nonneg) by auto

subsection ‹Algebraic manipulations on matrices›

(* Adding the zero matrix of matching shape on the right is the identity (simp rule). *)
lemma right_add_zero_mat[simp]:
  "(A :: 'a :: monoid_add mat) ∈ carrier_mat nr nc ⟹ A + 0⇩m nr nc = A"
  by (intro eq_matI, auto)

(* Sums of equal-shape matrices keep that shape. *)
lemma add_carrier_mat':
  "A ∈ carrier_mat nr nc ⟹ B ∈ carrier_mat nr nc ⟹ A + B ∈ carrier_mat nr nc"
  by simp

(* Differences of equal-shape matrices keep that shape. *)
lemma minus_carrier_mat':
  "A ∈ carrier_mat nr nc ⟹ B ∈ carrier_mat nr nc ⟹ A - B ∈ carrier_mat nr nc"
  by auto

(* Swap the last two summands of a three-term matrix sum. *)
lemma swap_plus_mat:
  fixes A B C :: "'a::semiring_1 mat"
  assumes "A ∈ carrier_mat n n" "B ∈ carrier_mat n n" "C ∈ carrier_mat n n"
  shows "A + B + C = A + C + B"
  by (metis assms assoc_add_mat comm_add_mat)

(* Negation of a complex matrix is scaling by -1. *)
lemma uminus_mat:
  fixes A :: "complex mat"
  assumes "A ∈ carrier_mat n n"
  shows "-A = (-1) ⋅⇩m A"
  by auto

(* Load the ML source mat_alg.ML and register mat_assoc_method from it as the
   proof method mat_assoc; later proofs in this theory call it with the
   dimension argument, e.g. (mat_assoc n). *)
ML_file "mat_alg.ML"
method_setup mat_assoc = mat_assoc_method
  "Normalization of expressions on matrices"

(* Regression examples for the mat_assoc method on n-by-n complex matrices. *)
lemma mat_assoc_test:
  fixes A B C D :: "complex mat"
  assumes "A ∈ carrier_mat n n" "B ∈ carrier_mat n n" "C ∈ carrier_mat n n" "D ∈ carrier_mat n n"
  shows
    "(A * B) * (C * D) = A * B * C * D"
    "adjoint (A * adjoint B) * C = B * (adjoint A * C)"
    "A * 1⇩m n * 1⇩m n * B * 1⇩m n = A * B"
    "(A - B) + (B - C) = A + (-B) + B + (-C)"
    "A + (B - C) = A + B - C"
    "A - (B + C + D) = A - B - C - D"
    "(A + B) * (B + C) = A * B + B * B + A * C + B * C"
    "A - B = A + (-1) ⋅⇩m B"
    "A * (B - C) * D = A * B * D - A * C * D"
    "trace (A * B * C) = trace (B * C * A)"
    "trace (A * B * C * D) = trace (C * D * A * B)"
    "trace (A + B * C) = trace A + trace (C * B)"
    "A + B = B + A"
    "A + B + C = C + B + A"
    "A + B + (C + D) = A + C + (B + D)"
  using assms by (mat_assoc n)+

subsection ‹Hermitian matrices›

text ‹A Hermitian matrix is a matrix that is equal to its Hermitian adjoint.›
(* A matrix is Hermitian iff it equals its adjoint. *)
definition hermitian :: "'a::conjugatable_field mat ⇒ bool" where
  "hermitian A ⟷ (adjoint A = A)"

(* The identity matrix is Hermitian; the key fact is conjugate 1 = 1. *)
lemma hermitian_one:
  shows "hermitian ((1⇩m n)::('a::conjugatable_field mat))"
  unfolding hermitian_def 
proof-
  have "conjugate (1::'a) = 1"   
    apply (subst mult_1_right[symmetric, of "conjugate 1"])
    apply (subst conjugate_id[symmetric, of "conjugate 1 * 1"])
    apply (subst conjugate_dist_mul)
    apply auto
    done
  then show "adjoint ((1⇩m n)::('a::conjugatable_field mat)) = (1⇩m n)" 
    by (auto simp add: adjoint_eval) 
qed

subsection ‹Inverse matrices›

(* For square matrices over a field, a right inverse is also a left inverse. *)
lemma inverts_mat_symm:
  fixes A B :: "'a::field mat"
  assumes dim: "A ∈ carrier_mat n n" "B ∈ carrier_mat n n"
    and AB: "inverts_mat A B"
  shows "inverts_mat B A"
proof -
  have "A * B = 1⇩m n" using dim AB unfolding inverts_mat_def by auto
  with dim have "B * A = 1⇩m n"  by (rule mat_mult_left_right_inverse)
  then show "inverts_mat B A" using dim inverts_mat_def by auto
qed

(* Inverses of square matrices over a field are unique. *)
lemma inverts_mat_unique:
  fixes A B C :: "'a::field mat"
  assumes dim: "A ∈ carrier_mat n n" "B ∈ carrier_mat n n" "C ∈ carrier_mat n n" 
    and AB: "inverts_mat A B" and AC: "inverts_mat A C"
  shows "B = C"
proof -
  have AB1: "A * B = 1⇩m n" using AB dim unfolding inverts_mat_def by auto
  have "A * C = 1⇩m n" using AC dim unfolding inverts_mat_def by auto
  then have CA1: "C * A = 1⇩m n" using mat_mult_left_right_inverse[of A n C] dim by auto
  then have "C = C * 1⇩m n" using dim by auto
  also have "… = C * (A * B)" using AB1 by auto
  also have "… = (C * A) * B" using dim by auto
  also have "… = 1⇩m n * B" using CA1 by auto
  also have "… = B" using dim by auto
  finally show "B = C" ..
qed

subsection ‹Unitary matrices›

text ‹A unitary matrix is a matrix whose Hermitian adjoint is also its inverse.›
(* Unitary: square, and the adjoint is a right inverse. *)
definition unitary :: "'a::conjugatable_field mat ⇒ bool" where
  "unitary A ⟷ A ∈ carrier_mat (dim_row A) (dim_row A) ∧ inverts_mat A (adjoint A)"

(* For a unitary matrix the adjoint is also a left inverse. *)
lemma unitaryD2:
  assumes "A ∈ carrier_mat n n"
  shows "unitary A ⟹ inverts_mat (adjoint A) A"
  using assms adjoint_dim inverts_mat_symm unitary_def by blast

(* Simp rules: for unitary n-by-n A, both products with the adjoint are 1⇩m n. *)
lemma unitary_simps [simp]:
  "A ∈ carrier_mat n n ⟹ unitary A ⟹ adjoint A * A = 1⇩m n"
  "A ∈ carrier_mat n n ⟹ unitary A ⟹ A * adjoint A = 1⇩m n"
  apply (metis adjoint_dim_row carrier_matD(2) inverts_mat_def unitaryD2)
  by (simp add: inverts_mat_def unitary_def)

(* The adjoint of a unitary matrix is unitary (simp rule). *)
lemma unitary_adjoint [simp]:
  assumes "A ∈ carrier_mat n n" "unitary A"
  shows "unitary (adjoint A)" 
  unfolding unitary_def
  using  adjoint_dim[OF assms(1)] assms by (auto simp add: unitaryD2[OF assms] adjoint_adjoint)

(* The identity matrix is unitary, since it is Hermitian and idempotent. *)
lemma unitary_one:
  shows "unitary ((1⇩m n)::('a::conjugatable_field mat))"
  unfolding unitary_def 
proof -
  define I where I_def[simp]: "I ≡ ((1⇩m n)::('a::conjugatable_field mat))"
  have dim: "I ∈ carrier_mat n n" by auto
  have "hermitian I" using hermitian_one  by auto
  hence "adjoint I = I" using hermitian_def by auto
  with dim show "I ∈ carrier_mat (dim_row I) (dim_row I) ∧ inverts_mat I (adjoint I)" 
    unfolding inverts_mat_def using dim by auto
qed

(* Every 0-by-0 matrix is (vacuously) unitary. *)
lemma unitary_zero:
  fixes A :: "'a::conjugatable_field mat"
  assumes "A ∈ carrier_mat 0 0"
  shows "unitary A"
  unfolding unitary_def inverts_mat_def Let_def using assms by auto

(* Conjugation by a unitary P is injective: P A P^* = P B P^* implies A = B. *)
lemma unitary_elim:
  assumes dims: "A ∈ carrier_mat n n" "B ∈ carrier_mat n n" "P ∈ carrier_mat n n"
    and uP: "unitary P" and eq: "P * A * adjoint P = P * B * adjoint P"
  shows "A = B"
proof -
  have dimaP: "adjoint P ∈ carrier_mat n n" using dims by auto
  have iv: "inverts_mat P (adjoint P)" using uP unitary_def by auto
  then have "P * (adjoint P) = 1⇩m n" using inverts_mat_def dims by auto
  then have aPP: "adjoint P * P = 1⇩m n" using mat_mult_left_right_inverse[OF dims(3) dimaP] by auto
  have "adjoint P * (P * A * adjoint P) * P = (adjoint P * P) * A * (adjoint P * P)" 
    using dims dimaP by (mat_assoc n)
  also have "… = 1⇩m n * A * 1⇩m n" using aPP by auto
  also have "… = A" using dims by auto
  finally have eqA: "A = adjoint P * (P * A * adjoint P) * P" ..
  have "adjoint P * (P * B * adjoint P) * P = (adjoint P * P) * B * (adjoint P * P)" 
    using dims dimaP by (mat_assoc n)
  also have "… = 1⇩m n * B * 1⇩m n" using aPP by auto
  also have "… = B" using dims by auto
  finally have eqB: "B = adjoint P * (P * B * adjoint P) * P" ..
  then show ?thesis using eqA eqB eq by auto
qed

(* A unitary matrix is column-orthogonal: adjoint U * U is the identity,
   hence diagonal with nonzero diagonal entries. *)
lemma unitary_is_corthogonal:
  fixes U :: "'a::conjugatable_field mat"
  assumes dim: "U ∈ carrier_mat n n" 
    and U: "unitary U"
  shows "corthogonal_mat U"
  unfolding corthogonal_mat_def Let_def
proof (rule conjI)
  have dima: "adjoint U ∈ carrier_mat n n" using dim by auto
  have aUU: "mat_adjoint U * U = (1⇩m n)"
    apply (insert U[unfolded unitary_def] dim dima, drule conjunct2)
    apply (drule inverts_mat_symm[of "U", OF dim dima], unfold inverts_mat_def, auto)
    done
  then show "diagonal_mat (mat_adjoint U * U)"
    by (simp add: diagonal_mat_def)
  show "∀i<dim_col U. (mat_adjoint U * U) $$ (i, i) ≠ 0" using dim by (simp add: aUU)
qed

(* The product of two unitary n-by-n matrices is unitary. *)
lemma unitary_times_unitary:
  fixes P Q :: "'a:: conjugatable_field mat"
  assumes dim: "P ∈ carrier_mat n n" "Q ∈ carrier_mat n n"
    and uP: "unitary P" and uQ: "unitary Q"
  shows "unitary (P * Q)"
proof -
  have dim_pq: "P * Q ∈ carrier_mat n n" using dim by auto
  have "(P * Q) * adjoint (P * Q) = P * (Q * adjoint Q) * adjoint P" using dim by (mat_assoc n)
  also have "… = P * (1⇩m n) * adjoint P" using uQ dim by auto
  also have "… = P * adjoint P" using dim by (mat_assoc n)
  also have "… = 1⇩m n" using uP dim by simp
  finally have "(P * Q) * adjoint (P * Q) = 1⇩m n" by auto
  hence "inverts_mat (P * Q) (adjoint (P * Q))" 
    using inverts_mat_def dim_pq by auto
  thus "unitary (P*Q)" using unitary_def dim_pq by auto
qed

(* Unitary conjugation preserves the trace (via cyclicity of the trace). *)
lemma unitary_operator_keep_trace:
  fixes U A :: "complex mat"
  assumes dU: "U ∈ carrier_mat n n" and dA: "A ∈ carrier_mat n n" and u: "unitary U"
  shows "trace A = trace (adjoint U * A * U)"
proof -
  have u': "U * adjoint U = 1⇩m n" using u unfolding unitary_def inverts_mat_def using dU by auto
  have "trace (adjoint U * A * U) = trace (U * adjoint U * A)" using dU dA by (mat_assoc n)
  also have "… = trace A" using u' dA by auto
  finally show ?thesis by auto
qed

subsection ‹Normalization of vectors›

(* Euclidean norm of a complex vector, as a complex number: csqrt (v ∙c v). *)
definition vec_norm :: "complex vec ⇒ complex" where
  "vec_norm v ≡ csqrt (v ∙c v)"

(* vec_norm is non-negative (in the complex order). *)
lemma vec_norm_geq_0:
  fixes v :: "complex vec"
  shows "vec_norm v ≥ 0"
  unfolding vec_norm_def by (insert self_cscalar_prod_geq_0[of v], simp)

(* vec_norm vanishes exactly on the zero vector of the same dimension. *)
lemma vec_norm_zero:
  fixes v ::  "complex vec"
  assumes dim: "v ∈ carrier_vec n"
  shows "vec_norm v = 0 ⟷ v = 0⇩v n"
  unfolding vec_norm_def
  by (subst conjugate_square_eq_0_vec[OF dim, symmetric], rule csqrt_eq_0)

(* A nonzero vector has strictly positive norm (despite the name, the
   conclusion is strict). *)
lemma vec_norm_ge_0:
  fixes v ::  "complex vec"
  assumes dim_v: "v ∈ carrier_vec n" and neq0: "v ≠ 0⇩v n"
  shows "vec_norm v > 0"
proof -
  have geq: "vec_norm v ≥ 0" using vec_norm_geq_0 by auto
  have neq: "vec_norm v ≠ 0" 
    apply (insert dim_v neq0)
    apply (drule vec_norm_zero, auto)
    done
  show ?thesis using neq geq by (rule dual_order.not_eq_order_implies_strict)
qed

(* Normalization: the zero vector is returned unchanged; any other vector
   is scaled by 1 / vec_norm v. *)
definition vec_normalize :: "complex vec ⇒ complex vec" where
  "vec_normalize v = (if (v = 0⇩v (dim_vec v)) then v else 1 / (vec_norm v) ⋅⇩v v)"

(* Normalization preserves the dimension (simp rule). *)
lemma normalized_vec_dim[simp]:
  assumes "(v::complex vec) ∈ carrier_vec n"
  shows "vec_normalize v ∈ carrier_vec n"
  unfolding vec_normalize_def using assms by auto

(* Every vector equals its norm times its normalization; the zero case
   holds because the norm is then 0. *)
lemma vec_eq_norm_smult_normalized:
  shows "v = vec_norm v ⋅⇩v vec_normalize v"
proof (cases "v = 0⇩v (dim_vec v)")
  define n where "n = dim_vec v"
  then have dimv: "v ∈ carrier_vec n" by auto
  then have dimnv: "vec_normalize v ∈ carrier_vec n" by auto
  {
    case True
    then have v0: "v = 0⇩v n" using n_def by auto
    then have n0: "vec_norm v = 0" using vec_norm_def by auto
    have "vec_norm v ⋅⇩v vec_normalize v = 0⇩v n" 
      unfolding smult_vec_def by (auto simp add: n0 carrier_vecD[OF dimnv])
    then show ?thesis using v0 by auto
    next
    case False
    then have v: "v ≠ 0⇩v n" using n_def by auto
    then have ge0: "vec_norm v > 0" using vec_norm_ge_0 dimv by auto
    have "vec_normalize v = (1 / vec_norm v) ⋅⇩v v" using False vec_normalize_def by auto
    then have "vec_norm v ⋅⇩v vec_normalize v = (vec_norm v * (1 / vec_norm v)) ⋅⇩v v"
      using smult_smult_assoc by auto
    also have "… = v" using ge0 by auto
    finally have "v = vec_norm v ⋅⇩v vec_normalize v"..
    then show "v = vec_norm v ⋅⇩v vec_normalize v" using v by auto
  }
qed

(* The inner product factors through normalization. The two ifs inside
   vec_normalize are split into four cases (w zero or not, v zero or not);
   each case is proved in its own raw block and then discharged with show. *)
lemma normalized_cscalar_prod:
  fixes v w :: "complex vec"
  assumes dim_v: "v ∈ carrier_vec n" and dim_w: "w ∈ carrier_vec n"
  shows "v ∙c w = (vec_norm v * vec_norm w) * (vec_normalize v ∙c vec_normalize w)"
  unfolding vec_normalize_def apply (split if_split, split if_split)
proof (intro conjI impI)
  note dim0 = dim_v dim_w
  have dim: "dim_vec v = n" "dim_vec w = n" using dim0 by auto
  {
    assume "w = 0⇩v n" "v = 0⇩v n"
    then have lhs: "v ∙c w = 0" by auto
    then moreover have rhs: "vec_norm v * vec_norm w * (v ∙c w) = 0" by auto
    ultimately have "v ∙c w = vec_norm v * vec_norm w * (v ∙c w)" by auto
  }
  with dim show "w = 0⇩v (dim_vec w) ⟹ v = 0⇩v (dim_vec v) ⟹ v ∙c w = vec_norm v * vec_norm w * (v ∙c w)" by auto
  {
    assume asm: "w = 0⇩v n" "v ≠ 0⇩v n"
    then have w0: "conjugate w = 0⇩v n" by auto
    with dim0 have "(1 / vec_norm v ⋅⇩v v) ∙c w = 0" by auto
    then moreover have rhs: "vec_norm v * vec_norm w * ((1 / vec_norm v ⋅⇩v v) ∙c w) = 0" by auto
    moreover have "v ∙c w = 0" using w0 dim0 by auto
    ultimately have "v ∙c w = vec_norm v * vec_norm w * ((1 / vec_norm v ⋅⇩v v) ∙c w)" by auto
  }
  with dim show "w = 0⇩v (dim_vec w) ⟹ v ≠ 0⇩v (dim_vec v) ⟹ v ∙c w = vec_norm v * vec_norm w * ((1 / vec_norm v ⋅⇩v v) ∙c w)" by auto
  {
    assume asm: "w ≠ 0⇩v n" "v = 0⇩v n"
    with dim0 have "v ∙c (1 / vec_norm w ⋅⇩v w) = 0" by auto
    then moreover have rhs: "vec_norm v * vec_norm w * (v ∙c (1 / vec_norm w ⋅⇩v w)) = 0" by auto
    moreover have "v ∙c w = 0" using asm dim0 by auto
    ultimately have "v ∙c w = vec_norm v * vec_norm w * (v ∙c (1 / vec_norm w ⋅⇩v w))" by auto
  }
  with dim show "w ≠ 0⇩v (dim_vec w) ⟹ v = 0⇩v (dim_vec v) ⟹ v ∙c w = vec_norm v * vec_norm w * (v ∙c (1 / vec_norm w ⋅⇩v w))" by auto
  {
    assume asmw: "w ≠ 0⇩v n" and asmv: "v ≠ 0⇩v n"
    have "vec_norm w > 0" by (insert asmw dim0, rule vec_norm_ge_0, auto)
    (* a positive norm is real, so conjugation leaves its reciprocal fixed *)
    then have cw: "conjugate (1 / vec_norm w) = 1 / vec_norm w" by (simp add: complex_eq_iff complex_is_Real_iff) 
    from dim0 have 
      "((1 / vec_norm v ⋅⇩v v) ∙c (1 / vec_norm w ⋅⇩v w)) = 1 / vec_norm v * (v ∙c (1 / vec_norm w ⋅⇩v w))" by auto
    also have "… = 1 / vec_norm v * (v ∙ (conjugate (1 / vec_norm w) ⋅⇩v conjugate w))"
      by (subst conjugate_smult_vec, auto)
    also have "… = 1 / vec_norm v * conjugate (1 / vec_norm w) * (v ∙ conjugate w)" using dim by auto
    also have "… = 1 / vec_norm v * (1 / vec_norm w) * (v ∙c w)" using vec_norm_ge_0 cw by auto
    finally have eq1: "(1 / vec_norm v ⋅⇩v v) ∙c (1 / vec_norm w ⋅⇩v w) = 1 / vec_norm v * (1 / vec_norm w) * (v ∙c w)" .
    then have "vec_norm v * vec_norm w * ((1 / vec_norm v ⋅⇩v v) ∙c (1 / vec_norm w ⋅⇩v w)) = (v ∙c w)" 
      by (subst eq1, insert vec_norm_ge_0[of v n, OF dim_v asmv] vec_norm_ge_0[of w n, OF dim_w asmw], auto)
  }
  with dim show "w ≠ 0⇩v (dim_vec w) ⟹ v ≠ 0⇩v (dim_vec v) ⟹ v ∙c w = vec_norm v * vec_norm w * ((1 / vec_norm v ⋅⇩v v) ∙c (1 / vec_norm w ⋅⇩v w))" by auto
qed

(* A nonzero vector normalizes to a vector whose self inner product is 1. *)
lemma normalized_vec_norm :
  fixes v :: "complex vec"
  assumes dim_v: "v ∈ carrier_vec n" 
    and neq0: "v ≠ 0⇩v n"
  shows "vec_normalize v ∙c vec_normalize v = 1"
proof -
  (* v is not the zero vector of its own dimension, so vec_normalize scales it *)
  have v_nz: "v ≠ 0⇩v (dim_vec v)" using neq0 dim_v by auto
  have nvge0: "vec_norm v > 0" using vec_norm_ge_0 neq0 dim_v by auto
  then have vvvv: "v ∙c v = (vec_norm v) * (vec_norm v)" unfolding vec_norm_def by (metis power2_csqrt power2_eq_square)
  (* a positive norm is real, hence fixed by conjugation *)
  from nvge0 have "conjugate (vec_norm v) = vec_norm v" by (simp add: complex_eq_iff complex_is_Real_iff) 
  then have "v ∙c (1 / vec_norm v ⋅⇩v v) = 1 / vec_norm v * (v ∙c v)" 
    by (subst conjugate_smult_vec, auto)
  also have "… = 1 / vec_norm v * vec_norm v * vec_norm v" using vvvv by auto
  also have "… = vec_norm v" by auto
  finally have "v ∙c (1 / vec_norm v ⋅⇩v v) = vec_norm v".
  then show ?thesis unfolding vec_normalize_def using v_nz nvge0 by auto
qed

(* A vector normalizes to zero exactly when it is zero. *)
lemma normalize_zero:
  assumes "v ∈ carrier_vec n"
  shows "vec_normalize v = 0⇩v n ⟷ v = 0⇩v n"
proof
  show "v = 0⇩v n ⟹ vec_normalize v = 0⇩v n" unfolding vec_normalize_def by auto
next
  have "v ≠ 0⇩v n ⟹ vec_normalize v ≠ 0⇩v n" unfolding vec_normalize_def 
  proof (simp, rule impI)
    assume asm: "v ≠ 0⇩v n"
    then have "vec_norm v > 0" using vec_norm_ge_0 assms by auto
    then have nvge0: "1 / vec_norm v > 0" by (simp add: complex_is_Real_iff)
    (* a nonzero vector has some nonzero coordinate k; scaling keeps it nonzero *)
    have "∃k < n. v $ k ≠ 0" using asm assms by auto
    then obtain k where kn: "k < n" and  vkneq0: "v $ k ≠ 0" by auto
    then have "(1 / vec_norm v ⋅⇩v v) $ k = (1 / vec_norm v) * (v $ k)" 
      using assms carrier_vecD index_smult_vec(1) by blast
    with nvge0 vkneq0 have "(1 / vec_norm v ⋅⇩v v) $ k ≠ 0" by auto
    then show "1 / vec_norm v ⋅⇩v v ≠ 0⇩v n" using assms kn by fastforce
  qed
  then show "vec_normalize v = 0⇩v n ⟹ v = 0⇩v n" by auto
qed

(* Normalization is idempotent (simp rule): a zero vector is left alone,
   and a normalized nonzero vector already has norm 1. *)
lemma normalize_normalize[simp]:
  "vec_normalize (vec_normalize v) = vec_normalize v"
proof (rule disjE[of "v = 0⇩v (dim_vec v)" "v ≠ 0⇩v (dim_vec v)"], auto)
  let ?n = "dim_vec v"
  {
    assume "v = 0⇩v ?n"
    then have "vec_normalize v = v" unfolding vec_normalize_def by auto
    then show "vec_normalize (vec_normalize v) = vec_normalize v" by auto
  }
  assume neq0: "v ≠ 0⇩v ?n"
  have dim: "v ∈ carrier_vec ?n" by auto
  have "vec_norm (vec_normalize v) = 1" unfolding vec_norm_def
    using normalized_vec_norm[OF dim neq0] by auto
  then show "vec_normalize (vec_normalize v) = vec_normalize v" 
    by (subst (1) vec_normalize_def, simp)
qed

subsection ‹Spectral decomposition of normal complex matrices›

(* Normalizing every vector of a c-orthogonal list keeps it c-orthogonal:
   all norms are nonzero, so each pairwise inner product vanishes iff the
   unnormalized one does. *)
lemma normalize_keep_corthogonal:
  fixes vs :: "complex vec list"
  assumes cor: "corthogonal vs" and dims: "set vs ⊆ carrier_vec n"
  shows "corthogonal (map vec_normalize vs)"
  unfolding corthogonal_def
proof (rule allI, rule impI, rule allI, rule impI, goal_cases)
  case c: (1 i j)
  let ?m = "length vs"
  have len: "length (map vec_normalize vs) = ?m" by auto
  have dim: "⋀k. k < ?m ⟹ (vs ! k) ∈ carrier_vec n" using dims by auto
  have map: "⋀k. k < ?m ⟹ map vec_normalize vs ! k = vec_normalize (vs ! k)" by auto

  have eq1: "⋀j k. j < ?m ⟹ k < ?m ⟹ ((vs ! j) ∙c (vs ! k) = 0) = (j ≠ k)" using assms unfolding corthogonal_def by auto
  then have "⋀k. k < ?m ⟹ (vs ! k) ∙c (vs ! k) ≠ 0 " by auto
  then have "⋀k. k < ?m ⟹ (vs ! k) ≠ (0⇩v n)" using dim 
    by (auto simp add: conjugate_square_eq_0_vec[of _ n, OF dim])
  then have vnneq0: "⋀k. k < ?m ⟹ vec_norm (vs ! k) ≠ 0" using vec_norm_zero[OF dim] by auto
  then have i0: "vec_norm (vs ! i) ≠ 0" and j0: "vec_norm (vs ! j) ≠ 0" using c by auto
  have "(vs ! i) ∙c (vs ! j) = vec_norm (vs ! i) * vec_norm (vs ! j) * (vec_normalize (vs ! i) ∙c vec_normalize (vs ! j))"
    by (subst normalized_cscalar_prod[of "vs ! i" n "vs ! j"], auto, insert dim c, auto)
  with i0 j0 have "(vec_normalize (vs ! i) ∙c vec_normalize (vs ! j) = 0) = ((vs ! i) ∙c (vs ! j) = 0)" by auto
  with eq1 c have "(vec_normalize (vs ! i) ∙c vec_normalize (vs ! j) = 0) = (i ≠ j)" by auto
  with map c show "(map vec_normalize vs ! i ∙c map vec_normalize vs ! j = 0) = (i ≠ j)" by auto
qed

(* Normalizing the vectors of a conjugate-orthogonal list ws and using them
   as the columns of an n x n matrix gives a unitary matrix.
   Here corthogonal ws means: every ws ! k has a nonzero conjugate scalar
   product with itself, and distinct positions have conjugate scalar
   product 0 (this is how corthogonal_def is used below).
   Proof idea: entry (j, i) of mat_adjoint W * W is the conjugate scalar
   product of columns i and j of W.  Normalized columns give 1 on the
   diagonal and orthogonality gives 0 elsewhere, so adjoint W * W = 1.
   The right inverse W * adjoint W = 1 then follows from
   mat_mult_left_right_inverse. *)
lemma normalized_corthogonal_mat_is_unitary:
  assumes W: "set ws  carrier_vec n"
    and orth: "corthogonal ws"
    and len: "length ws = n"
  shows "unitary (mat_of_cols n (map vec_normalize ws))" (is "unitary ?W")
proof -
  define vs where "vs = map vec_normalize ws"
  define W where "W = mat_of_cols n vs"
  have W': "set vs  carrier_vec n" using assms vs_def by auto
  then have W'': "k. k < length vs  vs ! k  carrier_vec n" by auto
  have orth': "corthogonal vs" using assms normalize_keep_corthogonal vs_def by auto
  have len'[simp]: "length vs = n" using assms vs_def by auto
  have dimW: "W  carrier_mat n n" using W_def len by auto
  have "adjoint W  carrier_mat n n" using dimW by auto
  then have dimaW: "mat_adjoint W  carrier_mat n n" by auto
  (* Entrywise computation of mat_adjoint W * W for arbitrary
     column indices i, j < n. *)
  {
    fix i j assume i: "i < n" and j: "j < n"
    have dimws: "(ws ! i)  carrier_vec n" "(ws ! j)  carrier_vec n" using W len i j by auto
    (* orth makes ws ! i and ws ! j nonzero, hence of positive norm. *)
    have "(ws ! i) ∙c (ws ! i)  0" "(ws ! j) ∙c (ws ! j)  0" using orth corthogonal_def[of ws] len i j by auto
    then have neq0: "(ws ! i)  0v n" "(ws ! j)  0v n"
      by (auto simp add: conjugate_square_eq_0_vec[of "ws ! i" n])
    then have "vec_norm (ws ! i) > 0" "vec_norm (ws ! j) > 0" using vec_norm_ge_0 dimws by auto
    then have ge0: "vec_norm (ws ! i) * vec_norm (ws ! j) > 0" by auto
    have ws': "vs ! i = vec_normalize (ws ! i)" 
        "vs ! j = vec_normalize (ws ! j)" 
      using len i j vs_def by auto
    (* Diagonal case: a normalized nonzero vector has self product 1. *)
    have ii1: "(vs ! i) ∙c (vs ! i) = 1" 
      apply (simp add: ws')
      apply (rule normalized_vec_norm[of "ws ! i"], rule dimws, rule neq0)
      done
    (* Off-diagonal case: the product of the original vectors is 0 and equals
       a positive factor times the product of the normalized vectors,
       so the latter is 0 as well. *)
    have ij0: "i  j  (ws ! i)  ∙c (ws ! j) = 0" using i j 
      by (insert orth, auto simp add: corthogonal_def[of ws] len)
    have "i  j  (ws ! i)  ∙c (ws ! j) =  (vec_norm (ws ! i) * vec_norm (ws ! j)) * ((vs ! i) ∙c (vs ! j))"
      apply (auto simp add: ws')
      apply (rule normalized_cscalar_prod)
       apply (rule dimws, rule dimws)
      done
    with ij0 have ij0': "i  j  (vs ! i) ∙c (vs ! j) = 0" using ge0 by auto
    have cWk: "k. k < n  col W k = vs ! k" unfolding W_def 
    apply (subst col_mat_of_cols)
      apply (auto simp add: W'')
      done
    (* Entry (j, i) rewritten as the conjugate scalar product of
       columns i and j of W. *)
    have "(mat_adjoint W * W) $$ (j, i) = row (mat_adjoint W) j  col W i"
      by (insert dimW i j dimaW, auto)
    also have " = conjugate (col W j)  col W i" 
      by (insert dimW i j dimaW, auto simp add: mat_adjoint_def)
    also have " = col W i  conjugate (col W j)" using comm_scalar_prod[of "col W i" n] dimW by auto
    also have " = (vs ! i) ∙c (vs ! j)" using W_def col_mat_of_cols i j len cWk by auto
    finally have "(mat_adjoint W * W) $$ (j, i) = (vs ! i) ∙c (vs ! j)".
    then have "(mat_adjoint W * W) $$ (j, i) = (if (j = i) then 1 else 0)"
      by (auto simp add: ii1 ij0')
  }
  note maWW = this
  (* Left inverse from the entrywise computation.  The right inverse follows
     for these square n x n matrices via mat_mult_left_right_inverse. *)
  then have "mat_adjoint W * W = 1m n" unfolding one_mat_def using dimW dimaW
    by (auto simp add: maWW adjoint_def)
  then have iv0: "adjoint W * W = 1m n"  by auto
  have dimaW: "adjoint W  carrier_mat n n" using dimaW by auto
  then have iv1: "W * adjoint W  = 1m n" using mat_mult_left_right_inverse dimW iv0 by auto
  then show "unitary W" unfolding unitary_def inverts_mat_def using dimW dimaW iv0 iv1 by auto 
qed

(* Normalizing an eigenvector of A for eigenvalue e gives another eigenvector
   of A for the same eigenvalue e.
   After unfolding eigenvector_def the goal is a conjunction, proved in
   three steps:
   - vec_normalize v has dimension dim_row A;
   - it is nonzero, via normalize_zero, because v is nonzero;
   - A *v vec_normalize v is e times vec_normalize v, by unfolding
     vec_normalize_def and rewriting with A *v v = e times v. *)
lemma normalize_keep_eigenvector:
  assumes ev: "eigenvector A v e" 
    and dim: "A  carrier_mat n n" "v  carrier_vec n"
  shows "eigenvector A (vec_normalize v) e"
  unfolding eigenvector_def
proof 
  show "vec_normalize v  carrier_vec (dim_row A)" using dim by auto
  have eg: "A *v v = e v v" using ev dim eigenvector_def by auto
  have vneq0: "v  0v n" using ev dim unfolding eigenvector_def by auto
  then have s0: "vec_normalize v  0v n" 
    by (insert dim, subst normalize_zero[of v], auto)
  from vneq0 have vvge0: "vec_norm v > 0" using vec_norm_ge_0 dim by auto
  (* Eigen-equation for the normalized vector, obtained from eg. *)
  have s1: "A *v vec_normalize v = e v vec_normalize v" unfolding vec_normalize_def 
    using vneq0 dim 
    apply (auto, simp add: mult_mat_vec)
    apply (subst eg, auto)
    done
  with s0 dim show "vec_normalize v  0v (dim_row A)  A *v vec_normalize v = e v vec_normalize v" by auto
qed

(* The adjoint of a 2 x 2 block matrix is the block matrix of the adjoints
   with the two off-diagonal blocks swapped:
   adjoint [A B; C D] = [adjoint A, adjoint C; adjoint B, adjoint D].
   Proved entrywise with eq_matI and adjoint_eval. *)
lemma four_block_mat_adjoint:
  fixes A B C D :: "'a::conjugatable_field mat"
  assumes dim: "A  carrier_mat nr1 nc1" "B  carrier_mat nr1 nc2"
  "C  carrier_mat nr2 nc1" "D  carrier_mat nr2 nc2"
  shows "adjoint (four_block_mat A B C D) 
    = four_block_mat (adjoint A) (adjoint C) (adjoint B) (adjoint D)"
  by (rule eq_matI, insert dim, auto simp add: adjoint_eval)

(* Recursive unitary Schur decomposition, driven by a list of eigenvalues.
   unitary_schur_decomposition A es returns a triple (B, P, Q).
   - For the empty list it returns (A, 1, 1), with identity matrices of
     size dim_row A.
   - For e # es it proceeds as follows:
       v' = find_eigenvector A e and v = vec_normalize v';
       v is extended to a basis (basis_completion), orthogonalized
         (gram_schmidt) and normalized, giving the columns ws of W;
       W' = corthogonal_inv W and A' = W' * A * W;
       A' is split with a 1 x 1 top-left block A1 (split_block A' 1 1);
       the recursion runs on the bottom-right block A3 with es.
     The recursive result is embedded back with block matrices.
   The function does not check that the list entries are eigenvalues.
   The correctness theorem below assumes char_poly A is the product of the
   factors [:- e, 1:] over es.
   NOTE: the proof of that theorem unfolds this let-expression and folds it
   back into named definitions, so the binding order here matters. *)
fun unitary_schur_decomposition :: "complex mat  complex list  complex mat × complex mat × complex mat" where 
  "unitary_schur_decomposition A [] = (A, 1m (dim_row A), 1m (dim_row A))"
| "unitary_schur_decomposition A (e # es) = (let
       n = dim_row A;
       n1 = n - 1;
       v' = find_eigenvector A e;
       v = vec_normalize v';
       ws0 = gram_schmidt n (basis_completion v);
       ws = map vec_normalize ws0;
       W = mat_of_cols n ws;
       W' = corthogonal_inv W;
       A' = W' * A * W;
       (A1,A2,A0,A3) = split_block A' 1 1;
       (B,P,Q) = unitary_schur_decomposition A3 es;
       z_row = (0m 1 n1);
       z_col = (0m n1 1);
       one_1 = 1m 1
     in (four_block_mat A1 (A2 * P) A0 B, 
     W * four_block_mat one_1 z_row z_col P, 
     four_block_mat one_1 z_row z_col Q * W'))"

(* Correctness of unitary_schur_decomposition.
   Assume char_poly A is the product of the linear factors [:- e, 1:] over
   es, and the function returns (B, P, Q).  Then:
   - similar_mat_wit A B P Q holds (so A = P * B * Q and P * Q = 1);
   - B is upper triangular and its diagonal is exactly es;
   - P is unitary and Q = adjoint P.
   The proof is by induction on es, generalizing over n, A, B, P and Q. *)
theorem unitary_schur_decomposition:
  assumes A: "(A::complex mat)  carrier_mat n n"
      and c: "char_poly A = ( (e :: complex)  es. [:- e, 1:])"
      and B: "unitary_schur_decomposition A es = (B,P,Q)"
  shows "similar_mat_wit A B P Q  upper_triangular B  diag_mat B = es  unitary P  (Q = adjoint P)"
  using assms
proof (induct es arbitrary: n A B P Q)
  (* Empty list: char_poly A is the empty product, so by
     degree_monic_char_poly n = 0 and the identity witnesses suffice. *)
  case Nil
  with degree_monic_char_poly[of A n] 
  show ?case by (auto intro: similar_mat_wit_refl simp: diag_mat_def unitary_zero)
next
  case (Cons e es n A C P Q)
  let ?n1 = "n - 1"
  from Cons have A: "A  carrier_mat n n" and dim: "dim_row A = n" by auto
  let ?cp = "char_poly A"
  from Cons(3)
  have cp: "?cp = [: -e, 1 :] * (e  es. [:- e, 1:])" by auto
  have mon: "monic (e es. [:- e, 1:])" by (rule monic_prod_list, auto)
  (* The extra linear factor gives char_poly A positive degree, so n > 0. *)
  have deg: "degree ?cp = Suc (degree (e es. [:- e, 1:]))" unfolding cp
    by (subst degree_mult_eq, insert mon, auto)
  with degree_monic_char_poly[OF A] have n: "n  0" by auto
  (* Name the intermediate values exactly as in the Cons equation of
     unitary_schur_decomposition, so that the equation can be folded back
     into these names when deriving C, P and Q below. *)
  define v' where "v' = find_eigenvector A e"
  define v where "v = vec_normalize v'"
  define b where "b = basis_completion v"
  define ws0 where "ws0 = gram_schmidt n b"
  define ws where "ws = map vec_normalize ws0"
  define W where "W = mat_of_cols n ws"
  define W' where "W' = corthogonal_inv W"
  define A' where "A' = W' * A * W"
  obtain A1 A2 A0 A3 where splitA': "split_block A' 1 1 = (A1,A2,A0,A3)"
    by (cases "split_block A' 1 1", auto)
  obtain B P' Q' where schur: "unitary_schur_decomposition A3 es = (B,P',Q')" 
    by (cases "unitary_schur_decomposition A3 es", auto)
  let ?P' = "four_block_mat (1m 1) (0m 1 ?n1) (0m ?n1 1) P'"
  let ?Q' = "four_block_mat (1m 1) (0m 1 ?n1) (0m ?n1 1) Q'"
  have C: "C = four_block_mat A1 (A2 * P') A0 B" and P: "P = W * ?P'" and Q: "Q = ?Q' * W'"
    using Cons(4) unfolding unitary_schur_decomposition.simps
    Let_def list.sel dim
    v'_def[symmetric] v_def[symmetric] b_def[symmetric] ws0_def[symmetric] ws_def[symmetric] W'_def[symmetric] W_def[symmetric]
    A'_def[symmetric] split splitA' schur by auto
  (* e is a root of char_poly A, hence an eigenvalue.  v is a normalized
     eigenvector for it. *)
  have e: "eigenvalue A e" 
    unfolding eigenvalue_root_char_poly[OF A] cp by simp
  from find_eigenvector[OF A e] have ev': "eigenvector A v' e" unfolding v'_def .
  then have "v'  carrier_vec n" unfolding eigenvector_def using A by auto
  with ev' have ev: "eigenvector A v e" unfolding v_def using A dim normalize_keep_eigenvector by auto
  from this[unfolded eigenvector_def]
  have v[simp]: "v  carrier_vec n" and v0: "v  0v n" using A by auto
  interpret cof_vec_space n "TYPE(complex)" .
  (* Basis completion starting with v, then Gram-Schmidt and normalization:
     ws is conjugate-orthogonal with n vectors, and its head is v. *)
  from basis_completion[OF v v0, folded b_def]
  have span_b: "span (set b) = carrier_vec n" and dist_b: "distinct b" 
    and indep: "¬ lin_dep (set b)" and b: "set b  carrier_vec n" and hdb: "hd b = v" 
    and len_b: "length b = n" by auto
  from hdb len_b n obtain vs where bv: "b = v # vs" by (cases b, auto)
  from gram_schmidt_result[OF b dist_b indep refl, folded ws0_def]
  have ws0: "set ws0  carrier_vec n" "corthogonal ws0" "length ws0 = n" 
    by (auto simp: len_b)
  then have ws: "set ws  carrier_vec n" "corthogonal ws" "length ws = n" unfolding ws_def
    using normalize_keep_corthogonal by auto
  have ws0ne: "ws0  []" using ‹length ws0 = n n by auto
  from gram_schmidt_hd[OF v, of vs, folded bv] have hdws0: "hd ws0 = (vec_normalize v')" unfolding ws0_def v_def .
  have "hd ws = vec_normalize (hd ws0)" unfolding ws_def using hd_map[OF ws0ne]  by auto
  then have hdws: "hd ws = v" unfolding v_def using normalize_normalize[of v'] hdws0 by auto
  (* W and W' are mutually inverse n x n matrices, so A and A' = W' * A * W
     are similar. *)
  have orth_W: "corthogonal_mat W" using orthogonal_mat_of_cols ws unfolding W_def.
  have W: "W  carrier_mat n n"
    using ws unfolding W_def using mat_of_cols_carrier(1)[of n ws] by auto
  have W': "W'  carrier_mat n n" unfolding W'_def corthogonal_inv_def using W 
    by (auto simp: mat_of_rows_def)  
  from corthogonal_inv_result[OF orth_W] 
  have W'W: "inverts_mat W' W" unfolding W'_def .
  hence WW': "inverts_mat W W'" using mat_mult_left_right_inverse[OF W' W] W' W
    unfolding inverts_mat_def by auto
  have A': "A'  carrier_mat n n" using W W' A unfolding A'_def by auto
  have A'A_wit: "similar_mat_wit A' A W' W"
    by (rule similar_mat_witI[of _ _ n], insert W W' A A' W'W WW', auto simp: A'_def
    inverts_mat_def)
  hence A'A: "similar_mat A' A" unfolding similar_mat_def by blast
  from similar_mat_wit_sym[OF A'A_wit] have simAA': "similar_mat_wit A A' W W'" by auto
  (* The first column of W is the eigenvector v, so the first column of A'
     is e times the first unit vector.  Hence A1 is the 1 x 1 matrix e and
     A0 (below A1) is zero. *)
  have eigen[simp]: "A *v v = e v v" and v0: "v  0v n"
    using v_def v'_def find_eigenvector[OF A e] A normalize_keep_eigenvector
    unfolding eigenvector_def by auto
  let ?f = "(λ i. if i = 0 then e else 0)"
  have col0: "col A' 0 = vec n ?f"
    unfolding A'_def W'_def W_def
    using corthogonal_col_ev_0[OF A v v0 eigen n hdws ws].
  from A' n have "dim_row A' = 1 + ?n1" "dim_col A' = 1 + ?n1" by auto
  from split_block[OF splitA' this] have A2: "A2  carrier_mat 1 ?n1"
    and A3: "A3  carrier_mat ?n1 ?n1" 
    and A'block: "A' = four_block_mat A1 A2 A0 A3" by auto
  have A1id: "A1 = mat 1 1 (λ _. e)"
    using splitA'[unfolded split_block_def Let_def] arg_cong[OF col0, of "λ v. v $ 0"] A' n
    by (auto simp: col_def)
  have A1: "A1  carrier_mat 1 1" unfolding A1id by auto
  {
    fix i
    assume "i < ?n1"
    with arg_cong[OF col0, of "λ v. v $ Suc i"] A'
    have "A' $$ (Suc i, 0) = 0" by auto
  } note A'0 = this
  have A0id: "A0 = 0m ?n1 1"
    using splitA'[unfolded split_block_def Let_def] A'0 A' by auto
  have A0: "A0  carrier_mat ?n1 1" unfolding A0id by auto
  (* char_poly A' = char_poly A1 * char_poly A3, and char_poly A1 = [:-e,1:].
     Cancelling this factor yields the hypothesis for the induction on A3. *)
  from cp char_poly_similar[OF A'A]
  have cp: "char_poly A' = [: -e,1 :] * ( e  es. [:- e, 1:])" by simp
  also have "char_poly A' = char_poly A1 * char_poly A3" 
    unfolding A'block A0id
    by (rule char_poly_four_block_zeros_col[OF A1 A2 A3])
  also have "char_poly A1 = [: -e,1 :]"
    by (simp add: A1id char_poly_defs det_def signof_def sign_def)
  finally have cp: "char_poly A3 = ( e  es. [:- e, 1:])"
    by (metis mult_cancel_left pCons_eq_0_iff zero_neq_one)
  from Cons(1)[OF A3 cp schur]
  have simIH: "similar_mat_wit A3 B P' Q'" and ut: "upper_triangular B" and diag: "diag_mat B = es"
    and uP': "unitary P'" and Q'P': "Q' = adjoint P'"
    by auto
  from similar_mat_witD2[OF A3 simIH] 
  have B: "B  carrier_mat ?n1 ?n1" and P': "P'  carrier_mat ?n1 ?n1" and Q': "Q'  carrier_mat ?n1 ?n1" 
    and PQ': "P' * Q' = 1m ?n1" by auto
  (* Lift the inductive similarity of A3 to the block matrix A'. *)
  have A0_eq: "A0 = P' * A0 * 1m 1" unfolding A0id using P' by auto
  have simA'C: "similar_mat_wit A' C ?P' ?Q'" unfolding A'block C
    by (rule similar_mat_wit_four_block[OF similar_mat_wit_refl[OF A1] simIH _ A0_eq A1 A3 A0],
    insert PQ' A2 P' Q', auto)
  have ut1: "upper_triangular A1" unfolding A1id by auto
  have ut: "upper_triangular C" unfolding C A0id
    by (intro upper_triangular_four_block[OF _ B ut1 ut], auto simp: A1id)
  from A1id have diagA1: "diag_mat A1 = [e]" unfolding diag_mat_def by auto
  from diag_four_block_mat[OF A1 B] have diag: "diag_mat C = e # es" unfolding diag diagA1 C by simp

  (* W is unitary, so its corthogonal inverse W' coincides with adjoint W. *)
  have aW: "adjoint W  carrier_mat n n" using W by auto
  have aW': "adjoint W'  carrier_mat n n" using W' by auto
  have "unitary W" using W_def ws_def ws0 normalized_corthogonal_mat_is_unitary by auto
  then have ivWaW: "inverts_mat W (adjoint W)" using unitary_def W aW by auto
  with WW' have W'aW: "W' = (adjoint W)" using inverts_mat_unique W W' aW by auto
  then have "adjoint W' = W" using adjoint_adjoint by auto
  with ivWaW have "inverts_mat W' (adjoint W')" using inverts_mat_symm W aW W'aW by auto
  then have "unitary W'" using unitary_def W' by auto

  (* Q' = adjoint P' lifts to the block matrices (four_block_mat_adjoint),
     which gives Q = adjoint P. *)
  have newP': "P'  carrier_mat (n - Suc 0) (n - Suc 0)" using P' by auto
  have rl: " x1 x2 x3 x4 y1 y2 y3 y4 f. x1 = y1  x2 = y2  x3 = y3  x4 = y4  f x1 x2 x3 x4 = f y1 y2 y3 y4" by simp
  have Q'aP': "?Q' = adjoint ?P'"
    apply (subst four_block_mat_adjoint, auto simp add: newP')
    apply (rule rl[where f2 = four_block_mat])
       apply (auto simp add: eq_matI adjoint_eval Q'P')
    done
  have "adjoint P = adjoint ?P' * adjoint W"  using W newP' n
    apply (simp add: P)
    apply (subst adjoint_mult[of W, symmetric])
     apply (auto simp add: W P' carrier_matD[of W n n])
    done
  also have " = ?Q' * W'" using Q'aP' W'aW by auto
  also have " = Q" using Q by auto
  finally have QaP: "Q = adjoint P" ..

  (* Compose A ~ A' (witnesses W, W') with A' ~ C (witnesses ?P', ?Q').
     Unitarity of P then follows from P * Q = 1 and Q = adjoint P. *)
  from similar_mat_wit_trans[OF simAA' simA'C, folded P Q] have smw: "similar_mat_wit A C P Q" by blast
  then have dimP: "P  carrier_mat n n" and dimQ: "Q  carrier_mat n n" unfolding similar_mat_wit_def using A by auto
  from smw have "P * Q = 1m n" unfolding similar_mat_wit_def using A  by auto
  then have "inverts_mat P Q" using inverts_mat_def dimP by auto
  then have uP: "unitary P" using QaP unitary_def dimP by auto
    
  from ut similar_mat_wit_trans[OF simAA' simA'C, folded P Q] diag uP QaP
  show ?case by blast
qed

(* Over the complex numbers, the characteristic polynomial of an n x n
   matrix is a product of exactly n linear factors [:- a, 1:].
   The proof uses the factorized fundamental theorem of algebra, then
   removes the leading-coefficient scaling because char_poly A is monic
   of degree n (degree_monic_char_poly). *)
lemma complex_mat_char_poly_factorizable:
  fixes A :: "complex mat"
  assumes "A  carrier_mat n n"
  shows "as. char_poly A =  ( a  as. [:- a, 1:])  length as = n"
proof -
  let ?ca = "char_poly A"
  have ex0: "bs. Polynomial.smult (lead_coeff ?ca) (bbs. [:- b, 1:]) = ?ca 
     length bs = degree ?ca"
    by (simp add: fundamental_theorem_algebra_factorized)
  then obtain bs where " Polynomial.smult (lead_coeff ?ca) (bbs. [:- b, 1:]) = ?ca 
     length bs = degree ?ca" by auto
  (* A monic polynomial has leading coefficient 1, so smult is trivial. *)
  moreover have "lead_coeff ?ca = (1::complex)" 
    using assms degree_monic_char_poly by blast
  ultimately have ex1: "?ca = (bbs. [:- b, 1:])  length bs = degree ?ca" by auto
  moreover have "degree ?ca = n"
    by (simp add: assms degree_monic_char_poly)
  ultimately show ?thesis by auto
qed

(* Existence of a unitary Schur decomposition for every square complex
   matrix A.  There are B, a unitary P and a list es such that:
   - similar_mat_wit A B P (adjoint P) holds;
   - char_poly A is the product of the factors [:- e, 1:] over es;
   - the diagonal of B is es.
   The witness comes from running unitary_schur_decomposition on the root
   list given by complex_mat_char_poly_factorizable. *)
lemma complex_mat_has_unitary_schur_decomposition:
  fixes A :: "complex mat"
  assumes "A  carrier_mat n n"
  shows "B P es. similar_mat_wit A B P (adjoint P)  unitary P 
     char_poly A = ( (e :: complex)  es. [:- e, 1:])  diag_mat B = es"
proof -
  have "es. char_poly A =  ( e  es. [:- e, 1:])  length es = n" 
    using assms by (simp add: complex_mat_char_poly_factorizable)
  then obtain es where es: "char_poly A =  ( e  es. [:- e, 1:])  length es = n" by auto
  obtain B P Q where B: "unitary_schur_decomposition A es = (B,P,Q)" by (cases "unitary_schur_decomposition A es", auto)

  have "similar_mat_wit A B P Q  upper_triangular B  unitary P  (Q = adjoint P)  
   char_poly A = ( (e :: complex)  es. [:- e, 1:])  diag_mat B = es" using assms es B
    by (auto simp add: unitary_schur_decomposition)
  then show ?thesis by auto
qed

(* A square upper triangular matrix that is normal
   (A * adjoint A = adjoint A * A) is diagonal.
   The proof shows row by row, by strong induction on the row index i, that
   every off-diagonal entry of row i vanishes.  Compare the diagonal
   entries (i, i) of A * adjoint A and adjoint A * A:
   - the column side reduces to conjugate (A(i,i)) * A(i,i), using upper
     triangularity and the induction hypothesis for the rows above;
   - hence the sum of the terms conjugate (A(i,k)) * A(i,k) for k > i is 0;
   - each such term is nonnegative (conjugate_square_positive), so all are
     0 (sum_nonneg_0). *)
lemma normal_upper_triangular_matrix_is_diagonal:
  fixes A :: "'a::conjugatable_ordered_field mat"
  assumes "A  carrier_mat n n"
    and tri: "upper_triangular A"
    and norm: "A * adjoint A = adjoint A * A"
  shows "diagonal_mat A"
proof (rule disjE[of "n = 0" "n > 0"], blast)
  have dim: "dim_row A = n" "dim_col A = n" using assms by auto
  from norm have eq0: "i j. (A * adjoint A)$$(i,j) = (adjoint A * A)$$(i,j)" by auto
  (* Strong induction principle restricted to indices below n. *)
  have nat_induct_strong: 
    "P. (P::natbool) 0  (i. i < n  (k. k < i  P k)  P i)  (i. i < n  P i)"
    by (metis dual_order.strict_trans infinite_descent0 linorder_neqE_nat)
  show "n = 0  ?thesis" using dim unfolding diagonal_mat_def by auto
  (* Case n > 0: induction over the row index.  Row 0 is shown first
     (case0); rows i > 0 are then handled using the induction hypothesis. *)
  show "n > 0  ?thesis" unfolding diagonal_mat_def dim
    apply (rule allI, rule impI)
    apply (rule nat_induct_strong)
  proof (rule allI, rule impI, rule impI)
    (* Row 0: the column side of entry (0,0) has only the diagonal term. *)
    assume asm: "n > 0"
    from tri upper_triangularD[of A 0 j] dim have z0: "j. 0< j  j < n  A$$(j, 0) = 0" 
      by auto
    then have ada00: "(adjoint A * A)$$(0,0) = conjugate (A$$(0,0)) * A$$(0,0)"
      using asm dim by (auto simp add: scalar_prod_def adjoint_eval sum.atLeast_Suc_lessThan)
    have aad00: "(A * adjoint A)$$(0,0) = (k=0..<n. A$$(0, k) * conjugate (A$$(0, k)))"
      using asm dim by (auto simp add: scalar_prod_def adjoint_eval)
    moreover have 
      " = A$$(0,0) * conjugate (A$$(0,0))
          + (k=1..<n. A$$(0, k) * conjugate (A$$(0, k)))"
      using dim asm by (subst sum.atLeast_Suc_lessThan[of 0 n "λk. A$$(0, k) * conjugate (A$$(0, k))"], auto)
    ultimately have f1tneq0: "(k=(Suc 0)..<n. A$$(0, k) * conjugate (A$$(0, k))) = 0"
      using eq0 ada00 by (simp)
    have geq0: "k. k < n  A$$(0, k) * conjugate (A$$(0, k))  0"  
      using conjugate_square_positive by auto
    have "k. 1  k  k < n  A$$(0, k) * conjugate (A$$(0, k)) = 0"
      by (rule sum_nonneg_0[of "{1..<n}"], auto, rule geq0, auto, rule f1tneq0)
    with dim asm show 
      case0: "j. 0 < n  j < n  0  j  A $$ (0, j) = 0"
      by auto
    (* Rows i > 0, assuming (ih) that all rows k < i are zero off the
       diagonal. *)
    {
      fix i
      assume asm: "n > 0" "i < n" "i > 0"
        and ih: "k. k < i  j<n. k  j  A $$ (k, j) = 0"
      then have "j. j<n  i  j  A $$ (i, j) = 0"
      proof -
        (* Splitting a sum over {b..<e} at an intermediate point m. *)
        have inter_part: "b m e. (b::nat) < e  b < m  m < e  {b..<m}  {m..<e} = {b..<e}" by auto
        then have  
          "b m e f. (b::nat) < e  b < m  m < e 
             (k=b..<e. f k) = (k{b..<m}{m..<e}. f k)"
          using sum.union_disjoint by auto
        then have sum_part:
          "b m e f. (b::nat) < e  b < m  m < e 
                       (k=b..<e. f k) = (k=b..<m. f k) + (k=m..<e. f k)"
          by (auto simp add: sum.union_disjoint)
        (* Upper triangularity: zeros left of (i, i) in row i and below
           (i, i) in column i. *)
        from tri upper_triangularD[of A j i] asm dim have 
            zsi0: "j. j < i  A$$(i, j) = 0" by auto
        from tri upper_triangularD[of A j i] asm dim have 
            zsi1: "k. i < k  k < n  A$$(k, i) = 0" by auto
        (* Row side: diagonal term plus the terms to the right of it. *)
        have 
          "(A * adjoint A)$$(i, i) 
          = (k=0..<n. conjugate (A$$(i, k)) * A$$(i, k))" using asm dim
          apply (auto simp add: scalar_prod_def adjoint_eval)
          apply (rule sum.cong, auto)
          done
        also have
          " = (k=0..<i. conjugate (A$$(i, k)) * A$$(i, k))
              + (k=i..<n. conjugate (A$$(i, k)) * A$$(i, k))"
          using asm
          by (auto simp add: sum_part[of 0 n i])
        also have
          " = (k=i..<n. conjugate (A$$(i, k)) * A$$(i, k))"
          using zsi0
          by auto
        also have
          " = conjugate (A$$(i, i)) * A$$(i, i) 
            + (k=(Suc i)..<n. conjugate (A$$(i, k)) * A$$(i, k))"
          using asm
          by (auto simp add: sum.atLeast_Suc_lessThan)
        finally have
          adaii: "(A * adjoint A)$$(i, i) 
            = conjugate (A$$(i, i)) * A$$(i, i) 
            + (k=(Suc i)..<n. conjugate (A$$(i, k)) * A$$(i, k))" .
        (* Column side: only the diagonal term survives, using ih for the
           rows above and zsi1 for the rows below. *)
        have 
          "(adjoint A * A)$$(i, i) = (k=0..<n. conjugate (A$$(k, i)) * A$$(k, i))"
          using asm dim by (auto simp add: scalar_prod_def adjoint_eval)
        also have
          " = (k=0..<i. conjugate (A$$(k, i)) * A$$(k, i))
              + (k=i..<n. conjugate (A$$(k, i)) * A$$(k, i))"
          using asm by (auto simp add: sum_part[of 0 n i])
        also have 
            " = (k=i..<n. conjugate (A$$(k, i)) * A$$(k, i))"
          using asm ih by auto
        also have
            " = conjugate (A$$(i, i)) * A$$(i, i)"
          using asm zsi1 by (auto simp add: sum.atLeast_Suc_lessThan)
        finally have "(adjoint A * A)$$(i, i) = conjugate (A$$(i, i)) * A$$(i, i)" .
        (* Normality equates both sides, so the right-of-diagonal sum is 0,
           and each of its nonnegative terms is 0. *)
        with adaii eq0 have 
          fsitoneq0: "(k=(Suc i)..<n. conjugate (A$$(i, k)) * A$$(i, k)) = 0" by auto
        have "k. k<n  i < k  conjugate (A$$(i, k)) * A$$(i, k) = 0"
          by (rule sum_nonneg_0[of "{(Suc i)..<n}"], auto, subst mult.commute, 
              rule conjugate_square_positive, rule fsitoneq0)
        then have "k. k<n   i<k  A $$ (i, k) = 0" by auto
        with zsi0 show "j. j<n  i  j  A $$ (i, j) = 0" 
          by (metis linorder_neqE_nat)
      qed
    }
    (* Combine row 0 (case0) with the step for rows i > 0. *)
    with case0 show "i ia.
       0 < n 
       i < n 
       ia < n 
       (k. k < ia  j<n. k  j  A $$ (k, j) = 0) 
       j<n. ia  j  A $$ (ia, j) = 0" by auto
  qed
qed

(* Spectral theorem for normal complex matrices.  If A is normal, the
   matrix B from unitary_schur_decomposition is diagonal.  Hence
   similar_mat_wit A B P (adjoint P) holds with P unitary and diagonal B
   whose diagonal is es.
   Proof: since A = P * B * adjoint P and adjoint P * P = 1, conjugating
   A * adjoint A and adjoint A * A by P gives B * adjoint B and
   adjoint B * B.  So B is normal; as B is also upper triangular,
   normal_upper_triangular_matrix_is_diagonal applies. *)
lemma normal_complex_mat_has_spectral_decomposition:
  assumes A: "(A::complex mat)  carrier_mat n n"
    and normal: "A * adjoint A  = adjoint A * A"
    and c: "char_poly A = ( (e :: complex)  es. [:- e, 1:])"
    and B: "unitary_schur_decomposition A es = (B,P,Q)"
  shows "similar_mat_wit A B P (adjoint P)  diagonal_mat B  diag_mat B = es  unitary P"
proof -
  have smw: "similar_mat_wit A B P (adjoint P)" 
    and ut: "upper_triangular B"
    and uP: "unitary P" 
    and dB: "diag_mat B = es"
    and "(Q = adjoint P)"
    using assms by (auto simp add: unitary_schur_decomposition)
  from smw have dimP: "P  carrier_mat n n" and dimB: "B  carrier_mat n n" 
    and dimaP: "adjoint P  carrier_mat n n"
    unfolding similar_mat_wit_def using A by auto
  have dimaB: "adjoint B  carrier_mat n n" using dimB by auto
  note dims = dimP dimB dimaP dimaB

  (* adjoint P is a left inverse of P; A and adjoint A in terms of B. *)
  have "inverts_mat P (adjoint P)" using unitary_def uP dims by auto
  then have iaPP: "inverts_mat (adjoint P) P" using inverts_mat_symm using dims by auto
  have aPP: "adjoint P * P = 1m n" using dims iaPP unfolding inverts_mat_def by auto
  from smw have A: "A = P * B * (adjoint P)" unfolding similar_mat_wit_def Let_def by auto
  then have aA: "adjoint A = P * adjoint B * adjoint P" 
    by (insert A dimP dimB dimaP, auto simp add: adjoint_mult[of _ n n _ n] adjoint_adjoint)
  (* Conjugating A * adjoint A by P yields B * adjoint B. *)
  have "A * adjoint A = (P * B * adjoint P) * (P * adjoint B * adjoint P)" using A aA by auto
  also have " = P * B * (adjoint P * P) * (adjoint B * adjoint P)" using dims by (mat_assoc n)
  also have " = P * B * 1m n * (adjoint B * adjoint P)" using dims aPP by (auto)
  also have " = P * B * adjoint B * adjoint P" using dims by (mat_assoc n)
  finally have "A * adjoint A = P * B * adjoint B * adjoint P".
  then have "adjoint P * (A * adjoint A) * P = (adjoint P * P) * B * adjoint B * (adjoint P * P)"
    using dims by (simp add: assoc_mult_mat[of _ n n _ n _ n])
  also have " = 1m n * B * adjoint B * 1m n" using aPP by auto
  also have " = B * adjoint B" using dims by auto
  finally have eq0: "adjoint P * (A * adjoint A) * P = B * adjoint B".

  (* Conjugating adjoint A * A by P yields adjoint B * B. *)
  have "adjoint A * A = (P * adjoint B * adjoint P) * (P * B * adjoint P)" using A aA by auto
  also have " = P * adjoint B * (adjoint P * P) * (B * adjoint P)" using dims by (mat_assoc n)
  also have " = P * adjoint B * 1m n * (B * adjoint P)" using dims aPP by (auto)
  also have " = P * adjoint B * B * adjoint P" using dims by (mat_assoc n)  
  finally have "adjoint A * A = P * adjoint B * B * adjoint P" by auto
  then have "adjoint P * (adjoint A * A) * P = (adjoint P * P) * adjoint B * B * (adjoint P * P)"
    using dims by (simp add: assoc_mult_mat[of _ n n _ n _ n])
  also have " = 1m n * adjoint B * B * 1m n" using aPP by auto
  also have " = adjoint B * B" using dims by auto
  finally have eq1: "adjoint P * (adjoint A * A) * P = adjoint B * B".

  (* Normality of A transfers to B; upper triangular plus normal is
     diagonal. *)
  from normal have "adjoint P * (adjoint A * A) * P = adjoint P * (A * adjoint A) * P" by auto
  with eq0 eq1 have "B * adjoint B = adjoint B * B" by auto
  with ut dims have "diagonal_mat B" using normal_upper_triangular_matrix_is_diagonal by auto
  with smw uP dB show "similar_mat_wit A B P (adjoint P)  diagonal_mat B  diag_mat B = es  unitary P" by auto
qed

(* Every square complex matrix has a Jordan normal form.  Its
   characteristic polynomial splits into linear factors
   (complex_mat_char_poly_factorizable), which suffices by
   jordan_nf_iff_linear_factorization. *)
lemma complex_mat_has_jordan_nf:
  fixes A :: "complex mat"
  assumes "A  carrier_mat n n"
  shows "n_as. jordan_nf A n_as"
proof -
  have "as. char_poly A =  ( a  as. [:- a, 1:])  length as = n" 
    using assms by (simp add: complex_mat_char_poly_factorizable)
  then show ?thesis using assms
    by (auto simp add: jordan_nf_iff_linear_factorization)
qed

(* A Hermitian matrix commutes with its adjoint, i.e. it is normal.
   This follows directly from hermitian_def. *)
lemma hermitian_is_normal:
  assumes "hermitian A"
  shows "A * adjoint A = adjoint A * A"
  using assms by (auto simp add: hermitian_def)

(* For a Hermitian matrix A, the spectral decomposition obtained from
   unitary_schur_decomposition has a diagonal B with real diagonal
   entries.
   Proof: A = adjoint A gives P * B * adjoint P = P * adjoint B * adjoint P.
   Cancelling the unitary P (unitary_elim) yields B = adjoint B, so each
   diagonal entry B(i,i) equals its own conjugate and is therefore real. *)
lemma hermitian_eigenvalue_real:
  assumes dim: "(A::complex mat)  carrier_mat n n"
    and hA: "hermitian A"
    and c: "char_poly A = ( (e :: complex)  es. [:- e, 1:])"
    and B: "unitary_schur_decomposition A es = (B,P,Q)"
  shows "similar_mat_wit A B P (adjoint P)  diagonal_mat B  diag_mat B = es 
     unitary P  (i < n. B$$(i, i)  Reals)"
proof -
  have normal: "A * adjoint A = adjoint A * A" using hA hermitian_is_normal by auto
  then have schur: "similar_mat_wit A B P (adjoint P)  diagonal_mat B  diag_mat B = es  unitary P"
    using normal_complex_mat_has_spectral_decomposition[OF dim normal c B] by (simp)
  then have "similar_mat_wit A B P (adjoint P)" 
    and uP: "unitary P" and dB: "diag_mat B = es"
    using assms by auto
  then have A: "A = P * B * (adjoint P)" 
    and dimB: "B  carrier_mat n n" and dimP: "P  carrier_mat n n"
    unfolding similar_mat_wit_def Let_def using dim by auto
  then have dimaB: "adjoint B  carrier_mat n n" by auto
  (* adjoint A expressed through B. *)
  have "adjoint A = adjoint (adjoint P) * adjoint (P * B)" 
    apply (subst A)
    apply (subst adjoint_mult[of "P * B" n n "adjoint P" n])
      apply (insert dimB dimP, auto)
    done
  also have " = P * adjoint (P * B)" by (auto simp add: adjoint_adjoint)
  also have " = P * (adjoint B * adjoint P)" using dimB dimP by (auto simp add: adjoint_mult)
  also have " = P * adjoint B * adjoint P" using dimB dimP by (subst assoc_mult_mat[symmetric, of P n n "adjoint B" n "adjoint P" n], auto)
  finally have aA: "adjoint A = P * adjoint B * adjoint P" .
  (* Hermitian A forces B to be self-adjoint. *)
  have "A = adjoint A" using hA hermitian_def[of A] by auto
  then have "P * B * adjoint P = P * adjoint B * adjoint P" using A aA by auto
  then have BaB: "B = adjoint B" using unitary_elim[OF dimB dimaB dimP] uP by auto
  {
    fix i 
    assume "i < n"
    then have "B$$(i, i) = conjugate (B$$(i, i))" 
      apply (subst BaB)
      by (insert dimB, simp add: adjoint_eval)
    then have "B$$(i, i)  Reals" unfolding conjugate_complex_def 
      using Reals_cnj_iff by auto
  }
  then have "i<n. B$$(i, i)  Reals" by auto
  with schur show ?thesis by auto
qed

(* For a Hermitian matrix A and a vector v of matching dimension,
   inner_prod v (A *v v) is real.
   Proof: write A = P * B * adjoint P with B diagonal and real on the
   diagonal (hermitian_eigenvalue_real).  Let w = adjoint P *v v.  Then the
   value equals the sum over i of B(i,i) * conjugate (w$i) * w$i, and every
   summand is real. *)
lemma hermitian_inner_prod_real:
  assumes dimA: "(A::complex mat)  carrier_mat n n"
    and dimv: "v  carrier_vec n"
    and hA: "hermitian A"
  shows "inner_prod v (A *v v)  Reals"
proof -
  obtain es where es: "char_poly A = ( (e :: complex)  es. [:- e, 1:])" 
    using complex_mat_char_poly_factorizable dimA by auto
  obtain B P Q where "unitary_schur_decomposition A es = (B,P,Q)" 
    by (cases "unitary_schur_decomposition A es", auto)
  then have "similar_mat_wit A B P (adjoint P)  diagonal_mat B  diag_mat B = es 
     unitary P  (i < n. B$$(i, i)  Reals)"
    using hermitian_eigenvalue_real dimA es hA by auto
  then have A: "A = P * B * (adjoint P)" and dB: "diagonal_mat B"
    and Bii: "i. i < n  B$$(i, i)  Reals"
    and dimB: "B  carrier_mat n n" and dimP: "P  carrier_mat n n" and dimaP: "adjoint P  carrier_mat n n"
    unfolding similar_mat_wit_def Let_def using dimA by auto
  define w where "w = (adjoint P) *v v"
  then have dimw: "w  carrier_vec n" using dimaP dimv by auto
  (* Move P to the left argument of the inner product (adjoint_def_alter). *)
  from A have "inner_prod v (A *v v) = inner_prod v ((P * B * (adjoint P)) *v v)" by auto
  also have " = inner_prod v ((P * B) *v ((adjoint P) *v v))" using dimP dimB dimv
    by (subst assoc_mult_mat_vec[of _ n n "adjoint P" n], auto)
  also have " = inner_prod v (P *v (B *v ((adjoint P) *v v)))" using dimP dimB dimv dimaP
    by (subst assoc_mult_mat_vec[of _ n n "B" n], auto)
  also have " = inner_prod w (B *v w)" unfolding w_def 
    apply (rule adjoint_def_alter[OF _ _ dimP])
     apply (insert mult_mat_vec_carrier[OF dimB mult_mat_vec_carrier[OF dimaP dimv]], auto simp add: dimv)
    done

  (* Expand as a double sum; diagonality of B leaves only the i = j terms. *)
  also have " = (i=0..<n. (j=0..<n.
                conjugate (w$i) * B$$(i, j) * w$j))" unfolding scalar_prod_def using dimw dimB
    apply (simp add: scalar_prod_def sum_distrib_right)
    apply (rule sum.cong, auto, rule sum.cong, auto)
    done
  also have " = (i=0..<n. B$$(i, i) *  conjugate (w$i) * w$i)" 
    apply (rule sum.cong, auto)
    apply (simp add: sum.remove)
    apply (insert dB[unfolded diagonal_mat_def] dimB, auto)
    done
  finally have sum: "inner_prod v (A *v v) = (i=0..<n. B$$(i, i) *  conjugate (w$i) * w$i)" .
  (* Each summand is real because B(i,i) is real (Bii). *)
  have "i. i < n  B$$(i, i) *  conjugate (w$i) * w$i  Reals" using Bii by (simp add: Reals_cnj_iff)
  then have "(i=0..<n. B$$(i, i) *  conjugate (w$i) * w$i)  Reals" by auto
  then show ?thesis using sum by auto
qed

(* With the i-th unit vector, the inner product picks out a diagonal entry:
   inner_prod (unit_vec n i) (A *v unit_vec n i) = A(i,i).
   A *v (unit_vec n i) is column i of A, and the unit vector is fixed by
   conjugation. *)
lemma unit_vec_bracket:
  fixes A :: "complex mat"
  assumes dimA: "A  carrier_mat n n" and i: "i < n"
  shows "inner_prod (unit_vec n i) (A *v (unit_vec n i)) = A$$(i, i)"
proof -
  define w where "(w::complex vec) = unit_vec n i"
  have "A *v w = col A i" using i dimA w_def by auto
  then have 1: "inner_prod w (A *v w) = inner_prod w (col A i)" using w_def by auto
  have "conjugate w = w" unfolding w_def unit_vec_def conjugate_vec_def using i by auto
  then have 2: "inner_prod w (col A i) = A$$(i, i)" using i dimA w_def by auto
  from 1 2 show "inner_prod w (A *v w) = A$$(i, i)" by auto
qed

(* Recovering B(i,i) from P * B * adjoint P with P unitary:
   inner_prod (col P i) (P * B * adjoint P *v col P i) = B(i,i).
   adjoint P * P = 1, so adjoint P maps col P i to the i-th unit vector;
   the result then follows from unit_vec_bracket.
   NOTE(review): the proof below does not appear to use dB (diagonal_mat B).
   It is kept because callers instantiate this lemma positionally with
   [OF dimP dimB uP dB i]. *)
lemma spectral_decomposition_extract_diag:
  fixes P B :: "complex mat"
  assumes dimP: "P  carrier_mat n n" and dimB: "B  carrier_mat n n"
    and uP: "unitary P" and dB: "diagonal_mat B" and i: "i < n"
  shows "inner_prod (col P i) (P * B * (adjoint P) *v (col P i)) = B$$(i, i)"
proof -
  have dimaP: "adjoint P carrier_mat n n" using dimP by auto
  have uaP: "unitary (adjoint P)" using unitary_adjoint uP dimP by auto
  then have "inverts_mat (adjoint P) P" by (simp add: unitary_def adjoint_adjoint)
  then have iv: "(adjoint P) * P = 1m n" using dimaP inverts_mat_def by auto
  define v where "v = col P i"
  then have dimv: "v  carrier_vec n" using dimP by auto
  define w where "(w::complex vec) = unit_vec n i"
  then have dimw: "w  carrier_vec n" by auto
  have BaPv: "B *v (adjoint P *v v)  carrier_vec n" using dimB dimaP dimv by auto
  (* adjoint P *v col P i is column i of adjoint P * P = 1, i.e. unit_vec n i. *)
  have "(adjoint P) *v v = (col (adjoint P * P) i)" 
    by (simp add: col_mult2[OF dimaP dimP i, symmetric] v_def)
  then have aPv: "(adjoint P) *v v = w"
    by (auto simp add: iv i w_def)
  have "inner_prod v (P * B * (adjoint P) *v v) = inner_prod v ((P * B) *v ((adjoint P) *v v))" using dimP dimB dimv
    by (subst assoc_mult_mat_vec[of _ n n "adjoint P" n], auto)
  also have " = inner_prod v (P *v (B *v ((adjoint P) *v v)))" using dimP dimB dimv dimaP
    by (subst assoc_mult_mat_vec[of _ n n "B" n], auto)
  also have " = inner_prod (adjoint P *v v) (B *v (adjoint P *v v))" 
    by (simp add: adjoint_def_alter[OF dimv BaPv dimP])
  also have " = inner_prod w (B *v w)" using aPv by auto
  also have " = B$$(i, i)" using w_def unit_vec_bracket dimB i by auto
  finally show "inner_prod v (P * B * (adjoint P) *v v) = B$$(i, i)".
qed

(* A Hermitian matrix A with inner_prod v (A *v v) = 0 for every v of
   dimension n is the zero matrix.
   Proof: take the spectral decomposition A = P * B * adjoint P with B
   diagonal.  Each diagonal entry B(i,i) equals
   inner_prod (col P i) (A *v col P i) (spectral_decomposition_extract_diag),
   which is 0 by assumption.  So B = 0 and hence A = 0. *)
lemma hermitian_inner_prod_zero:
  fixes A :: "complex mat"
  assumes dimA: "A  carrier_mat n n" and hA: "hermitian A"
    and zero: "vcarrier_vec n. inner_prod v (A *v v) = 0"
  shows "A = 0m n n"
proof -
  obtain es where es: "char_poly A = ( (e :: complex)  es. [:- e, 1:])" 
    using complex_mat_char_poly_factorizable dimA by auto
  obtain B P Q where "unitary_schur_decomposition A es = (B,P,Q)" 
    by (cases "unitary_schur_decomposition A es", auto)
  then have "similar_mat_wit A B P (adjoint P)  diagonal_mat B  diag_mat B = es 
     unitary P  (i < n. B$$(i, i)  Reals)"
    using hermitian_eigenvalue_real dimA es hA by auto
  then have A: "A = P * B * (adjoint P)" and dB: "diagonal_mat B"
    and Bii: "i. i < n  B$$(i, i)  Reals"
    and dimB: "B  carrier_mat n n" and dimP: "P  carrier_mat n n" and dimaP: "adjoint P  carrier_mat n n"
    and uP: "unitary P"
    unfolding similar_mat_wit_def Let_def unitary_def using dimA by auto
  then have uaP: "unitary (adjoint P)" using unitary_adjoint by auto
  then have "inverts_mat (adjoint P) P" by (simp add: unitary_def adjoint_adjoint)
  then have iv: "adjoint P * P = 1m n" using dimaP inverts_mat_def by auto
  have "B = 0m n n"
  proof-
    (* Every diagonal entry of B vanishes; B is diagonal, so B = 0. *)
    {
      fix i assume i: "i < n"
      define v where "v = col P i"
      then have dimv: "v  carrier_vec n" using v_def dimP by auto
      have "inner_prod v (A *v v) = B$$(i, i)" unfolding A v_def
        using spectral_decomposition_extract_diag[OF dimP dimB uP dB i]  by auto
      moreover have "inner_prod v (A *v v) = 0" using dimv zero by auto
      ultimately have "B$$(i, i) = 0" by auto
    }
    note zB = this
    show "B = 0m n n" by (insert zB dB dimB, rule eq_matI, auto simp add: diagonal_mat_def)
  qed
  then show "A = 0m n n" using A dimB dimP dimaP by auto
qed

(* Cartesian decomposition of a square complex matrix: A = B + i C with
   B and C Hermitian and of the same size as A.  The witnesses are
   B = (1/2) (A + adjoint A) and C = (-i/2) (A - adjoint A), where i is
   the imaginary unit.  Every claim is checked entrywise. *)
lemma complex_mat_decomposition_to_hermitian:
  fixes A :: "complex mat"
  assumes dim: "A  carrier_mat n n"
  shows "B C. hermitian B  hermitian C  A = B + 𝗂 m C  B  carrier_mat n n  C  carrier_mat n n"
proof -
  obtain B C where B: "B = (1 / 2) m (A + adjoint A)" 
    and C: "C = (-𝗂 / 2) m (A - adjoint A)" by auto
  then have dimB: "B  carrier_mat n n" and dimC: "C  carrier_mat n n" using dim by auto
  have "hermitian B" unfolding B hermitian_def using dim
    by (auto simp add: adjoint_eval)
  moreover have "hermitian C" unfolding C hermitian_def using dim
    apply (subst eq_matI)
       apply (auto simp add: adjoint_eval algebra_simps)
    done
  moreover have "A = B + 𝗂 m C" using dim B C 
    apply (subst eq_matI)
       apply (auto simp add: adjoint_eval algebra_simps)
    done
  ultimately show ?thesis using dimB dimC by auto
qed

subsection ‹Outer product›

(* Outer product of two vectors: the (dim_vec v) x 1 matrix whose column is v,
   times the 1 x (dim_vec w) matrix whose row is conjugate w.  Its (i, j) entry
   is v $ i * conjugate (w $ j) (lemma index_outer_prod). *)
definition outer_prod :: "'a::conjugatable_field vec  'a vec  'a mat" where
  "outer_prod v w = mat (dim_vec v) 1 (λ(i, j). v $ i) * mat 1 (dim_vec w) (λ(i, j). (conjugate w) $ j)"

(* The outer product of an n-vector and an m-vector is an n x m matrix. *)
lemma outer_prod_dim[simp]:
  fixes v w :: "'a::conjugatable_field vec"
  assumes v: "v  carrier_vec n" and w: "w  carrier_vec m"
  shows "outer_prod v w  carrier_mat n m"
  unfolding outer_prod_def using assms mat_of_cols_carrier mat_of_rows_carrier by auto

(* The row matrix of conjugate v times the column matrix of w is the 1 x 1
   matrix whose only entry is inner_prod v w. *)
lemma mat_of_vec_mult_eq_scalar_prod:
  fixes v w :: "'a::conjugatable_field vec"
  assumes "v  carrier_vec n" and "w  carrier_vec n"
  shows "mat 1 (dim_vec v) (λ(i, j). (conjugate v) $ j) * mat (dim_vec w) 1 (λ(i, j). w $ i) 
    = mat 1 1 (λk. inner_prod v w)"
  apply (rule eq_matI) using assms apply (simp add: scalar_prod_def) apply (rule sum.cong) by auto

(* Left multiplication by the constant 1 x 1 matrix with entry a is scalar
   multiplication by a.  (Only B is fixed; a previously declared but unused
   variable has been dropped, which does not change the theorem.) *)
lemma one_dim_mat_mult_is_scale:
  fixes B :: "('a::conjugatable_field mat)"
  assumes "B  carrier_mat 1 n"
  shows "(mat 1 1 (λk. a)) * B = a m B"
  apply (rule eq_matI) using assms by (auto simp add: scalar_prod_def)

(* Composition of outer products:
   outer_prod a b * outer_prod c d = inner_prod b c .m outer_prod a d.
   Proof: re-associate the four factor matrices so that the middle
   row-times-column product collapses to the 1 x 1 matrix holding
   inner_prod b c, which then acts as a scalar. *)
lemma outer_prod_mult_outer_prod:
  fixes a b c d :: "'a::conjugatable_field vec"
  assumes a: "a  carrier_vec d1" and b: "b  carrier_vec d2"
    and c: "c  carrier_vec d2" and d: "d  carrier_vec d3"
  shows "outer_prod a b * outer_prod c d = inner_prod b c m outer_prod a d"
proof -
  (* The column/row factor matrices of the two outer products. *)
  let ?ma = "mat (dim_vec a) 1 (λ(i, j). a $ i)"
  let ?mb = "mat 1 (dim_vec b) (λ(i, j). (conjugate b) $ j)"
  let ?mc = "mat (dim_vec c) 1 (λ(i, j). c $ i)"
  let ?md = "mat 1 (dim_vec d) (λ(i, j). (conjugate d) $ j)"
  have "(?ma * ?mb) * (?mc * ?md) = ?ma * (?mb * (?mc * ?md))"
    apply (subst assoc_mult_mat[of "?ma" d1 1 "?mb" d2 "?mc * ?md" d3] )
    using assms by auto
  also have " = ?ma * ((?mb * ?mc) * ?md)"
    apply (subst assoc_mult_mat[symmetric, of "?mb" 1 d2 "?mc" 1 "?md" d3])
    using assms by auto
  also have " = ?ma * ((mat 1 1 (λk. inner_prod b c)) * ?md)"
    apply (subst mat_of_vec_mult_eq_scalar_prod[of b d2 c]) using assms by auto
  also have " = ?ma * (inner_prod b c m ?md)" 
    apply (subst one_dim_mat_mult_is_scale) using assms by auto
  also have " = (inner_prod b c) m (?ma * ?md)" using assms by auto
  finally show ?thesis unfolding outer_prod_def by auto
qed

(* Entry (i, j) of the outer product of v and w is v $ i * conjugate (w $ j). *)
lemma index_outer_prod:
  fixes v w :: "'a::conjugatable_field vec"
  assumes v: "v  carrier_vec n" and w: "w  carrier_vec m"
    and ij: "i < n" "j < m"
  shows "(outer_prod v w)$$(i, j) = v $ i * conjugate (w $ j)"
  unfolding outer_prod_def using assms by (simp add: scalar_prod_def)

(* The 1 x d row matrix of conjugate a applied to b is the 1-vector holding
   inner_prod a b.  (Only a and b are fixed; a previously declared but unused
   variable has been dropped, which does not change the theorem.) *)
lemma mat_of_vec_mult_vec:
  fixes a b :: "'a::conjugatable_field vec"
  assumes a: "a  carrier_vec d" and b: "b  carrier_vec d"
  shows "mat 1 d (λ(i, j). (conjugate a) $ j) *v b = vec 1 (λk. inner_prod a b)"
  apply (rule eq_vecI) 
   apply (simp add: scalar_prod_def carrier_vecD[OF a] carrier_vecD[OF b])
  apply (rule sum.cong) by auto

(* The d x 1 column matrix of a applied to the constant 1-vector with entry c
   is c .v a.  (Only a is fixed; a previously declared but unused variable
   has been dropped, which does not change the theorem.) *)
lemma mat_of_vec_mult_one_dim_vec:
  fixes a :: "'a::conjugatable_field vec"
  assumes a: "a  carrier_vec d" 
  shows "mat d 1 (λ(i, j). a $ i) *v vec 1 (λk. c) = c v a"
  apply (rule eq_vecI)
  by (auto simp add: scalar_prod_def carrier_vecD[OF a])

(* Applying an outer product to a vector: outer_prod a b *v c = inner_prod b c .v a.
   Proof: split the outer product into its column and row factors, apply the
   row factor first (mat_of_vec_mult_vec), then the column factor
   (mat_of_vec_mult_one_dim_vec). *)
lemma outer_prod_mult_vec:
  fixes a b c :: "'a::conjugatable_field vec"
  assumes a: "a  carrier_vec d1" and b: "b  carrier_vec d2"
    and c: "c  carrier_vec d2"
  shows "outer_prod a b *v c = inner_prod b c v a"
proof -
  have "outer_prod a b *v c 
    = mat d1 1 (λ(i, j). a $ i) 
    * mat 1 d2 (λ(i, j). (conjugate b) $ j)
    *v c" unfolding outer_prod_def using assms by auto
  also have " = mat d1 1 (λ(i, j). a $ i) 
    *v (mat 1 d2 (λ(i, j). (conjugate b) $ j)
    *v c)" apply (subst assoc_mult_mat_vec) using assms by auto
  also have " = mat d1 1 (λ(i, j). a $ i) 
    *v vec 1 (λk. inner_prod b c)" using mat_of_vec_mult_vec[of b] assms by auto
  also have " = inner_prod b c v a" using mat_of_vec_mult_one_dim_vec assms by auto
  finally show ?thesis by auto
qed

(* trace (A * outer_prod v w) = inner_prod w (A *v v).
   Proof: expand both sides into the same double sum
   sum_i sum_j A$$(i,j) * v$j * conjugate (w$i). *)
lemma trace_outer_prod_right:
  fixes A :: "'a::conjugatable_field mat" and v w :: "'a vec"
  assumes A: "A  carrier_mat n n"
    and v: "v  carrier_vec n" and w: "w  carrier_vec n"
  shows "trace (A * outer_prod v w) = inner_prod w (A *v v)" (is "?lhs = ?rhs")
proof -
  define B where "B = outer_prod v w"
  then have B: "B  carrier_mat n n" using assms by auto
  have "trace(A * B) = (i = 0..<n. j = 0..<n. A $$ (i,j) * B $$ (j,i))"
    unfolding trace_def using A B by (simp add: scalar_prod_def)
  also have " = (i = 0..<n. j = 0..<n. A $$ (i,j) * v $ j * conjugate (w $ i))"
    unfolding B_def
    apply (rule sum.cong, simp, rule sum.cong, simp)
    by (insert v w, auto simp add: index_outer_prod)
  finally have "?lhs = (i = 0..<n. j = 0..<n. A $$ (i,j) * v $ j * conjugate (w $ i))" using B_def by auto
  moreover have "?rhs = (i = 0..<n. j = 0..<n. A $$ (i,j) * v $ j * conjugate (w $ i))" using A v w
    by (simp add: scalar_prod_def sum_distrib_right)
  ultimately show ?thesis by auto
qed

(* trace (outer_prod v w) = inner_prod w v; the special case A = 1m n of
   trace_outer_prod_right. *)
lemma trace_outer_prod:
  fixes v w :: "('a::conjugatable_field vec)"
  assumes v: "v  carrier_vec n" and w: "w  carrier_vec n"
  shows "trace (outer_prod v w) = inner_prod w v" (is "?lhs = ?rhs")
proof -
  have "(1m n) * (outer_prod v w) = outer_prod v w" apply (subst left_mult_one_mat) using outer_prod_dim assms by auto
  moreover have "1m n *v v = v" using assms by auto
  ultimately show ?thesis using trace_outer_prod_right[of "1m n" n v w] assms by auto
qed

(* inner_prod a (outer_prod b c *v d) = inner_prod a b * inner_prod c d.
   Proof: expand both sides into the same double sum over i < n and j < m. *)
lemma inner_prod_outer_prod:
  fixes a b c d :: "'a::conjugatable_field vec"
  assumes a: "a  carrier_vec n" and b: "b  carrier_vec n"
    and c: "c  carrier_vec m" and d: "d  carrier_vec m"
  shows "inner_prod a (outer_prod b c *v d) = inner_prod a b * inner_prod c d" (is "?lhs = ?rhs")
proof -
  define P where "P = outer_prod b c"
  then have dimP: "P  carrier_mat n m" using assms by auto
  have "inner_prod a (P *v d) = (i=0..<n. (j=0..<m. conjugate (a$i) * P$$(i, j) * d$j))" using assms dimP
    apply (simp add: scalar_prod_def sum_distrib_right)
    apply (rule sum.cong, auto)
    apply (rule sum.cong, auto)
    done
  also have " = (i=0..<n. (j=0..<m. conjugate (a$i) * b$i * conjugate(c$j) * d$j))"
    using P_def b c by(simp add: index_outer_prod algebra_simps)
  finally have eq: "?lhs = (i=0..<n. (j=0..<m. conjugate (a$i) * b$i * conjugate(c$j) * d$j))" using P_def by auto

  (* The right-hand side: a product of two sums distributed into a double sum. *)
  have "?rhs = (i=0..<n. conjugate (a$i) * b$i) * (j=0..<m. conjugate(c$j) * d$j)" using assms 
    by (auto simp add: scalar_prod_def algebra_simps)
  also have " = (i=0..<n. (j=0..<m. conjugate (a$i) * b$i * conjugate(c$j) * d$j))"
    using assms by (simp add: sum_product algebra_simps)
  finally show "?lhs = ?rhs" using eq by auto
qed

subsection ‹Semi-definite matrices›

(* Positive (semi-definite) complex matrices: A is square and its quadratic
   form inner_prod v (A *v v) is non-negative, in the complex ordering, for
   every vector v of matching dimension. *)
definition positive :: "complex mat  bool" where
  "positive A 
     A  carrier_mat (dim_col A) (dim_col A) 
     (v. dim_vec v = dim_col A  inner_prod v (A *v v)  0)"

(* Positivity need only be checked on unit vectors (vec_norm v = 1).
   The non-trivial direction rescales a non-zero v to the unit vector
   w = vec_normalize v, using inner_prod v (A *v v) =
   vec_norm v * vec_norm v * inner_prod w (A *v w). *)
lemma positive_iff_normalized_vec:
  "positive A 
    A  carrier_mat (dim_col A) (dim_col A) 
    (v. (dim_vec v = dim_col A  vec_norm v = 1)  inner_prod v (A *v v)  0)"
proof (rule)
  assume "positive A"
  then show "A  carrier_mat (dim_col A) (dim_col A)  
    (v. dim_vec v = dim_col A  vec_norm v = 1  0  inner_prod v (A *v v))"
    unfolding positive_def by auto
next
  define n where "n = dim_col A"
  assume "A  carrier_mat (dim_col A) (dim_col A)  (v. dim_vec v = dim_col A  vec_norm v = 1  0  inner_prod v (A *v v))"
  then have A: "A  carrier_mat (dim_col A) (dim_col A)" and geq0: "v. dim_vec v = dim_col A  vec_norm v = 1  0  inner_prod v (A *v v)" by auto
  then have dimA: "A  carrier_mat n n" using n_def[symmetric] by auto
  {
    fix v assume dimv: "(v::complex vec)  carrier_vec n"
    have "0  inner_prod v (A *v v)"
    proof (cases "v = 0v n")
      case True
      then show "0  inner_prod v (A *v v)" using dimA by auto
    next
      case False
      (* A non-zero v has positive norm, which is real and hence self-conjugate. *)
      then have 1: "vec_norm v > 0" using vec_norm_ge_0 dimv by auto
      then have cnv: "cnj (vec_norm v) = vec_norm v" using Reals_cnj_iff complex_is_Real_iff by auto
      define w where "w = vec_normalize v"
      then have dimw: "w  carrier_vec n" using dimv by auto
      have nvw: "v = vec_norm v v w" using w_def vec_eq_norm_smult_normalized by auto
      have "vec_norm w = 1" using normalized_vec_norm[OF dimv False] vec_norm_def w_def by auto
      then have 2: "0  inner_prod w (A *v w)" using geq0 dimw dimA by auto
      (* Pull the scalar vec_norm v out of both arguments of the inner product. *)
      have "inner_prod v (A *v v) = vec_norm v * vec_norm v * inner_prod w (A *v w)" using dimA dimv dimw
        apply (subst (1 2) nvw)
        apply (subst mult_mat_vec, simp, simp)
        apply (subst scalar_prod_smult_left[of "(A *v w)" "conjugate (vec_norm v v w)" "vec_norm v"], simp)
        apply (simp add: conjugate_smult_vec cnv)
        done
      also have "  0" using 1 2 by auto
      finally show "0  inner_prod v (A *v v)" by auto
    qed
  }
  then have geq: "v. dim_vec v = dim_col A  0  inner_prod v (A *v v)" using dimA by auto
  show "positive A" unfolding positive_def 
    by (rule, simp add: A, rule geq)
qed

(* Every positive matrix is Hermitian.
   Proof: write A = B + i .m C with B and C Hermitian
   (complex_mat_decomposition_to_hermitian).  The quadratic forms of B and C
   are real (hermitian_inner_prod_real), and so is that of A (positivity).
   Hence the quadratic form of C vanishes, so C = 0 by
   hermitian_inner_prod_zero, and A = B. *)
lemma positive_is_hermitian:
  fixes A :: "complex mat"
  assumes pA: "positive A"
  shows "hermitian A"
proof -
  define n where "n = dim_col A"
  then have dimA: "A  carrier_mat n n" using positive_def pA by auto
  obtain B C where B: "hermitian B" and C: "hermitian C" and A: "A = B + 𝗂 m C"
    and dimB: "B  carrier_mat n n" and dimC: "C  carrier_mat n n" and dimiC: "𝗂 m C  carrier_mat n n"
    using complex_mat_decomposition_to_hermitian[OF dimA] by auto
  {
    fix v :: "complex vec" assume dimv: "v  carrier_vec n"
    have dimvA: "dim_vec v = dim_col A" using dimv dimA by auto
    have "inner_prod v (A *v v) = inner_prod v (B *v v) + inner_prod v ((𝗂 m C) *v v)"
      unfolding A using dimB dimiC dimv by (simp add: add_mult_distrib_mat_vec inner_prod_distrib_right)
    moreover have "inner_prod v ((𝗂 m C) *v v) = 𝗂 * inner_prod v (C *v v)" using dimv dimC
      apply (simp add: scalar_prod_def sum_distrib_left cong: sum.cong)
      apply (rule sum.cong, auto)
      done
    ultimately have ABC: "inner_prod v (A *v v) = inner_prod v (B *v v) + 𝗂 * inner_prod v (C *v v)" by auto
    moreover have "inner_prod v (B *v v)  Reals" using B dimB dimv hermitian_inner_prod_real by auto
    moreover have "inner_prod v (C *v v)  Reals" using C dimC dimv hermitian_inner_prod_real by auto
    moreover have "inner_prod v (A *v v)  Reals" using pA unfolding positive_def 
      apply (rule) 
      apply (fold n_def)
      apply (simp add: complex_is_Real_iff[of "inner_prod v (A *v v)"])
      apply (auto simp add: dimvA)
      done
    (* real = real + i * real forces the imaginary contribution to vanish. *)
    ultimately have "inner_prod v (C *v v) = 0" using of_real_Re by fastforce
  } 
  then have "C = 0m n n" using hermitian_inner_prod_zero dimC C by auto
  then have "A = B" using A dimC dimB by auto
  then show "hermitian A" using B by auto
qed

(* For a positive A with unitary Schur decomposition (B, P, Q) with respect to
   the eigenvalue list es, every diagonal entry of B is non-negative.
   Each B$$(i, i) equals the quadratic form of A at the i-th column of P. *)
lemma positive_eigenvalue_positive:
  assumes dimA: "(A::complex mat)  carrier_mat n n"
    and pA: "positive A"
    and c: "char_poly A = ( (e :: complex)  es. [:- e, 1:])"
    and B: "unitary_schur_decomposition A es = (B,P,Q)"
  shows "i. i < n  B$$(i, i)  0"
proof -
  have hA: "hermitian A" using positive_is_hermitian pA by auto
  have "similar_mat_wit A B P (adjoint P)  diagonal_mat B  diag_mat B = es 
     unitary P  (i < n. B$$(i, i)  Reals)"
    using hermitian_eigenvalue_real dimA hA B c by auto
  then have A: "A = P * B * (adjoint P)" and dB: "diagonal_mat B"
    and Bii: "i. i < n  B$$(i, i)  Reals"
    and dimB: "B  carrier_mat n n" and dimP: "P  carrier_mat n n" and dimaP: "adjoint P  carrier_mat n n"
    and uP: "unitary P" 
    unfolding similar_mat_wit_def Let_def unitary_def using dimA by auto
  {
    fix i assume i: "i < n"
    define v where "v = col P i"
    then have dimv: "v  carrier_vec n" using v_def dimP by auto
    have "inner_prod v (A *v v) = B$$(i, i)" unfolding A v_def
      using spectral_decomposition_extract_diag[OF dimP dimB uP dB i]  by auto
    moreover have "inner_prod v (A *v v)  0" using dimv pA dimA positive_def by auto
    ultimately show "B$$(i, i)  0" by auto
  }
qed

(* The product of two n x n diagonal matrices is diagonal, with the entrywise
   products of the diagonals on its diagonal. *)
lemma diag_mat_mult_diag_mat:
  fixes B D :: "'a::semiring_0 mat"
  assumes dimB: "B  carrier_mat n n" and dimD: "D  carrier_mat n n"
    and dB: "diagonal_mat B" and dD: "diagonal_mat D"
  shows "B * D = mat n n (λ(i,j). (if i = j then (B$$(i, i)) * (D$$(i, i)) else 0))"
proof(rule eq_matI, auto)
  have Bij: "x y. x < n  y < n  x  y  B$$(x, y) = 0" using dB diagonal_mat_def dimB by auto
  have Dij: "x y. x < n  y < n  x  y  D$$(x, y) = 0" using dD diagonal_mat_def dimD by auto
{
  fix i j assume ij: "i < n" "j < n"
  have "(B * D) $$ (i, j) = (k=0..<n. (B $$ (i, k)) * (D $$ (k, j)))" using dimB dimD
    by (auto simp add: scalar_prod_def ij)
  (* Only the k = i summand can be non-zero, since B is diagonal. *)
  also have " = B$$(i, i) * D$$(i, j)"
    apply (simp add: sum.remove[of _i] ij)
    apply (simp add: Bij Dij ij)
    done
  finally have "(B * D) $$ (i, j) = B$$(i, i) * D$$(i, j)".
}
  note BDij = this
  from BDij show "j. j < n  (B * D) $$ (j, j) = B $$ (j, j) * D $$ (j, j)" by auto
  from BDij show "i j. i < n  j < n  i  j  (B * D) $$ (i, j) = 0" using Bij Dij by auto
  from assms show "dim_row B = n" "dim_col D = n" by auto
qed

(* Every positive n x n matrix A factors as M * adjoint M with M n x n.
   Construction: with A = P * B * adjoint P from the unitary Schur
   decomposition (B diagonal, non-negative diagonal entries by
   positive_eigenvalue_positive), let D be the diagonal matrix of
   csqrt (B$$(i, i)); then D * adjoint D = B, and M = P * D works.
   (The unused realness fact extracted from the decomposition, which was
   shadowed by the later fact Bii, is no longer extracted.) *)
lemma positive_only_if_decomp:
  assumes dimA: "A  carrier_mat n n" and pA: "positive A"
  shows "M  carrier_mat n n. M * adjoint M = A"
proof -
  from pA have hA: "hermitian A" using positive_is_hermitian by auto
  obtain es where es: "char_poly A = ( (e :: complex)  es. [:- e, 1:])" 
    using complex_mat_char_poly_factorizable dimA by auto
  obtain B P Q where schur: "unitary_schur_decomposition A es = (B,P,Q)" 
    by (cases "unitary_schur_decomposition A es", auto)
  then have "similar_mat_wit A B P (adjoint P)  diagonal_mat B  diag_mat B = es 
     unitary P  (i < n. B$$(i, i)  Reals)"
    using hermitian_eigenvalue_real dimA es hA by auto
  then have A: "A = P * B * (adjoint P)" and dB: "diagonal_mat B"
    and dimB: "B  carrier_mat n n" and dimP: "P  carrier_mat n n" and dimaP: "adjoint P  carrier_mat n n"
    unfolding similar_mat_wit_def Let_def using dimA by auto
  (* The eigenvalues on the diagonal of B are non-negative. *)
  have Bii: "i. i < n  B$$(i, i)  0" using pA dimA es schur positive_eigenvalue_positive by auto
  (* D is the diagonal square root of B. *)
  define D where "D = mat n n (λ(i, j). (if (i = j) then csqrt (B$$(i, i)) else 0))"
  then have dimD: "D  carrier_mat n n" and dimaD: "adjoint D  carrier_mat n n" using dimB by auto
  have dD: "diagonal_mat D" using dB D_def unfolding diagonal_mat_def by auto
  then have daD: "diagonal_mat (adjoint D)" by (simp add: adjoint_eval diagonal_mat_def)
  have Dii: "i. i < n  D$$(i, i) = csqrt (B$$(i, i))" using dimD D_def by auto
  {
    fix i assume i: "i < n"
    define c where "c = csqrt (B$$(i, i))"
    (* csqrt of a non-negative value is non-negative, hence real and
       self-conjugate, so c * cnj c = c ^ 2 = B$$(i, i). *)
    have c: "c  0" using Bii i c_def by auto
    then have "conjugate c = c" 
      using Reals_cnj_iff complex_is_Real_iff by auto
    then have "c * cnj c = B$$(i, i)" using c_def c unfolding conjugate_complex_def by (metis power2_csqrt power2_eq_square)
  }  
  note cBii = this
  have "D * adjoint D = mat n n (λ(i,j). (if (i = j) then B$$(i, i) else 0))"
    apply (simp add: diag_mat_mult_diag_mat[OF dimD dimaD dD daD])
    apply (rule eq_matI, auto simp add: D_def adjoint_eval cBii)
    done
  also have " = B" using dimB dB[unfolded diagonal_mat_def] by auto
  finally have DaDB: "D * adjoint D = B".
  define M where "M = P * D"
  then have dimM: "M  carrier_mat n n" using dimP dimD by auto
  have "M * adjoint M = (P * D) * (adjoint D * adjoint P)" using M_def adjoint_mult[OF dimP dimD] by auto
  also have " = P * (D * adjoint D) * (adjoint P)" using dimP dimD by (mat_assoc n)
  also have " = P * B * (adjoint P)" using DaDB by auto
  finally have "M * adjoint M = A" using A by auto
  with dimM show "M  carrier_mat n n. M * adjoint M = A" by auto
qed

(* Any square A of the form M * adjoint M is positive, because
   inner_prod v (M * adjoint M *v v) = inner_prod (adjoint M *v v) (adjoint M *v v),
   which is non-negative (self_cscalar_prod_geq_0). *)
lemma positive_if_decomp:
  assumes dimA: "A  carrier_mat n n" and "M. M * adjoint M = A"
  shows "positive A"
proof -
  from assms obtain M where M: "M * adjoint M = A" by auto
  define m where "m = dim_col M"
  have dimM: "M  carrier_mat n m" using M dimA m_def by auto
{
  fix v assume dimv: "(v::complex vec)  carrier_vec n"
  have dimaM: "adjoint M  carrier_mat m n" using dimM by auto
  have dimaMv: "(adjoint M) *v v  carrier_vec m" using dimaM dimv by auto
  have "inner_prod v (A *v v) = inner_prod v (M * adjoint M *v v)" using M by auto
  also have " = inner_prod v (M *v (adjoint M *v v))" using assoc_mult_mat_vec dimM dimaM dimv by auto
  (* Move M to the left argument as its adjoint. *)
  also have " = inner_prod (adjoint M *v v) (adjoint M *v v)" using adjoint_def_alter[OF dimv dimaMv dimM] by auto
  also have "  0" using self_cscalar_prod_geq_0 by auto
  finally have "inner_prod v (A *v v)  0".
}
  note geq0 = this
  from dimA geq0 show "positive A" using positive_def by auto
qed

(* A square matrix is positive iff it equals M * adjoint M for some n x n M;
   combines positive_only_if_decomp and positive_if_decomp. *)
lemma positive_iff_decomp:
  assumes dimA: "A  carrier_mat n n"
  shows "positive A  (Mcarrier_mat n n. M * adjoint M = A)"
proof
  assume pA: "positive A"
  then show "Mcarrier_mat n n. M * adjoint M = A" using positive_only_if_decomp assms by auto
next
  assume "Mcarrier_mat n n. M * adjoint M = A"
  then obtain M where M: "M * adjoint M = A" by auto
  then show "positive A" using M positive_if_decomp assms by auto
qed

(* Positive matrices are square: their row and column counts agree. *)
lemma positive_dim_eq:
  assumes "positive A"
  shows "dim_row A = dim_col A"
  using carrier_matD(1)[of A "dim_col A" "dim_col A"]  assms[unfolded positive_def] by simp

(* The n x n zero matrix is positive (its quadratic form is identically 0). *)
lemma positive_zero:
  "positive (0m n n)"
  by (simp add: positive_def zero_mat_def mult_mat_vec_def scalar_prod_def)

(* The n x n identity matrix is positive: it is 1m n * adjoint (1m n). *)
lemma positive_one:
  "positive (1m n)"
proof (rule positive_if_decomp)
  show "1m n  carrier_mat n n" by auto
  have "adjoint (1m n) = 1m n" using hermitian_one hermitian_def by auto
  then have "1m n * adjoint (1m n) = 1m n" by auto
  then show "M. M * adjoint M = 1m n" by fastforce
qed

(* If both A and -A are positive, then A is the zero matrix.
   Proof: diagonalise A = P * B * adjoint P.  The diagonal entries of B are
   >= 0 because A is positive, and <= 0 because -A = P * (-B) * adjoint P is
   positive, so B = 0.  The final step uses antisymmetry of the order
   directly instead of a positional reference into the no_atp fact
   collection, which is fragile across Isabelle versions. *)
lemma positive_antisym:
  assumes pA: "positive A" and pnA: "positive (-A)"
  shows "A = 0m (dim_col A) (dim_col A)"
proof -
  define n where "n = dim_col A"
  from pA have dimA: "A  carrier_mat n n" and dimnA: "-A  carrier_mat n n"
    using positive_def n_def by auto
  from pA have hA: "hermitian A" using positive_is_hermitian by auto
  obtain es where es: "char_poly A = ( (e :: complex)  es. [:- e, 1:])" 
    using complex_mat_char_poly_factorizable dimA by auto
  obtain B P Q where schur: "unitary_schur_decomposition A es = (B,P,Q)" 
    by (cases "unitary_schur_decomposition A es", auto)
  then have "similar_mat_wit A B P (adjoint P)  diagonal_mat B  unitary P"
    using hermitian_eigenvalue_real dimA es hA by auto
  then have A: "A = P * B * (adjoint P)" and dB: "diagonal_mat B" and uP: "unitary P"
    and dimB: "B  carrier_mat n n" and dimnB: "-B  carrier_mat n n"
    and dimP: "P  carrier_mat n n" and dimaP: "adjoint P  carrier_mat n n"
    unfolding similar_mat_wit_def Let_def using dimA by auto
  from es schur have geq0: "i. i < n  B$$(i, i)  0" using positive_eigenvalue_positive dimA pA by auto
  from A have nA: "-A = P * (-B) * (adjoint P)" using mult_smult_assoc_mat dimB dimP dimaP by auto
  from dB have dnB: "diagonal_mat (-B)" by (simp add: diagonal_mat_def)
  {
    fix i assume i: "i < n"
    define v where "v = col P i"
    then have dimv: "v  carrier_vec n" using v_def dimP by auto
    have "inner_prod v ((-A) *v v) = (-B)$$(i, i)" unfolding nA v_def
      using spectral_decomposition_extract_diag[OF dimP dimnB uP dnB i]  by auto
    moreover have "inner_prod v ((-A) *v v)  0" using dimv pnA dimnA positive_def by auto
    ultimately have "B$$(i, i)  0" using dimB i by auto
    moreover have "B$$(i, i)  0" using i geq0 by auto
    ultimately have "B$$(i, i) = 0" by (rule antisym)
  }
  then have "B = 0m n n" using dimB dB[unfolded diagonal_mat_def]
    by (subst eq_matI, auto)
  then show "A = 0m n n" using A dimB dimP dimaP by auto
qed

(* The sum of two positive n x n matrices is positive: the quadratic form
   of A + B is the sum of those of A and B. *)
lemma positive_add:
  assumes pA: "positive A" and pB: "positive B"
    and dimA: "A  carrier_mat n n" and dimB: "B  carrier_mat n n"
  shows "positive (A + B)"
  unfolding positive_def
proof
  have dimApB: "A + B  carrier_mat n n" using dimA dimB by auto
  then show "A + B  carrier_mat (dim_col (A + B)) (dim_col (A + B))" using carrier_matD[of "A+B"] by auto
  {
    fix v assume dimv: "(v::complex vec)  carrier_vec n"
    have 1: "inner_prod v (A *v v)  0" using dimv pA[unfolded positive_def] dimA by auto
    have 2: "inner_prod v (B *v v)  0" using dimv pB[unfolded positive_def] dimB by auto
    have "inner_prod v ((A + B) *v v) = inner_prod v (A *v v) + inner_prod v (B *v v)"
      using dimA dimB dimv by (simp add: add_mult_distrib_mat_vec inner_prod_distrib_right) 
    also have "  0" using 1 2 by auto
    finally have "inner_prod v ((A + B) *v v)  0".
  }
  note geq0 = this
  then have "v. dim_vec v = n  0  inner_prod v ((A + B) *v v)" by auto
  then show "v. dim_vec v = dim_col (A + B)  0  inner_prod v ((A + B) *v v)" using dimApB by auto
qed

(* A positive matrix has non-negative trace.  Goes through the factorisation
   A = M * adjoint M (positive_iff_decomp).  NOTE(review): relies on
   trace_adjoint_positive, which is defined outside this excerpt; presumably it
   states that trace (M * adjoint M) is non-negative. *)
lemma positive_trace:
  assumes "A  carrier_mat n n" and "positive A"
  shows "trace A  0"
  using assms positive_iff_decomp trace_adjoint_positive by auto

(* Positivity is preserved under congruence: if A is positive then so is
   M * A * adjoint M, since its quadratic form at v equals the quadratic form
   of A at adjoint M *v v. *)
lemma positive_close_under_left_right_mult_adjoint:
  fixes M A :: "complex mat"
  assumes dM: "M  carrier_mat n n" and dA: "A  carrier_mat n n" 
    and pA: "positive A"
  shows "positive (M * A * adjoint M)"
  unfolding positive_def
proof (rule, simp add: mult_carrier_mat[OF mult_carrier_mat[OF dM dA] adjoint_dim[OF dM]] carrier_matD[OF dM], rule, rule)
  have daM: "adjoint M  carrier_mat n n" using dM by auto
  fix v::"complex vec" assume "dim_vec v = dim_col (M * A * adjoint M)"
  then have dv: "v  carrier_vec n" using assms by auto
  then have "adjoint M *v v  carrier_vec n" using daM by auto
  have assoc: "M * A * adjoint M *v v = M *v (A *v (adjoint M *v v))"
    using dA dM daM dv by (auto simp add: assoc_mult_mat_vec[of _ n n _ n])
  (* Move the outer M to the left argument of the inner product as adjoint M. *)
  have "inner_prod v (M * A * adjoint M *v v) = inner_prod (adjoint M *v v) (A *v (adjoint M *v v))"
    apply (subst assoc)
    apply (subst adjoint_def_alter[where ?A = "M"])
    by (auto simp add: dv dA daM dM carrier_matD[OF dM] mult_mat_vec_carrier[of _ n n])
  also have "  0" using dA dv daM pA positive_def by auto
  finally show "inner_prod v (M * A * adjoint M *v v)  0" by auto
qed

(* The outer product of a vector with itself is positive: it equals L * adjoint L
   where L is the single-column matrix holding v.  (Only v is fixed; a
   previously declared but unused variable has been dropped, which does not
   change the theorem.) *)
lemma positive_same_outer_prod:
  fixes v :: "complex vec"
  assumes v: "v  carrier_vec n" 
  shows "positive (outer_prod v v)"
proof -
  have d1: "adjoint (mat (dim_vec v) 1 (λ(i, j). v $ i))  carrier_mat 1 n" using assms by auto
  have d2: "mat 1 (dim_vec v) (λ(i, y). conjugate v $ y)  carrier_mat 1 n" using assms by auto
  have dv: "dim_vec v = n" using assms by auto
  (* The row factor of the outer product is the adjoint of its column factor. *)
  have "mat 1 (dim_vec v) (λ(i, y). conjugate v $ y) = adjoint (mat (dim_vec v) 1 (λ(i, j). v $ i))" (is "?r = adjoint ?l")
    apply (rule eq_matI) 
    subgoal for i j by (simp add: dv adjoint_eval)
    using d1 d2 by auto
  then have "outer_prod v v = ?l * adjoint ?l" unfolding outer_prod_def by auto
  then have "M. M * adjoint M = outer_prod v v" by auto
  then show "positive (outer_prod v v)" using positive_if_decomp[OF outer_prod_dim[OF v v]] by auto
qed

(* Repeated scalar multiplication by complex scalars:
   k .m (l .m A) = (k * l) .m A. *)
lemma smult_smult_mat: 
  fixes k :: complex and l :: complex
  assumes "A  carrier_mat nr n"
  shows "k m (l m A) = (k * l) m A" by auto

(* Scaling a positive matrix by a non-negative complex scalar c keeps it
   positive: if A = M * adjoint M, then
   c .m A = (csqrt c .m M) * adjoint (csqrt c .m M), because csqrt c is
   non-negative and hence self-conjugate.  Unused intermediate facts from the
   earlier version of this proof have been removed. *)
lemma positive_smult: 
  assumes "A  carrier_mat n n"
    and "positive A"
    and "c  0"
  shows "positive (c m A)"
proof -
  have sc: "csqrt c  0" using assms(3) by fastforce
  obtain M where dimM: "M  carrier_mat n n" and A: "M * adjoint M = A" using assms(1-2) positive_iff_decomp by auto
  (* csqrt c is real, so conjugation leaves it unchanged. *)
  have ccsq: "conjugate (csqrt c) = (csqrt c)" using sc Reals_cnj_iff[of "csqrt c"] complex_is_Real_iff by auto
  have MM: "(M * adjoint M)  carrier_mat n n" using A assms by fastforce
  have "(csqrt c m M) * (adjoint (csqrt c m M)) = (csqrt c m M) * ((conjugate (csqrt c)) m  adjoint M)"
    using adjoint_scale assms(1) by (metis adjoint_scale)
  also have " = (csqrt c m M) * (csqrt c m adjoint M)" using sc ccsq by fastforce
  also have " = csqrt c m (M * (csqrt c m adjoint M))"
    using mult_smult_assoc_mat index_smult_mat(2,3) by fastforce
  also have " =  csqrt c m ((csqrt c) m (M * adjoint M))"
    using mult_smult_distrib by fastforce
  also have " = c m (M * adjoint M)" 
    using smult_smult_mat[of "M * adjoint M" n n "(csqrt c)" "(csqrt c)"]  MM sc
    by (metis power2_csqrt power2_eq_square )   
  also have " = c m A" using A by auto
  finally have "(csqrt c m M) * (adjoint (csqrt c m M)) = c m A" by auto
  moreover have "c m A  carrier_mat n n" using assms(1) by auto
  moreover have "csqrt c m M  carrier_mat n n" using dimM by auto
  ultimately show ?thesis using positive_iff_decomp by auto
qed

text ‹Version of the previous theorem for real scalars›
(* Real-scalar variant of positive_smult: scaling a positive matrix by a
   non-negative real c keeps it positive. *)
lemma positive_scale: 
  fixes c :: real
  assumes  "A  carrier_mat n n"
    and "positive A"
    and "c  0"
  shows "positive (c m A)"
  apply (rule positive_smult) using assms by auto

subsection ‹L\"{o}wner partial order›

(* Loewner partial order on complex matrices: A is below B iff A and B have the
   same dimensions and B - A is positive. *)
definition lowner_le :: "complex mat  complex mat  bool"  (infix "L" 50) where
  "A L B  dim_row A = dim_row B  dim_col A = dim_col B  positive (B - A)"

(* Reflexivity of the Loewner order: A - A is the zero matrix, which is positive. *)
lemma lowner_le_refl:
  assumes "A  carrier_mat n n"
  shows "A L A"
  unfolding lowner_le_def
  apply (simp add: minus_r_inv_mat[OF assms])
  by (rule positive_zero)

(* Antisymmetry of the Loewner order: B - A and its negation A - B are both
   positive, so B - A = 0 by positive_antisym, hence A = B. *)
lemma lowner_le_antisym:
  assumes A: "A  carrier_mat n n" and B: "B  carrier_mat n n"
    and L1: "A L B" and L2: "B L A"
  shows "A = B"
proof -
  from L1 have P1: "positive (B - A)" by (simp add: lowner_le_def)
  from L2 have P2: "positive (A - B)" by (simp add: lowner_le_def)
  have "A - B = - (B - A)" using A B by auto
  then have P3: "positive (- (B - A))" using P2 by auto
  have BA: "B - A  carrier_mat n n" using A B by auto
  have "B - A = 0m n n" using BA by (subst positive_antisym[OF P1 P3], auto)
  (* Add A back on both sides to recover B = A. *)
  then have "B + (-A) + A = 0m n n + A" using A B minus_add_uminus_mat[OF B A] by auto
  then have "B + (-A + A) = 0m n n + A" using A B by auto
  then show "A = B" using A B BA uminus_l_inv_mat[OF A] by auto
qed

(* If A is below B in the Loewner order, then the quadratic form of A is
   bounded by that of B at every vector v. *)
lemma lowner_le_inner_prod_le:
  fixes A B :: "complex mat" and v :: "complex vec"
  assumes A: "A  carrier_mat n n" and B: "B  carrier_mat n n"
    and v: "v  carrier_vec n"
    and "A L B"
  shows "inner_prod v (A *v v)  inner_prod v (B *v v)"
proof -
  from assms have "positive (B-A)" by (auto simp add: lowner_le_def)
  with assms have geq: "inner_prod v ((B-A) *v v)  0" 
    unfolding positive_def by auto
  (* The quadratic form is linear in the matrix. *)
  have "inner_prod v ((B-A) *v v) = inner_prod v (B *v v) - inner_prod v (A *v v)" 
    unfolding minus_add_uminus_mat[OF B A]
    by (subst add_mult_distrib_mat_vec[OF B _ v], insert A B v, auto simp add: inner_prod_distrib_right[OF v])
  then show ?thesis using geq by auto
qed

(* Transitivity of the Loewner order, proved pointwise on quadratic forms via
   lowner_le_inner_prod_le. *)
lemma lowner_le_trans:
  fixes A B C :: "complex mat"
  assumes A: "A  carrier_mat n n" and B: "B  carrier_mat n n" and C: "C  carrier_mat n n"
    and L1: "A L B" and L2: "B L C"
  shows "A L C"
  unfolding lowner_le_def
proof (auto simp add: carrier_matD[OF A] carrier_matD[OF C])
  have dim: "C - A  carrier_mat n n" using A C by auto
  {
    fix v assume v: "(v::complex vec)  carrier_vec n"
    from L1 have "inner_prod v (A *v v)  inner_prod v (B *v v)" using lowner_le_inner_prod_le A B v by auto
    also from L2 have "  inner_prod v (C *v v)" using lowner_le_inner_prod_le B C v by auto
    finally have "inner_prod v (A *v v)  inner_prod v (C *v v)".
    then have "inner_prod v (C *v v) - inner_prod v (A *v v)  0" by auto
    then have "inner_prod v ((C - A) *v v)  0" using A C v
      apply (subst minus_add_uminus_mat[OF C A])
      apply (subst add_mult_distrib_mat_vec[OF C _ v], simp)
      apply (simp add: inner_prod_distrib_right[OF v])
      done
  }
  note leq = this
  show "positive (C - A)" unfolding positive_def
    apply (rule, simp add: carrier_matD[OF A] dim)
    apply (subst carrier_matD[OF dim], insert leq, auto)
    done
qed

(* The trace is monotone in the Loewner order: B - A is positive, so its
   trace, which equals trace B - trace A, is non-negative. *)
lemma lowner_le_imp_trace_le:
  assumes "A  carrier_mat n n" and "B  carrier_mat n n"
    and "A L B"
  shows "trace A  trace B"
proof -
  have "positive (B - A)" using assms lowner_le_def by auto
  moreover have "B - A  carrier_mat n n" using assms by auto
  ultimately have "trace (B - A)  0" using positive_trace by auto
  moreover have "trace (B - A) = trace B - trace A" using trace_minus_linear assms by auto
  ultimately have "trace B - trace A  0" by auto
  then show "trace A  trace B" by auto
qed

(* Loewner inequalities can be added: (B + D) - (A + C) = (B - A) + (D - C)
   is a sum of positive matrices. *)
lemma lowner_le_add:
  assumes "A  carrier_mat n n" "B  carrier_mat n n" "C  carrier_mat n n" "D  carrier_mat n n"
    and "A L B" "C L D"
  shows "A + C L B + D"
proof -
  have "B + D - (A + C) = B - A + (D - C) " using assms by auto
  then have "positive (B + D - (A + C))" using assms unfolding lowner_le_def using positive_add
    by (metis minus_carrier_mat)
  then show "A + C L B + D" unfolding lowner_le_def using assms by fastforce
qed

(* Negation reverses the Loewner order: (-A) - (-B) = B - A. *)
lemma lowner_le_swap:
  assumes "A  carrier_mat n n" "B  carrier_mat n n" 
    and "A L B" 
  shows "-B L -A"
proof -
  have "positive (B - A)" using assms lowner_le_def by fastforce
  moreover have "B - A = (-A) - (-B)" using assms by fastforce
  ultimately have "positive ((-A) - (-B))" by auto
  then show ?thesis using lowner_le_def assms by fastforce
qed

(* Loewner inequalities can be subtracted crosswise: from A below B and C below D
   derive A - D below B - C, via lowner_le_swap and lowner_le_add. *)
lemma lowner_le_minus:
  assumes "A  carrier_mat n n" "B  carrier_mat n n" "C  carrier_mat n n" "D  carrier_mat n n"
    and "A L B" "C L D"
  shows "A - D L B - C"
proof -
  have "positive (D - C)" using assms lowner_le_def by auto
  then have "-D L -C" using lowner_le_swap assms by auto
  then have "A + (-D) L B + (-C)" using lowner_le_add[of "A" n  "B"] assms by auto
  moreover have "A + (-D) = A - D" and "B + (-C) = B - C" by auto
  ultimately show ?thesis by auto
qed

(* For a vector v with inner_prod v v at most 1, outer_prod v v is below the
   identity in the Loewner order.  By inner_prod_outer_prod and Cauchy-Schwarz,
   the quadratic form of outer_prod v v at u is at most
   inner_prod u u * inner_prod v v, which is at most inner_prod u u. *)
lemma outer_prod_le_one:
  assumes "v  carrier_vec n"
    and "inner_prod v v  1"
  shows "outer_prod v v L 1m n"
proof -
  let ?o = "outer_prod v v"
  have do: "?o  carrier_mat n n" using assms by auto
  {
    fix u :: "complex vec" assume "dim_vec u = n"
    then have du: "u  carrier_vec n" by auto
    have r: "inner_prod u u  Reals" apply (simp add: scalar_prod_def carrier_vecD[OF du])
      using complex_In_mult_cnj_zero complex_is_Real_iff by blast
    have geq0: "inner_prod u u  0" 
      using self_cscalar_prod_geq_0 by auto

    have "inner_prod u (?o *v u) = inner_prod u v * inner_prod v u"
      apply (subst inner_prod_outer_prod)
      using du assms by auto
    also have "  inner_prod u u * inner_prod v v" using Cauchy_Schwarz_complex_vec du assms by auto
    also have "  inner_prod u u" using assms(2) r geq0 
      by (simp add: mult_right_le_one_le)
    finally have le: "inner_prod u (?o *v u)  inner_prod u u".

    (* Hence the quadratic form of 1m n - outer_prod v v is non-negative at u. *)
    have "inner_prod u ((1m n - ?o) *v u) = inner_prod u ((1m n *v u) - ?o *v u)"
      apply (subst minus_mult_distrib_mat_vec) using do du by auto
    also have " = inner_prod u u - inner_prod u (?o *v u)"
      apply (subst inner_prod_minus_distrib_right)
      using du do by auto
    also have "  0" using le by auto
    finally have "inner_prod u ((1m n - ?o) *v u)  0" by auto
  }
  then have "positive (1m n - outer_prod v v)"
    unfolding positive_def using do by auto
  then show ?thesis unfolding lowner_le_def using do by auto
qed

(* A matrix above zero in the Loewner order is positive, since A - 0 = A. *)
lemma zero_lowner_le_positiveD:
  fixes A :: "complex mat"
  assumes dA: "A  carrier_mat n n" and le: "0m n n L A"
  shows "positive A"
  using assms unfolding lowner_le_def by (subgoal_tac "A - 0m n n = A", auto)

(* Converse of zero_lowner_le_positiveD: a positive square matrix lies
   above the zero matrix in the Loewner order. *)
lemma zero_lowner_le_positiveI:
  fixes A :: "complex mat"
  assumes dA: "A  carrier_mat n n" and le: "positive A"
  shows "0m n n L A"
  using assms unfolding lowner_le_def by (subgoal_tac "A - 0m n n = A", auto)

(* Positivity is inherited upwards in the Loewner order: if A is positive
   and A <=_L B, then B is positive.  Proof: 0 <=_L A <=_L B by
   transitivity, then zero_lowner_le_positiveD. *)
lemma lowner_le_trans_positiveI:
  fixes A B :: "complex mat"
  assumes dA: "A  carrier_mat n n" and pA: "positive A" and le: "A L B"
  shows "positive B"
proof -
  have dB: "B  carrier_mat n n" using le dA lowner_le_def by auto
  have "0m n n L A" using zero_lowner_le_positiveI dA pA by auto
  then have "0m n n L B" using dA dB le by (simp add: lowner_le_trans[of _ n A B])
  then show ?thesis using dB zero_lowner_le_positiveD by auto
qed

(* Congruence by M preserves the Loewner order:
   A <=_L B implies adjoint M * A * M <=_L adjoint M * B * M.
   The dimension conjuncts of lowner_le_def are discharged by fastforce.
   Positivity of adjoint M * (B - A) * M comes from
   positive_close_under_left_right_mult_adjoint; mat_assoc then expands it
   to the required difference. *)
lemma lowner_le_keep_under_measurement:
  fixes M A B :: "complex mat"
  assumes dM: "M  carrier_mat n n" and dA: "A  carrier_mat n n" and dB: "B  carrier_mat n n"
    and le: "A L B"
  shows "adjoint M * A * M L adjoint M * B * M"
  unfolding lowner_le_def
proof (rule conjI, fastforce)+
  have daM: "adjoint M  carrier_mat n n" using dM by auto
  have dBmA: "B - A  carrier_mat n n" using dB dA by fastforce
  have "positive (B - A)" using le lowner_le_def by auto
  then have p: "positive (adjoint M * (B - A) * M)" 
    using positive_close_under_left_right_mult_adjoint[OF daM dBmA] adjoint_adjoint[of M] by auto
  moreover have e: "adjoint M * (B - A) * M = adjoint M * B * M - adjoint M * A * M" using dM dB dA by (mat_assoc n)
  ultimately show "positive (adjoint M * B * M - adjoint M * A * M)" by auto
qed

(* Scalar multiplication distributes over matrix subtraction:
   c . (B - A) = c . B - c . A for square matrices of equal size.
   Proved by unfolding the difference as addition of the negation. *)
lemma smult_distrib_left_minus_mat:
  fixes A B :: "'a::comm_ring_1 mat"
  assumes "A  carrier_mat n n" "B  carrier_mat n n"
  shows "c m (B - A) = c m B - c m A"
  using assms by (auto simp add: minus_add_uminus_mat add_smult_distrib_left_mat)

(* Scaling both sides of a Loewner inequality by a complex scalar c that
   satisfies the assumption on c (presumably c >= 0 in the complex order,
   i.e. real and non-negative -- TODO confirm) preserves the inequality.
   The key step is positivity of c . (B - A) = c . B - c . A. *)
lemma lowner_le_smultc:
  fixes c :: complex
  assumes "c  0" "A L B" "A  carrier_mat n n" "B  carrier_mat n n"
  shows "c m A L c m B"
proof -
  have eqBA: "c m (B - A) = c m B - c m A"
    using assms by (auto simp add: smult_distrib_left_minus_mat)

  have "positive (B - A)" using assms(2) unfolding lowner_le_def by auto
  then have "positive (c m (B - A))" using positive_smult[of "B-A" n c] assms by fastforce
  moreover have "c m A  carrier_mat n n" using index_smult_mat(2,3) assms(3) by auto
  moreover have "c m B  carrier_mat n n" using index_smult_mat(2,3) assms(4) by auto 
  ultimately show ?thesis unfolding lowner_le_def using eqBA by fastforce
qed

(* Real-scalar variant of lowner_le_smultc.  The real c is presumably
   coerced to a complex scalar in c . A (TODO confirm); the proof simply
   applies lowner_le_smultc. *)
lemma lowner_le_smult:
  fixes c :: real
  assumes "c  0" "A L B" "A  carrier_mat n n" "B  carrier_mat n n"
  shows "c m A L c m B"
  apply (rule lowner_le_smultc) using assms by auto

(* Vector scalar multiplication distributes over a difference of scalars:
   (a - b) . w = a . w - b . w, proved componentwise. *)
lemma minus_smult_vec_distrib:
  fixes w :: "'a::comm_ring_1 vec"
  shows "(a - b) v w = a v w - b v w"
  apply (rule eq_vecI)
  by (auto simp add: scalar_prod_def algebra_simps)

(* Pulling a scalar out of a matrix-vector product, left form:
   (a . A) *v w = a . (A *v w).  Componentwise, via sum_distrib_left. *)
lemma smult_mat_mult_mat_vec_assoc:
  fixes A :: "'a::comm_ring_1 mat"
  assumes A: "A  carrier_mat n m" and w: "w  carrier_vec m"
  shows "a m A *v w = a v (A *v w)"
  apply (rule eq_vecI)
   apply (simp add: scalar_prod_def carrier_matD[OF A] carrier_vecD[OF w])
   apply (subst sum_distrib_left) apply (rule sum.cong, simp)
  by auto

(* Pulling a scalar out of a matrix-vector product, right form:
   A *v (a . w) = a . (A *v w).  Componentwise, via sum_distrib_left. *)
lemma mult_mat_vec_smult_vec_assoc:
  fixes A :: "'a::comm_ring_1 mat"
  assumes A: "A  carrier_mat n m" and w: "w  carrier_vec m"
  shows "A *v (a v w) = a v (A *v w)"
  apply (rule eq_vecI)
   apply (simp add: scalar_prod_def carrier_matD[OF A] carrier_vecD[OF w])
   apply (subst sum_distrib_left) apply (rule sum.cong, simp)
  by auto 

(* Multiplying an outer product on both sides:
   A * outer_prod u v * B = outer_prod (A *v u) (adjoint B *v v).
   After unfolding, outer_prod u v is a column matrix (entries u $ i) times
   a row matrix (entries conjugate v $ y).
   - eq1: A times the column of u is the column of A *v u.
   - eq2: the row of conjugate v times B is the row of
     conjugate (adjoint B *v v).
   - The remaining steps only re-associate the product. *)
lemma outer_prod_left_right_mat:
  fixes A B :: "complex mat"
  assumes du: "u  carrier_vec d2" and dv: "v  carrier_vec d3"
    and dA: "A  carrier_mat d1 d2" and dB: "B  carrier_mat d3 d4"
  shows "A * (outer_prod u v) * B = (outer_prod (A *v u) (adjoint B *v v))"
  unfolding outer_prod_def
proof -
  have eq1: "A * (mat (dim_vec u) 1 (λ(i, j). u $ i)) = mat (dim_vec (A *v u)) 1 (λ(i, j). (A *v u) $ i)"
    apply (rule eq_matI)
    by (auto simp add: dA du scalar_prod_def)
  have conj: "conjugate a * b = conjugate ((a::complex) * conjugate b) " for a b by auto
  have eq2: "mat 1 (dim_vec v) (λ(i, y). conjugate v $ y) * B = mat 1 (dim_vec (adjoint B *v v)) (λ(i, y). conjugate (adjoint B *v v) $ y)"
    apply (rule eq_matI)
      apply (auto simp add: carrier_matD[OF dB] carrier_vecD[OF dv] scalar_prod_def adjoint_def conjugate_vec_def sum_conjugate )
    apply (rule sum.cong)
    by (auto simp add: conj)
  have "A * (mat (dim_vec u) 1 (λ(i, j). u $ i) * mat 1 (dim_vec v) (λ(i, y). conjugate v $ y)) * B =
       (A * (mat (dim_vec u) 1 (λ(i, j). u $ i))) *(mat 1 (dim_vec v) (λ(i, y). conjugate v $ y)) * B"
    using dA du dv dB assoc_mult_mat[OF dA, of "mat (dim_vec u) 1 (λ(i, j). u $ i)" 1 "mat 1 (dim_vec v) (λ(i, y). conjugate v $ y)"] by fastforce
  also have " = (A * (mat (dim_vec u) 1 (λ(i, j). u $ i))) *((mat 1 (dim_vec v) (λ(i, y). conjugate v $ y)) * B)"
    using dA du dv dB assoc_mult_mat[OF _ _ dB, of "(A * (mat (dim_vec u) 1 (λ(i, j). u $ i)))" d1 1] by fastforce
  finally show "A * (mat (dim_vec u) 1 (λ(i, j). u $ i) * mat 1 (dim_vec v) (λ(i, y). conjugate v $ y)) * B =
    mat (dim_vec (A *v u)) 1 (λ(i, j). (A *v u) $ i) * mat 1 (dim_vec (adjoint B *v v)) (λ(i, y). conjugate (adjoint B *v v) $ y)" 
    using eq1 eq2 by auto
qed

subsection ‹Density operators›

(* A density operator is a positive complex matrix whose trace is 1. *)
definition density_operator :: "complex mat  bool" where
  "density_operator A  positive A  trace A = 1"

(* A partial density operator is a positive complex matrix whose trace is
   bounded by 1 (compared in the ordering on complex numbers used
   throughout this theory). *)
definition partial_density_operator :: "complex mat  bool" where
  "partial_density_operator A  positive A  trace A  1"

(* A unit vector v yields the partial density operator outer_prod v v.
   - Positivity: the quadratic form at w is
     inner_prod w v * conjugate (inner_prod w v), which is non-negative.
   - Trace: the trace equals the sum of v$i * conjugate (v$i), and that sum
     is 1 because vec_norm v is the csqrt of it and equals 1. *)
lemma pure_state_self_outer_prod_is_partial_density_operator:
  fixes v :: "complex vec"
  assumes dimv: "v  carrier_vec n" and nv: "vec_norm v = 1"
  shows "partial_density_operator (outer_prod v v)"
  unfolding partial_density_operator_def
proof
  have dimov: "outer_prod v v  carrier_mat n n" using dimv by auto
  show "positive (outer_prod v v)" unfolding positive_def
  proof (rule, simp add: carrier_matD(2)[OF dimov] dimov, rule allI, rule impI)
    fix w assume "dim_vec (w::complex vec) = dim_col (outer_prod v v)"
    then have dimw: "w  carrier_vec n"  using dimov carrier_vecI by auto
    then have "inner_prod w ((outer_prod v v) *v w) = inner_prod w v * inner_prod v w" 
      using inner_prod_outer_prod dimw dimv by auto
    also have " = inner_prod w v * conjugate (inner_prod w v)" using dimw dimv
      apply (subst conjugate_scalar_prod[of v "conjugate w"], simp)
      apply (subst conjugate_vec_sprod_comm[of "conjugate v" _ "conjugate w"], auto)
       apply (rule carrier_vec_conjugate[OF dimv])
      apply (rule carrier_vec_conjugate[OF dimw])
      done
    also have "  0" by auto
    finally show "inner_prod w ((outer_prod v v) *v w)  0".
  qed
  (* Trace part: reduce the trace to the squared-norm sum of v. *)
  have eq: "trace (outer_prod v v) = (i=0..<n. v$i * conjugate(v$i))" unfolding trace_def 
    apply (subst carrier_matD(1)[OF dimov])
    apply (simp add: index_outer_prod[OF dimv dimv])
    done
  have "vec_norm v = csqrt (i=0..<n. v$i * conjugate(v$i))" unfolding vec_norm_def using dimv
    by (simp add: scalar_prod_def)
  then have "(i=0..<n. v$i * conjugate(v$i)) = 1" using nv by auto
  with eq show "trace (outer_prod v v)  1" by auto
qed

(* Lemma 2.1 (numbering presumably from the paper this session formalizes;
   TODO confirm): for n x n matrices, A <=_L B holds iff
   trace (A * rho) <= trace (B * rho) for every n x n partial density
   operator rho.
   Forward direction: write B - A = M * adjoint M and rho = P * adjoint P.
   Then trace ((B - A) * rho) is the trace of
   (adjoint P * M) * adjoint (adjoint P * M), which is non-negative.
   Backward direction: testing with rho = outer_prod v v for unit vectors v
   gives inner_prod v ((B - A) *v v) >= 0, so B - A is positive. *)
lemma lowner_le_trace:
  assumes A: "A  carrier_mat n n"
    and B: "B  carrier_mat n n"
  shows "A L B  (ρcarrier_mat n n. partial_density_operator ρ  trace (A * ρ)  trace (B * ρ))"
proof (rule iffI)
  have dimBmA: "B - A  carrier_mat n n" using A B by auto
  (* Forward direction: Loewner order implies the trace inequality. *)
  {
    assume "A L B"
    then have pBmA: "positive (B - A)" using lowner_le_def by auto
    moreover have "B - A  carrier_mat n n" using assms by auto
    ultimately have "Mcarrier_mat n n. M * adjoint M = B - A" using positive_iff_decomp[of "B - A"] by auto
    then obtain M where dimM: "M  carrier_mat n n" and M: "M * adjoint M = B - A" by auto
    {
      fix ρ assume dimr: "ρ  carrier_mat n n" and pdr: "partial_density_operator ρ"
      have eq: "trace(B * ρ) - trace(A * ρ) = trace((B - A) * ρ)" using A B dimr
        apply (subst minus_mult_distrib_mat, auto)
        apply (subst trace_minus_linear, auto)
        done
      have pr: "positive ρ" using pdr partial_density_operator_def by auto
      then have "Pcarrier_mat n n. ρ = P * adjoint P" using positive_iff_decomp dimr by auto
      then obtain P where dimP: "P  carrier_mat n n" and P: "ρ = P * adjoint P" by auto
      have "trace((B - A) * ρ) = trace(M * adjoint M * (P * adjoint P))" using P M by auto
      also have " = trace((adjoint P * M) * adjoint (adjoint P * M))" using dimM dimP by (mat_assoc n)
      also have "  0" using trace_adjoint_positive by auto
      finally have "trace((B - A) * ρ)  0".
      with eq have " trace (B * ρ) - trace (A * ρ)  0" by auto
    }
    then show "ρcarrier_mat n n. partial_density_operator ρ  trace (A * ρ)  trace (B * ρ)" by auto
  }

  (* Backward direction: test the trace inequality on pure states. *)
  {
    assume asm: "ρcarrier_mat n n. partial_density_operator ρ  trace (A * ρ)  trace (B * ρ)"
    have "positive (B - A)" 
    proof -
      {
        fix v assume "dim_vec (v::complex vec) = dim_col (B - A)  vec_norm v = 1"
        then have dimv: "v  carrier_vec n" and nv: "vec_norm v = 1"
          using carrier_matD[OF dimBmA] by (auto intro: carrier_vecI)
        have dimov: "outer_prod v v  carrier_mat n n" using dimv by auto
        then have "partial_density_operator (outer_prod v v)" 
          using dimv nv pure_state_self_outer_prod_is_partial_density_operator by auto
        then have leq: "trace(A * (outer_prod v v))  trace(B * (outer_prod v v))" using asm dimov by auto
        have "trace((B - A) * (outer_prod v v)) = trace(B * (outer_prod v v)) - trace(A * (outer_prod v v))" using A B dimov
          apply (subst minus_mult_distrib_mat, auto)
          apply (subst trace_minus_linear, auto)
          done
        then have "trace((B - A) * (outer_prod v v))  0" using leq by auto
        then have "inner_prod v ((B - A) *v v)  0" using trace_outer_prod_right[OF dimBmA dimv dimv] by auto
      }
      then show "positive (B - A)" using positive_iff_normalized_vec[of "B - A"] dimBmA A by simp
    qed
    then show "A L B" using lowner_le_def A B by auto
  }
qed

(* Introduction form of lowner_le_trace: the trace inequality against all
   n x n partial density operators yields A <=_L B. *)
lemma lowner_le_traceI:
  assumes "A  carrier_mat n n"
    and "B  carrier_mat n n"
    and "ρ. ρ  carrier_mat n n  partial_density_operator ρ  trace (A * ρ)  trace (B * ρ)"
  shows "A L B"
  using lowner_le_trace assms by auto

(* Two n x n matrices with equal traces against every partial density
   operator are equal.  Proof: lowner_le_trace gives both A <=_L B and
   B <=_L A, then antisymmetry (lowner_le_antisym). *)
lemma trace_pdo_eq_imp_eq:
  assumes A: "A  carrier_mat n n"
    and B: "B  carrier_mat n n"
    and teq: "ρ. ρ  carrier_mat n n  partial_density_operator ρ  trace (A * ρ) = trace (B * ρ)"
  shows "A = B"
proof -
  from teq have "A L B" using lowner_le_trace[OF A B] teq by auto
  moreover from teq have "B L A" using lowner_le_trace[OF B A] teq by auto
  ultimately show "A = B" using lowner_le_antisym A B by auto
qed

(* Destruction form of lowner_le_trace: A <=_L B gives the trace
   inequality for a specific n x n partial density operator rho. *)
lemma lowner_le_traceD:
  assumes "A  carrier_mat n n" "B  carrier_mat n n" "ρ  carrier_mat n n"
    and "A L B"
    and "partial_density_operator ρ"
  shows "trace (A * ρ)  trace (B * ρ)"
  using lowner_le_trace assms by blast

(* If A is finite, j is in A, and g vanishes on every other element of A,
   then sum g A = g j.  The sum shrinks to the singleton {j} via
   sum.mono_neutral_right. *)
lemma sum_only_one_neq_0:
  assumes "finite A" and "j  A" and "i. i  A  i  j  g i = 0"
  shows "sum g A = g j" 
proof -
  have "{j}  A" using assms by auto
  moreover have "iA - {j}. g i = 0" using assms by simp
  ultimately have "sum g A = sum g {j}" using assms 
    by (auto simp add: comm_monoid_add_class.sum.mono_neutral_right[of A "{j}" g])
  moreover have "sum g {j} = g j" by simp
  ultimately show ?thesis by auto
qed

end

File ‹mat_alg.ML›

(* Algebraic manipulations on matrices *)

(* Debugging helper: pretty-print each term of ts in ctxt and join the
   results into a single comma-separated string. *)
fun string_of_terms ctxt ts =
  let
    val pretty_ts = map (Syntax.pretty_term ctxt) ts
  in
    Pretty.string_of (Pretty.block (Pretty.commas pretty_ts))
  end

(* Debugging: emit a tracing message made of the label s, a space, and the
   term t printed in ctxt. *)
fun trace_t ctxt s t =
  let
    val shown = Syntax.string_of_term ctxt t
  in
    tracing (String.concat [s, " ", shown])
  end

(* Debugging: emit "s [h1, ..., hk] ==> prop", showing the hypotheses and
   the proposition of theorem th. *)
fun trace_fullthm ctxt s th =
  let
    val hyps_str = string_of_terms ctxt (Thm.hyps_of th)
    val prop_str = Syntax.string_of_term ctxt (Thm.prop_of th)
  in
    tracing (s ^ " [" ^ hyps_str ^ "] ==> " ^ prop_str)
  end

(* The HOL type nat; matrix dimensions in carrier_mat have this type. *)
val natT : typ = HOLogic.natT

(* True iff t is the constant times applied to two arguments (a * b). *)
fun is_times (Const (@{const_name times}, _) $ _ $ _) = true
  | is_times _ = false

(* True iff t is the constant plus applied to two arguments (a + b). *)
fun is_plus (Const (@{const_name plus}, _) $ _ $ _) = true
  | is_plus _ = false

(* True iff t is the binary constant minus applied to two arguments (a - b). *)
fun is_minus (Const (@{const_name minus}, _) $ _ $ _) = true
  | is_minus _ = false

(* True iff t is unary negation applied to one argument (- a). *)
fun is_uminus (Const (@{const_name uminus}, _) $ _) = true
  | is_uminus _ = false

(* Given t = f $ a $ b (e.g. a binary operator applied to both operands),
   return (a, b); raise Fail "dest_binop" on any other shape. *)
fun dest_binop (_ $ lhs $ rhs) = (lhs, rhs)
  | dest_binop _ = raise Fail "dest_binop"

(* Given an application f $ x, return the (last) argument x;
   raise Fail "dest_arg" if t is not an application. *)
fun dest_arg (_ $ operand) = operand
  | dest_arg _ = raise Fail "dest_arg"

(* Given f $ x $ y, return the first argument x;
   raise Fail "dest_arg1" for any other shape. *)
fun dest_arg1 (_ $ first $ _) = first
  | dest_arg1 _ = raise Fail "dest_arg1"

(* True iff the type of t is an instance of the type constructor
   Matrix.mat; used to leave scalar-valued terms untouched. *)
fun is_mat_type t =
  (case fastype_of t of
     Type (tyname, _) => tyname = "Matrix.mat"
   | _ => false)

(* True iff t is a scalar-matrix product smult_mat c A. *)
fun is_smult_mat (Const (@{const_name smult_mat}, _) $ _ $ _) = true
  | is_smult_mat _ = false

(* True iff t is mat_adjoint applied to a single argument (adjoint a). *)
fun is_adjoint (Const (@{const_name mat_adjoint}, _) $ _) = true
  | is_adjoint _ = false

(* True iff t is an identity matrix one_mat n (written 1_m n). *)
fun is_id_mat (Const (@{const_name one_mat}, _) $ _) = true
  | is_id_mat _ = false

(* True iff t is a zero matrix zero_mat nr nc (written 0_m nr nc). *)
fun is_zero_mat (Const (@{const_name zero_mat}, _) $ _ $ _) = true
  | is_zero_mat _ = false

(* Given a left-nested product (the normal form), return its factors from
   left to right, e.g. strip_times (((a * b) * c) * d) = [a, b, c, d].
   A non-product term yields the singleton list [t].  Walks down the left
   spine, consing each right factor onto an accumulator, which avoids
   repeated list appends. *)
fun strip_times t =
  let
    fun collect u acc =
      if is_times u then collect (dest_arg1 u) (dest_arg u :: acc)
      else u :: acc
  in
    collect t []
  end

(* Build the term "carrier_mat n n".  The matrix t is only used to supply
   the element type: the set has type (type of t) set. *)
fun carrier_mat n t =
  let
    val setT = HOLogic.mk_setT (fastype_of t)
    val carrier = Const (@{const_name carrier_mat}, natT --> natT --> setT)
  in
    carrier $ n $ n
  end

(* Build the membership term "t : carrier_mat n n". *)
fun mk_mem_carrier n t =
  let
    val carrier_set = carrier_mat n t
  in
    HOLogic.mk_mem (t, carrier_set)
  end

(* Produce the assumed theorem  [t : carrier_mat n n] |- t : carrier_mat n n
   (the membership statement is both a hypothesis and the conclusion). *)
fun assume_carrier ctxt n t =
  mk_mem_carrier n t
  |> HOLogic.mk_Trueprop
  |> Thm.cterm_of ctxt
  |> Thm.assume

(* Prove "t : carrier_mat n n" for a term t built from *, +, binary -,
   unary -, adjoint and scalar multiplication.  Each leaf a that is not one
   of these forms is assumed via assume_carrier, so the result carries the
   hypotheses a : carrier_mat n n for every such leaf.
   E.g. for t = a * b * c the result is
     [a : carrier n n, b : carrier n n, c : carrier n n] |- a * b * c : carrier n n. *)
fun prod_in_carrier ctxt n t =
  let
    val recur = prod_in_carrier ctxt n
    (* Combine the carrier facts of both operands through a binary rule. *)
    fun binary rule =
      let
        val (a, b) = dest_binop t
      in
        [recur a, recur b] MRS rule
      end
    (* Lift the carrier fact of the (last) argument through a unary rule. *)
    fun unary rule = recur (dest_arg t) RS rule
  in
    if is_times t then binary @{thm mult_carrier_mat}
    else if is_plus t then binary @{thm add_carrier_mat'}
    else if is_uminus t then unary @{thm uminus_carrier_mat}
    else if is_minus t then binary @{thm minus_carrier_mat'}
    else if is_adjoint t then unary @{thm adjoint_dim}
    else if is_smult_mat t then unary @{thm smult_carrier_mat}
    else assume_carrier ctxt n t
  end

(* Flip an object-level equation: from  a = b  derive  b = a. *)
fun obj_sym (th : thm) : thm =
  th RS @{thm HOL.sym}

(* Turn an object-level equation  a = b  into the meta equation  a == b. *)
fun to_meta_eq (th : thm) : thm =
  th RS @{thm HOL.eq_reflection}

(* Turn a meta equation  a == b  into the object-level equation  a = b. *)
fun to_obj_eq (th : thm) : thm =
  th RS @{thm HOL.meta_eq_to_obj_eq}

(* Conversion: rewrite ct with the object-level equation th (l = r).
   Steps:
   1. Turn th into l == r and match l against ct.
   2. Discharge each premise of the instantiated rule, which is expected to
      be a membership premise x : ..., by proving the corresponding
      carrier_mat n n fact for x with prod_in_carrier.
   Failures: if matching fails, th is traced and Fail "MATCH" is raised; if
   resolution fails, th is traced and Fail "THM" is raised. *)
fun rewr_cv ctxt n th ct =
  let
    val eq = to_meta_eq th
    val lhs_pat = Thm.cterm_of ctxt (dest_arg1 (Thm.concl_of eq))
    val inst_eq = Thm.instantiate (Thm.match (lhs_pat, ct)) eq
    fun carrier_fact prem = prod_in_carrier ctxt n (dest_arg1 (dest_arg prem))
  in
    map carrier_fact (Thm.prems_of inst_eq) MRS inst_eq
  end
  handle THM _ => let val _ = trace_fullthm ctxt "here" th in raise Fail "THM" end
       | Pattern.MATCH => let val _ = trace_fullthm ctxt "here" th in raise Fail "MATCH" end
  
(* Normalize a * b, where a = a_1 * ... * a_n and b = b_1 * ... * b_m are
   already-normalized left-nested products:
   - pull a scalar factor out of either side (mult_smult_assoc_mat /
     mult_smult_distrib) and continue under the scalar;
   - re-associate a * (x * y) into (a * x) * y, normalizing a * x;
   - drop an identity factor 1_m on either side.
   Any other shape is left unchanged. *)
fun assoc_times_norm ctxt n ct =
  let
    val (a, b) = dest_binop (Thm.term_of ct)
    val rw = rewr_cv ctxt n
    val renorm = assoc_times_norm ctxt n
    val conv =
      if is_smult_mat a then
        Conv.every_conv [rw @{thm mult_smult_assoc_mat}, Conv.arg_conv renorm]
      else if is_smult_mat b then
        Conv.every_conv [rw @{thm mult_smult_distrib}, Conv.arg_conv renorm]
      else if is_times b then
        Conv.every_conv [rw (obj_sym @{thm assoc_mult_mat}), Conv.arg1_conv renorm]
      else if is_id_mat a then
        rw @{thm left_mult_one_mat}
      else if is_id_mat b then
        rw @{thm right_mult_one_mat}
      else
        Conv.all_conv
  in
    conv ct
  end

(* Insert a single summand b into the sorted left-nested sum
   a = a_1 + ... + a_n, in the spirit of insertion sort.  Ordering uses
   Term_Ord.term_ord: while the last summand of a is GREATER than b, swap
   them (swap_plus_mat) and continue on the left part.  For a single
   summand a, commute when a is GREATER than b (comm_add_mat).
   Terms that are not matrices are left unchanged. *)
fun assoc_plus_one_norm ctxt n ct =
  let
    val t = Thm.term_of ct
    val (a, b) = dest_binop t
    fun greater_than_b x = Term_Ord.term_ord (x, b) = GREATER
  in
    case (is_mat_type t, is_plus a) of
      (false, _) => Conv.all_conv ct
    | (true, true) =>
        if greater_than_b (dest_arg a) then
          Conv.every_conv [
            rewr_cv ctxt n @{thm swap_plus_mat},
            Conv.arg1_conv (assoc_plus_one_norm ctxt n)] ct
        else
          Conv.all_conv ct
    | (true, false) =>
        if greater_than_b a then rewr_cv ctxt n @{thm comm_add_mat} ct
        else Conv.all_conv ct
  end

(* Merge two normalized sums (a_1 + ... + a_n) + (b_1 + ... + b_m) into
   one sorted left-nested sum.
   - If b is a sum b' + b_m, re-associate to (a + b') + b_m, merge a + b',
     then insert b_m with assoc_plus_one_norm.
   - A zero matrix on either side is dropped.
   - Otherwise b is a single summand and is inserted into a.
   Terms that are not matrices are left unchanged. *)
fun assoc_plus_norm ctxt n ct =
  let
    val t = Thm.term_of ct
    val (a, b) = dest_binop t
    val rw = rewr_cv ctxt n
    val conv =
      if not (is_mat_type t) then
        Conv.all_conv
      else if is_plus b then
        Conv.every_conv [
          rw (obj_sym @{thm assoc_add_mat}),
          Conv.arg1_conv (assoc_plus_norm ctxt n),
          assoc_plus_one_norm ctxt n]
      else if is_zero_mat a then
        rw @{thm left_add_zero_mat}
      else if is_zero_mat b then
        rw @{thm right_add_zero_mat}
      else
        assoc_plus_one_norm ctxt n
  in
    conv ct
  end

(* Distribute a scalar over a left-nested sum:
   c . (a_1 + ... + a_n) becomes c . a_1 + ... + c . a_n by repeatedly
   applying add_smult_distrib_left_mat to the left part. *)
fun smult_plus_norm ctxt n ct =
  if is_plus (dest_arg (Thm.term_of ct)) then
    Conv.every_conv [
      rewr_cv ctxt n @{thm add_smult_distrib_left_mat},
      Conv.arg1_conv (smult_plus_norm ctxt n)] ct
  else
    Conv.all_conv ct

(* Normalize (a_1 + ... + a_n) * b, where b is a single product.
   Right-distribute b over the sum (add_mult_distrib_mat), normalize both
   resulting parts, and merge the sums.  A non-sum left factor is handled
   by assoc_times_norm directly. *)
fun norm_mult_poly_monomial ctxt n ct =
  let
    val left = dest_arg1 (Thm.term_of ct)
  in
    if not (is_plus left) then
      assoc_times_norm ctxt n ct
    else
      Conv.every_conv [
        rewr_cv ctxt n @{thm add_mult_distrib_mat},
        Conv.arg1_conv (norm_mult_poly_monomial ctxt n),
        Conv.arg_conv (assoc_times_norm ctxt n),
        assoc_plus_norm ctxt n] ct
  end

(* Normalize (a_1 + ... + a_n) * (b_1 + ... + b_m).
   Left-distribute over the right sum (mult_add_distrib_mat), normalize each
   resulting product, and merge the sums.  A non-sum right factor is handed
   to norm_mult_poly_monomial. *)
fun norm_mult_polynomials ctxt n ct =
  let
    val right = dest_arg (Thm.term_of ct)
  in
    if not (is_plus right) then
      norm_mult_poly_monomial ctxt n ct
    else
      Conv.every_conv [
        rewr_cv ctxt n @{thm mult_add_distrib_mat},
        Conv.arg1_conv (norm_mult_polynomials ctxt n),
        Conv.arg_conv (norm_mult_poly_monomial ctxt n),
        assoc_plus_norm ctxt n] ct
  end

(* True iff t is the constant trace applied to one argument. *)
fun is_trace (Const (@{const_name trace}, _) $ _) = true
  | is_trace _ = false

(* Normalize trace (a_1 * ... * a_n) by cyclic rotation.
   While some earlier factor is greater (Term_Ord.term_ord) than the last
   factor, move the last factor to the front (trace_comm), re-associate the
   product and repeat.  The loop stops once the last factor is maximal.
   NOTE(review): when the maximal factor occurs more than once, different
   rotations of the same cyclic product may stop at different rotations, so
   equal traces are not always brought to the same form. *)
fun norm_trace_times ctxt n ct =
  let
    val factors = strip_times (dest_arg (Thm.term_of ct))
    val (others, final) = split_last factors
    fun exceeds_final f = Term_Ord.term_ord (final, f) = LESS
  in
    if exists exceeds_final others then
      Conv.every_conv [
        rewr_cv ctxt n @{thm trace_comm},
        Conv.arg_conv (assoc_times_norm ctxt n),
        norm_trace_times ctxt n] ct
    else
      Conv.all_conv ct
  end

(* Normalize trace (a_1 + ... + a_n).  The trace is split over the sum by
   linearity (trace_add_linear), and each trace of a product is rotated
   into normal form by norm_trace_times. *)
fun norm_trace_plus ctxt n ct =
  let
    val arg = dest_arg (Thm.term_of ct)
  in
    if not (is_plus arg) then
      norm_trace_times ctxt n ct
    else
      Conv.every_conv [
        rewr_cv ctxt n @{thm trace_add_linear},
        Conv.arg1_conv (norm_trace_plus ctxt n),
        Conv.arg_conv (norm_trace_times ctxt n)] ct
  end

(* Normalize with respect to associativity (and the related distributivity,
   identity, zero, scalar, adjoint and trace rules).
   - For *, +, c . A and trace, the operands are normalized first; then the
     outer operator is handled by norm_mult_polynomials, assoc_plus_norm,
     smult_plus_norm or norm_trace_plus respectively.
   - A matrix difference a - b is rewritten with minus_add_uminus_mat and
     normalized again.
   - A matrix negation - a is rewritten with uminus_mat and normalized
     again.
   - adjoint is pushed through products (adjoint_mult), and double adjoints
     are cancelled (adjoint_adjoint).
   Products, differences and negations that are not of matrix type (for
   example of traces, which are scalars) only have their operands
   normalized.  The matrix rules cannot match such terms, and applying them
   would make rewr_cv raise Fail; assoc_plus_norm already skips non-matrix
   sums in the same way. *)
fun assoc_norm ctxt n ct =
  let
    val t = Thm.term_of ct
    val norm = assoc_norm ctxt n
  in
    if is_times t then
      if is_mat_type t then
        Conv.every_conv [
          Conv.binop_conv norm,
          norm_mult_polynomials ctxt n] ct
      else
        Conv.binop_conv norm ct
    else if is_plus t then
      Conv.every_conv [
        Conv.binop_conv norm,
        assoc_plus_norm ctxt n] ct
    else if is_smult_mat t then
      Conv.every_conv [
        Conv.arg_conv norm,
        smult_plus_norm ctxt n] ct
    else if is_minus t then
      if is_mat_type t then
        Conv.every_conv [
          rewr_cv ctxt n @{thm minus_add_uminus_mat},
          norm] ct
      else
        Conv.binop_conv norm ct
    else if is_uminus t then
      if is_mat_type t then
        Conv.every_conv [
          rewr_cv ctxt n @{thm uminus_mat},
          norm] ct
      else
        Conv.arg_conv norm ct
    else if is_adjoint t then
      if is_times (dest_arg t) then
        Conv.every_conv [
          rewr_cv ctxt n @{thm adjoint_mult},
          norm] ct
      else if is_adjoint (dest_arg t) then
        Conv.every_conv [
          Conv.rewr_conv (to_meta_eq @{thm adjoint_adjoint}),
          norm] ct
      else
        Conv.all_conv ct
    else if is_trace t then
      Conv.every_conv [
        Conv.arg_conv norm,
        norm_trace_plus ctxt n] ct
    else
      Conv.all_conv ct
  end

(* Prove an equation s = t by normalizing both sides with assoc_norm and
   comparing the normal forms up to alpha-conversion.
   Example: for A * (B * C) = (A * B) * C, the result is that equation with
   hypotheses A, B, C : carrier_mat n n (introduced by prod_in_carrier).
   If the normal forms differ, both are traced and Fail is raised. *)
fun prove_by_assoc_norm ctxt n t =
  let
    val _ = trace_t ctxt "To show equation:" t
    val (lhs, rhs) = dest_binop t
    val normalize = assoc_norm ctxt n o Thm.cterm_of ctxt
    val lhs_eq = normalize lhs
    val rhs_eq = normalize rhs
    val lhs_nf = Thm.rhs_of lhs_eq
    val rhs_nf = Thm.rhs_of rhs_eq
  in
    if not (lhs_nf aconvc rhs_nf) then
      (trace_t ctxt "Left side is:" (Thm.term_of lhs_nf);
       trace_t ctxt "Right side is:" (Thm.term_of rhs_nf);
       raise Fail "Normalization are not equal.")
    else
      to_obj_eq (Thm.transitive lhs_eq (Thm.symmetric rhs_eq))
  end

(* Tactic behind mat_assoc.  The dimension string n is parsed in ctxt.
   The first subgoal [A1, ..., Ak] ==> s = t is then closed as follows:
   - prove s = t by normalization;
   - discharge the premises A1 ... Ak;
   - re-introduce every remaining hypothesis (typically carrier_mat facts
     for leaf matrices) as a new premise of the proof state.
   With no subgoals the tactic yields no result. *)
fun prove_by_assoc_norm_tac n ctxt state =
  let
    val n = Syntax.read_term ctxt n
  in
    case Drule.cprems_of state of
      [] => Seq.empty
    | subgoal :: _ =>
        let
          val cprems = Drule.strip_imp_prems subgoal
          val concl = HOLogic.dest_Trueprop (Thm.term_of (Drule.strip_imp_concl subgoal))
          (* A1 ==> ... ==> Ak ==> s = t, possibly with additional
             assumptions remaining in its hypotheses. *)
          val subgoal_th =
            fold Thm.implies_intr (rev cprems) (prove_by_assoc_norm ctxt n concl)
          val res = Thm.implies_elim state subgoal_th
        in
          Seq.single (fold Thm.implies_intr (Thm.chyps_of subgoal_th) res)
        end
  end

(* Proof-method parser: reads the dimension term n and wraps
   prove_by_assoc_norm_tac as a simple method. *)
val mat_assoc_method : (Proof.context -> Method.method) context_parser =
  let
    fun method_of n ctxt = SIMPLE_METHOD (prove_by_assoc_norm_tac n ctxt)
  in
    Scan.lift Parse.term >> method_of
  end

Theory Matrix_Limit

section ‹Matrix limits›

theory Matrix_Limit
  imports Complex_Matrix
begin

subsection ‹Definition of limit of matrices›

(* limit_mat X A m: every X n and the limit A are m x m matrices, and X
   converges to A entrywise, i.e. for all i, j < m the complex sequence
   (X n) $$ (i, j) tends to A $$ (i, j). *)
definition limit_mat :: "(nat  complex mat)  complex mat  nat  bool" where
  "limit_mat X A m  ( n. X n  carrier_mat m m  A  carrier_mat m m 
                       ( i < m.  j < m. (λ n. (X n) $$ (i, j))  (A $$ (i, j))))"

(* Matrix limits are unique.  Each entry has a unique limit by
   LIMSEQ_unique, and mat_eq_iff lifts this to the whole matrices. *)
lemma limit_mat_unique:
  assumes limA: "limit_mat X A m" and limB: "limit_mat X B m"
  shows "A = B"
proof -
  have dim: "A  carrier_mat m m" "B  carrier_mat m m" using limA limB limit_mat_def by auto
  {
    fix i j assume i: "i < m" and j: "j < m"
    have "(λ n. (X n) $$ (i, j))  (A $$ (i, j))" using limit_mat_def limA i j by auto
    moreover have "(λ n. (X n) $$ (i, j))  (B $$ (i, j))" using limit_mat_def limB i j by auto
    ultimately have "(A $$ (i, j)) = (B $$ (i, j))" using LIMSEQ_unique by auto
  }
  then show "A = B" using mat_eq_iff dim by auto
qed

(* A constant sequence of m x m matrices converges to its value. *)
lemma limit_mat_const:
  fixes A :: "complex mat"
  assumes "A  carrier_mat m m"
  shows "limit_mat (λk. A) A m"
  unfolding limit_mat_def using assms by auto

(* Matrix limits commute with scalar multiplication: if X converges to A,
   then c . X n converges to c . A.  Entrywise, this is tendsto_mult with
   the constant sequence c, combined with index_smult_mat. *)
lemma limit_mat_scale:
  fixes X :: "nat  complex mat" and A :: "complex mat"
  assumes limX: "limit_mat X A m"
  shows "limit_mat (λn. c m X n) (c m A) m"
proof -
  have dimA: "A  carrier_mat m m" using limX limit_mat_def by auto
  have dimX: "n. X n  carrier_mat m m" using limX unfolding limit_mat_def by auto
  have "i j. i < m  j < m  (λn. (c m X n) $$ (i, j))  (c m A) $$ (i, j)"
  proof -
    fix i j assume i: "i < m" and j: "j < m"
    have "(λn. (X n) $$ (i, j))  A$$(i, j)" using limX limit_mat_def i j by auto
    moreover have "(λn. c)  c" by auto
    ultimately have "(λn. c * (X n) $$ (i, j))  c * A$$(i, j)"
      using tendsto_mult[of "λn. c" c] limX limit_mat_def by auto
    moreover have "(c m X n) $$ (i, j) = c * (X n) $$ (i, j)" for n
      using index_smult_mat(1)[of i "X n" j c] i j dimX[of n] by auto
    moreover have "(c m A) $$ (i, j) = c * A $$ (i, j)"
      using index_smult_mat(1)[of i "A" j c] i j dimA by auto
    ultimately show "(λn. (c m X n) $$ (i, j))  (c m A) $$ (i, j)" by auto
  qed
  then show ?thesis unfolding limit_mat_def using dimA dimX by auto
qed

lemma limit_mat_add:
  fixes X :: "nat  complex mat" and Y :: "nat  complex mat" and A :: "complex mat"
    and m :: nat and B :: "complex mat"
  assumes limX: "limit_mat X A m" and limY: "limit_mat Y B m"
  shows "limit_mat (λk. X k + Y k) (A + B) m"
proof -
  have dimA: "A  carrier_mat m m" using limX limit_mat_def by auto
  have dimB: "B  carrier_mat m m" using limY limit_mat_def by auto
  have dimX: "n. X n  carrier_mat m m" using limX unfolding limit_mat_def by auto
  have dimY: "n. Y n  carrier_mat m m" using limY unfolding limit_mat_def by auto
  then have dimXAB: "n. X n + Y n  carrier_mat m m  A + B  carrier_mat m m" using dimA dimB dimX dimY
    by (simp)

  have "(i j. i < m  j < m  (λn. (X n + Y n) $$ (i, j))  (A + B) $$ (i, j))"
  proof -
    fix i j assume i: "i < m" and j: "j < m"
    have "(λn. (X n) $$ (i, j))  A$$(i, j)" using limX limit_mat_def i j by auto
    moreover have "(λn. (Y n) $$ (i, j))  B$$(i, j)" using limY limit_mat_def i j by auto
    ultimately have "(λn. (X n)$$(i, j) + (Y n) $$ (i, j))  (A$$(i, j) + B$$(i, j))"
      using tendsto_add[of "λn. (X n) $$ (i, j)" "A $$ (i, j)"] by auto
    moreover have "(X n + Y