Session BenOr_Kozen_Reif

Theory More_Matrix

theory More_Matrix
  imports "Jordan_Normal_Form.Matrix"
    "Jordan_Normal_Form.DL_Rank"
    "Jordan_Normal_Form.VS_Connect"
    "Jordan_Normal_Form.Gauss_Jordan_Elimination"
begin

section "Kronecker Product"

(* Kronecker product of A (ra x ca) and B (rb x cb): the (ra*rb) x (ca*cb)
   matrix whose entry at (i, j) is
     A $$ (i div rb, j div cb) * B $$ (i mod rb, j mod cb),
   so the block at block-position (r, s) has entries A $$ (r, s) * B $$ (v, w). *)
definition kronecker_product :: "'a :: ring mat  'a mat  'a mat" where
  "kronecker_product A B =
  (let ra = dim_row A; ca = dim_col A;
       rb = dim_row B; cb = dim_col B
  in
    mat (ra*rb) (ca*cb)
    (λ(i,j).
      A $$ (i div rb, j div cb) *
      B $$ (i mod rb, j mod cb)
  ))"

(* Bound for a flattened pair of indices over nat: if d < a and c < b,
   then b*d + c < a*b. Used to show that indices of the form
   dim_row B * r + v stay inside a Kronecker product. *)
lemma arith:
  assumes "d < a"
  assumes "c < b"
  shows "b*d+c < a*(b::nat)"
proof -
  (* c < b, so adding c stays below the next multiple of b *)
  have "b*d+c < b*(d+1)"
    by (simp add: assms(2))
  (* d < a gives d + 1 <= a, hence b*(d+1) <= a*b *)
  thus ?thesis
    by (metis One_nat_def Suc_leI add.right_neutral add_Suc_right assms(1) less_le_trans mult.commute mult_le_cancel2)
qed

(* The dimensions of a Kronecker product are the products of the factors' dimensions. *)
lemma dim_kronecker[simp]:
  "dim_row (kronecker_product A B) = dim_row A * dim_row B"
  "dim_col (kronecker_product A B) = dim_col A * dim_col B"
  unfolding kronecker_product_def Let_def by auto

(* Inverse view of the index formula: for in-range r, s (for A) and v, w (for B),
   the entry of kronecker_product A B at (dim_row B*r+v, dim_col B*s+w)
   is A $$ (r,s) * B $$ (v,w). *)
lemma kronecker_inverse_index:
  assumes "r < dim_row A" "s < dim_col A"
  assumes "v < dim_row B" "w < dim_col B"
  shows "kronecker_product A B $$ (dim_row B*r+v, dim_col B*s+w) = A $$ (r,s) * B $$ (v,w)"
proof -
  (* both flattened indices lie within the dimensions of the product *)
  from arith[OF assms(1) assms(3)]
  have "dim_row B*r+v < dim_row A * dim_row B" .
  moreover from arith[OF assms(2) assms(4)]
  have "dim_col B * s + w < dim_col A * dim_col B" .
  ultimately show ?thesis
    unfolding kronecker_product_def Let_def
    using assms by auto
qed

(* Left distributivity over matrix addition:
   kronecker_product A (B + C) = kronecker_product A B + kronecker_product A C,
   for B and C of equal dimensions. *)
lemma kronecker_distr_left:
  assumes "dim_row B = dim_row C" "dim_col B = dim_col C"
  shows "kronecker_product A (B+C) = kronecker_product A B + kronecker_product A C"
  unfolding kronecker_product_def Let_def
  using assms apply (auto simp add: mat_eq_iff) 
  by (metis (no_types, lifting) distrib_left index_add_mat(1) mod_less_divisor mult_eq_0_iff neq0_conv not_less_zero)

(* Right distributivity over matrix addition:
   kronecker_product (B + C) A = kronecker_product B A + kronecker_product C A,
   for B and C of equal dimensions. *)
lemma kronecker_distr_right:
  assumes "dim_row B = dim_row C" "dim_col B = dim_col C"
  shows "kronecker_product (B+C) A = kronecker_product B A + kronecker_product C A"
  unfolding kronecker_product_def Let_def
  using assms by (auto simp add: mat_eq_iff less_mult_imp_div_less distrib_right)

(* When nr > 0 and nc > 0, indices reduced mod nr and mod nc are always in range,
   so indexing the generated matrix evaluates the generating function. *)
lemma index_mat_mod[simp]: "nr > 0 & nc > 0  mat nr nc f $$ (i mod nr,j mod nc) = f (i mod nr,j mod nc)"
  by auto

(* Associativity of the Kronecker product. The proof splits on whether
   kronecker_product B C has positive row and column dimensions. *)
lemma kronecker_assoc:
  shows "kronecker_product A (kronecker_product B C) = kronecker_product (kronecker_product A B) C"
  unfolding kronecker_product_def Let_def
  apply (case_tac "dim_row B * dim_row C > 0 & dim_col B * dim_col C > 0")
   apply (auto simp add: mat_eq_iff less_mult_imp_div_less)
  by (smt div_mult2_eq div_mult_mod_eq kronecker_inverse_index less_mult_imp_div_less linordered_semiring_strict_class.mult_pos_pos mod_less_divisor mod_mult2_eq mult.assoc mult.commute)

(* Flattening a double sum: summing f ia ja over ia < x, ja < y equals summing
   f (ia div y) (ia mod y) over ia < x*y. The proof reindexes along the
   bijection ia |-> (ia div y, ia mod y) from {0..<x*y} onto {0..<x} x {0..<y}. *)
lemma sum_sum_mod_div:
  "(ia = 0::nat..<x. ja = 0..<y. f ia ja) =
   (ia = 0..<x*y. f (ia div y) (ia mod y))"
proof -
  (* the reindexing map is injective on {0..<x*y} *)
  have 1: "inj_on (λia. (ia div y, ia mod y)) {0..<x * y}"
    by (smt (verit, best) Pair_inject div_mod_decomp inj_onI)
  (* every pair (a, b) with a < x and b < y is hit, namely by a * y + b *)
  have 21: "{0..<x} × {0..<y}  (λia. (ia div y, ia mod y)) ` {0..<x * y}"
  proof clarsimp
    fix a b
    assume *:"a < x" "b < y"
    have "a * y +  b  {0..<x*y}"
      by (metis arith * atLeastLessThan_iff le0 mult.commute)
    thus "(a, b)  (λia. (ia div y, ia mod y)) ` {0..<x * y}"
      by (metis (no_types, lifting) "*"(2) Euclidean_Division.div_eq_0_iff add_cancel_right_right div_mult_self3 gr_implies_not0 image_iff mod_less mod_mult_self3)
  qed
  (* conversely, the map lands in {0..<x} x {0..<y} *)
  have 22:"(λia. (ia div y, ia mod y)) ` {0..<x * y}  {0..<x} × {0..<y}"
    using less_mult_imp_div_less apply auto
    by (metis mod_less_divisor mult.commute neq0_conv not_less_zero)
  have 2: "{0..<x} × {0..<y} = (λia. (ia div y, ia mod y)) ` {0..<x * y}"
    using 21 22 by auto
  (* rewrite the double sum as a single sum over the product set *)
  have *: "(ia = 0::nat..<x. ja = 0..<y. f ia ja) =
        ((x, y){0..<x} × {0..<y}. f x y)"
    by (auto simp add: sum.cartesian_product)
  show ?thesis unfolding *
    apply (intro sum.reindex_cong[of "λia. (ia div y, ia mod y)"])
    using 1 2 by auto
qed

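(* Mixed-product property: the product of two Kronecker products is the Kronecker
   product of the ordinary products,
   kronecker_product A B * kronecker_product C D = kronecker_product (A * C) (B * D),
   whenever A * C and B * D are defined *)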
(* The assumptions say that the products A * C and B * D are well-formed
   (matching inner dimensions). *)
lemma kronecker_of_mult:
  assumes "dim_col (A :: 'a :: comm_ring mat) = dim_row C"
  assumes "dim_col B = dim_row D"
  shows "kronecker_product A B * kronecker_product C D = kronecker_product (A * C) (B * D)"
  unfolding kronecker_product_def Let_def mat_eq_iff
proof clarsimp
  (* compare the entries at an arbitrary in-range position (i, j) *)
  fix i j
  assume ij: "i < dim_row A * dim_row B" "j < dim_col C * dim_col D"
  (* facts 1 and 2: the entries of A * C and B * D used at (i, j),
     written as row-column scalar products *)
  have 1: "(A * C) $$ (i div dim_row B, j div dim_col D) =
    row A (i div dim_row B)  col C (j div dim_col D)"
    using ij less_mult_imp_div_less by (auto intro!: index_mult_mat)
  have 2: "(B * D) $$ (i mod dim_row B, j mod dim_col D) =
    row B (i mod dim_row B)  col D (j mod dim_col D)"
    using ij apply (auto intro!: index_mult_mat)
    using gr_implies_not0 apply fastforce
    using gr_implies_not0 by fastforce
  (* fact 3: pointwise, each summand written with matrix entries equals the
     summand written with row and column accesses *)
  have 3: "x. x < dim_row C * dim_row D 
         A $$ (i div dim_row B, x div dim_row D) *
         B $$ (i mod dim_row B, x mod dim_row D) *
         (C $$ (x div dim_row D, j div dim_col D) *
          D $$ (x mod dim_row D, j mod dim_col D)) =
         row A (i div dim_row B) $ (x div dim_row D) *
         col C (j div dim_col D) $ (x div dim_row D) *
         (row B (i mod dim_row B) $ (x mod dim_row D) *
          col D (j mod dim_col D) $ (x mod dim_row D))"
  proof -
    (* unfold each row/column access into a matrix entry, then reorder factors *)
    fix x
    assume *:"x < dim_row C * dim_row D"
    have 1: "row A (i div dim_row B) $ (x div dim_row D) = A $$ (i div dim_row B, x div dim_row D)"
      by (simp add: * assms(1) less_mult_imp_div_less row_def)
    have 2: "row B (i mod dim_row B) $ (x mod dim_row D) = B $$ (i mod dim_row B, x mod dim_row D)"
      by (metis "*" assms(2) ij(1) index_row(1) mod_less_divisor nat_0_less_mult_iff neq0_conv not_less_zero)
    have 3: "col C (j div dim_col D) $ (x div dim_row D) = C $$ (x div dim_row D, j div dim_col D)"
      by (simp add: "*" ij(2) less_mult_imp_div_less)
    have 4: "col D (j mod dim_col D) $ (x mod dim_row D) = D $$ (x mod dim_row D, j mod dim_col D)"
      by (metis "*" Euclidean_Division.div_eq_0_iff gr_implies_not0 ij(2) index_col mod_div_trivial mult_not_zero)
    show "A $$ (i div dim_row B, x div dim_row D) *
         B $$ (i mod dim_row B, x mod dim_row D) *
         (C $$ (x div dim_row D, j div dim_col D) *
          D $$ (x mod dim_row D, j mod dim_col D)) =
         row A (i div dim_row B) $ (x div dim_row D) *
         col C (j div dim_col D) $ (x div dim_row D) *
         (row B (i mod dim_row B) $ (x mod dim_row D) *
          col D (j mod dim_col D) $ (x mod dim_row D))" unfolding 1 2 3 4
      by (simp add: mult.assoc mult.left_commute)
  qed
  (* the product of the two scalar products, expanded into one flattened sum
     over dim_row C * dim_row D using sum_product and sum_sum_mod_div *)
  have *: "(A * C) $$ (i div dim_row B, j div dim_col D) *
        (B * D) $$ (i mod dim_row B, j mod dim_col D) =
    (ia = 0..<dim_row C * dim_row D.
               A $$ (i div dim_row B, ia div dim_row D) *
               B $$ (i mod dim_row B, ia mod dim_row D) *
               (C $$ (ia div dim_row D, j div dim_col D) *
                D $$ (ia mod dim_row D, j mod dim_col D)))"
    unfolding 1 2 scalar_prod_def sum_product sum_sum_mod_div
    apply (auto simp add: sum_product sum_sum_mod_div intro!: sum.cong)
    using 3 by presburger
  (* the (i, j) entry of the left-hand side, a scalar product of a row and a
     column of the two Kronecker products, matches that sum *)
  show "vec (dim_col A * dim_col B)
          (λj. A $$ (i div dim_row B, j div dim_col B) *
               B $$ (i mod dim_row B, j mod dim_col B)) 
       vec (dim_row C * dim_row D)
          (λi. C $$ (i div dim_row D, j div dim_col D) *
               D $$ (i mod dim_row D, j mod dim_col D)) =
        (A * C) $$ (i div dim_row B, j div dim_col D) *
        (B * D) $$ (i mod dim_row B, j mod dim_col D)"
    unfolding * scalar_prod_def
    by (auto simp add: assms sum_product sum_sum_mod_div intro!: sum.cong)
qed

(* If A is square and A and B invert each other, B has the same dimensions as A. *)
lemma inverts_mat_length:
  assumes "square_mat A" "inverts_mat A B" "inverts_mat B A"
  shows "dim_row B = dim_row A" "dim_col B = dim_col A"
   apply (metis assms(1) assms(3) index_mult_mat(3) index_one_mat(3) inverts_mat_def square_mat.simps)
  by (metis assms(1) assms(2) index_mult_mat(3) index_one_mat(3) inverts_mat_def square_mat.simps)

(* Companion of less_mult_imp_div_less: m < n * i forces i > 0, so m mod i < i. *)
lemma less_mult_imp_mod_less:
  "m mod i < i" if "m < n * i" for m n i :: nat
  using gr_implies_not_zero that by fastforce

(* The Kronecker product of the identity matrices of sizes x and y is the
   identity matrix of size x*y. *)
lemma kronecker_one:
  shows "kronecker_product ((1m x)::'a :: ring_1 mat) (1m y) = 1m (x*y)"
  unfolding kronecker_product_def Let_def
  apply  (auto simp add:mat_eq_iff less_mult_imp_div_less less_mult_imp_mod_less)
  by (metis div_mult_mod_eq)

(* The Kronecker product of invertible matrices is invertible. The inverse is
   the Kronecker product of the inverses, via kronecker_of_mult and kronecker_one. *)
lemma kronecker_invertible:
  assumes "invertible_mat (A :: 'a :: comm_ring_1 mat)" "invertible_mat B"
  shows "invertible_mat (kronecker_product A B)"
proof -
  obtain Ai where Ai: "inverts_mat A Ai" "inverts_mat Ai A" using assms invertible_mat_def by blast
  obtain Bi where Bi: "inverts_mat B Bi" "inverts_mat Bi B" using assms invertible_mat_def by blast
  have "square_mat (kronecker_product A B)"
    by (metis (no_types, lifting) assms(1) assms(2) dim_col_mat(1) dim_row_mat(1) invertible_mat_def kronecker_product_def square_mat.simps)
  (* kronecker_product Ai Bi is a two-sided inverse *)
  moreover have "inverts_mat (kronecker_product A B) (kronecker_product Ai Bi)"
    using Ai Bi unfolding inverts_mat_def
    by (metis (no_types, lifting) dim_kronecker(1) index_mult_mat(3) index_one_mat(3) kronecker_of_mult kronecker_one)
  moreover have "inverts_mat (kronecker_product Ai Bi) (kronecker_product A B)"
    using Ai Bi unfolding inverts_mat_def
    by (metis (no_types, lifting) dim_kronecker(1) index_mult_mat(3) index_one_mat(3) kronecker_of_mult kronecker_one)
  ultimately show ?thesis unfolding invertible_mat_def by blast
qed

section "More DL Rank"

(* Conjugation of matrices: the conjugate class is lifted to matrices entrywise *)
instantiation mat :: (conjugate) conjugate
begin

(* entrywise conjugation; the dimensions are unchanged *)
definition conjugate_mat :: "'a :: conjugate mat  'a mat"
  where "conjugate m = mat (dim_row m) (dim_col m) (λ(i,j). conjugate (m $$ (i,j)))"

lemma dim_row_conjugate[simp]: "dim_row (conjugate m) = dim_row m"
  unfolding conjugate_mat_def by auto

lemma dim_col_conjugate[simp]: "dim_col (conjugate m) = dim_col m"
  unfolding conjugate_mat_def by auto

(* NOTE(review): despite its name, this lemma is about carrier_mat, not carrier_vec *)
lemma carrier_vec_conjugate[simp]: "m  carrier_mat nr nc  conjugate m  carrier_mat nr nc"
  by (auto)

lemma mat_index_conjugate[simp]:
  shows "i < dim_row m  j < dim_col m  conjugate m  $$ (i,j) = conjugate (m $$ (i,j))"
  unfolding conjugate_mat_def by auto

(* rows and columns of the conjugate are the conjugated rows and columns *)
lemma row_conjugate[simp]: "i < dim_row m  row (conjugate m) i = conjugate (row m i)"
  by (auto)

lemma col_conjugate[simp]: "i < dim_col m  col (conjugate m) i = conjugate (col m i)"
  by (auto)

lemma rows_conjugate: "rows (conjugate m) = map conjugate (rows m)"
  by (simp add: list_eq_iff_nth_eq)

lemma cols_conjugate: "cols (conjugate m) = map conjugate (cols m)"
  by (simp add: list_eq_iff_nth_eq)

(* class axioms: conjugation is an involution, and equal conjugates imply equal matrices *)
instance
proof
  fix a b :: "'a mat"
  show "conjugate (conjugate a) = a"
    unfolding mat_eq_iff by auto
  (* NOTE(review): ?a and ?b are not referenced by the proof below;
     they look like leftovers and could be removed *)
  let ?a = "conjugate a"
  let ?b = "conjugate b"
  show "conjugate a = conjugate b  a = b"
    by (metis dim_col_conjugate dim_row_conjugate mat_index_conjugate conjugate_cancel_iff mat_eq_iff)
qed

end

(* Conjugate transpose: A^H abbreviates the entrywise conjugate of the
   transpose A^T, with postfix notation H at priority 1000. *)
abbreviation conjugate_transpose :: "'a::conjugate mat   'a mat"
  where "conjugate_transpose A  conjugate (AT)"

notation conjugate_transpose ("(_H)" [1000])

(* Conjugation and transposition commute: transposing the conjugate gives A^H. *)
lemma transpose_conjugate:
  shows "(conjugate A)T = AH"
  unfolding conjugate_mat_def
  by auto

(* The zero vector of dimension dim_row A lies in the span of the columns of A.
   The span is taken in the module whose record is written out explicitly
   (carrier carrier_vec (dim_row A), zero vector of that dimension, vector
   addition and scalar multiplication). Used when proving vec_module_col. *)
lemma vec_module_col_helper:
  fixes A:: "('a :: field) mat"
  shows "(0v (dim_row A)  LinearCombinations.module.span class_ring carrier = carrier_vec (dim_row A), mult = undefined, one = undefined, zero = 0v (dim_row A), add = (+), smult = (⋅v) (set (cols A)))"
proof -
  (* auxiliary fact, chained into the metis call below *)
  have "v. (0::'a) v v + v = v"
    by auto
  then show "0v (dim_row A)  LinearCombinations.module.span class_ring carrier = carrier_vec (dim_row A), mult = undefined, one = undefined, zero = 0v (dim_row A), add = (+), smult = (⋅v) (set (cols A))"
    by (metis cols_dim module_vec_def right_zero_vec smult_carrier_vec vec_space.prod_in_span zero_carrier_vec)
qed

(* The column span of A (same explicit module record as in vec_module_col_helper)
   is closed under scalar multiplication. It is used to discharge a subgoal of
   vec_module_col; the proof below only assumes the span-membership premise. *)
lemma vec_module_col_helper2:
  fixes A:: "('a :: field) mat"
  shows "a x. x  LinearCombinations.module.span class_ring
                carrier = carrier_vec (dim_row A), mult = undefined, one = undefined,
                   zero = 0v (dim_row A), add = (+), smult = (⋅v)
                (set (cols A)) 
           (a b v. (a + b) v v = a v v + b v v) 
           a v x
            LinearCombinations.module.span class_ring
               carrier = carrier_vec (dim_row A), mult = undefined, one = undefined,
                  zero = 0v (dim_row A), add = (+), smult = (⋅v)
               (set (cols A))"
proof -
  fix a :: 'a and x :: "'a vec"
  assume "x  LinearCombinations.module.span class_ring carrier = carrier_vec (dim_row A), mult = undefined, one = undefined, zero = 0v (dim_row A), add = (+), smult = (⋅v) (set (cols A))"
  then show "a v x  LinearCombinations.module.span class_ring carrier = carrier_vec (dim_row A), mult = undefined, one = undefined, zero = 0v (dim_row A), add = (+), smult = (⋅v) (set (cols A))"
    by (metis (full_types) cols_dim idom_vec.smult_in_span module_vec_def)
qed

(* The span of the columns of A, used as the carrier of the vector module of
   dimension dim_row A, is a module over class_ring. The proof first
   interprets the abelian-group structure on that carrier and then discharges
   the module axioms. *)
lemma vec_module_col: "module (class_ring :: 'a :: field ring)
  (module_vec TYPE('a) 
    (dim_row A)
      carrier :=
         LinearCombinations.module.span
          class_ring (module_vec TYPE('a) (dim_row A)) (set (cols A)))"
proof -
  (* abelian group structure of the column span *)
  interpret abelian_group "module_vec TYPE('a) (dim_row A)
      carrier :=
         LinearCombinations.module.span
          class_ring (module_vec TYPE('a) (dim_row A)) (set (cols A))"
    apply (unfold_locales)
          apply (auto simp add:module_vec_def)
          apply (metis cols_dim module_vec_def partial_object.select_convs(1) ring.simps(2) vec_vs vectorspace.span_add1)
         apply (metis assoc_add_vec cols_dim module_vec_def vec_space.cV vec_vs vectorspace.span_closed)
    using vec_module_col_helper[of A] apply (auto)    
       apply (metis cols_dim left_zero_vec module_vec_def partial_object.select_convs(1) vec_vs vectorspace.span_closed)
      apply (metis cols_dim module_vec_def partial_object.select_convs(1) right_zero_vec vec_vs vectorspace.span_closed)
     apply (metis cols_dim comm_add_vec module_vec_def vec_space.cV vec_vs vectorspace.span_closed)
    unfolding Units_def apply auto
    by (smt cols_dim module_vec_def partial_object.select_convs(1) uminus_l_inv_vec uminus_r_inv_vec vec_space.vec_neg vec_vs vectorspace.span_closed vectorspace.span_neg)
  (* remaining module axioms *)
  show ?thesis
    apply (unfold_locales)
    unfolding class_ring_simps apply auto
    unfolding module_vec_simps using add_smult_distrib_vec apply auto
     apply (auto simp add:module_vec_def)
    using vec_module_col_helper2
     apply blast
    using cols_dim module_vec_def partial_object.select_convs(1) smult_add_distrib_vec vec_vs vectorspace.span_closed
    by (smt (z3))
qed

(* The span of the columns of a matrix forms a vector space *)
lemma vec_vs_col: "vectorspace (class_ring :: 'a :: field ring)
  (module_vec TYPE('a) (dim_row A)
      carrier :=
         LinearCombinations.module.span
          class_ring
          (module_vec TYPE('a)
            (dim_row A))
          (set (cols A)))"
  (* a vector space is a module over a field: combine vec_module_col with class_field *)
  unfolding vectorspace_def
  using vec_module_col class_field 
  by (auto simp: class_field_def)

(* Column j of A * B is the matrix-vector product of A with column j of B. *)
lemma cols_mat_mul_map:
  shows "cols (A * B) = map ((*v) A) (cols B)"
  unfolding list_eq_iff_nth_eq
  by auto

(* Set version of cols_mat_mul_map. *)
lemma cols_mat_mul:
  shows "set (cols (A * B)) = (*v) A ` set (cols B)"
  by (simp add: cols_mat_mul_map)

(* A subset S of the elements of a list can be enumerated by a distinct list.
   NOTE(review): the obtained ss is only required to be distinct with
   set ss = S; the statement does not make it a sublist of ls. *)
lemma set_obtain_sublist:
  assumes "S  set ls"
  obtains ss where "distinct ss" "S = set ss"
  using assms finite_distinct_list infinite_super by blast

(* Left multiplication acts column by column: if A is nr x n and every cs ! j
   has dimension n, then A * mat_of_cols n cs is the matrix whose columns are
   A applied to each cs ! j. *)
lemma mul_mat_of_cols:
  assumes "A  carrier_mat nr n"
  assumes "j. j < length cs  cs ! j  carrier_vec n"
  shows "A * (mat_of_cols n cs) = mat_of_cols nr (map ((*v) A) cs)"
  unfolding mat_eq_iff
  using assms apply auto
  apply (subst mat_of_cols_index)
  by auto

(* Rearranging a product of three factors: x * (y * z) = y * x * z.
   Only associativity and commutativity of multiplication are needed, so the
   lemma is stated for any ab_semigroup_mult. This covers every
   {conjugatable_ring, comm_ring} instance of the previous, more restrictive
   statement, so existing uses (e.g. cscalar_prod_conjugate_transpose) are
   unaffected. *)
lemma helper:
  fixes x y z ::"'a :: ab_semigroup_mult"
  shows "x * (y * z) = y * x * z"
  by (simp add: mult.assoc mult.left_commute)

(* Adjoint identity for the conjugating scalar product: for A of size nr x nc,
   x of dimension nr and y of dimension nc, the conjugating scalar product of
   x with A *v y equals that of A^H *v x with y. The proof expands both sides
   into double sums and swaps the order of summation. *)
lemma cscalar_prod_conjugate_transpose:
  fixes x y ::"'a :: {conjugatable_ring, comm_ring} vec"
  assumes "A  carrier_mat nr nc"
  assumes "x  carrier_vec nr"
  assumes "y  carrier_vec nc"
  shows "x ∙c (A *v y) = (AH *v x) ∙c y"
  unfolding mult_mat_vec_def scalar_prod_def
  using assms apply (auto simp add: sum_distrib_left sum_distrib_right sum_conjugate conjugate_dist_mul)
  apply (subst sum.swap)
  by (meson helper mult.assoc mult.left_commute sum.cong)

(* If A *v (A^H *v v) is zero then A^H *v v is zero. The conjugating scalar
   product of A^H *v v with itself equals that of A *v (A^H *v v) with v,
   which is 0. *)
lemma mat_mul_conjugate_transpose_vec_eq_0:                        
  fixes v ::"'a :: {conjugatable_ordered_ring,semiring_no_zero_divisors,comm_ring} vec"
  assumes "A  carrier_mat nr nc"
  assumes "v  carrier_vec nr"
  assumes "A *v (AH *v v) = 0v nr"
  shows "AH *v v = 0v nc"
proof -
  (* move A^H across the scalar product (cscalar_prod_conjugate_transpose) *)
  have "(AH *v v) ∙c (AH *v v) = (A *v (AH *v v)) ∙c v"
    by (metis (mono_tags, lifting) Matrix.carrier_vec_conjugate assms(1) assms(2) assms(3) carrier_matD(2) conjugate_zero_vec cscalar_prod_conjugate_transpose dim_row_conjugate index_transpose_mat(2) mult_mat_vec_def scalar_prod_left_zero scalar_prod_right_zero vec_carrier)
  (* by assumption the left argument is the zero vector *)
  also have "... = 0"
    by (simp add: assms(2) assms(3))
  ultimately have "(AH *v v) ∙c (AH *v v) = 0" by auto
  (* this step requires an ordered conjugatable ring (e.g. real entries):
     by conjugate_square_eq_0_vec, a vector whose conjugating scalar product
     with itself is 0 is the zero vector *)
  thus ?thesis
    apply (subst conjugate_square_eq_0_vec[symmetric])
    using assms(1) carrier_dim_vec apply fastforce
    by auto
qed

(* Row i of mat_of_cols nr ls consists of the i-th entries of the vectors in ls. *)
lemma row_mat_of_cols:
  assumes "i < nr"
  shows "row (mat_of_cols nr ls) i = vec (length ls) (λj. (ls ! j) $i)"
  by (smt assms dim_vec eq_vecI index_row(1) index_row(2) index_vec mat_of_cols_carrier(2) mat_of_cols_carrier(3) mat_of_cols_index)

(* Multiplying mat_of_cols nr (a # ls) by vCons m v splits off the first column:
   the result is m times a plus mat_of_cols nr ls times v. *)
lemma mat_of_cols_cons_mat_vec:
  fixes v ::"'a::comm_ring vec"
  assumes "v  carrier_vec (length ls)"
  assumes "dim_vec a = nr"
  shows
    "mat_of_cols nr (a # ls) *v (vCons m v) =
   m v a + mat_of_cols nr ls *v v"
  unfolding mult_mat_vec_def vec_eq_iff
  using assms by
    (auto simp add: row_mat_of_cols vec_Suc o_def mult.commute)

(* Scaling a vector by 0 gives the zero vector of the same dimension. *)
lemma smult_vec_zero:
  fixes v ::"'a::ring vec"
  shows "0 v v = 0v (dim_vec v)"
  unfolding smult_vec_def vec_eq_iff
  by (auto)

(* Prepending extra columns ls (each of dimension nr) leaves a matrix-vector
   product unchanged when the vector is padded in front with one zero
   coefficient per added column. Proof by induction on ls.
   NOTE(review): A is fixed but does not occur in the statement. *)
lemma helper2:
  fixes A ::"'a::comm_ring mat"
  fixes v ::"'a vec"
  assumes "v  carrier_vec (length ss)"
  assumes "x. x  set ls  dim_vec x = nr"
  shows
    "mat_of_cols nr ss *v v =
   mat_of_cols nr (ls @ ss) *v (0v (length ls) @v v)"
  using assms(2)
proof (induction ls)
  case Nil
  then show ?case by auto
next
  case (Cons a ls)
  (* peel off column a, whose coefficient is 0 *)
  then show ?case apply (auto simp add:zero_vec_Suc)
    apply (subst mat_of_cols_cons_mat_vec)
    by (auto simp add:assms smult_vec_zero)
qed

(* Applying the same permutation f of {..<length ss} to the column list ss and
   to the coefficient list v does not change the matrix-vector product. *)
lemma mat_of_cols_mult_mat_vec_permute_list:
  fixes v ::"'a::comm_ring list"
  assumes "f permutes {..<length ss}"
  assumes "length ss = length v"
  shows
    "mat_of_cols nr (permute_list f ss) *v vec_of_list (permute_list f v) =
     mat_of_cols nr ss *v vec_of_list v"
  unfolding mat_of_cols_def mult_mat_vec_def vec_eq_iff scalar_prod_def
proof clarsimp
  fix i
  assume "i < nr"
  (* reindex the sum for row i along the permutation f *)
  from sum.permute[OF assms(1)]
  have "(ia<length ss. ss ! f ia $ i * v ! f ia) =
  sum ((λia. ss ! f ia $ i * v ! f ia)  f) {..<length ss}" .
  also have "... = (ia = 0..<length v. ss ! f ia $ i * v ! f ia)"
    using assms(2) calculation lessThan_atLeast0 by auto
  ultimately have *: "(ia = 0..<length v.
             ss ! f ia $ i * v ! f ia) =
         (ia = 0..<length v.
             ss ! ia $ i * v ! ia)"
    by (metis (mono_tags, lifting) g. sum g {..<length ss} = sum (g  f) {..<length ss} assms(2) comp_apply lessThan_atLeast0 sum.cong)
  show "(ia = 0..<length v.
         vec (length ss) (λj. permute_list f ss ! j $ i) $ ia *
         vec_of_list (permute_list f v) $ ia) =
         (ia = 0..<length v. vec (length ss) (λj. ss ! j $ i) $ ia * vec_of_list v $ ia)"
    using assms * by (auto simp add: permute_list_nth vec_of_list_index)
qed

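(* There is a permutation of ls that moves the entries at the indices in ss to the
   back (in the order given by ss), while the remaining entries keep their relative
   order at the front *)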
lemma subindex_permutation:
  assumes "distinct ss" "set ss  {..<length ls}"
  obtains f where "f permutes {..<length ls}"
    "permute_list f ls = map ((!) ls) (filter (λi. i  set ss) [0..<length ls]) @ map ((!) ls) ss"
proof -
  (* the indices outside ss followed by ss cover exactly the index range *)
  have "set [0..<length ls] = set (filter (λi. i  set ss) [0..<length ls] @ ss)"
    using assms unfolding multiset_eq_iff by auto
  (* both index lists are distinct, so they agree as multisets *)
  then have "mset [0..<length ls] = mset (filter (λi. i  set ss) [0..<length ls] @ ss)"
    apply (subst set_eq_iff_mset_eq_distinct[symmetric])
    using assms by auto  
  (* map both index lists to the corresponding entries of ls *)
  then have "mset ls = mset (map ((!) ls)
           (filter (λi. i  set ss)
             [0..<length ls]) @ map ((!) ls) ss)"
    by (smt length_map map_append map_nth mset_eq_permutation mset_permute_list permute_list_map)
  thus ?thesis
    by (metis mset_eq_permutation that)
qed

(* Reverse direction of subindex_permutation: ls itself is a permutation of the
   list with the entries outside ss first and the entries at ss last. *)
lemma subindex_permutation2:
  assumes "distinct ss" "set ss  {..<length ls}"
  obtains f where "f permutes {..<length ls}"
    "ls = permute_list f (map ((!) ls) (filter (λi. i  set ss) [0..<length ls]) @ map ((!) ls) ss)"
  using subindex_permutation
  by (metis assms(1) assms(2) length_permute_list mset_eq_permutation mset_permute_list)

(* A distinct list ss of elements of ls can be written as map ((!) ls) ids for a
   distinct list ids of valid indices into ls. Each index is chosen with Hilbert
   choice among the positions that hold the corresponding element. *)
lemma distinct_list_subset_nths:
  assumes "distinct ss" "set ss  set ls"
  obtains ids where "distinct ids" "set ids  {..<length ls}" "ss = map ((!) ls) ids"
proof -
  let ?ids = "map (λi. @j. j < length ls  ls!j = i ) ss"
  (* distinct elements get distinct positions *)
  have 1: "distinct ?ids" unfolding distinct_map
    using assms apply (auto simp add: inj_on_def)
    by (smt in_mono in_set_conv_nth tfl_some)
  (* every chosen position is in range *)
  have 2: "set ?ids  {..<length ls}"
    using assms apply (auto)
    by (metis (mono_tags, lifting) in_mono in_set_conv_nth tfl_some)
  (* and holds the intended element *)
  have 3: "ss = map ((!) ls) ?ids"
    using assms apply (auto simp add: list_eq_iff_nth_eq)
    by (smt imageI in_set_conv_nth subset_iff tfl_some)
  show "(ids. distinct ids 
            set ids  {..<length ls} 
            ss = map ((!) ls) ids  thesis) 
    thesis" using 1 2 3 by blast
qed

(* A combination, with coefficients v, of a distinct list ss of columns of A
   (A of size nr x nc) equals A *v c for some c of dimension nc. Proof idea:
   write ss as columns of A selected by indices ids, pad v with zeros for the
   other columns (helper2), and undo the column permutation from
   subindex_permutation2 (mat_of_cols_mult_mat_vec_permute_list). *)
lemma helper3: 
  fixes A ::"'a::comm_ring mat"
  assumes A: "A  carrier_mat nr nc"
  assumes ss:"distinct ss" "set ss  set (cols A)"
  assumes "v  carrier_vec (length ss)"
  obtains c where "mat_of_cols nr ss *v v = A *v c" "dim_vec c = nc"
proof -
  from distinct_list_subset_nths[OF ss]
  obtain ids where ids: "distinct ids" "set ids  {..<length (cols A)}"
    and ss: "ss = map ((!) (cols A)) ids" by blast
  (* ?ls: the columns of A that are not selected by ids *)
  let ?ls = " map ((!) (cols A)) (filter (λi. i  set ids) [0..<length (cols A)])"
  from subindex_permutation2[OF ids] obtain f where
    f: "f permutes {..<length (cols A)}"
    "cols A = permute_list f (?ls @ ss)" using ss by blast
  have *: "x. x  set ?ls  dim_vec x = nr"
    using A by auto
  (* ?cs1: v padded in front with one zero per column of ?ls *)
  let ?cs1 = "(list_of_vec (0v (length ?ls) @v v))"
  from helper2[OF assms(4) ]
  have "mat_of_cols nr ss *v v = mat_of_cols nr (?ls @ ss) *v vec_of_list (?cs1)"
    using *
    by (metis vec_list)
  (* apply the permutation f to both the columns and the coefficients *)
  also have "... = mat_of_cols nr (permute_list f (?ls @ ss)) *v vec_of_list (permute_list f ?cs1)"
    apply (auto intro!: mat_of_cols_mult_mat_vec_permute_list[symmetric])
     apply (metis cols_length f(1) f(2) length_append length_map length_permute_list)
    using assms(4) by auto
  also have "... =  A *v vec_of_list (permute_list f ?cs1)" using f(2) assms by auto
  ultimately show
    "(c. mat_of_cols nr ss *v v = A *v c  dim_vec c = nc  thesis)  thesis"
    by (metis A assms(4) carrier_matD(2) carrier_vecD cols_length dim_vec_of_list f(2) index_append_vec(2) index_zero_vec(2) length_append length_list_of_vec length_permute_list)
qed

(* Variant of mat_mul_conjugate_transpose_vec_eq_0 for combinations of columns
   of A^H: if ss is a distinct list of columns of A^H and A annihilates
   mat_of_cols nc ss *v v, then that vector is zero. *)
lemma mat_mul_conjugate_transpose_sub_vec_eq_0:                        
  fixes A ::"'a :: {conjugatable_ordered_ring,semiring_no_zero_divisors,comm_ring} mat"
  assumes "A  carrier_mat nr nc"
  assumes "distinct ss" "set ss  set (cols (AH))"
  assumes "v  carrier_vec (length ss)"
  assumes "A *v (mat_of_cols nc ss *v v) = 0v nr"
  shows "(mat_of_cols nc ss *v v) = 0v nc"
proof -
  have "AH  carrier_mat nc nr" using assms(1) by auto
  (* rewrite the combination as A^H *v c (helper3) *)
  from  helper3[OF this assms(2-4)]
  obtain c where c: "mat_of_cols nc ss *v v = AH *v c" "dim_vec c = nr" by blast
  have 1: "c  carrier_vec nr"
    using c carrier_vec_dim_vec by blast
  have 2: "A *v (AH *v c) = 0v nr" using c assms(5) by auto
  (* conclude with the single-vector version *)
  from mat_mul_conjugate_transpose_vec_eq_0[OF assms(1) 1 2]
  have "AH *v c = 0v nc" .
  thus ?thesis unfolding c(1)[symmetric] .
qed

(* A unit of the matrix ring ring_mat TYPE('a) n b is an invertible matrix. *)
lemma Units_invertible:
  fixes A:: "'a::semiring_1 mat"
  assumes "A  Units (ring_mat TYPE('a) n b)"
  shows "invertible_mat A"
  using assms unfolding Units_def invertible_mat_def
  apply (auto simp add: ring_mat_def)
  using inverts_mat_def by blast

(* Converse of Units_invertible: an invertible matrix is a unit of the matrix
   ring of dimension dim_row A. *)
lemma invertible_Units:
  fixes A:: "'a::semiring_1 mat"
  assumes "invertible_mat A"
  shows "A  Units (ring_mat TYPE('a) (dim_row A) b)"
  using assms unfolding Units_def invertible_mat_def
  apply (auto simp add: ring_mat_def)
  by (metis assms carrier_mat_triv invertible_mat_def inverts_mat_def inverts_mat_length(1) inverts_mat_length(2))

(* Over a field, a square matrix is invertible iff its determinant is nonzero.
   Both directions go through the Units characterization
   (invertible_Units and Units_invertible). *)
lemma invertible_det:
  fixes A:: "'a::field mat"
  assumes "A  carrier_mat n n"
  shows "invertible_mat A  det A  0"
  apply auto
  using invertible_Units unit_imp_det_non_zero apply fastforce
  using assms by (auto intro!: Units_invertible det_non_zero_imp_unit)

context vec_space begin

(* In a distinct list, ss ! i occurs only at index i. *)
lemma find_indices_distinct:
  assumes "distinct ss"
  assumes "i < length ss"
  shows "find_indices (ss ! i) ss = [i]"
proof -
  have "set (find_indices (ss ! i) ss) = {i}"
    using assms apply auto by (simp add: assms(1) assms(2) nth_eq_iff_index_eq)
  (* find_indices yields a distinct list, so a singleton set means a singleton list *)
  thus ?thesis
    by (metis distinct.simps(2) distinct_find_indices empty_iff empty_set insert_iff list.exhaust list.simps(15)) 
qed

(* All coefficients of a vanishing linear combination of a distinct,
   linearly independent list of vectors are zero. *)
lemma lin_indpt_lin_comb_list:
  assumes "distinct ss"
  assumes "lin_indpt (set ss)"
  assumes "set ss  carrier_vec n"
  assumes "lincomb_list f ss = 0v n"
  assumes "i < length ss"
  shows "f i = 0"
proof -
  (* convert the list combination into a combination over the set *)
  from lincomb_list_as_lincomb[OF assms(3)]
  have "lincomb_list f ss = lincomb (mk_coeff ss f) (set ss)" .
  also have "... = lincomb  (λv. sum f (set (find_indices v ss))) (set ss)"
    unfolding mk_coeff_def
    apply (subst R.sumlist_map_as_finsum)
    by (auto simp add: distinct_find_indices)
  ultimately have "lincomb_list f ss = lincomb  (λv. sum f (set (find_indices v ss))) (set ss)" by auto
  then have *:"lincomb (λv. sum f (set (find_indices v ss))) (set ss) = 0v n"
    using assms(4) by auto
  have "finite (set ss)" by simp
  (* linear independence forces every coefficient of the set combination to be 0 *)
  from not_lindepD[OF assms(2) this _ _ *]
  have "(λv. sum f (set (find_indices v ss)))  set ss  {0}"
    by auto
  from funcset_mem[OF this]
  have "sum f (set (find_indices (nth ss i) ss))  {0}"
    using assms(5) by auto
  (* by distinctness, the coefficient of ss ! i is just f i *)
  thus ?thesis unfolding find_indices_distinct[OF assms(1) assms(5)]
    by auto
qed

(* Note: rank in this locale is vec_space.rank n, so it is meant for matrices
   with dim_row A = n; for example:
lemma foo:
  assumes "dim_row A = n"
  shows "rank A = vec_space.rank (dim_row A) A"
  by (simp add: assms) *)

(* The columns of A * B lie in the span of the columns of A, and hence so does
   their span. *)
lemma span_mat_mul_subset:
  assumes "A  carrier_mat n d"
  assumes "B  carrier_mat d nc"
  shows "span (set (cols (A * B)))  span (set (cols A))"
proof -
  (* every combination of the columns of A * B is a combination of the columns
     of A, with coefficients taken from B *v vec nc v *)
  have *: "v. ca. lincomb_list v (cols (A * B)) =
              lincomb_list ca  (cols A)"
  proof -
    fix v
    have "lincomb_list v (cols (A * B)) = (A * B) *v vec nc v"
      apply (subst lincomb_list_as_mat_mult)
       apply (metis assms(1) carrier_dim_vec carrier_matD(1) cols_dim index_mult_mat(2) subset_code(1))
      by (metis assms(1) assms(2) carrier_matD(1) carrier_matD(2) cols_length index_mult_mat(2) index_mult_mat(3) mat_of_cols_cols)
    also have "... = A *v (B *v vec nc v)"
      using assms(1) assms(2) by auto
    also have "... = lincomb_list (λi. (B *v vec nc v) $ i) (cols A)"
      apply (subst lincomb_list_as_mat_mult)
      using assms(1) carrier_dim_vec cols_dim apply blast
      by (metis assms(1) assms(2) carrier_matD(1) carrier_matD(2) cols_length dim_mult_mat_vec dim_vec eq_vecI index_vec mat_of_cols_cols)
    ultimately have "lincomb_list v (cols (A * B)) =
              lincomb_list (λi. (B *v vec nc v) $ i) (cols A)" by auto
    thus "ca. lincomb_list v (cols (A * B)) = lincomb_list ca (cols A)" by auto
  qed
  (* express both spans as spans of lists and compare *)
  show ?thesis
    apply (subst span_list_as_span[symmetric])
     apply (metis assms(1) carrier_matD(1) cols_dim index_mult_mat(2))
    apply (subst span_list_as_span[symmetric])
    using assms(1) cols_dim apply blast
    by (auto simp add:span_list_def *)
qed

(* Multiplying on the right does not increase the rank: rank (A * B) <= rank A. *)
lemma rank_mat_mul_right:
  assumes "A  carrier_mat n d"
  assumes "B  carrier_mat d nc"
  shows "rank (A * B)  rank A"
proof -
  (* the column span of A * B is a subspace of the column span of A *)
  have "subspace class_ring (local.span (set (cols (A*B))))
        (vs (local.span (set (cols A))))"
    unfolding subspace_def
    by (metis assms(1) assms(2) carrier_matD(1) cols_dim index_mult_mat(2) nested_submodules span_is_submodule vec_space.span_mat_mul_subset vec_vs_col)
  (* so its dimension is at most that of the column span of A *)
  from vectorspace.subspace_dim[OF _ this]
  have "vectorspace.dim class_ring
   (vs (local.span (set (cols A)))
    carrier := local.span (set (cols (A * B)))) 
  vectorspace.dim class_ring
      (vs (local.span (set (cols A))))"
    apply auto
    by (metis (no_types) assms(1) carrier_matD(1) fin_dim_span_cols index_mult_mat(2) mat_of_cols_carrier(1) mat_of_cols_cols vec_vs_col)
  thus ?thesis unfolding rank_def
    by auto
qed

(* Removing zero vectors from a list of n-dimensional vectors does not change its sum. *)
lemma sumlist_drop:
  assumes "v. v  set ls  dim_vec v = n"
  shows "sumlist ls = sumlist (filter (λv. v  0v n) ls)"
  using assms
proof (induction ls)
  case Nil
  then show ?case by auto
next
  case (Cons a ls)
  then show ?case using dim_sumlist by auto
qed

(* lincomb_list c s written as an explicit sum, over the index list
   [0..<length s], of c i scaling s ! i. *)
lemma lincomb_list_alt:
  shows "lincomb_list c s =
    sumlist (map2 (λi j. i v s ! j) (map (λi. c i) [0..<length s]) [0..<length s])"
  unfolding lincomb_list_def
  by (smt length_map map2_map_map map_nth nth_equalityI nth_map)

(* In that explicit form, indices whose coefficient c i is 0 can be dropped. *)
lemma lincomb_list_alt2:
  assumes "v. v  set s  dim_vec v = n"
  assumes "i. i  set ls  i < length s"
  shows "
    sumlist (map2 (λi j. i v s ! j) (map (λi. c i) ls) ls) =
    sumlist (map2 (λi j. i v s ! j) (map (λi. c i) (filter (λi. c i  0) ls)) (filter (λi. c i  0) ls))"
  using assms(2)
proof (induction ls)
  case Nil
  then show ?case by auto
next
  (* NOTE(review): the tail is named s here, shadowing the vector list s of the
     statement; a different name would read more clearly *)
  case (Cons a s)
  (* a zero coefficient contributes the zero vector, which is then absorbed *)
  then show ?case
    apply auto
    apply (subst smult_l_null)
     apply (simp add: assms(1) carrier_vecI)
    apply (subst left_zero_vec)
     apply (subst sumlist_carrier)
      apply auto
    by (metis (no_types, lifting) assms(1) carrier_dim_vec mem_Collect_eq nth_mem set_filter set_zip_rightD)
qed 

(* A distinct list whose elements are exactly a and b (with a and b different)
   is [a, b] or [b, a]. *)
lemma two_set:
  assumes "distinct ls"
  assumes "set ls = set [a,b]"
  assumes "a  b"
  shows "ls = [a,b]  ls = [b,a]"
  apply (cases ls)
  (* the empty case contradicts set ls = set [a,b] *)
  using assms(2) apply auto[1]
proof -
  (* write ls = x # y # ys with x, y among a, b and ys empty by distinctness *)
  fix x xs
  assume ls:"ls = x # xs"
  obtain y ys where xs:"xs = y # ys"
    by (metis (no_types) ls = x # xs assms(2) assms(3) list.set_cases list.set_intros(1) list.set_intros(2) set_ConsD)
  have 1:"x = a  x = b"
    using ls = x # xs assms(2) by auto
  have 2:"y = a  y = b"
    using ls = x # xs xs = y # ys assms(2) by auto
  have 3:"ys = []"
    by (metis (no_types) ls = x # xs xs = y # ys assms(1) assms(2) distinct.simps(2) distinct_length_2_or_more in_set_member member_rec(2) neq_Nil_conv set_ConsD)
  show "ls = [a, b]  ls = [b, a]" using ls xs 1 2 3 assms
    by auto
qed

(* Filtering [0..<length ls] with a predicate that selects the distinct in-range
   indices i and j yields [i, j] or [j, i]. The proof works with the filter
   predicate written as ia = i or ia = j. *)
lemma filter_disj_inds:
  assumes "i < length ls" "j < length ls" "i  j"
  shows "filter (λia. ia  j  ia = i) [0..<length ls] = [i, j] 
  filter (λia. ia  j  ia = i) [0..<length ls] = [j,i]"
proof -
  have 1: "distinct (filter (λia. ia = i  ia = j) [0..<length ls])"
    using distinct_filter distinct_upt by blast
  have 2:"set (filter (λia. ia = i  ia = j) [0..<length ls]) = {i, j}"
    using assms by auto
  (* NOTE(review): the second using line repeats facts that are also passed to
     smt directly, so it could be dropped *)
  show ?thesis using two_set[OF 1]
    using assms(3) empty_set filter_cong list.simps(15)
    by (smt "2" assms(3) empty_set filter_cong list.simps(15))
qed

(* If only the trivial combination of ls vanishes, then ls is distinct: a repeated
   entry ls ! i = ls ! j with i <> j would give the nontrivial combination with
   coefficient 1 at i and -1 at j. *)
lemma lincomb_list_indpt_distinct:
  assumes "v. v  set ls  dim_vec v = n"
  assumes
    "c. lincomb_list c ls = 0v n  (i. i < (length ls)  c i = 0)"
  shows "distinct ls"
  unfolding distinct_conv_nth
proof clarsimp
  fix i j
  assume ij: "i < length ls" "j < length ls" "i  j" 
  assume lsij: "ls ! i = ls ! j"
  (* the combination with 1 at i and -1 at j evaluates to ls ! i - ls ! j ... *)
  have "lincomb_list (λk. if k = i then 1 else if k = j then -1 else 0) ls =
     (ls ! i) - (ls ! j)"
    unfolding lincomb_list_alt
    apply (subst lincomb_list_alt2[OF assms(1)])
      apply auto
    using  filter_disj_inds[OF ij]
    apply auto
    using ij(3) apply force
    using assms(1) ij(2) apply auto[1]
    using ij(3) apply blast
    using assms(1) ij(2) by auto
  (* ... which is zero because the two entries coincide *)
  also have "...  = 0v n" unfolding lsij
    apply (rule minus_cancel_vec)
    using j < length ls assms(1)
    using carrier_vec_dim_vec nth_mem by blast
  ultimately have "lincomb_list (λk. if k = i then 1 else if k = j then -1 else 0) ls = 0v n" by auto
  from assms(2)[OF this]
  show False
    using i < length ls by auto
qed

end

(* Locale conjugatable_vec_space: vec_space over a field in the type class
   conjugatable_ordered_field, so that conjugate and the conjugate
   transpose are available in the lemmas below.  It only fixes the
   type witness f_ty and the dimension n of vec_space. *)
locale conjugatable_vec_space = vec_space f_ty n for
  f_ty::"'a::conjugatable_ordered_field itself"
  and n
begin                                                           

(* For an n x nc matrix A, the rank of its conjugate transpose AH (taken
   in the space of nc-dimensional vectors) is at most rank (A * AH).
   Proof outline:
   - Let S be a maximal linearly independent set of columns of AH, so
     card S equals the rank of AH (rankeq).
   - Multiplying by A maps S to columns of A * AH (fact 4).
   - The image has the same size as S (fact 3) and is still linearly
     independent (fact 5).  Cancelling A on vectors built from columns
     of AH uses mat_mul_conjugate_transpose_sub_vec_eq_0.
   - rank_ge_card_indpt then bounds card S by rank (A * AH). *)
lemma transpose_rank_mul_conjugate_transpose:
  fixes A :: "'a mat"
  assumes "A  carrier_mat n nc"
  shows "vec_space.rank nc AH  rank (A * AH)"
proof -
  have 1: "AH  carrier_mat nc n" using assms by auto
  have 2: "A * AH  carrier_mat n n" using assms by auto
      (* ?P T: T is a linearly independent set of columns of AH, i.e. of
         nc-dimensional vectors.  S below is a maximal such set. *)
  let ?P = "(λT. T  set (cols AH)  module.lin_indpt class_ring (module_vec TYPE('a) nc) T)"
  (* Any such set has at most n elements, because AH has n columns. *)
  have *:"A. ?P A  finite A  card A  n"
  proof clarsimp
    fix S
    assume S: "S  set (cols AH)"
    have "card S  card (set (cols AH))" using S
      using card_mono by blast
    also have "...  length (cols AH)" using card_length by blast
    also have "...  n" using assms by auto
    ultimately show "finite S  card S  n"
      by (meson List.finite_set S dual_order.trans finite_subset)
  qed
  have **:"?P {}"
    apply (subst module.lin_dep_def)
    by (auto simp add: vec_module)
  from maximal_exists[OF *]
  obtain S where S: "maximal S ?P" using **
    by (metis (no_types, lifting)) 
      (* Some properties of S *)
  from vec_space.rank_card_indpt[OF 1 S]
  have rankeq: "vec_space.rank nc AH = card S" .

  have s_hyp: "S  set (cols AH)"
    using S unfolding maximal_def by simp
  have modhyp: "module.lin_indpt class_ring (module_vec TYPE('a) nc) S" 
    using S unfolding maximal_def by simp

(* Switch to a list representation: ss enumerates S without repetitions. *)
  obtain ss where ss: "set ss = S" "distinct ss"
    by (metis (mono_tags) S maximal_def set_obtain_sublist)
  have ss2: "set (map ((*v) A) ss) = (*v) A ` S"
    by (simp add: ss(1))
  have rw_hyp: "cols (mat_of_cols n (map ((*v) A) ss)) = cols  (A * mat_of_cols nc ss)" 
    unfolding cols_def apply (auto)
    using mat_vec_as_mat_mat_mult[of A n nc]
    by (metis (no_types, lifting) "1" assms carrier_matD(1) cols_dim mul_mat_of_cols nth_mem s_hyp ss(1) subset_code(1))
  then have rw: "mat_of_cols n (map ((*v) A) ss) = A * mat_of_cols nc ss"
    by (metis assms carrier_matD(1) index_mult_mat(2) mat_of_cols_carrier(2) mat_of_cols_cols) 
  (* The images of ss under multiplication by A are linearly independent.
     A vanishing combination of them gives A applied to a combination of
     ss equal to zero; cancelling A leaves a vanishing combination of ss,
     whose coefficients are then all zero because S is independent. *)
  have indpt: "c. lincomb_list c (map ((*v) A) ss) = 0v n 
      i. (i < (length ss)  c i = 0)"
  proof clarsimp
    fix c i
    assume *: "lincomb_list c (map ((*v) A) ss) = 0v n"
    assume i: "i < length ss"
    have "wset (map ((*v) A) ss). dim_vec w = n"
      using assms by auto
    from lincomb_list_as_mat_mult[OF this]
    have "A * mat_of_cols nc ss *v  vec (length ss) c = 0v n"
      using * rw by auto
    then have hq: "A *v (mat_of_cols nc ss *v vec (length ss) c) =  0v n"
      by (metis assms assoc_mult_mat_vec mat_of_cols_carrier(1) vec_carrier)

    (* Cancel A; this step needs the conjugate-transpose structure. *)
    then have eq1: "(mat_of_cols nc ss *v vec (length ss) c) =  0v nc"
      apply (intro mat_mul_conjugate_transpose_sub_vec_eq_0)
      using assms ss s_hyp by auto

(* Rewrite the inner vector back to a lincomb_list *)
    have dim_hyp2: "wset ss. dim_vec w = nc"
      using ss(1) s_hyp
      by (metis "1" carrier_matD(1) carrier_vecD cols_dim subsetD) 
    from vec_module.lincomb_list_as_mat_mult[OF this, symmetric]
    have "mat_of_cols nc ss *v vec (length ss) c = module.lincomb_list (module_vec TYPE('a) nc) c ss" .
    then have "module.lincomb_list (module_vec TYPE('a) nc) c ss = 0v nc" using eq1 by auto
    from vec_space.lin_indpt_lin_comb_list[OF ss(2) _ _ this i]
    show "c i = 0" using modhyp ss s_hyp
      using "1" cols_dim by blast
  qed
  (* By lincomb_list_indpt_distinct, the image list has no repetitions, so
     multiplication by A does not collapse elements of S. *)
  have distinct: "distinct (map ((*v) A) ss)"
    by (metis (no_types, lifting) assms carrier_matD(1) dim_mult_mat_vec imageE indpt length_map lincomb_list_indpt_distinct ss2)
  then have 3: "card S = card ((*v) A ` S)"
    by (metis ss distinct_card image_set length_map)
  then have 4: "(*v) A ` S  set (cols (A * AH))"
    using cols_mat_mul S  set (cols AH) by blast
  (* Set-level independence of the image, derived from indpt through the
     distinct list representation. *)
  have 5: "lin_indpt ((*v) A ` S)"
  proof clarsimp
    assume ld:"lin_dep ((*v) A ` S)"
    have *: "finite ((*v) A ` S)"
      by (metis List.finite_set ss2)
    have **: "(*v) A ` S  carrier_vec n"
      using "2" "4" cols_dim by blast
    from finite_lin_dep[OF * ld **]
    obtain a v where
      a: "lincomb a ((*v) A ` S) = 0v n" and
      v: "v  (*v) A ` S" "a v  0" by blast
    obtain i where i:"v = map ((*v) A) ss ! i" "i < length ss"
      using v unfolding ss2[symmetric]
      using find_first_le nth_find_first by force
    from ss2[symmetric]
    have "set (map ((*v) A) ss) carrier_vec n" using ** ss2 by auto
    from lincomb_as_lincomb_list_distinct[OF this distinct] have
      "lincomb_list
     (λi. a (map ((*v) A) ss ! i))  (map ((*v) A) ss) = 0v n"
      using a ss2 by auto
    from indpt[OF this]
    show False using v i by simp
  qed
  (* An independent set of columns bounds the rank from below. *)
  from rank_ge_card_indpt[OF 2 4 5]
  have "card ((*v) A ` S)  rank (A * AH)" .
  thus ?thesis using rankeq 3 by linarith
qed

(* For an n x nc matrix A, the rank of its conjugate transpose AH is at
   most rank A.  The metis call chains transpose_rank_mul_conjugate_transpose
   (rank AH is bounded by rank (A * AH)) with rank_mat_mul_right. *)
lemma conjugate_transpose_rank_le:
  assumes "A  carrier_mat n nc"
  shows "vec_space.rank nc (AH)  rank A"
  by (metis assms carrier_matD(2) carrier_mat_triv dim_row_conjugate dual_order.trans index_transpose_mat(2) rank_mat_mul_right transpose_rank_mul_conjugate_transpose)

(* Conjugation commutes with finite vector sums.  For f mapping U into
   carrier_vec n, the conjugate of the finsum of f over U equals the
   finsum of the composition of conjugate and f over U.
   Proof by infinite_finite_induct on U.  The insert step unfolds
   finsum_insert on both sides, then splits the conjugate of the sum with
   conjugate_add_vec. *)
lemma conjugate_finsum:
  assumes f: "f : U  carrier_vec n"
  shows "conjugate (finsum V f U) = finsum V (conjugate  f) U"
  using f
proof (induct U rule: infinite_finite_induct)
  case (infinite A)
  then show ?case by auto
next
  case empty
  then show ?case by auto
next
  case (insert u U)
  hence f: "f : U  carrier_vec n" "f u : carrier_vec n"  by auto
  (* The conjugated summands stay in carrier_vec n, as finsum_insert needs. *)
  then have cf: "conjugate  f : U  carrier_vec n"
    "(conjugate  f) u : carrier_vec n"
     apply (simp add: Pi_iff)
    by (simp add: f(2))
  then show ?case
    unfolding finsum_insert[OF insert(1) insert(2) f]
    unfolding finsum_insert[OF insert(1) insert(2) cf ]
    apply (subst conjugate_add_vec[of _ n])
    using f(2) apply blast
    using M.finsum_closed f(1) apply blast
    by (simp add: comp_def f(1) insert.hyps(3))
qed

(* For A in carrier_mat n d, rank (conjugate A) is at most rank A.
   Let S be a maximal linearly independent set of columns of conjugate A,
   so card S = rank (conjugate A) (rankeq).  The image of S under
   conjugate:
   - consists of columns of A (fact 1),
   - is linearly independent (fact 2),
   - has the same size as S, since conjugation is injective (fact 4).
   rank_ge_card_indpt therefore gives the bound. *)
lemma rank_conjugate_le:
  assumes A:"A  carrier_mat n d"
  shows "rank (conjugate (A))  rank A"
proof -
  (* S: a maximal linearly independent set of columns of conjugate A. *)
  let ?P = "(λT. T  set (cols (conjugate A))  lin_indpt T)"
  have *:"A. ?P A  finite A  card A  d"
    by (metis List.finite_set assms card_length card_mono carrier_matD(2) cols_length dim_col_conjugate dual_order.trans rev_finite_subset)
  have **:"?P {}"
    by (simp add: finite_lin_indpt2)
  from maximal_exists[OF *]
  obtain S where S: "maximal S ?P" using **
    by (metis (no_types, lifting))
  have s_hyp: "S  set (cols (conjugate A))" "lin_indpt S"
    using S unfolding maximal_def
     apply blast
    by (metis (no_types, lifting) S maximal_def)
  from rank_card_indpt[OF _ S, of d]
  have rankeq: "rank (conjugate A) = card S" using assms by auto 
  have 1:"conjugate ` S  set (cols A)"
    using S apply auto
    by (metis (no_types, lifting) cols_conjugate conjugate_id image_eqI in_mono list.set_map s_hyp(1))
  (* Suppose some finite T inside the image of S has a vanishing
     combination with a nonzero coefficient.  Conjugating both the
     coefficients and the vectors gives a vanishing combination over the
     subset conjugate ` T of S, contradicting lin_indpt S. *)
  have 2: "lin_indpt (conjugate ` S)"
    apply (rule ccontr)
    apply (auto simp add: lin_dep_def)
  proof -
    fix T c v
    assume T: "T  conjugate ` S" "finite T" and
      lc:"lincomb c T = 0v n" and "v  T"  "c v  0"
    let ?T = "conjugate ` T"
    let ?c = "conjugate  c  conjugate"
    have 1: "finite ?T"  using T by auto
    have 2: "?T  S"  using T by auto
    have 3: "?c  ?T  UNIV" by auto
    (* Reindex the sum over the conjugated set back to T. *)
    have "lincomb ?c ?T = (VxT. conjugate (c x) v conjugate x)"
      unfolding lincomb_def
      apply (subst finsum_reindex)
        apply auto
       apply (metis "2" carrier_vec_conjugate assms carrier_matD(1) cols_dim image_eqI s_hyp(1) subsetD)
      by (meson conjugate_cancel_iff inj_onI)
    also have "... = (VxT. conjugate (c x v x)) "
      by (simp add: conjugate_smult_vec)
    also have "... = conjugate (VxT. (c x v x))"
      apply(subst conjugate_finsum[of "λx.(c x v x)" T])
       apply (auto simp add:o_def)
      by (smt Matrix.carrier_vec_conjugate Pi_I' T(1) assms carrier_matD(1) cols_dim dim_row_conjugate imageE s_hyp(1) smult_carrier_vec subset_eq) 
    also have "... = conjugate (lincomb c T)"
      using lincomb_def by presburger
    ultimately have "lincomb ?c ?T = conjugate (lincomb c T)" by auto
    then have 4:"lincomb ?c ?T = 0v n" using lc by auto
    from not_lindepD[OF s_hyp(2) 1 2 3 4]
    have "conjugate  c  conjugate  conjugate ` T  {0}" .
    then have "c v = 0"
      by (simp add: Pi_iff v  T)
    thus False using c v  0 by auto
  qed
  from rank_ge_card_indpt[OF A 1 2]
  have 3:"card (conjugate ` S)  rank A" .
  (* Conjugation is injective, so it preserves cardinality. *)
  have 4: "card (conjugate ` S) = card S"
    apply (auto intro!: card_image)
    by (meson conjugate_cancel_iff inj_onI)
  show ?thesis using rankeq 3 4 by auto
qed

(* Conjugation preserves rank.  Apply rank_conjugate_le to A and to
   conjugate A; conjugate_id removes the double conjugation, and
   antisymmetry of the order gives equality. *)
lemma rank_conjugate:
  assumes "A  carrier_mat n d"
  shows "rank (conjugate A) = rank A"
  using  rank_conjugate_le
  by (metis carrier_vec_conjugate assms conjugate_id dual_order.antisym)

end (* end of locale conjugatable_vec_space *)

(* Outside the locale.  For any matrix A over a conjugatable_ordered_field,
   rank A equals the rank of its conjugate transpose AH, each rank taken
   in the matching vector dimension.  The metis call combines the locale
   lemma conjugate_transpose_rank_le with the involution facts for
   transpose and conjugate and with antisymmetry of the order. *)
lemma conjugate_transpose_rank:
  fixes A::"'a::{conjugatable_ordered_field} mat"
  shows "vec_space.rank (dim_row A) A = vec_space.rank (dim_col A) (AH)"
  using  conjugatable_vec_space.conjugate_transpose_rank_le
  by (metis (no_types, lifting) Matrix.transpose_transpose carrier_matI conjugate_id dim_col_conjugate dual_order.antisym index_transpose_mat(2) transpose_conjugate)

(* Row rank equals column rank: rank A equals the rank of its transpose AT.
   Follows from conjugate_transpose_rank and the locale lemma
   rank_conjugate. *)
lemma transpose_rank:
  fixes A::"'a::{conjugatable_ordered_field} mat"
  shows "vec_space.rank (dim_row A) A = vec_space.rank (dim_col A) (AT)"
  by (metis carrier_mat_triv conjugatable_vec_space.rank_conjugate conjugate_transpose_rank index_transpose_mat(2))

(* For A in carrier_mat n d and B in carrier_mat d nc, the rank of A * B is
   bounded by the rank of B.  The proof transposes the product
   (transpose_mult), then reuses vec_space.rank_mat_mul_right together with
   the rank equalities from conjugate_transpose_rank and rank_conjugate. *)
lemma rank_mat_mul_left:
  fixes A::"'a::{conjugatable_ordered_field} mat"
  assumes "A  carrier_mat n d"
  assumes "B  carrier_mat d nc"
  shows "vec_space.rank n (A * B)  vec_space.rank d B"
  by (metis (no_types, lifting) Matrix.transpose_transpose assms(1) assms(2) carrier_matD(1) carrier_matD(2) carrier_mat_triv conjugatable_vec_space.rank_conjugate conjugate_transpose_rank index_mult_mat(3) index_transpose_mat(3) transpose_mult vec_space.rank_mat_mul_right)

section "Results on Invertibility"

(* take_cols A inds: the matrix with dim_row A rows whose columns are the
   columns of A at the positions listed in inds, in the listed order, with
   repeats kept.  Positions that are not below dim_col A are filtered out
   first. *)
definition take_cols :: "'a mat  nat list  'a mat"
  where "take_cols A inds = mat_of_cols (dim_row A) (map ((!) (cols A)) (filter ((>) (dim_col A)) inds))"

(* Variant of take_cols that does not filter inds: every listed position is
   used to index cols A.
   NOTE(review): callers presumably ensure all positions are below
   dim_col A. *)
definition take_cols_var :: "'a mat  nat list  'a mat"
  where "take_cols_var A inds = mat_of_cols (dim_row A) (map ((!) (cols A)) (inds))"

(* take_rows A inds: the row analogue of take_cols.  It is the matrix with
   dim_col A columns whose rows are the rows of A at the positions in inds
   that are below dim_row A, in the listed order. *)
definition take_rows :: "'a mat  nat list  'a mat"
  where "take_rows A inds = mat_of_rows (dim_col A) (map ((!) (rows A)) (filter ((>) (dim_row A)) inds))"

(* Congruence rule for mat_of_cols: equal column lists give equal matrices.
   Used as an intro! rule in take_cols_mat_mul. *)
lemma cong1:
  "x = y   mat_of_cols n x = mat_of_cols n y"
  by auto

(* The element of filter P ls at any valid index j satisfies P. *)
lemma nth_filter:
  assumes "j < length (filter P ls)"
  shows "P  ((filter P ls) ! j)"
  by (simp add: assms list_ball_nth)

(* Left multiplication commutes with column selection:
   A * take_cols B inds = take_cols (A * B) inds. *)
lemma take_cols_mat_mul:
  assumes "A  carrier_mat nr n"
  assumes "B  carrier_mat n nc"
  shows "A * take_cols B inds = take_cols (A * B) inds"
proof -
  (* Every selected column of B lies in carrier_vec n, so mul_mat_of_cols applies. *)
  have "j. j < length (map ((!) (cols B)) (filter ((>) nc) inds)) 
      (map ((!) (cols B)) (filter ((>) nc) inds)) ! j  carrier_vec n"
    using assms apply auto
    apply (subst cols_nth)
    using nth_filter by auto
  from mul_mat_of_cols[OF assms(1) this]
  have "A *  take_cols B inds = mat_of_cols nr (map (λx. A *v cols B ! x) (filter ((>) (dim_col B)) inds))"
    unfolding take_cols_def using assms by (auto simp add: o_def)
  (* A times column x of B is column x of A * B. *)
  also have "... = take_cols (A * B) inds"
    unfolding take_cols_def using assms apply (auto intro!: cong1)
    by (simp add: mult_mat_vec_def)
  ultimately show ?thesis by auto
qed

(* take_cols A inds has the same number of rows as A.  Its number of
   columns is only given existentially here. *)
lemma take_cols_carrier_mat:
  assumes "A  carrier_mat nr nc"
  obtains n where "take_cols A inds  carrier_mat nr n"
  unfolding take_cols_def
  using assms
  by fastforce

(* When every position in inds is below nc, none is dropped, so
   take_cols A inds is an nr x length inds matrix. *)
lemma take_cols_carrier_mat_strict:
  assumes "A  carrier_mat nr nc"
  assumes "i. i  set inds  i < nc"
  shows "take_cols A inds  carrier_mat nr (length inds)"
  unfolding take_cols_def
  using assms by auto

(* Running gauss_jordan on A with right-hand side take_cols A inds yields
   (C, D) with D = take_cols C inds.
   gauss_jordan_transform supplies an invertible P with C = P * A and
   D = P * take_cols A inds; take_cols_mat_mul moves P inside take_cols. *)
lemma gauss_jordan_take_cols:  
  assumes "gauss_jordan A (take_cols A inds) = (C,D)"
  shows "D = take_cols C inds"
proof -
  obtain nr nc where A: "A   carrier_mat nr nc" by auto
  from take_cols_carrier_mat[OF this]
  obtain n where B: "take_cols A inds  carrier_mat nr n" by auto
  from gauss_jordan_transform[OF A B assms, of undefined]
  obtain P where PP:"PUnits (ring_mat TYPE('a) nr undefined)" and
    CD: "C = P * A" "D = P * take_cols A inds" by blast
  (* Units of the matrix ring are square nr x nr matrices. *)
  have P: "P  carrier_mat nr nr"
    by (metis (no_types, lifting) Units_def PP mem_Collect_eq partial_object.select_convs(1) ring_mat_def)
  from take_cols_mat_mul[OF P A]
  have "P * take_cols A inds = take_cols (P * A) inds" .
  thus ?thesis using CD by blast  
qed

(* With all positions in range, take_cols keeps one column per entry of inds. *)
lemma dim_col_take_cols:
  assumes "j. j  set inds  j < dim_col A"
  shows "dim_col (take_cols A inds) = length inds"
  unfolding take_cols_def
  using assms by auto

(* take_rows does not change the number of columns. *)
lemma dim_col_take_rows[simp]:
  shows "dim_col (take_rows A inds) = dim_col A"
  unfolding take_rows_def by auto

(* Every column of take_cols A inds is a column of A. *)
lemma cols_take_cols_subset:
  shows "set (cols (take_cols A inds))  set (cols A)"
  unfolding take_cols_def
  apply (subst cols_mat_of_cols)
   apply auto
  using in_set_conv_nth by fastforce

(* take_cols does not change the number of rows. *)
lemma dim_row_take_cols[simp]:
  shows "dim_row (take_cols A ls) = dim_row A"
  by (simp add: take_cols_def)

(* Stacking A on top of B with @r adds the row counts. *)
lemma dim_row_append_rows[simp]:
  shows "dim_row (A @r B) = dim_row A + dim_row B"
  by (simp add: append_rows_def)

(* A matrix is determined by its column count and its list of rows. *)
lemma rows_inj:
  assumes "dim_col A = dim_col B"
  assumes "rows A = rows B"
  shows "A = B"
  unfolding mat_eq_iff
  apply auto
    apply (metis assms(2) length_rows)
  using assms(1) apply blast
  by (metis assms(1) assms(2) mat_of_rows_rows)

(* Entry (i, j) of A @r B, for A and B with the same column count: it is
   taken from A when i < dim_row A, and from row i - dim_row A of B
   otherwise. *)
lemma append_rows_index:
  assumes "dim_col A = dim_col B"
  assumes "i < dim_row A + dim_row B"
  assumes "j < dim_col A"
  shows "(A @r B) $$ (i,j) = (if i < dim_row A then A $$ (i,j) else B $$ (i-dim_row A,j))"
  unfolding append_rows_def
  apply (subst index_mat_four_block)
  using assms by auto

(* Row-level version of append_rows_index: row i of A @r B is a row of A
   or the corresponding row of B. *)
lemma row_append_rows:
  assumes "dim_col A = dim_col B"
  assumes "i < dim_row A + dim_row B"
  shows "row (A @r B) i = (if i < dim_row A then row A i else row B (i-dim_row A))"
  unfolding vec_eq_iff
  using assms by (auto simp add: append_rows_def)

(* Right multiplication distributes over row stacking:
   (A @r B) * C = A * C @r B * C when A and B have the same column count.
   The proof compares dimensions, then entries, splitting on whether the
   row index falls in the A part or the B part. *)
lemma append_rows_mat_mul:
  assumes "dim_col A = dim_col B"
  shows "(A @r B) * C = A * C @r B * C"
  unfolding mat_eq_iff
  apply auto
   apply (simp add: append_rows_def)
  apply (subst index_mult_mat)
    apply auto
   apply (simp add: append_rows_def)
  apply (subst  append_rows_index)
     apply auto
    apply (simp add: append_rows_def)
   apply (metis add.right_neutral append_rows_def assms index_mat_four_block(3) index_mult_mat(1) index_mult_mat(3) index_zero_mat(3) row_append_rows trans_less_add1)
  by (metis add_cancel_right_right add_diff_inverse_nat append_rows_def assms index_mat_four_block(3) index_mult_mat(1) index_mult_mat(3) index_zero_mat(3) nat_add_left_cancel_less row_append_rows)

(* The set of naturals below n has at most n elements. *)
lemma cardlt:
  shows "card  {i. i < (n::nat)}  n"
  by simp

(* A row-echelon matrix A splits into its first length (pivot_positions A)
   rows (via take_rows) stacked on top of a zero block made of the
   remaining dim_row A - length (pivot_positions A) rows.
   NOTE(review): the operator in dim_asm is lost in this view.  alt_char
   combines dim_asm with pivot_len (number of pivots at most dim_row A) to
   keep the selected rows below dim_col A, which suggests
   dim_row A <= dim_col A.  Confirm against the original statement.
   Structure of the proof:
   - h1: the top block agrees with A;
   - h2: every row at or beyond the pivot count is zero, so it matches the
     zero block;
   - h3, h4: the dimensions add up. *)
lemma row_echelon_form_zero_rows:
  assumes row_ech: "row_echelon_form A"
  assumes dim_asm: "dim_col A  dim_row A"
  shows "take_rows A [0..<length (pivot_positions A)] @r  0m (dim_row A - length (pivot_positions A))  (dim_col A) = A"
proof -
  have ex_pivot_fun: " f. pivot_fun A f (dim_col A)" using row_ech unfolding row_echelon_form_def by auto
  (* The number of pivots equals the number of nonzero rows. *)
  have len_help: "length (pivot_positions A) = card {i. i < dim_row A  row A i  0v (dim_col A)}"
    using ex_pivot_fun pivot_positions[where A = "A",where nr = "dim_row A", where nc = "dim_col A"]
    by auto
  then have len_help2: "length (pivot_positions A)  dim_row A"
    by (metis (no_types, lifting) card_mono cardlt finite_Collect_less_nat le_trans mem_Collect_eq subsetI)
  (* All of 0 ..< number of pivots are valid row indices, so take_rows drops none. *)
  have fileq: "filter (λy. y < dim_row A) [0..< length (pivot_positions A)] = [0..<length (pivot_positions A)]"
    apply (rule filter_True)
    using len_help2 by auto
  (* NOTE(review): this and pivot_len below re-derive the bound already
     stated as len_help2. *)
  have "n. card {i. i < n   row A i  0v (dim_col A)}  n"
  proof clarsimp 
    fix n
    have h: "x. x  {i. i < n  row A i  0v (dim_col A)}  x{..<n}"
      by simp
    then have h1: "{i. i < n   row A i  0v (dim_col A)}  {..<n}"
      by blast
    then have h2: "(card {i. i < n   row A i  0v (dim_col A)}::nat)  (card {..<n}::nat)"
      using card_mono by blast 
    then show "(card {i. i < n  row A i  0v (dim_col A)}::nat)  (n::nat)" using h2 card_lessThan[of n]
      by auto
  qed
  then have pivot_len: "length (pivot_positions A)  dim_row A "  using len_help
    by simp
  have alt_char: "mat_of_rows (dim_col A)
         (map ((!) (rows A)) (filter (λy. y < dim_col A) [0..<length (pivot_positions A)])) = 
      mat_of_rows (dim_col A) (map ((!) (rows A))  [0..<length (pivot_positions A)])"
    using pivot_len dim_asm
    by auto
  (* Top block: entries of take_rows coincide with those of A. *)
  have h1: "i j. i < dim_row A 
           j < dim_col A 
           i < dim_row (take_rows A [0..<length (pivot_positions A)]) 
           take_rows A [0..<length (pivot_positions A)] $$ (i, j) = A $$ (i, j)"
  proof - 
    fix i 
    fix j
    assume "i < dim_row A"
    assume j_lt: "j < dim_col A"
    assume i_lt: "i < dim_row (take_rows A [0..<length (pivot_positions A)])" 
    have lt: "length (pivot_positions A)  dim_row A" using pivot_len by auto
    have h1: "take_rows A [0..<length (pivot_positions A)] $$ (i, j) = (row (take_rows A [0..<length (pivot_positions A)]) i)$j"
      by (simp add: i_lt j_lt)
    then have h2: "(row (take_rows A [0..<length (pivot_positions A)]) i)$j = (row A i)$j"
      using lt alt_char i_lt unfolding take_rows_def by auto
    show "take_rows A [0..<length (pivot_positions A)] $$ (i, j) = A $$ (i, j)"
      using h1 h2
      by (simp add: i < dim_row A j_lt) 
  qed
  let ?nc = "dim_col A"
  let ?nr = "dim_row A"
  (* Bottom block: rows at or beyond the pivot count are zero in A. *)
  have h2: "i j. i < dim_row A 
           j < dim_col A 
           ¬ i < dim_row (take_rows A [0..<length (pivot_positions A)]) 
           0m (dim_row A - length (pivot_positions A)) (dim_col A) $$
           (i - dim_row (take_rows A [0..<length (pivot_positions A)]), j) =
           A $$ (i, j)"
  proof - 
    fix i
    fix j
    assume lt_i: "i < dim_row A"
    assume lt_j: "j < dim_col A"
    assume not_lt: "¬ i < dim_row (take_rows A [0..<length (pivot_positions A)])"
    let ?ip = "i+1"
    (* Key step: every pivot function maps row i to dim_col A, so row i has
       no pivot.  If f i were different, all rows 0 .. i would have pivots,
       giving at least i + 1 nonzero rows, more than the number of pivots. *)
    have h0: "f. pivot_fun A f (dim_col A)  f i = ?nc"
    proof -  
      have half1: "f. pivot_fun A f (dim_col A)" using assms unfolding row_echelon_form_def
        by blast
      have half2: "f. pivot_fun A f (dim_col A)  f i = ?nc " 
      proof clarsimp
        fix f
        assume is_piv: "pivot_fun A f (dim_col A)"
        have len_pp: "length (pivot_positions A) = card {i. i < ?nr  row A i  0v ?nc}" using is_piv pivot_positions[of A ?nr ?nc f]
          by auto
        have  "i. (i < ?nr  row A i  0v ?nc)   (i < ?nr  f i  ?nc)"
          using is_piv pivot_fun_zero_row_iff[of A f ?nc ?nr]
          by blast
        then have len_pp_var: "length (pivot_positions A) = card {i. i < ?nr  f i  ?nc}" 
          using len_pp  by auto 
        have allj_hyp: "j < ?nr. f j = ?nc  ((Suc j) < ?nr  f (Suc j) = ?nc)" 
          using is_piv unfolding pivot_fun_def 
          using lt_i
          by (metis le_antisym le_less) 
        (* If row i has a pivot, so does every earlier row. *)
        have if_then_bad: "f i  ?nc  (j. j  i  f j  ?nc)"
        proof clarsimp 
          fix j
          assume not_i: "f i  ?nc"
          assume j_leq: "j  i"
          assume bad_asm: "f j = ?nc"
          (* The proof below appears to be machine generated (it uses moura).
             It shows that once f j = dim_col A, f stays dim_col A on all
             later rows below dim_row A. *)
          have "k. k  j   k < ?nr  f k = ?nc"
          proof -
            fix k :: nat
            assume a1: "j  k"
            assume a2: "k < dim_row A"
            have f3: "n. ¬ n < dim_row A  f n  f j  ¬ Suc n < dim_row A  f (Suc n) = f j"
              using allj_hyp bad_asm by presburger
            obtain nn :: "nat  nat  (nat  bool)  nat" where
              f4: "n na p nb nc. (¬ n  na  Suc n  Suc na)  (¬ p nb  ¬ nc  nb  ¬ p (nn nc nb p)  p nc)  (¬ p nb  ¬ nc  nb  p nc  p (Suc (nn nc nb p)))"
              using inc_induct order_refl by moura
            then have f5: "p. ¬ p k  p j  p (Suc (nn j k p))"
              using a1 by presburger
            have f6: "p. ¬ p k  ¬ p (nn j k p)  p j"
              using f4 a1 by meson
            { assume "nn j k (λn. n < dim_row A  f n  dim_col A) < dim_row A  f (nn j k (λn. n < dim_row A  f n  dim_col A))  dim_col A"
              moreover
              { assume "(nn j k (λn. n < dim_row A  f n  dim_col A) < dim_row A  f (nn j k (λn. n < dim_row A  f n  dim_col A))  dim_col A)  (¬ j < dim_row A  f j = dim_col A)"
                then have "¬ k < dim_row A  f k = dim_col A"
                  using f6
                  by (metis (mono_tags, lifting)) }
              ultimately have "(¬ j < dim_row A  f j = dim_col A)  (¬ Suc (nn j k (λn. n < dim_row A  f n  dim_col A)) < dim_row A  f (Suc (nn j k (λn. n < dim_row A  f n  dim_col A))) = dim_col A)  ¬ k < dim_row A  f k = dim_col A"
                using bad_asm
                by blast }
            moreover
            { assume "(¬ j < dim_row A  f j = dim_col A)  (¬ Suc (nn j k (λn. n < dim_row A  f n  dim_col A)) < dim_row A  f (Suc (nn j k (λn. n < dim_row A  f n  dim_col A))) = dim_col A)"
              then have "¬ k < dim_row A  f k = dim_col A"
                using f5
              proof -
                have "¬ (Suc (nn j k (λn. n < dim_row A  f n  dim_col A)) < dim_row A  f (Suc (nn j k (λn. n < dim_row A  f n  dim_col A)))  dim_col A)  ¬ (j < dim_row A  f j  dim_col A)"
                  using (¬ j < dim_row A  f j = dim_col A)  (¬ Suc (nn j k (λn. n < dim_row A  f n  dim_col A)) < dim_row A  f (Suc (nn j k (λn. n < dim_row A  f n  dim_col A))) = dim_col A) by linarith
                then have "¬ (k < dim_row A  f k  dim_col A)"
                  by (metis (mono_tags, lifting) a2 bad_asm f5 le_less)
                then show ?thesis
                  by meson
              qed }
            ultimately show "f k = dim_col A"
              using f3 a2 by (metis (lifting) Suc_lessD bad_asm)
          qed
          then show "False" using lt_i not_i
            using j_leq by blast 
        qed
        have "f i  ?nc  ({0..<?ip}  {y. y < ?nr  f y  dim_col A})"
        proof -
          have h1: "f i  dim_col A  (ji. j < ?nr  f j  dim_col A)"
            using if_then_bad lt_i by auto
          then show ?thesis by auto
        qed
        then have gteq: "f i  ?nc  (card {i. i < ?nr  f i  dim_col A}  (i+1))"
          using card_lessThan[of ?ip] card_mono[where B = "{i. i < ?nr  f i  dim_col A} ", where A = "{0..<?ip}"]
          by auto
        then have clear: "dim_row (take_rows A [0..<length (pivot_positions A)]) = length (pivot_positions A)"
          unfolding take_rows_def using dim_asm fileq by (auto)
        have "i + 1 > length (pivot_positions A)" using not_lt clear by auto
        then show "f i = ?nc" using gteq len_pp_var by auto
      qed
      show ?thesis using half1 half2
        by blast 
    qed
    (* A row mapped to dim_col A by the pivot function is a zero row. *)
    then have h1a: "row A i =  0v (dim_col A)" 
      using pivot_fun_zero_row_iff[of A _ ?nc ?nr]
      using lt_i by blast
    then have h1: "A $$ (i, j) = 0"
      using index_row(1) lt_i lt_j by fastforce 
    have h2a: "i - dim_row (take_rows A [0..<length (pivot_positions A)]) < dim_row A - length (pivot_positions A)"
      using pivot_len lt_i not_lt
      by (simp add: take_rows_def)
    then have h2: "0m (dim_row A - length (pivot_positions A)) (dim_col A) $$
           (i - dim_row (take_rows A [0..<length (pivot_positions A)]), j) = 0 " 
      unfolding zero_mat_def using pivot_len lt_i lt_j
      using index_mat(1) by blast 
    then show "0m (dim_row A - length (pivot_positions A)) (dim_col A) $$
           (i - dim_row (take_rows A [0..<length (pivot_positions A)]), j) =
           A $$ (i, j)" using h1 h2
      by simp 
  qed
  (* The two blocks together have exactly dim_row A rows. *)
  have h3: "(dim_row (take_rows A [0..<length (pivot_positions A)])::nat) + ((dim_row A::nat) - (length (pivot_positions A)::nat)) =
    dim_row A"
  proof - 
    have h0: "dim_row (take_rows A [0..<length (pivot_positions A)]) = (length (pivot_positions A)::nat)" 
      by (simp add: take_rows_def fileq)
    then show ?thesis using add_diff_inverse_nat  pivot_len
      by linarith
  qed
  have h4: " i j. i < dim_row A 
           j < dim_col A 
           i < dim_row (take_rows A [0..<length (pivot_positions A)]) +
               (dim_row A - length (pivot_positions A))"
    using pivot_len
    by (simp add: h3) 
  then show ?thesis apply (subst mat_eq_iff)
    using h1 h2 h3 h4 by (auto simp add: append_rows_def)
qed

(* A row-echelon matrix has at most dim_row A pivot positions.  Their
   number equals the number of nonzero rows (pivot_positions(4)), which is
   bounded by the number of rows. *)
lemma length_pivot_positions_dim_row:
  assumes "row_echelon_form A"
  shows "length (pivot_positions A)  dim_row A"
proof -
  have 1: "A  carrier_mat (dim_row A) (dim_col A)" by auto
  obtain f where 2: "pivot_fun A f (dim_col A)"
    using assms row_echelon_form_def by blast
  from pivot_positions(4)[OF 1 2] have
    "length (pivot_positions A) = card {i. i < dim_row A  row A i  0v (dim_col A)}" .
  also have "...  card {i. i < dim_row A}"
    apply (rule card_mono)
    by auto
  ultimately show ?thesis by auto
qed

(* Every pivot position (i, j) of a row-echelon R in carrier_mat nr nc
   satisfies i < nr and j < nc.  Pivot positions are the pairs (i, f i)
   with f i different from nc, and a pivot function is bounded by nc. *)
lemma rref_pivot_positions:
  assumes "row_echelon_form R"
  assumes R: "R  carrier_mat nr nc"
  shows "i j. (i,j)  set (pivot_positions R)  i < nr  j < nc"
proof -
  obtain f where f: "pivot_fun R f nc"
    using assms(1) assms(2) row_echelon_form_def by blast
  have *: "i. i < nr  f i  nc" using f
    using R pivot_funD(1) by blast
  from pivot_positions[OF R f]
  have "set (pivot_positions R) = {(i, f i) |i. i < nr  f i  nc}" by auto
  then have **: "set (pivot_positions R) = {(i, f i) |i. i < nr  f i < nc}"
    using *
    by fastforce
  fix i j
  assume "(i, j)  set (pivot_positions R)"
  thus "i < nr  j < nc" using **
    by simp
qed

(* Monotonicity of a pivot function f of A (dim_row A = nr): for rows
   i < k below nr, f i is bounded by f k.  Proved by induction on k from
   pivot_funD(1) and pivot_funD(3). *)
lemma pivot_fun_monoton: 
  assumes pf: "pivot_fun A f (dim_col A)"
  assumes dr: "dim_row A = nr"
  shows " i. i < nr  ( k. ((k < nr  i < k)  f i  f k))"
proof -
  fix i
  assume "i < nr"
  show "( k. ((k < nr  i < k)  f i  f k))"
  proof -
    fix k
    show "((k < nr  i < k)  f i  f k)"
    proof (induct k)
      case 0
      then show ?case
        by blast 
    next
      case (Suc k)
      then show ?case 
        by (smt dr le_less_trans less_Suc_eq less_imp_le_nat pf pivot_funD(1) pivot_funD(3))
    qed
  qed
qed

(* For a row-echelon A with pivot function f (and the same dimension
   hypothesis dim_h as row_echelon_form_zero_rows), every index i below
   length (pivot_positions A) gives a pivot position (i, f i).
   - i_nr: such an i is a valid row index.
   - i_nc: f i < dim_col A.  If f i = dim_col A, rows i onward would be
     zero, leaving at most i nonzero rows, fewer than the number of pivots. *)
lemma pivot_positions_contains:
  assumes row_ech: "row_echelon_form A"
  assumes dim_h: "dim_col A  dim_row A"
  assumes "pivot_fun A f (dim_col A)"
  shows "i < (length (pivot_positions A)). (i, f i)  set (pivot_positions A)"
proof - 
  let ?nr = "dim_row A"
  let ?nc = "dim_col A"
  let ?pp = "pivot_positions A"          
  have i_nr: "i < (length ?pp). i < ?nr" using rref_pivot_positions assms
    using length_pivot_positions_dim_row less_le_trans by blast 
  have i_nc: "i < (length ?pp). f i < ?nc"
  proof clarsimp 
    fix i
    assume i_lt: "i < length ?pp"
    (* Once f reaches dim_col A it stays there (monotonicity). *)
    have fis_nc: "f i = ?nc  ( k > i. k < ?nr  f k = ?nc)"
    proof -
      assume is_nc: "f i = ?nc"
      show "( k > i. k < ?nr  f k = ?nc)" 
      proof clarsimp
        fix k
        assume k_gt: "k > i"
        assume k_lt: "k < ?nr"
        have fk_lt: "f k  ?nc" using pivot_funD(1)[of A ?nr f ?nc k] k_lt apply (auto)
          using ‹pivot_fun A f (dim_col A) by blast 
        show "f k = ?nc"
          using fk_lt is_nc k_gt k_lt assms pivot_fun_monoton[of A f ?nr i k]
          using ‹pivot_fun A f (dim_col A) by auto 
      qed
    qed
    (* Hence no row from i onward is nonzero. *)
    have ncimp: "f i = ?nc  ( k i. k  { i. i < ?nr  row A i  0v ?nc})"
    proof -
      assume nchyp: "f i = ?nc"
      show "( k i. k  { i. i < ?nr  row A i  0v ?nc})"
      proof clarsimp 
        fix k
        assume i_lt: "i  k" 
        assume k_lt: "k < dim_row A"
        show "row A k = 0v (dim_col A) "
          using i_lt k_lt fis_nc
          using pivot_fun_zero_row_iff[of A f ?nc ?nr]
          using ‹pivot_fun A f (dim_col A) le_neq_implies_less nchyp by blast 
      qed
    qed
    (* So the nonzero rows lie in 0 ..< i, and there are at most i of them. *)
    then have "f i = ?nc  card { i. i < ?nr  row A i  0v ?nc}  i"
    proof - 
      assume nchyp: "f i = ?nc"
      have h: "{ i. i < ?nr  row A i  0v ?nc}  {0..<i}"
        using atLeast0LessThan le_less_linear nchyp ncimp by blast
      then show " card { i. i < ?nr  row A i  0v ?nc}  i"
        using card_lessThan
        using subset_eq_atLeast0_lessThan_card by blast 
    qed
    then show "f i < ?nc" using i_lt pivot_positions(4)[of A ?nr ?nc f]
      apply (auto)
      by (metis ‹pivot_fun A f (dim_col A) i_nr le_neq_implies_less not_less pivot_funD(1)) 
  qed
  then show ?thesis
    using pivot_positions(1)
    by (smt ‹pivot_fun A f (dim_col A) carrier_matI i_nr less_not_refl mem_Collect_eq)
qed

(* Every pair (a, b) produced by pivot_positions_main_gen started at row i
   has a row component a bounded below by i.  The induction follows the
   recursion of pivot_positions_main_gen. *)
lemma pivot_positions_form_helper_1:
  shows "(a, b)  set (pivot_positions_main_gen z A nr nc i j)  i  a"
proof  (induct i j rule: pivot_positions_main_gen.induct[of nr nc A z])
  case (1 i j)
  then show ?case using  pivot_positions_main_gen.simps[of z A nr nc i j]
    apply (auto)
    by (smt Suc_leD le_refl old.prod.inject set_ConsD)
qed

(* The row components produced by pivot_positions_main_gen are strictly
   increasing.  This uses pivot_positions_form_helper_1 for the recursive
   tail. *)
lemma pivot_positions_form_helper_2:
  shows "strict_sorted (map fst (pivot_positions_main_gen z A nr nc i j))"
proof  (induct i j rule: pivot_positions_main_gen.induct[of nr nc A z])
  case (1 i j)
  then show ?case using  pivot_positions_main_gen.simps[of z A nr nc i j]
    apply (auto) using pivot_positions_form_helper_1
    by (simp add: pivot_positions_form_helper_1 Suc_le_lessD)
qed

(* The row components of pivot_positions A are strictly increasing: the
   instance of pivot_positions_form_helper_2 obtained by unfolding
   pivot_positions_gen_def. *)
lemma sorted_pivot_positions:
  shows "strict_sorted (map fst (pivot_positions A))"
  using pivot_positions_form_helper_2
  by (simp add: pivot_positions_form_helper_2 pivot_positions_gen_def) 

(* For a row-echelon A (with the dimension hypothesis dim_h), the i-th
   pivot position lies in row i, for every i below length (pivot_positions A).
   Proof by induction on i.  Row components of pivot positions are strictly
   sorted (sorted_hyp), and every row index below the length occurs among
   them (all_f_in), so the i-th one must be exactly i. *)
lemma pivot_positions_form:
  assumes row_ech: "row_echelon_form A"
  assumes dim_h: "dim_col A  dim_row A"
  shows " i < (length (pivot_positions A)). fst ((pivot_positions A) ! i) = i"
proof clarsimp 
  let ?nr = "dim_row A"
  let ?nc = "dim_col A"
  let ?pp = "pivot_positions A :: (nat × nat) list"
  fix i
  assume i_lt: "i < length (pivot_positions A)"
  have "f. pivot_fun A f ?nc" using row_ech unfolding row_echelon_form_def
    by blast
  then obtain f where pf:"pivot_fun A f ?nc"
    by blast                  
  have all_f_in: "i < (length ?pp). (i, f i)  set ?pp"
    using pivot_positions_contains pf
      assms 
    by blast   
  (* Strict sortedness, stated index-wise. *)
  have sorted_hyp: " (p::nat) (q::nat). p < (length ?pp)  q < (length ?pp)  p < q  (fst (?pp ! p) < fst (?pp ! q))"
  proof -
    fix p::nat
    fix q::nat
    assume p_lt: "p < q"
    assume p_welldef: "p < (length ?pp)"
    assume q_welldef: "q < (length ?pp)"
    show "fst (?pp ! p) < fst (?pp ! q)"
      using sorted_pivot_positions p_lt p_welldef q_welldef apply (auto)
      by (smt find_first_unique length_map nat_less_le nth_map p_welldef sorted_nth_mono sorted_pivot_positions strict_sorted_iff)     
  qed
  have h: "i < (length ?pp)  fst (pivot_positions A ! i) = i"
  proof (induct i)
    case 0
    (* Row 0 occurs somewhere; by sortedness it must be at index 0. *)
    have "j. fst (pivot_positions A ! j) = 0"
      by (metis all_f_in fst_conv i_lt in_set_conv_nth length_greater_0_conv list.size(3) not_less0)
    then obtain j where jth:" fst (pivot_positions A ! j) = 0"
      by blast      
    have "j  0  (fst (pivot_positions A ! 0) > 0  j  0)"
      using sorted_hyp apply (auto)
      by (metis all_f_in fst_conv i_lt in_set_conv_nth length_greater_0_conv list.size(3) neq0_conv not_less0)  
    then show ?case
      using jth neq0_conv by blast
  next
    case (Suc i)
    have ind_h: "i < length (pivot_positions A)  fst (pivot_positions A ! i) = i"
      using Suc.hyps by blast 
    (* Row Suc i occurs at some index j.  j lies after i by sortedness, and
       cannot lie beyond Suc i without leaving no room for the entry at
       index Suc i. *)
    have thesis_h: "(Suc i) < length (pivot_positions A)  fst (pivot_positions A ! (Suc i)) = (Suc i)"
    proof - 
      assume suc_i_lt: "(Suc i) < length (pivot_positions A)"
      have fst_i_is: "fst (pivot_positions A ! i) = i" using suc_i_lt ind_h
        using Suc_lessD by blast 
      have "(j < (length ?pp). fst (pivot_positions A ! j) = (Suc i))"
        by (metis suc_i_lt all_f_in fst_conv  in_set_conv_nth)
      then obtain j where jth: "j < (length ?pp)  fst (pivot_positions A ! j) = (Suc i)"
        by blast
      have "j > i"
        using sorted_hyp apply (auto)
        by (metis Suc_lessD ‹fst (pivot_positions A ! i) = i jth less_not_refl linorder_neqE_nat n_not_Suc_n suc_i_lt)
      have "j > (Suc i)  False"
      proof -
        assume j_gt: "j > (Suc i)"
        then have h1: "fst (pivot_positions A ! (Suc i)) > i"
          using fst_i_is sorted_pivot_positions
          using sorted_hyp suc_i_lt by force
        have "fst (pivot_positions A ! j) > fst (pivot_positions A ! (Suc i))"
          using jth j_gt sorted_hyp apply (auto)
          by fastforce 
        then have h2: "fst (pivot_positions A ! (Suc i)) < (Suc i)" 
          using jth
          by simp   
        show "False" using h1 h2
          using not_less_eq by blast 
      qed
      show "fst (pivot_positions A ! (Suc i)) = (Suc i)"
        using Suc_lessI ‹Suc i < j  False› i < j jth by blast
    qed
    then show ?case
      by blast 
  qed
  then show "fst (pivot_positions A ! i) = i"
    using i_lt by auto
qed

(* For A in row echelon form with at least as many columns as rows (dim_h),
   selecting the columns of A at the pivot positions (map snd (pivot_positions A))
   gives the identity matrix of size length (pivot_positions A), stacked above a
   zero block that fills the remaining dim_row A - length (pivot_positions A) rows. *)
lemma take_cols_pivot_eq:
  assumes row_ech: "row_echelon_form A"
  assumes dim_h: "dim_col A  dim_row A"
  shows "take_cols A (map snd (pivot_positions A)) =
    1m (length (pivot_positions A)) @r
    0m (dim_row A - length (pivot_positions A)) (length (pivot_positions A))"
proof - 
  let ?nr = "dim_row A"
  let ?nc = "dim_col A"
  (* h1: the right-hand side has exactly length (pivot_positions A) columns. *)
  have h1: " dim_col
     (1m (length (pivot_positions A)) @r
      0m (dim_row A - length (pivot_positions A)) (length (pivot_positions A))) = (length (pivot_positions A))"
    by (simp add: append_rows_def)
  have len_pivot: "length (pivot_positions A) = card {i. i < ?nr  row A i  0v ?nc}"
    using row_ech pivot_positions(4) row_echelon_form_def by blast
  have pp_leq_nc: "f. pivot_fun A f ?nc  (i < ?nr. f i  ?nc)" unfolding pivot_fun_def
    by meson 
  have pivot_set: "f. pivot_fun A f ?nc  set (pivot_positions A) = {(i, f i) | i. i < ?nr  f i  ?nc}"
    using row_ech row_echelon_form_def pivot_positions(1)
    by (smt (verit) Collect_cong carrier_matI)
  then have pivot_set_alt: "f. pivot_fun A f ?nc  set (pivot_positions A) = {(i, f i) | i. i < ?nr  row A i  0v ?nc}"
    using pivot_positions pivot_fun_zero_row_iff Collect_cong carrier_mat_triv
    by (smt (verit, best))
  have "f. pivot_fun A f ?nc  set (pivot_positions A) = {(i, f i) | i. f i  ?nc  i < ?nr  f i  ?nc}"
    using pivot_set pp_leq_nc by auto
  then have pivot_set_var: "f. pivot_fun A f ?nc  set (pivot_positions A) = {(i, f i) | i. i < ?nr  f i < ?nc}"
    by auto
  (* The pivot column list has one entry per pivot (length_asm). *)
  have "length (map snd (pivot_positions A)) = card (set (map snd (pivot_positions A)))" 
    using row_ech row_echelon_form_def pivot_positions(3) distinct_card[where xs = "map snd (pivot_positions A)"]
    by (metis carrier_mat_triv)
  then have "length (map snd (pivot_positions A)) = card (set (pivot_positions A))"
    by (metis card_distinct distinct_card distinct_map length_map) 
  then have "length (map snd (pivot_positions A)) = card {i. i < ?nr  row A i  0v ?nc}"
    using pivot_set_alt
    by (simp add: len_pivot) 
  then have length_asm: "length (map snd (pivot_positions A)) = length (pivot_positions A)"
    using len_pivot by linarith
  (* Every pivot column index is below dim_col A, so the filter (λy. y < dim_col A)
     applied by take_cols keeps the whole pivot column list (h2b). *)
  then have "a. List.member (map snd (pivot_positions A)) a  a < dim_col A"
  proof clarsimp 
    fix a
    assume a_in: "List.member (map snd (pivot_positions A)) a"
    have "v  set (pivot_positions A). a = snd v" 
      using a_in in_set_member[where xs = "(pivot_positions A)"] apply (auto)
      by (metis in_set_impl_in_set_zip2 in_set_member length_map snd_conv zip_map_fst_snd) 
    then show "a < dim_col A"
      using pivot_set_var in_set_member by auto
  qed
  then have h2b: "(filter (λy. y < dim_col A) (map snd (pivot_positions A))) =  (map snd (pivot_positions A))"
    by (meson filter_True in_set_member)
  then have h2a: "length (map ((!) (cols A)) (filter (λy. y < dim_col A) (map snd (pivot_positions A)))) = length (pivot_positions A)"
    using length_asm
    by (simp add: h2b) 
  then have h2: "length (pivot_positions A)  dim_row A 
    dim_col (take_cols A (map snd (pivot_positions A))) = (length (pivot_positions A))" 
    unfolding take_cols_def using mat_of_cols_carrier by auto
  (* Both sides have the same number of columns. *)
  have h_len: "length (pivot_positions A)  dim_row A 
    dim_col (take_cols A (map snd (pivot_positions A))) =
    dim_col
     (1m (length (pivot_positions A)) @r
      0m (dim_row A - length (pivot_positions A)) (length (pivot_positions A)))" 
    using h1 h2
    by (simp add: h1 assms length_pivot_positions_dim_row)
  (* Entry-wise equality. NOTE: this h2 shadows the dimension fact h2 above; inside
     its own proof the earlier h2 is still the one referenced, and h_len already
     records it for the final step. *)
  have h2: "i j. length (pivot_positions A)  dim_row A 
           i < dim_row A 
           j < dim_col
                (1m (length (pivot_positions A)) @r
                 0m (dim_row A - length (pivot_positions A)) (length (pivot_positions A))) 
           take_cols A (map snd (pivot_positions A)) $$ (i, j) =
           (1m (length (pivot_positions A)) @r
            0m (dim_row A - length (pivot_positions A)) (length (pivot_positions A))) $$
           (i, j)" 
  proof -
    fix i 
    fix j 
    let ?pp = "(pivot_positions A)"
    assume len_lt: "length (pivot_positions A)  dim_row A" 
    assume i_lt: " i < dim_row A" 
    assume j_lt: "j < dim_col
                (1m (length (pivot_positions A)) @r
                 0m (dim_row A - length (pivot_positions A)) (length (pivot_positions A)))"
    let ?w = "((map snd (pivot_positions A)) ! j)"
    (* Entry (i, j) of the selected columns is entry i of column ?w of A. *)
    have breaking_it_down: "mat_of_cols (dim_row A)
     (map ((!) (cols A)) (map snd (pivot_positions A))) $$ (i, j)  
     =  ((cols A) ! ?w) $ i"
      apply (auto)
      by (metis comp_apply h1 i_lt j_lt length_map mat_of_cols_index nth_map) 
    (* Rows above length (pivot_positions A): the selected columns agree with the identity. *)
    have h1a: "i < (length ?pp)  (mat_of_cols (dim_row A) (map ((!) (cols A)) (map snd (pivot_positions A))) $$ (i, j) 
        = (1m (length (pivot_positions A))) $$ (i, j))"
    proof - 
      (* Via row_ech and pivot_fun: the j-th pivot is (j, f j), so column f j of A has
         entry 1 in row j and 0 in every other row (h0 and h1 below). *)
      assume "i < (length ?pp)"
      have "f. pivot_fun A f ?nc" using row_ech unfolding row_echelon_form_def
        by blast
      then obtain f where "pivot_fun A f ?nc"
        by blast
      have j_nc: "j < (length ?pp)" using j_lt
        by (simp add: h1) 
      then have j_lt_nr: "j < ?nr" using dim_h
        using len_lt by linarith 
      then have is_this_true: "(pivot_positions A) ! j = (j, f j)" 
        using pivot_positions_form pivot_positions(1)[of A ?nr ?nc f]
      proof -
        have "pivot_positions A ! j  set (pivot_positions A)"
          using j_nc nth_mem by blast
        then have "n. pivot_positions A ! j = (n, f n)  n < dim_row A  f n  dim_col A"
          using A  carrier_mat (dim_row A) (dim_col A); pivot_fun A f (dim_col A)  set (pivot_positions A) = {(i, f i) |i. i < dim_row A  f i  dim_col A} ‹pivot_fun A f (dim_col A) by blast
        then show ?thesis
          by (metis (no_types) A. row_echelon_form A; dim_row A  dim_col A  i<length (pivot_positions A). fst (pivot_positions A ! i) = i dim_h fst_conv j_nc row_ech)
      qed
      then have w_is: "?w = f j"
        by (metis h1 j_lt nth_map snd_conv)
      have h0: "i = j  ((cols A) ! ?w) $ i = 1" using w_is pivot_funD(4)[of A ?nr f ?nc i]
        by (metis a. List.member (map snd (pivot_positions A)) a  a < dim_col A i < length (pivot_positions A) ‹pivot_fun A f (dim_col A) cols_length i_lt in_set_member length_asm mat_of_cols_cols mat_of_cols_index nth_mem)
      have h1:  "i  j  ((cols A) ! ?w) $ i = 0" using w_is pivot_funD(5)
        by (metis a. List.member (map snd (pivot_positions A)) a  a < dim_col A ‹pivot_fun A f (dim_col A) cols_length h1 i_lt in_set_member j_lt len_lt length_asm less_le_trans mat_of_cols_cols mat_of_cols_index nth_mem)
      show "(mat_of_cols (dim_row A) (map ((!) (cols A)) (map snd (pivot_positions A))) $$ (i, j) 
        = (1m (length (pivot_positions A))) $$ (i, j))" using h0 h1 breaking_it_down
        by (metis i < length (pivot_positions A) h2 h_len index_one_mat(1) j_lt len_lt) 
    qed
    (* Rows from length (pivot_positions A) on are zero rows of A (row_echelon_form_zero_rows). *)
    have h1b: "i  (length ?pp)  (mat_of_cols (dim_row A) (map ((!) (cols A)) (map snd (pivot_positions A))) $$ (i, j)  = 0)"
    proof - 
      assume i_gt: "i  (length ?pp)"
      have h0a: "((cols A) ! ((map snd (pivot_positions A)) ! j)) $ i = (row A i) $ ?w"
        by (metis a. List.member (map snd (pivot_positions A)) a  a < dim_col A cols_length h1 i_lt in_set_member index_row(1) j_lt length_asm mat_of_cols_cols mat_of_cols_index nth_mem)
      have h0b: 
        "take_rows A [0..<length (pivot_positions A)] @r 0m (dim_row A - length (pivot_positions A)) (dim_col A) = A"
        using assms row_echelon_form_zero_rows[of A]
        by blast 
      then have h0c: "(row A i) = 0v (dim_col A)"  using i_gt
        using add_diff_cancel_right' add_less_cancel_left diff_is_0_eq' dim_col_take_rows dim_row_append_rows i_lt index_zero_mat(2) index_zero_mat(3) le_add_diff_inverse len_lt less_not_refl3 row_append_rows row_zero zero_less_diff
        by (smt add_diff_cancel_right' add_less_cancel_left diff_is_0_eq' dim_col_take_rows dim_row_append_rows i_lt index_zero_mat(2) index_zero_mat(3) le_add_diff_inverse len_lt less_not_refl3 row_append_rows row_zero zero_less_diff) 
      then show ?thesis using h0a breaking_it_down apply (auto)
        by (metis a. List.member (map snd (pivot_positions A)) a  a < dim_col A h1 in_set_member index_zero_vec(1) j_lt length_asm nth_mem) 
    qed
    (* Combine the two row ranges, then undo the (trivial) filter of take_cols via h2b. *)
    have h1: " mat_of_cols (dim_row A)
     (map ((!) (cols A)) (map snd (pivot_positions A))) $$ (i, j) =
           (1m (length (pivot_positions A)) @r
            0m (dim_row A - length (pivot_positions A)) (length (pivot_positions A))) $$
           (i, j) " using h1a h1b
      apply (auto)
      by (smt add_diff_inverse_nat add_less_cancel_left append_rows_index h1 i_lt index_one_mat(2) index_one_mat(3) index_zero_mat(1) index_zero_mat(2) index_zero_mat(3) j_lt len_lt not_less)  
    then show "take_cols A (map snd (pivot_positions A)) $$ (i, j) =
           (1m (length (pivot_positions A)) @r
            0m (dim_row A - length (pivot_positions A)) (length (pivot_positions A))) $$
           (i, j)" 
      unfolding take_cols_def
      by (simp add: h2b)
  qed
  (* Matrix equality from the column count (h_len) and entry-wise equality (h2). *)
  show ?thesis
    unfolding mat_eq_iff
    using length_pivot_positions_dim_row[OF assms(1)] h_len h2 by auto
qed

(* For A in row echelon form with at least as many columns as rows, A factors as
   (the pivot columns of A) times (the first length (pivot_positions A) rows of A). *)
lemma rref_right_mul:
  assumes "row_echelon_form A"
  assumes "dim_col A  dim_row A"
  shows
    "take_cols A (map snd (pivot_positions A)) * take_rows A [0..<length (pivot_positions A)] = A"
proof -
  (* By take_cols_pivot_eq the left factor is an identity block above a zero block. *)
  from take_cols_pivot_eq[OF assms] have
    1: "take_cols A (map snd (pivot_positions A)) =
    1m (length (pivot_positions A)) @r
    0m (dim_row A - length (pivot_positions A)) (length (pivot_positions A))" .
  (* Hence the product stacks the top rows of A above a zero block ... *)
  have 2: "take_cols A (map snd (pivot_positions A)) * take_rows A [0..<length (pivot_positions A)] =
    take_rows A [0..<length (pivot_positions A)]  @r 0m (dim_row A - length (pivot_positions A)) (dim_col A)"
    unfolding 1
    apply (auto simp add: append_rows_mat_mul)
    by (smt add_diff_cancel_right' assms diff_diff_cancel dim_col_take_rows dim_row_append_rows index_zero_mat(2) left_mult_one_mat' left_mult_zero_mat' length_pivot_positions_dim_row row_echelon_form_zero_rows)   
  (* ... and by row_echelon_form_zero_rows that stacked matrix is A itself. *)
  from row_echelon_form_zero_rows[OF assms] have
    "... = A" .
  thus ?thesis
    by (simp add: "2")
qed

context conjugatable_vec_space begin

(* The columns of the n×n identity matrix are linearly independent.
   Proof: as a set they coincide with its rows (the identity equals its transpose),
   and the rows are independent because det (1m n) is nonzero. *)
lemma lin_indpt_id:
  shows "lin_indpt (set (cols (1m n)::'a vec list))"
proof -
  have *: "set (cols (1m n)) = set (rows (1m n))"
    by (metis cols_transpose transpose_one)
  have "det (1m n)  0" using det_one by auto
  from det_not_0_imp_lin_indpt_rows[OF _ this]
  have "lin_indpt (set (rows (1m n)))"
    using one_carrier_mat by blast
  thus ?thesis
    by (simp add: *) 
qed

(* Any selection of columns of the identity matrix is linearly independent,
   being a subset of the independent identity columns (lin_indpt_id). *)
lemma lin_indpt_take_cols_id:
  shows "lin_indpt (set (cols (take_cols (1m n) inds)))"
proof - 
  have subset_h: "set (cols (take_cols (1m n) inds))  set (cols (1m n)::'a vec list)"
    using cols_take_cols_subset by blast
  then show ?thesis using lin_indpt_id subset_li_is_li by auto
qed

(* The column list of the d×d identity matrix is exactly unit_vecs d. *)
lemma cols_id_unit_vecs:
  shows "cols (1m d) = unit_vecs d"
  unfolding unit_vecs_def list_eq_iff_nth_eq
  by auto

(* The columns of the identity matrix are pairwise distinct (they are the unit vectors). *)
lemma distinct_cols_id:
  shows "distinct (cols (1m d)::'a vec list)"
  by (simp add: conjugatable_vec_space.cols_id_unit_vecs vec_space.unit_vecs_distinct)

(* Indexing a distinct list ls at distinct, in-range positions inds yields a distinct list. *)
lemma distinct_map_nth:
  assumes "distinct ls"
  assumes "distinct inds"
  assumes "j. j  set inds  j < length ls"
  shows "distinct (map ((!) ls) inds)"
  by (simp add: assms(1) assms(2) assms(3) distinct_map inj_on_nth)

(* Taking distinct column indices from the identity matrix yields distinct columns.
   take_cols first filters the index list, and filtering keeps it distinct (distinct_filter). *)
lemma distinct_take_cols_id:
  assumes "distinct inds"
  shows "distinct (cols (take_cols (1m n) inds) :: 'a vec list)"
  unfolding take_cols_def
  apply (subst cols_mat_of_cols)
   apply (auto intro!:  distinct_map_nth simp add: distinct_cols_id)
  using assms distinct_filter by blast

(* For distinct indices, the rank of a column selection of 1m n equals the number of
   indices that are below n (the entries kept by filter ((>) n)). Proof: the selected
   columns are linearly independent and distinct, so the rank is their count. *)
lemma rank_take_cols:
  assumes "distinct inds"
  shows "rank (take_cols (1m n) inds) = length (filter ((>) n) inds)"
  apply (subst lin_indpt_full_rank[of _ "length (filter ((>) n) inds)"])
     apply (auto simp add: lin_indpt_take_cols_id)
   apply (metis (full_types) index_one_mat(2) index_one_mat(3) length_map mat_of_cols_carrier(1) take_cols_def)
  by (simp add: assms distinct_take_cols_id)

(* Left multiplication by an invertible n×n matrix preserves rank.
   rank (A*B) is at most rank B by rank_mat_mul_left. For the reverse inequality,
   write B = C*A*B with C the inverse of A and apply rank_mat_mul_left again. *)
lemma rank_mul_left_invertible_mat:
  fixes A::"'a mat"
  assumes "invertible_mat A"
  assumes "A  carrier_mat n n"
  assumes "B  carrier_mat n nc"
  shows "rank (A * B) = rank B"
proof -
  obtain C where C: "inverts_mat A C" "inverts_mat C A"
    using assms invertible_mat_def by blast 
  from C have ceq: "C * A = 1m n"
    by (metis assms(2) carrier_matD(2) index_mult_mat(3) index_one_mat(3) inverts_mat_def)
  then have *:"B = C*A*B"
    using assms(3) by auto
  from rank_mat_mul_left[OF assms(2-3)]
  have **: "rank (A*B)  rank B" .
  have 1: "C  carrier_mat n n" using C ceq
    by (metis assms(2) carrier_matD(1) carrier_matI index_mult_mat(3) index_one_mat(3) inverts_mat_def) 
  have 2: "A * B  carrier_mat n nc" using assms by auto  
  have "rank B = rank (C* A * B)" using * by auto
  also have "...  rank (A*B)" using rank_mat_mul_left[OF 1 2]
    using "1" assms(2) assms(3) by auto
  ultimately show ?thesis using ** by auto
qed

(* Selecting distinct columns of an invertible n×n matrix A gives a matrix whose rank is
   the number of selected indices below n. Proof: take_cols A inds = A * take_cols (1m n) inds.
   Multiplying by the invertible A preserves rank, and rank_take_cols handles the identity. *)
lemma invertible_take_cols_rank:
  fixes A::"'a mat"
  assumes "invertible_mat A"
  assumes "A  carrier_mat n n"
  assumes "distinct inds"
  shows "rank (take_cols A inds) = length (filter ((>) n) inds)"
proof -
  have " A = A * 1m n" using assms(2) by auto
  then have "take_cols A inds = A * take_cols (1m n) inds"
    by (metis assms(2) one_carrier_mat take_cols_mat_mul)
  then have "rank (take_cols A inds) = rank (take_cols (1m n) inds)"
    by (metis assms(1) assms(2) conjugatable_vec_space.rank_mul_left_invertible_mat one_carrier_mat take_cols_carrier_mat)
  thus ?thesis
    by (simp add: assms(3) conjugatable_vec_space.rank_take_cols)
qed

(* Selecting columns never increases rank: take_cols R ls = R * take_cols (1m nc) ls,
   and right multiplication cannot increase rank (rank_mat_mul_right). *)
lemma rank_take_cols_leq:
  assumes R:"R  carrier_mat n nc"
  shows "rank (take_cols R ls)  rank R"
proof -
  from take_cols_mat_mul[OF R]
  have "take_cols R ls =  R * take_cols (1m nc) ls"
    by (metis assms one_carrier_mat right_mult_one_mat)
  thus ?thesis
    by (metis assms one_carrier_mat take_cols_carrier_mat vec_space.rank_mat_mul_right)
qed

(* Converse of rank_take_cols_leq: if R factors as (take_cols R ls) * B, then the column
   selection has at least the rank of R (rank_mat_mul_right applied to that factorization).
   NOTE: the assumption R is not used by the proof; it is kept for the callers' statement shape. *)
lemma rank_take_cols_geq:
  assumes R:"R  carrier_mat n nc"
  assumes t:"take_cols R ls  carrier_mat n r"
  assumes B:"B  carrier_mat r nc"
  assumes "R = (take_cols R ls) * B"
  shows "rank (take_cols R ls)  rank R"
  by (metis B assms(4) t vec_space.rank_mat_mul_right)

(* For R in row echelon form with nc ≥ n, keeping only the pivot columns preserves rank.
   rref_right_mul gives R = (pivot columns) * ?B, so rank_take_cols_geq gives "≥".
   rank_take_cols_leq gives "≤". *)
lemma rref_drop_pivots:
  assumes row_ech: "row_echelon_form R"
  assumes dims: "R  carrier_mat n nc"
  assumes order: "nc  n"
  shows "rank (take_cols R (map snd (pivot_positions R))) = rank R"
proof -
  let ?B = "take_rows R [0..<length (pivot_positions R)]"
  have equa: "R = take_cols R (map snd (pivot_positions R)) * ?B" using assms rref_right_mul
    by (metis carrier_matD(1) carrier_matD(2))
  (* The two factors have compatible dimensions n × r and r × nc, with r the number of pivots. *)
  have ex_r: "r. take_cols R (map snd (pivot_positions R))  carrier_mat n r  ?B  carrier_mat r nc"
  proof - 
    have h1:
      "take_cols R (map snd (pivot_positions R))  carrier_mat n (length (pivot_positions R))"
      using assms
      by (metis in_set_impl_in_set_zip2 length_map rref_pivot_positions take_cols_carrier_mat_strict zip_map_fst_snd)
    have " f. pivot_fun R f nc" using row_ech unfolding row_echelon_form_def using dims
      by blast
    then have "length (pivot_positions R) = card {i. i < n  row R i  0v nc}"
      using pivot_positions[of R n nc]
      using dims by auto 
    then have "nc  length (pivot_positions R)" using order
      using carrier_matD(1) dims dual_order.trans length_pivot_positions_dim_row row_ech by blast
    then have "dim_col R  length (pivot_positions R)" using dims by auto
    then have h2: "?B  carrier_mat (length (pivot_positions R)) nc" unfolding take_rows_def
      using dims 
      by (smt atLeastLessThan_iff carrier_matD(2) filter_True le_eq_less_or_eq length_map length_pivot_positions_dim_row less_trans map_nth mat_of_cols_carrier(1) row_ech set_upt transpose_carrier_mat transpose_mat_of_rows) 
    show ?thesis using h1 h2
      by blast
  qed
  (* ex_r supplies the two dimension assumptions (t and B) of rank_take_cols_geq. *)
  have "rank R   rank (take_cols R (map snd (pivot_positions R)))"
    using dims ex_r rank_take_cols_geq[where R = "R", where B = "?B", where ls = "(map snd (pivot_positions R))", where nc = "nc"]
    using equa by blast
  thus ?thesis
    using assms(2) conjugatable_vec_space.rank_take_cols_leq le_antisym by blast
qed

(* On the pivot columns of gauss_jordan_single A, take_cols and take_cols_var coincide.
   Every pivot column index is within dim_col A, so the filter (λy. y < dim_col A) used in
   take_cols_def does not change this index list. *)
lemma gjs_and_take_cols_var:
  fixes A::"'a mat"
  assumes A:"A  carrier_mat n nc"
  assumes order: "nc  n"
  shows "(take_cols A (map snd (pivot_positions (gauss_jordan_single A)))) = 
  (take_cols_var A (map snd (pivot_positions (gauss_jordan_single A))))"
proof -
  let ?gjs = "(gauss_jordan_single A)"
  have "x. List.member (map snd (pivot_positions (gauss_jordan_single A))) x  x  dim_col A"  
    using rref_pivot_positions gauss_jordan_single(3) carrier_matD(2) gauss_jordan_single(2) in_set_impl_in_set_zip2 in_set_member length_map less_irrefl less_trans not_le_imp_less zip_map_fst_snd
    by (smt A carrier_matD(2) gauss_jordan_single(2) in_set_impl_in_set_zip2 in_set_member length_map less_irrefl less_trans not_le_imp_less zip_map_fst_snd)
  then have "(filter (λy. y < dim_col A) (map snd (pivot_positions (gauss_jordan_single A)))) = 
    (map snd (pivot_positions (gauss_jordan_single A)))"
    by (metis (no_types, lifting) A carrier_matD(2) filter_True gauss_jordan_single(2) gauss_jordan_single(3) in_set_impl_in_set_zip2 length_map rref_pivot_positions zip_map_fst_snd)
  then show ?thesis unfolding take_cols_def take_cols_var_def
    by simp
qed

(* Selecting from A the columns at the pivot positions of its reduced form
   gauss_jordan_single A preserves rank, for A ∈ carrier_mat n nc with nc ≥ n.
   Proof: gauss_jordan_single A = P * A for an invertible P with left inverse Pi.
   Pulling back along Pi reduces the claim to rref_drop_pivots on the reduced form. *)
lemma gauss_jordan_single_rank:
  fixes A::"'a mat"
  assumes A:"A  carrier_mat n nc"
  assumes order: "nc  n"
  shows "rank (take_cols A (map snd (pivot_positions (gauss_jordan_single A)))) = rank A"
proof -
  let ?R = "gauss_jordan_single A"
  (* The reduced form is P * A for a unit P of the n×n matrix ring. *)
  obtain P where P:"PUnits (ring_mat TYPE('a) n undefined)" and
    i: "?R = P * A" using gauss_jordan_transform[OF A]
    using A assms det_mult det_non_zero_imp_unit det_one gauss_jordan_single(4) mult_not_zero one_neq_zero
    by (smt A assms det_mult det_non_zero_imp_unit det_one gauss_jordan_single(4) mult_not_zero one_neq_zero)
  have pcarrier: "P  carrier_mat n n" using P unfolding Units_def
    by (auto simp add: ring_mat_def)
  have "invertible_mat P" using P unfolding invertible_mat_def Units_def inverts_mat_def
    apply auto
     apply (simp add: ring_mat_simps(5))
    by (metis index_mult_mat(2) index_one_mat(2) ring_mat_simps(1) ring_mat_simps(3))
  then
  obtain Pi where Pi: "invertible_mat Pi" "Pi * P = 1m n"
  proof -
    assume a1: "Pi. invertible_mat Pi; Pi * P = 1m n  thesis"
    have "dim_row P = n"
      by (metis (no_types) A assms(1) carrier_matD(1) gauss_jordan_single(2) i index_mult_mat(2))
    then show ?thesis
      using a1 by (metis (no_types) ‹invertible_mat P index_mult_mat(3) index_one_mat(3) invertible_mat_def inverts_mat_def square_mat.simps)
  qed
  then have pi_carrier:"Pi  carrier_mat n n"
    by (metis carrier_mat_triv index_mult_mat(2) index_one_mat(2) invertible_mat_def square_mat.simps)
  have R1:"row_echelon_form ?R"
    using assms(2) gauss_jordan_single(3) by blast
  have R2: "?R  carrier_mat n nc"
    using A assms(2) gauss_jordan_single(2) by auto
  have Rcm: "take_cols ?R (map snd (pivot_positions ?R))
     carrier_mat n (length (map snd (pivot_positions ?R)))"
    apply (rule take_cols_carrier_mat_strict[OF R2])
    using rref_pivot_positions[OF R1 R2] by auto
  (* Rewrite A as Pi * ?R, pull Pi out of take_cols, drop it (it is invertible), and
     finish with rref_drop_pivots on the reduced form. *)
  have "Pi * ?R = A" using i Pi
    by (smt A ‹invertible_mat P assoc_mult_mat carrier_mat_triv index_mult_mat(2) index_mult_mat(3) index_one_mat(3) invertible_mat_def left_mult_one_mat square_mat.simps)
  then have "rank (take_cols A (map snd (pivot_positions ?R))) = rank (take_cols (Pi * ?R) (map snd (pivot_positions ?R)))"
    by auto
  also have "... = rank ( Pi * take_cols ?R (map snd (pivot_positions ?R)))"
    by (metis A gauss_jordan_single(2) pi_carrier take_cols_mat_mul)
  also have "... = rank (take_cols ?R (map snd (pivot_positions ?R)))"
    by (intro rank_mul_left_invertible_mat[OF Pi(1) pi_carrier Rcm])
  also have "... = rank ?R"
    using assms(2) conjugatable_vec_space.rref_drop_pivots gauss_jordan_single(3)
    using R1 R2 by blast
  (* Finally rank ?R = rank (P * A) = rank A since P is invertible. *)
  ultimately show ?thesis                                                            
    using A P  carrier_mat n n ‹invertible_mat P conjugatable_vec_space.rank_mul_left_invertible_mat i
    by auto
qed

(* Any set of columns of an invertible square matrix is linearly independent.
   det A is nonzero, so the rows of the transpose (that is, the columns of A) are
   independent, and independence passes to subsets. *)
lemma lin_indpt_subset_cols:
  fixes A:: "'a mat"
  fixes B:: "'a vec set"
  assumes "A  carrier_mat n n"
  assumes inv: "invertible_mat A"
  assumes "B  set (cols A)"
  shows "lin_indpt B"
proof -
  have "det A  0"
    using assms(1) inv invertible_det by blast
  then have "lin_indpt (set (rows AT))"
    using assms(1) idom_vec.lin_dep_cols_imp_det_0 by auto
  thus ?thesis using subset_li_is_li assms(3)
    by auto
qed

(* A distinct list B of columns taken from an invertible n×n matrix A, assembled into a
   matrix with mat_of_cols n, has full column rank length B. B is independent
   (lin_indpt_subset_cols), so it is its own maximal independent subset, and
   rank_card_indpt gives rank = card (set B) = length B. *)
lemma rank_invertible_subset_cols:
  fixes A:: "'a mat"
  fixes B:: "'a vec list"
  assumes inv: "invertible_mat A"
  assumes A_square: "A  carrier_mat n n"
  assumes set_sub: "set (B)  set (cols A)"
  assumes dist_B: "distinct B"
  shows "rank (mat_of_cols n B) = length B"
proof - 
  let ?B_mat = "(mat_of_cols n B)"
  have h1: "lin_indpt (set(B))" 
    using assms lin_indpt_subset_cols[of A] by auto
  have "set B  carrier_vec n"
    using set_sub A_square cols_dim[of A] by auto
  then have cols_B: "cols (mat_of_cols n B) = B" using cols_mat_of_cols by auto
  then have "maximal (set B) (λT. T  set (B)  lin_indpt T)" using h1
    by (simp add: maximal_def subset_antisym)
  then have h2: "maximal (set B) (λT. T  set (cols (mat_of_cols n B))  lin_indpt T)"
    using cols_B by auto
  have h3: "rank (mat_of_cols n B) = card (set B)"
    using h1 h2 rank_card_indpt[of ?B_mat]
    using mat_of_cols_carrier(1) by blast 
  then show ?thesis using assms distinct_card by auto
qed

end

end

Theory BKR_Algorithm

theory BKR_Algorithm
  imports
    "More_Matrix"
    "Sturm_Tarski.Sturm_Tarski"
begin

section "Setup"

(* retrieve_polys qss index_list: the polynomials of qss at the positions in index_list,
   in the order given (a plain map of nth; indices are not range-checked). *)
definition retrieve_polys:: "real poly list  nat list  real poly list"
  where "retrieve_polys qss index_list = (map (nth qss) index_list)"

(* construct_NofI p I: the Sturm sign-change count changes_R_smods p (pderiv p * prod_list I),
   cast to a rational. When p is nonzero and prod_list I is coprime to p, restate_tarski
   (in Matrix_Equation_Construction) identifies this with
   #{roots of p where prod_list I > 0} - #{roots of p where prod_list I < 0}. *)
definition construct_NofI:: "real poly  real poly list  rat"
  where "construct_NofI p I =  rat_of_int (changes_R_smods p ((pderiv p)*(prod_list I)))"

(* Right-hand side vector of the matrix equation: entry k is construct_NofI p applied to the
   polynomials of qs selected by the k-th index subset in Is. *)
definition construct_rhs_vector:: "real poly  real poly list  nat list list  rat vec"
  where "construct_rhs_vector p qs Is = vec_of_list (map (λ I.(construct_NofI p (retrieve_polys qs I))) Is)"

section "Base Case"

(* Base-case system for a single polynomial q, as (matrix, (subsets, signs)):
   the 2×2 matrix [[1,1],[1,-1]], the index subsets [] and [0], and the sign
   assignments [1] and [-1]. *)
definition base_case_info:: "(rat mat × (nat list list × rat list list))"
  where "base_case_info =
    ((mat_of_rows_list 2 [[1,1], [1,-1]]), ([[],[0]], [[1],[-1]]))"

(* Solve the base-case system directly: multiply the right-hand side for the subsets [] and
   [0] by [[1/2,1/2],[1/2,-1/2]], the inverse of the base-case matrix [[1,1],[1,-1]].
   When p and q are coprime the resulting vector is expected to have integer entries;
   no rounding is performed here, the entries stay rationals. *)
definition base_case_solve_for_lhs:: "real poly  real poly  rat vec"
  where "base_case_solve_for_lhs p q = (mult_mat_vec (mat_of_rows_list 2 [[1/2, 1/2], [1/2, -1/2]])  (construct_rhs_vector p [q] [[], [0]]))"

(* Diagnostic only: displays the theorem gauss_jordan_compute_inverse. *)
thm "gauss_jordan_compute_inverse"

(* matr_option dimen m: unwrap an optional matrix, falling back to the
   dimen×dimen identity matrix when m is None. *)
primrec matr_option:: "nat  'a::{one, zero} mat option  'a mat"
  where "matr_option dimen None = 1m dimen"
  | "matr_option dimen (Some c) = c" 

(* For smooth code export, we want to use a computable notion of matrix equality:
   two matrices are equal when their dimensions agree and their mat_to_list
   representations are equal. *)
definition mat_equal:: "'a:: field mat  'a :: field mat  bool"
  where "mat_equal A B = (dim_row A = dim_row B  dim_col A = dim_col B  (mat_to_list A) = (mat_to_list B))"

(* Computable matrix inverse. For square A, run gauss_jordan A one with one the identity of
   matching size. If the first result equals the identity (checked with mat_equal), return
   Some of the second result; otherwise return None. Non-square A also yields None. *)
definition mat_inverse_var :: "'a :: field mat  'a mat option" where 
  "mat_inverse_var A = (if dim_row A = dim_col A then
    let one = 1m (dim_row A) in
    (case gauss_jordan A one of
      (B, C)  if (mat_equal B one) then Some C else None) else None)"

(* Now solve for the LHS in general: multiply the inverse of matr with the right-hand side
   construct_rhs_vector p qs subsets.
   mat_inverse_var returns an option; matr_option unwraps it, falling back to the identity
   matrix of size dim_row matr when it is None. When this function is called by the
   algorithm, matr is always invertible given how the construction works, so the fallback
   is not expected to be used. *)
definition solve_for_lhs:: "real poly  real poly list  nat list list  rat mat  rat vec"
  where "solve_for_lhs p qs subsets matr =
     mult_mat_vec (matr_option (dim_row matr) (mat_inverse_var matr))  (construct_rhs_vector p qs subsets)"

section "Smashing" 

(* Pair up two subset lists: for each l1 in s1 (outer) and l2 in s2 (inner), produce
   l1 @ (l2 with every index shifted by n). smash_systems passes n = length qs1, so the
   shifted indices of l2 point into the qs2 part of the appended list qs1 @ qs2. *)
definition subsets_smash::"nat  nat list list  nat list list  nat list list"
  where "subsets_smash n s1 s2 = concat (map (λl1. map (λ l2. l1 @ (map ((+) n) l2)) s2) s1)"

(* Same pairing order as subsets_smash (s1 outer, s2 inner), but the two sign lists are
   simply concatenated with no shift. *)
definition signs_smash::"'a list list   'a list list  'a list list"
  where "signs_smash s1 s2 = concat (map (λl1. map (λ l2. l1 @ l2) s2) s1)"

(* Combine two systems into one over qs1 @ qs2:
   - the matrices are combined with kronecker_product;
   - the subsets with subsets_smash, shifting the second system's indices by length qs1;
   - the signs with signs_smash.
   The argument p is not used in the body. *)
definition smash_systems:: "real poly  real poly list  real poly list  nat list list  nat list list 
  rat list list  rat list list  rat mat  rat mat  
  real poly list × (rat mat × (nat list list × rat list list))"
  where "smash_systems p qs1 qs2 subsets1 subsets2 signs1 signs2 mat1 mat2 =
    (qs1@qs2, (kronecker_product mat1 mat2, (subsets_smash (length qs1) subsets1 subsets2, signs_smash signs1 signs2)))"

(* Tuple-unpacking wrapper: combine (qs, (matrix, (subsets, signs))) systems via smash_systems. *)
fun combine_systems:: "real poly  (real poly list × (rat mat × (nat list list × rat list list)))  (real poly list × (rat mat × (nat list list × rat list list)))
   (real poly list × (rat mat × (nat list list × rat list list)))"
  where "combine_systems p (qs1, m1, sub1, sgn1) (qs2, m2, sub2, sgn2) = 
    (smash_systems p qs1 qs2 sub1 sub2 sgn1 sgn2 m1 m2)"

(* Overall:
  Start with a matrix equation.
  Input a matrix, subsets, and signs.
  Drop the columns of the matrix at which the LHS vector is 0, i.e. keep only the columns at
   the indices of its nonzero entries. Reduce the signs accordingly.
  Then find a list of rows to keep based on rank (use the transpose result, pivot positions!),
   and delete the other rows. Reduce the subsets accordingly.
  End with a reduced system! *)
section "Reduction"
(* Indices i < dim_vec lhs_vec at which lhs_vec is nonzero, in increasing order. *)
definition find_nonzeros_from_input_vec:: "rat vec  nat list"
  where "find_nonzeros_from_input_vec lhs_vec = filter (λi. lhs_vec $ i  0) [0..< dim_vec lhs_vec]"

(* take_indices xs is: the elements of xs at the positions in is, in that order
   (positions are not range-checked). *)
definition take_indices:: "'a list  nat list  'a list"
  where "take_indices subsets indices = map ((!) subsets) indices"

(* Matrix built from the columns of matr at the listed indices, in list order, with the
   same number of rows as matr. *)
definition take_cols_from_matrix:: "'a mat  nat list  'a mat"
  where "take_cols_from_matrix matr indices_to_keep = 
    mat_of_cols (dim_row matr) ((take_indices (cols matr) indices_to_keep):: 'a vec list)"

(* Matrix built from the rows of matr at the listed indices, in list order, with the
   same number of columns as matr. *)
definition take_rows_from_matrix:: "'a mat  nat list  'a mat"
  where "take_rows_from_matrix matr indices_to_keep = 
    mat_of_rows (dim_col matr) ((take_indices (rows matr) indices_to_keep):: 'a vec list)"

(* Keep only the columns of A at which lhs_vec is nonzero. *)
fun reduce_mat_cols:: "'a mat  rat vec  'a mat"
  where "reduce_mat_cols A lhs_vec = take_cols_from_matrix A (find_nonzeros_from_input_vec lhs_vec)"

(* Find which rows to keep. The result lists row indices of A: the pivot column indices
   of gauss_jordan_single applied to the transpose of A, so each column of AT is a row of A.
   See gauss_jordan_single_rank for the rank-preservation fact about such pivot columns. *)
definition rows_to_keep:: "('a::field) mat  nat list" where
  "rows_to_keep A = map snd (pivot_positions (gauss_jordan_single (AT)))"

(* One reduction step, returning (matrix, (subsets, signs)):
   1. drop the columns of A where lhs_vec is zero, together with the matching sign assignments;
   2. on that column-reduced matrix, keep only the rows chosen by rows_to_keep, together with
      the matching subsets. *)
fun reduction_step:: "rat mat  rat list list  nat list list  rat vec  rat mat × (nat list list × rat list list)"
  where "reduction_step A signs subsets lhs_vec = 
    (let reduce_cols_A = (reduce_mat_cols A lhs_vec);
         rows_keep = rows_to_keep reduce_cols_A in
    (take_rows_from_matrix  reduce_cols_A rows_keep,
      (take_indices subsets rows_keep,
      take_indices signs (find_nonzeros_from_input_vec lhs_vec))))"

(* Reduce a system: compute the LHS vector with solve_for_lhs, then apply reduction_step. *)
fun reduce_system:: "real poly  (real poly list × (rat mat × (nat list list × rat list list)))  (rat mat × (nat list list × rat list list))"
  where "reduce_system p (qs,m,subs,signs) =
    reduction_step m signs subs (solve_for_lhs p qs subs m)" 

section "Overall algorithm "
  (* 
    Find the matrix, subsets, signs for an input p and qs.
    The "rat mat" in the output is the matrix. The "nat list list" is the list of subsets. 
    The "rat list list" is the list of signs.

    We will want to call this when p is nonzero and every q in qs is coprime to p.
    Properties of this algorithm are proved in BKR_Proofs.thy.

    Cases:
    - qs empty: reduce the base case for the constant polynomial [1], then drop the first
      sign (the one for that constant) from every sign assignment via map (drop 1);
    - one polynomial: reduce the base case for qs itself;
    - otherwise: split qs in half, recurse on both halves, combine the two results with
      combine_systems, and reduce the combined system.
  *)
fun calculate_data:: "real poly  real poly list   (rat mat × (nat list list × rat list list))"
  where
    "calculate_data p qs = 
  ( let len = length qs in
    if len = 0 then
      (λ(a,b,c).(a,b,map (drop 1) c)) (reduce_system p ([1],base_case_info))
    else if len  1 then reduce_system p (qs,base_case_info)
    else
    (let q1 = take (len div 2) qs; left = calculate_data p q1;
         q2 = drop (len div 2) qs; right = calculate_data p q2;
         comb = combine_systems p (q1,left) (q2,right) in
         reduce_system p comb
    )
  )"

(* Extract the list of consistent sign assignments: the third component (the signs) of
   calculate_data p qs. *)
definition find_consistent_signs_at_roots:: "real poly  real poly list  rat list list"
  where [code]:
    "find_consistent_signs_at_roots p qs =
  ( let (M,S,Σ) = calculate_data p qs in Σ )"

(* Projection form of find_consistent_signs_at_roots_def (snd of snd instead of a let-pattern),
   convenient for rewriting. *)
lemma find_consistent_signs_at_roots_thm:
  shows "find_consistent_signs_at_roots p qs = snd (snd (calculate_data p qs))"
  by (simp add: case_prod_beta find_consistent_signs_at_roots_def)

end

Theory Matrix_Equation_Construction

theory Matrix_Equation_Construction

imports "BKR_Algorithm"
begin

section "Results with Sturm's Theorem"

(* For p nonzero and q coprime to p, the Sturm count changes_R_smods p (pderiv p) equals the
   number of roots of p where q > 0 plus the number where q < 0. Coprimality rules out
   common roots, so these two disjoint sets cover all roots of p. *)
lemma relprime:
  fixes q::"real poly"
  assumes "coprime p q"
  assumes "p  0"
  assumes "q  0"
  shows "changes_R_smods p (pderiv p) = card {x. poly p x = 0  poly q x > 0} + card {x. poly p x = 0  poly q x < 0}"
proof -
  have 1: "{x. poly p x = 0  poly q x = 0} = {}"
    using assms(1) coprime_poly_0 by auto
  have 2: "changes_R_smods p (pderiv p) = int (card {x . poly p x = 0})" using sturm_R by auto
  have 3: "{x. poly p x = 0  poly q x > 0}  {x. poly p x = 0  poly q x < 0} = {}" by auto
  have "{x . poly p x = 0} =  {x. poly p x = 0  poly q x > 0} {x. poly p x = 0  poly q x < 0}  {x. poly p x = 0  poly q x = 0}" by force
  then have "{x . poly p x = 0} = {x. poly p x = 0  poly q x > 0} {x. poly p x = 0  poly q x < 0}" using 1 by auto
  then have "(card {x . poly p x = 0}) = (card ({x. poly p x = 0  poly q x > 0} {x. poly p x = 0  poly q x < 0}))" by presburger
  then have 4: "(card {x . poly p x = 0}) =  card {x. poly p x = 0  poly q x > 0} + card {x. poly p x = 0  poly q x < 0}" using 3 by (simp add: card_Un_disjoint assms(2) poly_roots_finite)
  show ?thesis  by (simp add: "2" "4")
qed

(* For a finite set A and a real constant k, k * card A is the sum of k over A.
   This is a scaled version of card_eq_sum, and the proof is the same as for card_eq_sum. *)
lemma card_eq_const_sum: 
  fixes k:: real
  assumes "finite A"
  shows "k*card A = sum (λx. k) A"
proof -
  have "plus  (λ_. Suc 0) = (λ_. Suc)"
    by (simp add: fun_eq_iff)
  then have "Finite_Set.fold (plus  (λ_. Suc 0)) = Finite_Set.fold (λ_. Suc)"
    by (rule arg_cong)
  then have "Finite_Set.fold (plus  (λ_. Suc 0)) 0 A = Finite_Set.fold (λ_. Suc) 0 A"
    by (blast intro: fun_cong)
  then show ?thesis
    by (simp add: card.eq_fold sum.eq_fold)
qed

(* Sturm–Tarski restated: for p nonzero and q coprime to p,
   changes_R_smods p (pderiv p * q) equals #{roots of p with q > 0} - #{roots of p with q < 0}.
   Proof: sturm_tarski_R equates the count with the Tarski query taq over the roots of p.
   That sum is split over the two disjoint sets where q > 0 and q < 0 (there are no common
   roots), and sign (poly q y) is constant on each set. *)
lemma restate_tarski:
  fixes q::"real poly"
  assumes "coprime p q"
  assumes "p  0"       
  assumes "q  0"
  shows "changes_R_smods p ((pderiv p) * q) = card {x. poly p x = 0  poly q x > 0} -  int(card {x. poly p x = 0  poly q x < 0})"
proof -
  have 3: "taq {x. poly p x=0} q  y{x. poly p x=0}. sign (poly q y)" by (simp add: taq_def)
  have 4: "{x. poly p x=0} =  {x. poly p x = 0  poly q x > 0}  {x. poly p x = 0  poly q x < 0}  {x. poly p x = 0  poly q x = 0}" by force
  then have 5: "{x. poly p x=0} =  {x. poly p x = 0  poly q x > 0}  {x. poly p x = 0  poly q x < 0}" using assms(1) coprime_poly_0 by auto
  then have 6: "y{x. poly p x=0}. sign (poly q y)  y{x. poly p x = 0  poly q x > 0}  {x. poly p x = 0  poly q x < 0}. sign (poly q y)" by presburger
  then have 12: "taq {x. poly p x=0} q  y{x. poly p x = 0  poly q x > 0}  {x. poly p x = 0  poly q x < 0}. sign (poly q y)" using 3 by linarith
  have 7: "{x. poly p x = 0  poly q x > 0}  {x. poly p x = 0  poly q x < 0} = {}" by auto
  then have 8: "y{x. poly p x = 0  poly q x > 0}  {x. poly p x = 0  poly q x < 0}. sign (poly q y)  (y{x. poly p x = 0  poly q x > 0}.sign (poly q y)) + (y{x. poly p x = 0  poly q x < 0}.sign(poly q y))" by (simp add: assms(2) poly_roots_finite sum.union_disjoint)
  then have 13: "taq {x. poly p x=0} q  (y{x. poly p x = 0  poly q x > 0}.sign (poly q y)) + (y{x. poly p x = 0  poly q x < 0}.sign(poly q y))" using 12 by linarith
  then have 9: "taq {x. poly p x = 0} q  (y{x. poly p x = 0  poly q x > 0}.1) + (y{x. poly p x = 0  poly q x < 0}.(-1))" by simp
  have 10: "(y{x. poly p x = 0  poly q x > 0}.1) =  card {x. poly p x = 0  poly q x > 0}" using card_eq_sum by auto
  have 11: " (y{x. poly p x = 0  poly q x < 0}.(-1)) = -1*card {x. poly p x = 0  poly q x < 0}" using card_eq_const_sum by simp
  have 14: "taq {x. poly p x = 0} q  card {x. poly p x = 0  poly q x > 0} + -1*card {x. poly p x = 0  poly q x < 0}" using 9 10 11 by simp
  have 1: "changes_R_smods p (pderiv p * q) = taq {x. poly p x=0} q" using sturm_tarski_R by simp
  then have 15: "changes_R_smods p (pderiv p * q) = card {x. poly p x = 0  poly q x > 0} + (-1*card {x. poly p x = 0  poly q x < 0})" using 14 by linarith
  have 16: "(-1*card {x. poly p x = 0  poly q x < 0}) = - card {x. poly p x = 0  poly q x < 0}" by auto
  then show ?thesis using 15 by linarith
qed

(* Same equation as restate_tarski, assuming only that p is nonzero. No coprimality is needed
   because roots of p where q vanishes contribute sign 0 to the sum.
   The goal is first unfolded (sturm_tarski_R, taq_def) to a sum of sign (poly q x) over the
   roots of p. That root set is split by the sign of q (set ?eq) and the sum is decomposed
   with sum_Un; finiteness comes from poly_roots_finite. *)
lemma restate_tarski2:
  fixes q::"real poly"
  assumes "p  0"
  shows "changes_R_smods p ((pderiv p) * q) =
        int(card {x. poly p x = 0  poly q x > 0}) -
        int(card {x. poly p x = 0  poly q x < 0})"
  unfolding sturm_tarski_R[symmetric] taq_def
proof -
  let ?all = "{x. poly p x=0}"
  let ?lt = "{x. poly p x=0  poly q x < 0}"
  let ?gt = "{x. poly p x=0  poly q x > 0}"
  let ?eq = "{x. poly p x=0  poly q x = 0}"
  have eq: "?all = ?lt  ?gt  ?eq" by force
  from poly_roots_finite[OF assms] have fin: "finite ?all" .
  show  "(x | poly p x = 0. sign (poly q x)) = int (card ?gt) - int (card ?lt)"
    unfolding eq
    apply (subst sum_Un)
      apply (auto simp add:fin)
    apply (subst sum_Un)
    by (auto simp add:fin)
qed

(* If p is coprime to every member of a finite set I of real polynomials, then p is coprime
   to the product of I. Proved by induction over the finite set; the insert step uses
   coprime_mult_right_iff. *)
lemma coprime_set_prod:
  fixes I:: "real poly set"
  shows "finite I  (( q  I. (coprime p q))  (coprime p ( I)))"
proof (induct rule: finite_induct)
  case empty
  then show ?case
    by simp 
next
  case (insert x F)
  then show ?case using coprime_mult_right_iff
    by simp
qed

(* The product of a finite set of nonzero real polynomials is nonzero. The conclusion is also
   available under the name nonzero_hyp. Proved by induction over the finite set: h unfolds
   the product over insert x F, and hq extracts that x itself is nonzero. *)
lemma finite_nonzero_set_prod:
  fixes I:: "real poly set"
  shows  nonzero_hyp: "finite I  (( q  I. q  0)   I  0)"
proof (induct rule: finite_induct)
  case empty
  then show ?case
    by simp 
next
  case (insert x F)
  have h: " (insert x F) = x * ( F)"
    by (simp add: insert.hyps(1) insert.hyps(2)) 
  have h_xin: "x  insert x F"
    by simp 
  have hq: "( q  (insert x F). q  0)  x  0" using h_xin
    by blast 
  show ?case using h hq
    using insert.hyps(3) by auto
qed

section "Setting up the construction: Definitions"

(* The real roots of p, as a sorted list (sorted_list_of_set of the root set).
   NOTE(review): the lemmas below only use this for nonzero p, where poly_roots_finite makes
   the root set finite. *)
definition characterize_root_list_p:: "real poly  real list"
  where "characterize_root_list_p p  sorted_list_of_set({x. poly p x = 0}::real set)"

(************** Renegar's N(I); towards defining the RHS of the matrix equation **************)

(* For nonzero p, construct_NofI p I (defined outside this chunk) is the rational value of
   (number of roots of p where prod_list I is positive) minus (number of roots of p where it
   is negative). This follows from restate_tarski2 after unfolding construct_NofI_def.
   NOTE(review): nonzero and rsquarefree_def are passed both via 'using' and to simp; the
   redundancy looks harmless. *)
lemma construct_NofI_prop:
  fixes p:: "real poly"
  fixes I:: "real poly list"
  assumes nonzero: "p0"
  shows "construct_NofI p I =
    rat_of_int (int (card {x. poly p x = 0  poly (prod_list I) x > 0}) - 
    int (card {x. poly p x = 0  poly (prod_list I) x < 0}))"
  unfolding construct_NofI_def
  using assms restate_tarski2 nonzero rsquarefree_def
  by (simp add: rsquarefree_def)

(* The vector whose entries are construct_NofI p I, one entry for each polynomial list I in
   Is, in the same order. *)
definition construct_s_vector:: "real poly  real poly list list  rat vec"
  where "construct_s_vector p Is = vec_of_list (map (λ I.(construct_NofI p I)) Is)"

(* Consistent sign assignments *)
(* Sign of a value of a linearly ordered field, as a rational: 1 if positive, -1 if negative,
   0 otherwise. *)
definition squash::"'a::linordered_field  rat"
  where "squash x = (if x > 0 then 1
                    else if x < 0 then -1
                    else 0)"

(* The squashed sign (1, -1 or 0) of each polynomial in qs evaluated at x, in list order. *)
definition signs_at::"real poly list  real  rat list"
  where "signs_at qs x 
    map (squash  (λq. poly q x)) qs"

(* The distinct sign vectors signs_at qs x as x ranges over the sorted roots of p
   (remdups removes repeated vectors). *)
definition characterize_consistent_signs_at_roots:: "real poly  real poly list  rat list list"
  where "characterize_consistent_signs_at_roots p qs =
  (remdups (map (signs_at qs) (characterize_root_list_p p)))"

(* An alternate version, designed for the case where every polynomial in qs is relatively prime to p: nonpositive values are mapped to -1, so 0 never appears. *)
(* Sign vector of qs at x with entries 1 (value positive) or -1 (value not positive). It agrees
   with signs_at only at points where no polynomial in qs vanishes; see csa_list_copr_rel. *)
definition consistent_sign_vec_copr::"real poly list  real  rat list"
  where "consistent_sign_vec_copr qs x 
    map (λ q. if (poly q x > 0) then (1::rat) else (-1::rat)) qs"

(* The distinct sign vectors consistent_sign_vec_copr qss x as x ranges over the sorted roots
   of p; this is the coprime counterpart of characterize_consistent_signs_at_roots. *)
definition characterize_consistent_signs_at_roots_copr:: "real poly  real poly list  rat list list"
  where "characterize_consistent_signs_at_roots_copr p qss =
  (remdups (map (consistent_sign_vec_copr qss) (characterize_root_list_p p)))"

(* If p is nonzero and every polynomial in qs is coprime to p, the two characterizations of
   consistent sign vectors coincide. No q in qs vanishes at a root of p (coprime_poly_0), so
   squash and the 1/-1 map agree there (fact h). *)
lemma csa_list_copr_rel:
  fixes p:: "real poly"
  fixes qs:: "real poly list"
  assumes nonzero: "p0"
  assumes pairwise_rel_prime: "q. ((List.member qs q)  (coprime p q))"
  shows "characterize_consistent_signs_at_roots p qs = characterize_consistent_signs_at_roots_copr p qs"
proof - 
  have "q  set(qs).  x  set (characterize_root_list_p p).  poly q x  0"
    using pairwise_rel_prime
    using coprime_poly_0 in_set_member nonzero poly_roots_finite characterize_root_list_p_def by fastforce 
  then have h: "q  set(qs).  x  set (characterize_root_list_p p). squash (poly q x) = (if (poly q x > 0) then (1::rat) else (-1::rat))"
    by (simp add: squash_def)
  have "map (λr. map (λp. if 0 < poly p r then 1 else - 1) qs) (characterize_root_list_p p) = map (λr. map (squash  (λp. poly p r)) qs) (characterize_root_list_p p)"
    by (simp add: h)
  thus ?thesis unfolding characterize_consistent_signs_at_roots_def characterize_consistent_signs_at_roots_copr_def
      signs_at_def consistent_sign_vec_copr_def
    by presburger
qed

(************** Towards defining Renegar's polynomial function and the LHS of the matrix equation **************)

(* Every index in L is strictly below n, i.e. a valid position in a list of length n. *)
definition list_constr:: "nat list  nat  bool"
  where "list_constr L n  list_all (λx. x < n) L"

(* Every index list that is a member of L satisfies list_constr with bound n. *)
definition all_list_constr:: "nat list list  nat  bool"
  where "all_list_constr L n  (x. List.member L x  list_constr x n)"

(* The first input is the subset, given as a list of indices; the second input is the consistent
  sign assignment. We map over the index list, look up the entry of the sign assignment at each
  index, and multiply the results together.
*)
(* Product of the entries of sign_asg at the positions in index_list; the empty index list
   gives the empty product.
   NOTE(review): nth is unchecked, so indices are expected to satisfy list_constr; an
   out-of-range index presumably yields an unspecified value. *)
definition z:: "nat list  rat list  rat"
  where "z index_list sign_asg  (prod_list (map (nth sign_asg) index_list))"

(* One matrix row: z index_list applied to each sign vector in sign_list, in order. *)
definition mtx_row:: "rat list list  nat list  rat list"
  where "mtx_row sign_list index_list  (map ( (z index_list)) sign_list)"

(* Matrix built from rows: one row mtx_row sign_list i per index list i in subset_list, each
   with length sign_list columns. It equals the entrywise alt_matrix_A (see alt_matrix_char). *)
definition matrix_A:: "rat list list  nat list list  rat mat" 
  where "matrix_A sign_list subset_list = 
    (mat_of_rows_list (length sign_list) (map (λi .(mtx_row sign_list i)) subset_list))"

(* Entrywise form of matrix_A: length subsets rows and length signs columns, with entry (i, j)
   equal to z (subsets ! i) (signs ! j). *)
definition alt_matrix_A:: "rat list list  nat list list  rat mat"
  where "alt_matrix_A signs subsets = (mat (length subsets) (length signs) 
    (λ(i, j). z (subsets ! i) (signs ! j)))"

(* The entrywise and row-list definitions of the matrix coincide. h0 evaluates entry (i, j) of
   the nested map used by matrix_A; h then rewrites the entry function under mat via eq_matI.
   The final step unfolds both definitions and mtx_row_def. *)
lemma alt_matrix_char: "alt_matrix_A signs subsets = matrix_A signs subsets"
proof - 
  have h0: "(i j. i < length subsets 
            j < length signs 
            map (λindex_list. map (z index_list) signs) subsets ! i ! j = z (subsets ! i) (signs ! j))"
  proof -
    fix i
    fix j
    assume i_lt: "i < length subsets"
    assume j_lt: "j < length signs"
    show "((map (λindex_list. map (z index_list) signs) subsets) ! i) ! j = z (subsets ! i) (signs ! j)"
    proof - 
      have h0: "(map (λindex_list. map (z index_list) signs) subsets) ! i =  map (z (subsets ! i)) signs" 
        using nth_map i_lt
        by blast
      then show ?thesis using nth_map j_lt
        by simp 
    qed
  qed
  have h: " mat (length subsets) (length signs) (λ(i, j). z (subsets ! i) (signs ! j)) =
    mat (length subsets) (length signs) (λ(i, y). map (λindex_list. map (z index_list) signs) subsets ! i ! y)"
    using h0 eq_matI[where A = "mat (length subsets) (length signs) (λ(i, j). z (subsets ! i) (signs ! j))",
        where B = "mat (length subsets) (length signs) (λ(i, y). map (λindex_list. map (z index_list) signs) subsets ! i ! y)"]
    by auto
  show ?thesis unfolding alt_matrix_A_def matrix_A_def mat_of_rows_list_def apply (auto) unfolding mtx_row_def
    using h   by blast
qed

(* Row i of alt_matrix_A (for i a valid subset index) is the vector of
   z (subsets ! i) (signs ! j) over all sign-vector indices j. *)
lemma subsets_are_rows: "i < (length subsets). row (alt_matrix_A signs subsets) i  = vec (length signs) (λj. z (subsets ! i) (signs ! j))"
  unfolding row_def unfolding alt_matrix_A_def by auto

(* Column i of alt_matrix_A (for i a valid sign-vector index) is the vector of
   z (subsets ! j) (signs ! i) over all subset indices j. *)
lemma signs_are_cols: "i < (length signs). col (alt_matrix_A signs subsets) i  = vec (length subsets) (λj. z (subsets ! j) (signs ! i))"
  unfolding col_def unfolding alt_matrix_A_def by auto

(* The ith entry of the LHS vector is the number of (distinct) real zeros of p at which the sign vector of qs is the ith entry of signs. *)
(* For each sign vector w in signs: the number of entries of the sorted root list of p whose
   consistent_sign_vec_copr qs value equals w, converted to a rational. *)
definition construct_lhs_vector:: "real poly  real poly list  rat list list   rat vec"
  where "construct_lhs_vector p qs signs 
  vec_of_list (map (λw.  rat_of_int (int (length (filter (λv. v = w) (map (consistent_sign_vec_copr qs) (characterize_root_list_p p)))))) signs)"

(* Putting all of the pieces of the construction together *)
(* The matrix equation: matrix_A sign_list subset_list times the LHS vector equals
   construct_rhs_vector p qs subset_list (construct_rhs_vector is defined outside this chunk). *)
definition satisfy_equation:: "real poly  real poly list  nat list list  rat list list  bool"
  where "satisfy_equation p qs subset_list sign_list =
        (mult_mat_vec (matrix_A sign_list subset_list) (construct_lhs_vector p qs sign_list) = (construct_rhs_vector p qs subset_list))"

section "Setting up the construction: Proofs"

(* Some matrix lemmas *)
(* If every list in rs has length nc, then row i of mat_of_rows_list nc rs (for i a valid row
   index) is vec_of_list (rs ! i).
   NOTE(review): closed by a single smt call with an explicit fact list; may need re-checking
   across Isabelle versions. *)
lemma row_mat_of_rows_list:
  assumes "list_all (λr. length r = nc) rs"
  assumes "i < length rs"
  shows "row (mat_of_rows_list nc rs) i = vec_of_list (nth rs i)"
  by (smt assms(1) assms(2) dim_col_mat(1) dim_vec_of_list eq_vecI index_row(2) index_vec list_all_length mat_of_rows_list_def row_mat split_conv vec_of_list_index)


(* Multiplying a row-list matrix (all rows of length nc) by a list vector of length nc gives
   the list of scalar products of each row with that vector. Uses row_mat_of_rows_list. *)
lemma mult_mat_vec_of_list:
  assumes "length ls = nc"
  assumes "list_all (λr. length r = nc) rs"
  shows "mat_of_rows_list nc rs *v vec_of_list ls =
    vec_of_list (map (λr. vec_of_list r  vec_of_list ls) rs)"
  unfolding mult_mat_vec_def
  using row_mat_of_rows_list assms 
  apply auto
  by (smt dim_row_mat(1) dim_vec dim_vec_of_list eq_vecI index_map_vec(1) index_map_vec(2) index_vec list_all_length mat_of_rows_list_def row_mat_of_rows_list vec_of_list_index)

(* Every row produced by mtx_row signs has length signs, which is the row-length precondition
   of row_mat_of_rows_list and mult_mat_vec_of_list for matrix_A. *)
lemma mtx_row_length:
  "list_all (λr. length r = length signs) (map (mtx_row signs) ls)"
  apply (induction ls)
  by (auto simp add: mtx_row_def)

(* Diagnostic commands: thm only displays the named facts and adds nothing to the theory.
   NOTE(review): these look like leftovers from development and could be removed. *)
thm construct_lhs_vector_def
thm  poly_roots_finite

(* Recharacterize the LHS vector *)
(* For nonzero p and a valid index i, entry i of construct_lhs_vector p qs signs is the number
   of roots x of p with consistent_sign_vec_copr qs x = signs ! i. The length of the filtered
   root list is turned into a card via distinct_length_filter. *)
lemma construct_lhs_vector_clean:
  assumes "p  0"
  assumes "i < length signs"
  shows "(construct_lhs_vector p qs signs) $ i =
    card {x. poly p x = 0  ((consistent_sign_vec_copr qs x) = (nth signs i))}"
proof -
  from poly_roots_finite[OF assms(1)] have "finite {x. poly p x = 0}" .
  (* eq: the roots in the sorted root list whose sign vector is signs ! i, as a set. *)
  then have eq: "(Collect
       ((λv. v = signs ! i) 
        consistent_sign_vec_copr qs) 
      set (sorted_list_of_set
            {x. poly p x = 0})) =
    {x. poly p x = 0  consistent_sign_vec_copr qs x = signs ! i}"
    by auto
  show ?thesis
    unfolding construct_lhs_vector_def vec_of_list_index characterize_root_list_p_def
    apply auto
    apply (subst nth_map[OF assms(2)])
    apply auto
    apply (subst distinct_length_filter)
    using eq by auto
qed

(* Whole-vector form of construct_lhs_vector_clean: for nonzero p, the LHS vector is the list
   of root counts, one per sign vector s in signs. *)
lemma construct_lhs_vector_cleaner:
  assumes "p  0"
  shows "(construct_lhs_vector p qs signs) =
   vec_of_list (map (λs. rat_of_int (card {x. poly p x = 0  ((consistent_sign_vec_copr qs x) = s)})) signs)"
  apply (rule eq_vecI)
  apply (auto simp add:  construct_lhs_vector_clean[OF assms] )
  apply (simp add: vec_of_list_index)
  unfolding construct_lhs_vector_def
  using assms construct_lhs_vector_clean construct_lhs_vector_def apply auto[1]
  by simp

(* Show that, because our consistent sign vectors consist only of 1 and -1 entries, z returns 1 or -1
  when applied to a consistent sign vector *)
(* If every index in I is in range for signs and every entry of signs is 1 or -1, then z I signs
   is 1 or -1. Induction on I: the Nil case is the empty product, and the Cons case multiplies
   by signs ! a, which is itself 1 or -1. *)
lemma z_signs:
  assumes "list_all (λi. i < length signs) I"
  assumes "list_all (λs. s = 1  s = -1) signs"
  shows "(z I signs = 1)  (z I signs = -1)" using assms
proof (induction I)
  case Nil
  then show ?case
    by (auto simp add:z_def)
next
  case (Cons a I)
  moreover have "signs ! a = 1  signs ! a = -1"
    by (metis (mono_tags, lifting) add_Suc_right calculation(2) calculation(3) gr0_conv_Suc list.size(4) list_all_length nth_Cons_0)
  ultimately show ?case
    by (auto simp add:z_def)
qed

(* z_signs specialized to a sign vector drawn from characterize_consistent_signs_at_roots_copr
   p qs and an index list I that is well-defined for qs. Such a vector has length qs (so I is
   in range) and all its entries are 1 or -1. *)
lemma z_lemma:
  fixes I:: "nat list" 
  fixes sign:: "rat list"
  assumes consistent: "sign  set (characterize_consistent_signs_at_roots_copr p qs)"
  assumes welldefined: "list_constr I (length qs)"
  shows "(z I sign = 1)  (z I sign = -1)"
proof (rule z_signs)
  have "length sign = length qs" using consistent
    by (auto simp add: characterize_consistent_signs_at_roots_copr_def consistent_sign_vec_copr_def)
  thus "list_all (λi. i < length sign) I"
    using welldefined
    by (auto simp add: list_constr_def characterize_consistent_signs_at_roots_copr_def consistent_sign_vec_copr_def)
  show "list_all (λs. s = 1  s = - 1) sign" using consistent
    apply (auto simp add: list.pred_map  characterize_consistent_signs_at_roots_copr_def  consistent_sign_vec_copr_def)
    using Ball_set
    by force
qed

(* Show that the sign vector of qs at any root of p is in characterize_consistent_signs_at_roots_copr p qs *)
(* For nonzero p and a root x of p, sign = consistent_sign_vec_copr qs x occurs in
   characterize_consistent_signs_at_roots_copr p qs. Finiteness of the root set
   (poly_roots_finite) is used to rewrite the set of sorted_list_of_set.
   NOTE(review): the assumption welldefined and the fixed I are not used in this proof; they
   look removable, but callers may rely on the current statement. *)
lemma in_set: 
  fixes p:: "real poly"
  assumes nonzero: "p0"
  fixes qs:: "real poly list"
  fixes I:: "nat list" 
  fixes sign:: "rat list"
  fixes x:: "real"
  assumes root_p: "x  {x. poly p x = 0}"
  assumes sign_fix: "sign = consistent_sign_vec_copr qs x"
  assumes welldefined: "list_constr I (length qs)"
  shows "sign  set (characterize_consistent_signs_at_roots_copr p qs)" 
proof -
  have h1: "consistent_sign_vec_copr qs x 
      set (remdups (map (consistent_sign_vec_copr qs) (sorted_list_of_set {x. poly p x = 0})))" 
    using root_p apply auto apply (subst set_sorted_list_of_set)
    using nonzero poly_roots_finite rsquarefree_def apply blast by auto
  thus ?thesis unfolding characterize_consistent_signs_at_roots_copr_def characterize_root_list_p_def using sign_fix
    by blast
qed

(* Since all of the polynomials in qs are relatively prime to p, products of subsets of these
    polynomials are also relatively prime to p, and hence nonzero at every root of p *)
(* If every polynomial in qs is coprime to p and I is well-defined for qs, then the product of
   the polynomials selected by I (retrieve_polys, defined outside this chunk) is strictly
   positive or strictly negative at every root x of p. Each selected polynomial is coprime to p,
   so their product is too (prod_list_coprime_right), and coprime_poly_0 rules out a common
   root. *)
lemma nonzero_product: 
  fixes p:: "real poly"
  assumes nonzero: "p0"
  fixes qs:: "real poly list"
  assumes pairwise_rel_prime_1: "q. ((List.member qs q)  (coprime p q))"
  fixes I:: "nat list" 
  fixes x:: "real"
  assumes root_p: "x  {x. poly p x = 0}"
  assumes welldefined: "list_constr I (length qs)"
  shows "(poly (prod_list (retrieve_polys qs I)) x > 0)  (poly (prod_list (retrieve_polys qs I)) x < 0)"
proof -
  have "x. x  set (retrieve_polys qs I)  coprime p x"
    unfolding retrieve_polys_def
    by (smt in_set_conv_nth in_set_member length_map list_all_length list_constr_def nth_map pairwise_rel_prime_1 welldefined)
  then have coprimeh: "coprime p (prod_list (retrieve_polys qs I))"
    using prod_list_coprime_right by auto
  thus ?thesis using root_p
    using coprime_poly_0 linorder_neqE_linordered_idom by blast 
qed

(* The next few lemmas relate z to the signs of the product of subsets of polynomials of qs *)
(* Inductive core of horiz_vector_helper_pos. For every index list I, if list_constr I
   (length qs) holds, then the product of the polynomials selected by I is positive at the
   root x exactly when z I sign = 1, where sign is the 1/-1 sign vector of qs at x.
   Induction on I: the Cons step splits the sign of poly (qs ! a * rest) x with
   zero_less_mult_iff and matches it with sign ! a * z I sign = 1. *)
lemma horiz_vector_helper_pos_ind: 
  fixes p:: "real poly"
  assumes nonzero: "p0"
  fixes qs:: "real poly list"
  assumes pairwise_rel_prime_1: "q. ((List.member qs q)  (coprime p q))"
  fixes I:: "nat list" 
  fixes sign:: "rat list"
  fixes x:: "real"
  assumes root_p: "x  {x. poly p x = 0}"
  assumes sign_fix: "sign = consistent_sign_vec_copr qs x"
  shows "(list_constr I (length qs))  (poly (prod_list (retrieve_polys qs I)) x > 0)  (z I sign = 1)"
proof (induct I)
  case Nil
  then show ?case
    by (simp add: retrieve_polys_def z_def) 
next
  case (Cons a I) 
  have welldef: "list_constr (a#I) (length qs)  (list_constr I (length qs))" 
    unfolding list_constr_def list_all_def by auto
  have set_hyp: "list_constr I (length qs)  sign  set (characterize_consistent_signs_at_roots_copr p qs)" 
    using in_set using nonzero root_p sign_fix by blast 
  have z_hyp: "list_constr I (length qs)  ((z I sign = 1)  (z I sign = -1))" 
    using set_hyp z_lemma[where sign="sign", where I = "I", where p="p", where qs="qs"] by blast
  have sign_hyp: "sign = map (λ q. if (poly q x > 0) then 1 else -1) qs" 
    using sign_fix unfolding consistent_sign_vec_copr_def by blast
  have ind_hyp_1: "list_constr (a#I) (length qs)  
    ((0 < poly (prod_list (retrieve_polys qs I)) x) = (z I sign = 1))"
    using welldef Cons.hyps by auto
  have ind_hyp_2: "list_constr (a#I) (length qs)  
    ((0 > poly (prod_list (retrieve_polys qs I)) x) = (z I sign = -1))"
    using welldef z_hyp Cons.hyps nonzero_product
    using pairwise_rel_prime_1 nonzero root_p by auto 
  (* One-step unfolding of the selected product and of z on a # I. *)
  have h1: "prod_list (retrieve_polys qs (a # I)) = (nth qs a)*(prod_list (retrieve_polys qs I))"
    by (simp add: retrieve_polys_def)
  have h2: "(z (a # I) sign) = (nth sign a)*(z I sign)"
    by (metis (mono_tags, hide_lams) list.simps(9) prod_list.Cons z_def)
  have h3help: "list_constr (a#I) (length qs)  a < length qs" unfolding list_constr_def
    by simp 
  then have h3: "list_constr (a#I) (length qs)  
    ((nth sign a) = (if (poly (nth qs a) x > 0) then 1 else -1))" 
    using nth_map sign_hyp by auto
  (* NOTE(review): this binds the name h2 a second time, shadowing the fact above. *)
  have h2: "(0 < poly ((nth qs a)*(prod_list (retrieve_polys qs I))) x)  
    ((0 < poly (nth qs a) x  (0 < poly (prod_list (retrieve_polys qs I)) x)) 
   (0 > poly (nth qs a) x  (0 > poly (prod_list (retrieve_polys qs I)) x)))"
    by (simp add: zero_less_mult_iff)
  (* The sign case split on qs ! a and the rest matches sign ! a * z I sign = 1. *)
  have final_hyp_a: "list_constr (a#I) (length qs)  (((0 < poly (nth qs a) x  (0 < poly (prod_list (retrieve_polys qs I)) x)) 
     (0 > poly (nth qs a) x  (0 > poly (prod_list (retrieve_polys qs I)) x))) = 
    ((nth sign a)*(z I sign) = 1))" 
  proof -
    have extra_hyp_a: "list_constr (a#I) (length qs)  (0 < poly (nth qs a) x = ((nth sign a) = 1))" using h3
      by simp 
    have extra_hyp_b: "list_constr (a#I) (length qs)   (0 > poly (nth qs a) x = ((nth sign a) = -1))" 
      using h3 apply (auto) using coprime_poly_0 h3help in_set_member nth_mem pairwise_rel_prime_1 root_p by fastforce 
    have ind_hyp_1: "list_constr (a#I) (length qs)  (((0 < poly (nth qs a) x  (z I sign = 1))  
    (0 > poly (nth qs a) x  (z I sign = -1)))
      = ((nth sign a)*(z I sign) = 1))" using extra_hyp_a extra_hyp_b
      using zmult_eq_1_iff
      by (simp add: h3)   
    then show ?thesis
      using ind_hyp_1 ind_hyp_2 by (simp add: Cons.hyps welldef)
  qed
  then show ?case 
    using h1 z_def by (simp add: zero_less_mult_iff)  
qed

(* For a fixed well-defined index list I and a root x of p: the product of the polynomials
   selected by I is positive at x exactly when z I sign = 1. This instantiates
   horiz_vector_helper_pos_ind. *)
lemma horiz_vector_helper_pos: 
  fixes p:: "real poly"
  assumes nonzero: "p0"
  fixes qs:: "real poly list"
  assumes pairwise_rel_prime_1: "q. ((List.member qs q)  (coprime p q))"
  fixes I:: "nat list" 
  fixes sign:: "rat list"
  fixes x:: "real"
  assumes root_p: "x  {x. poly p x = 0}"
  assumes sign_fix: "sign = consistent_sign_vec_copr qs x"
  assumes welldefined: "list_constr I (length qs)"
  shows "(poly (prod_list (retrieve_polys qs I)) x > 0)  (z I sign = 1)"
  using horiz_vector_helper_pos_ind
  using pairwise_rel_prime_1 nonzero  root_p sign_fix welldefined by blast 

(* Negative counterpart of horiz_vector_helper_pos: the selected product is negative at the
   root x exactly when z I sign = -1. Three facts combine here:
   - z I sign is 1 or -1 (z_lemma via in_set);
   - the product is nonzero at x (nonzero_product);
   - the positive case (horiz_vector_helper_pos). *)
lemma horiz_vector_helper_neg: 
  fixes p:: "real poly"
  assumes nonzero: "p0"
  fixes qs:: "real poly list"
  assumes pairwise_rel_prime_1: "q. ((List.member qs q)  (coprime p q))"
  fixes I:: "nat list" 
  fixes sign:: "rat list"
  fixes x:: "real"
  assumes root_p: "x  {x. poly p x = 0}"
  assumes sign_fix: "sign = consistent_sign_vec_copr qs x"
  assumes welldefined: "list_constr I (length qs)"
  shows "(poly (prod_list (retrieve_polys qs I)) x < 0)  (z I sign = -1)"
proof - 
  have set_hyp: "list_constr I (length qs)  sign  set (characterize_consistent_signs_at_roots_copr p qs)" 
    using in_set using nonzero root_p sign_fix by blast 
  have z_hyp: "list_constr I (length qs)  ((z I sign = 1)  (z I sign = -1))" 
    using set_hyp  z_lemma[where sign="sign", where I = "I", where p="p", where qs="qs"] by blast
  have poly_hyp: "(poly (prod_list (retrieve_polys qs I)) x > 0)  (poly (prod_list (retrieve_polys qs I)) x < 0)"
    using nonzero_product
    using pairwise_rel_prime_1 nonzero root_p
    using welldefined by blast
  have pos_hyp: "(poly (prod_list (retrieve_polys qs I)) x > 0)  (z I sign = 1)" using horiz_vector_helper_pos
    using pairwise_rel_prime_1 nonzero root_p sign_fix welldefined by blast
  show ?thesis using z_hyp poly_hyp pos_hyp apply (auto)
    using welldefined by blast
qed

(* Recharacterize the dot product *)
(* For lists of equal length, the scalar product of vec_of_list xs and vec_of_list ys is the
   sum of the pairwise products. Induction on xs with ys generalized; the Cons step is closed
   by a single smt call. *)
lemma vec_of_list_dot_rewrite:
  assumes "length xs = length ys"
  shows "vec_of_list xs  vec_of_list ys =
    sum_list (map2 (*) xs ys)"
  using assms
proof (induction xs arbitrary:ys)
  case Nil
  then show ?case by auto
next
  case (Cons a xs)
  then show ?case apply auto
    by (smt (verit, best) Suc_length_conv list.simps(9) old.prod.case scalar_prod_vCons sum_list.Cons vec_of_list_Cons zip_Cons_Cons)
qed

(* The scalar product of the matrix row mtx_row signs I with the LHS vector, written as a sum
   over the sign vectors s in signs of z I s times the number of roots of p whose sign vector
   is s. Uses construct_lhs_vector_cleaner and vec_of_list_dot_rewrite.
   NOTE(review): the first 'have' merely restates the assumption nonzero. *)
lemma lhs_dot_rewrite:
  fixes p:: "real poly"
  fixes qs:: "real poly list"
  fixes I:: "nat list" 
  fixes signs:: "rat list list"
  assumes nonzero: "p0"
  shows
    "(vec_of_list (mtx_row signs I)  (construct_lhs_vector p qs signs)) =
   sum_list (map (λs. (z I s)  *  rat_of_int (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s})) signs)"
proof -
  have "p  0" using nonzero by auto
  from construct_lhs_vector_cleaner[OF this]
  have rhseq: "construct_lhs_vector p qs signs =
    vec_of_list
    (map (λs. rat_of_int (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s})) signs)" by auto
  have "(vec_of_list (mtx_row signs I)  (construct_lhs_vector p qs signs)) =    
    sum_list (map2 (*) (mtx_row signs I) (map (λs. rat_of_int (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s})) signs))"
    unfolding rhseq
    apply (intro vec_of_list_dot_rewrite)
    by (auto simp add: mtx_row_def)
  thus ?thesis unfolding mtx_row_def
    using map2_map_map 
    by (auto simp add: map2_map_map)
qed

(* For distinct lists xs and ys with set ys contained in set xs, summing an int-valued f over
   xs equals summing it over ys whenever f vanishes on the extra elements set xs - set ys.
   Reduces to finite set sums (sum_list_distinct_conv_sum_set, sum.mono_neutral_cong_left). *)
lemma sum_list_distinct_filter:
  fixes f:: "'a  int"
  assumes "distinct xs" "distinct ys"
  assumes "set ys  set xs"
  assumes "x. x  set xs - set ys  f x = 0"
  shows "sum_list (map f xs) = sum_list (map f ys)"
  by (metis List.finite_set assms(1) assms(2) assms(3) assms(4) sum.mono_neutral_cong_left sum_list_distinct_conv_sum_set)

(* If we have a superset of the signs, we can drop to just the consistent ones *)
(* Assume signs is distinct and contains every consistent sign vector of qs at the roots of p.
   Then the row-times-LHS product computed over signs equals the one computed over only the
   consistent sign vectors. A sign vector in signs that is not realized at any root of p has
   root count 0, so its term contributes nothing.
   NOTE(review): the h0 sub-proof (names ff1, ff2, rr, 'by moura') and several long metis/smt
   calls look machine-generated. Simplifying this proof would require re-checking it
   interactively. *)
lemma construct_lhs_vector_drop_consistent:
  fixes p:: "real poly"
  fixes qs:: "real poly list"
  fixes I:: "nat list" 
  fixes signs:: "rat list list"
  assumes nonzero: "p0"
  assumes distinct_signs: "distinct signs"
  assumes all_info: "set (characterize_consistent_signs_at_roots_copr p qs)  set(signs)"
  assumes welldefined: "list_constr I (length qs)"
  shows
    "(vec_of_list (mtx_row signs I)  (construct_lhs_vector p qs signs)) =
     (vec_of_list (mtx_row (characterize_consistent_signs_at_roots_copr p qs) I) 
      (construct_lhs_vector p qs (characterize_consistent_signs_at_roots_copr p qs)))"
proof - 
  (* h0: a sign vector in signs that is not the sign vector of any root of p cannot have a
     positive root count, so the implication holds vacuously. *)
  have h0: " sgn. sgn  set signs  sgn  consistent_sign_vec_copr qs ` set (characterize_root_list_p p)  0 < rat_of_nat (card
                  {xa. poly p xa = 0  consistent_sign_vec_copr qs xa = sgn})  z I sgn = 0"
  proof - 
    have " sgn. sgn  set signs  sgn  consistent_sign_vec_copr qs ` set (characterize_root_list_p p)  0 < rat_of_int (card
                  {xa. poly p xa = 0  consistent_sign_vec_copr qs xa = sgn})  {xa. poly p xa = 0  consistent_sign_vec_copr qs xa = sgn}  {}" 
      by fastforce
    then show ?thesis
    proof -
      { fix iis :: "rat list"
        have ff1: "0  p"
          using nonzero rsquarefree_def by blast
        obtain rr :: "(real  bool)  real" where
          ff2: "p. p (rr p)  Collect p = {}"
          by moura
        { assume "is. is = iis  {r. poly p r = 0  consistent_sign_vec_copr qs r = is}  {}"
          then have "is. consistent_sign_vec_copr qs (rr (λr. poly p r = 0  consistent_sign_vec_copr qs r = is)) = iis  {r. poly p r = 0  consistent_sign_vec_copr qs r = is}  {}"
            using ff2
            by (metis (mono_tags, lifting))
          then have "r. poly p r = 0  consistent_sign_vec_copr qs r = iis"
            using ff2 by smt
          then have "iis  consistent_sign_vec_copr qs ` set (sorted_list_of_set {r. poly p r = 0})"
            using ff1 poly_roots_finite by fastforce }
        then have "iis  set signs  iis  consistent_sign_vec_copr qs ` set (characterize_root_list_p p)  ¬ 0 < rat_of_int (int (card {r. poly p r = 0  consistent_sign_vec_copr qs r = iis}))"
          by (metis (no_types) sgn. sgn  set signs  sgn  consistent_sign_vec_copr qs ` set (characterize_root_list_p p)  0 < rat_of_int (int (card {xa. poly p xa = 0  consistent_sign_vec_copr qs xa = sgn}))  {xa. poly p xa = 0  consistent_sign_vec_copr qs xa = sgn}  {} characterize_root_list_p_def) }
      then show ?thesis
        by fastforce
    qed
  qed
  then have " sgn. sgn  set signs  sgn  consistent_sign_vec_copr qs ` set (characterize_root_list_p p)  ((0 = rat_of_nat (card
                  {xa. poly p xa = 0  consistent_sign_vec_copr qs xa = sgn})  z I sgn = 0))"
    by auto
  (* hyp: each unrealized sign vector in signs contributes a zero term. *)
  then have hyp: " s. s  set signs  s  consistent_sign_vec_copr qs ` set (characterize_root_list_p p)  (z I s * rat_of_nat (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s}) = 0)"
    by auto
  (* Split the sum over set signs into realized and unrealized sign vectors; the unrealized
     part is 0 by hyp. *)
  then have "(s set(signs). z I s * rat_of_nat (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s})) = 
        (s(set (signs)  (consistent_sign_vec_copr qs ` set (characterize_root_list_p p))). z I s * rat_of_nat (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s}))"
  proof - 
    have "set(signs) =(set (signs)  (consistent_sign_vec_copr qs ` set (characterize_root_list_p p))) 
              (set(signs)-(consistent_sign_vec_copr qs ` set (characterize_root_list_p p)))"
      by blast
    then have sum_rewrite: "(s set(signs). z I s * rat_of_nat (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s})) =  
          (s (set (signs)  (consistent_sign_vec_copr qs ` set (characterize_root_list_p p))) 
              (set(signs)-(consistent_sign_vec_copr qs ` set (characterize_root_list_p p))). z I s * rat_of_nat (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s}))"
      by auto
    then have sum_split: "(s (set (signs)  (consistent_sign_vec_copr qs ` set (characterize_root_list_p p))) 
              (set(signs)-(consistent_sign_vec_copr qs ` set (characterize_root_list_p p))). z I s * rat_of_nat (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s}))
          = 
(s (set (signs)  (consistent_sign_vec_copr qs ` set (characterize_root_list_p p))). z I s * rat_of_nat (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s}))
+ (s (set(signs)-(consistent_sign_vec_copr qs ` set (characterize_root_list_p p))). z I s * rat_of_nat (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s}))"
      by (metis (no_types, lifting) List.finite_set sum.Int_Diff)
    have sum_zero: "(s (set(signs)-(consistent_sign_vec_copr qs ` set (characterize_root_list_p p))). z I s * rat_of_nat (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s})) = 0"   
      using hyp
      by (simp add: hyp)      
    show ?thesis using sum_rewrite sum_split sum_zero by linarith
  qed
  (* With all_info, the realized sign vectors inside signs are exactly the consistent ones. *)
  then have set_eq: "set (remdups
           (map (consistent_sign_vec_copr qs)
             (characterize_root_list_p p))) = set (signs)  (consistent_sign_vec_copr qs ` set (characterize_root_list_p p))"
    using all_info
    by (simp add: characterize_consistent_signs_at_roots_copr_def subset_antisym)
  (* Convert the list sums to set sums; both lists are distinct. *)
  have hyp1: "(ssigns. z I s * rat_of_nat (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s})) = 
        (sset (signs). z I s * rat_of_nat (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s}))"
    using distinct_signs sum_list_distinct_conv_sum_set by blast
  have hyp2: "(sremdups
           (map (consistent_sign_vec_copr qs)
             (characterize_root_list_p p)). z I s * rat_of_nat (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s}))
  = (s set (remdups
           (map (consistent_sign_vec_copr qs)
             (characterize_root_list_p p))). z I s * rat_of_nat (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s}))"
    using sum_list_distinct_conv_sum_set by blast 
  have set_sum_eq: "(s(set (signs)  (consistent_sign_vec_copr qs ` set (characterize_root_list_p p))). z I s * rat_of_nat (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s})) =
    (s set (remdups
           (map (consistent_sign_vec_copr qs)
             (characterize_root_list_p p))). z I s * rat_of_nat (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s}))"
    using set_eq by auto
  then have "(ssigns. z I s * rat_of_nat (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s})) =
    (sremdups
           (map (consistent_sign_vec_copr qs)
             (characterize_root_list_p p)). z I s * rat_of_nat (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s}))"
    using set_sum_eq hyp1 hyp2
    using (sset signs. z I s * rat_of_nat (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s})) = (sset signs  consistent_sign_vec_copr qs ` set (characterize_root_list_p p). z I s * rat_of_nat (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s})) by linarith
  then have "consistent_sign_vec_copr qs ` set (characterize_root_list_p p)  set signs 
    (p qss.
        characterize_consistent_signs_at_roots_copr p qss =
        remdups (map (consistent_sign_vec_copr qss) (characterize_root_list_p p))) 
    (ssigns. z I s * rat_of_nat (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s})) =
    (sremdups
           (map (consistent_sign_vec_copr qs)
             (characterize_root_list_p p)). z I s * rat_of_nat (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s}))"
    by linarith
  (* Rewrite both scalar products with lhs_dot_rewrite and conclude. *)
  then show ?thesis  unfolding lhs_dot_rewrite[OF nonzero]
    apply (auto intro!: sum_list_distinct_filter simp add: distinct_signs  characterize_consistent_signs_at_roots_copr_def)
    using all_info characterize_consistent_signs_at_roots_copr_def by auto[1]
qed

(* Both matrix_equation_helper_step and matrix_equation_main_step relate the matrix construction
   to Tarski queries: the product of row I of the matrix with the LHS vector equals the
   corresponding Tarski query on the RHS.

   matrix_equation_helper_step states that RHS explicitly as
     (number of roots x of p with poly (prod_list (retrieve_polys qs I)) x > 0)
   - (number of roots x of p with poly (prod_list (retrieve_polys qs I)) x < 0).
   Proof outline:
   1. restrict signs to characterize_consistent_signs_at_roots_copr p qs
      (construct_lhs_vector_drop_consistent);
   2. rewrite the dot product as a sum over those sign vectors (lhs_dot_rewrite);
   3. split that sum into the part where z I s = 1 (?gt) and where z I s = -1 (?lt);
   4. evaluate each part as a root count (gtchr, ltchr) and combine with linarith. *)
lemma matrix_equation_helper_step:
  fixes p:: "real poly"
  fixes qs:: "real poly list"
  fixes I:: "nat list" 
  fixes signs:: "rat list list"
  assumes nonzero: "p0"
  assumes distinct_signs: "distinct signs"
  assumes all_info: "set (characterize_consistent_signs_at_roots_copr p qs)  set(signs)"
  assumes welldefined: "list_constr I (length qs)"
  assumes pairwise_rel_prime_1: "q. ((List.member qs q)  (coprime p q))"
  shows "(vec_of_list (mtx_row signs I)  (construct_lhs_vector p qs signs)) =
   rat_of_int (card {x. poly p x = 0  poly (prod_list (retrieve_polys qs I)) x > 0}) -
   rat_of_int (card {x. poly p x = 0  poly (prod_list (retrieve_polys qs I)) x < 0})"
proof -
  have "finite (set (map (consistent_sign_vec_copr qs)  (characterize_root_list_p p)))" by auto
  (* Sign vectors realized at roots of p, split by the value of z I s *)
  let ?gt = "(set (map (consistent_sign_vec_copr qs)  (characterize_root_list_p p))  {s. z I s = 1})"
  let ?lt = "  (set (map (consistent_sign_vec_copr qs)  (characterize_root_list_p p))  {s. z I s = -1})"  
  have eq: "set (map (consistent_sign_vec_copr qs)  (characterize_root_list_p p)) = ?gt  ?lt"
    apply auto
    by (metis characterize_root_list_p_def horiz_vector_helper_neg horiz_vector_helper_pos_ind nonzero nonzero_product pairwise_rel_prime_1 poly_roots_finite sorted_list_of_set(1) welldefined)
      (* Step 1: restrict signs to characterize_consistent_signs_at_roots_copr p qs *)
  from construct_lhs_vector_drop_consistent[OF assms(1-4)] have
    "vec_of_list (mtx_row signs I)  construct_lhs_vector p qs signs =
  vec_of_list (mtx_row (characterize_consistent_signs_at_roots_copr p qs) I) 
  construct_lhs_vector p qs (characterize_consistent_signs_at_roots_copr p qs)" .
    (* Steps 2-3: rewrite as a sum over consistent sign vectors, then split it into ?gt and ?lt *)
  from lhs_dot_rewrite[OF assms(1)]
  moreover have "... =
  (scharacterize_consistent_signs_at_roots_copr p qs.
    z I s * rat_of_int (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s}))" .
  moreover have "... =
  (sset (map (consistent_sign_vec_copr qs)  (characterize_root_list_p p)).
    z I s * rat_of_int (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s}))" unfolding characterize_consistent_signs_at_roots_copr_def sum_code[symmetric]
    by (auto)
  ultimately have "... =
  (s?gt. z I s * rat_of_int (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s})) +
  (s?lt. z I s * rat_of_int (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s}))"
    apply (subst eq)
    apply (rule sum.union_disjoint)
    by auto
      (* Step 4: evaluate the ?gt and ?lt sums as root counts *)
  have setroots: "set (characterize_root_list_p p) = {x. poly p x = 0}" unfolding characterize_root_list_p_def
    using poly_roots_finite nonzero rsquarefree_def set_sorted_list_of_set by blast    
  have *: "s. {x. poly p x = 0  consistent_sign_vec_copr qs x = s} =
        {x {x. poly p x = 0}. consistent_sign_vec_copr qs x = s}"
    by auto
  (* For a fixed root x, the only sign vector that can be counted is consistent_sign_vec_copr qs x;
     it lies in the z I s = 1 part exactly when the product is positive at x. *)
  have lem_e1: "x. x  {x. poly p x = 0} 
       card
        {s  consistent_sign_vec_copr  qs ` {x. poly p x = 0}  {s. z I s = 1}.
         consistent_sign_vec_copr qs x = s} =
       (if 0 < poly (prod_list (retrieve_polys qs I)) x then 1 else 0)"
  proof -
    fix x
    assume rt: "x  {x. poly p x = 0}"
    then have 1: "{s  consistent_sign_vec_copr qs ` {x. poly p x = 0}  {s. z I s = 1}. consistent_sign_vec_copr qs x = s} =
      {s. z I s = 1  consistent_sign_vec_copr qs x = s}"
      by auto
    from horiz_vector_helper_pos[OF assms(1) assms(5) rt]
    have 2: "... = {s. (0 < poly (prod_list (retrieve_polys qs I)) x)   consistent_sign_vec_copr qs x = s}"
      using welldefined by blast
    have 3: "... = (if (0 < poly (prod_list (retrieve_polys qs I)) x)  then {consistent_sign_vec_copr qs x} else {})"
      by auto
    thus "card {s  consistent_sign_vec_copr qs ` {x. poly p x = 0}  {s. z I s = 1}. consistent_sign_vec_copr qs x = s} =
         (if 0 < poly (prod_list (retrieve_polys qs I)) x then 1 else 0) " using 1 2 3 by auto
  qed
  (* Reindex the z I s = 1 count as a sum over the roots of p (sum_multicount_gen) *)
  have e1: "(sconsistent_sign_vec_copr qs ` {x. poly p x = 0}  {s. z I s = 1}.
       card {x. poly p x = 0  consistent_sign_vec_copr qs x = s}) =
     (sum (λx. if (poly (prod_list (retrieve_polys qs I)) x) > 0 then 1 else 0) {x. poly p x = 0})"
    unfolding * apply (rule sum_multicount_gen)
    using ‹finite (set (map (consistent_sign_vec_copr qs) (characterize_root_list_p p))) setroots apply auto[1]
    apply (metis List.finite_set setroots)
    using lem_e1 by auto
  have gtchr: "(s?gt. z I s * rat_of_int (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s})) =
    rat_of_int (card {x. poly p x = 0  0 < poly (prod_list (retrieve_polys qs I)) x})"
    apply (auto simp add: setroots)
    apply (subst of_nat_sum[symmetric])
    apply (subst of_nat_eq_iff)
    apply (subst e1)
    apply (subst card_eq_sum)
    apply (rule sum.mono_neutral_cong_right)
    apply (metis List.finite_set setroots)
    by auto
  (* Same as lem_e1, for the z I s = -1 part and a negative product *)
  have lem_e2: "x. x  {x. poly p x = 0} 
       card
        {s  consistent_sign_vec_copr  qs ` {x. poly p x = 0}  {s. z I s = -1}.
         consistent_sign_vec_copr qs x = s} =
       (if poly (prod_list (retrieve_polys qs I)) x < 0 then 1 else 0)"
  proof -
    fix x
    assume rt: "x  {x. poly p x = 0}"
    then have 1: "{s  consistent_sign_vec_copr qs ` {x. poly p x = 0}  {s. z I s = -1}. consistent_sign_vec_copr qs x = s} =
      {s. z I s = -1  consistent_sign_vec_copr qs x = s}"
      by auto
    from horiz_vector_helper_neg[OF assms(1) assms(5) rt]
    have 2: "... = {s. (0 > poly (prod_list (retrieve_polys qs I)) x)   consistent_sign_vec_copr qs x = s}"
      using welldefined by blast
    have 3: "... = (if (0 > poly (prod_list (retrieve_polys qs I)) x)  then {consistent_sign_vec_copr qs x} else {})"
      by auto
    thus "card {s  consistent_sign_vec_copr qs ` {x. poly p x = 0}  {s. z I s = -1}. consistent_sign_vec_copr qs x = s} =
       (if poly (prod_list (retrieve_polys qs I)) x < 0 then 1 else 0)" using 1 2 3 by auto
  qed
  (* Reindex the z I s = -1 count as a sum over the roots of p (sum_multicount_gen) *)
  have e2: " (sconsistent_sign_vec_copr qs ` {x. poly p x = 0}  {s. z I s = - 1}.
       card {x. poly p x = 0  consistent_sign_vec_copr qs x = s}) =
     (sum (λx. if (poly (prod_list (retrieve_polys qs I)) x) < 0 then 1 else 0) {x. poly p x = 0})"
    unfolding * apply (rule sum_multicount_gen)
    using ‹finite (set (map (consistent_sign_vec_copr qs) (characterize_root_list_p p))) setroots apply auto[1]
     apply (metis List.finite_set setroots)
    using lem_e2 by auto
  have ltchr: "(s?lt. z I s * rat_of_int (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s})) =
    - rat_of_int (card {x. poly p x = 0  0 > poly (prod_list (retrieve_polys qs I)) x})"
    apply (auto simp add: setroots sum_negf)
    apply (subst of_nat_sum[symmetric])
    apply (subst of_nat_eq_iff)
    apply (subst e2)
    apply (subst card_eq_sum)
    apply (rule sum.mono_neutral_cong_right)
       apply (metis List.finite_set setroots)
    by auto
  (* Combine the chain of equalities from steps 1-3 with gtchr and ltchr *)
  show ?thesis using gtchr ltchr
    using (sset (map (consistent_sign_vec_copr qs) (characterize_root_list_p p)). z I s * rat_of_int (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s})) = (sset (map (consistent_sign_vec_copr qs) (characterize_root_list_p p))  {s. z I s = 1}. z I s * rat_of_int (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s})) + (sset (map (consistent_sign_vec_copr qs) (characterize_root_list_p p))  {s. z I s = - 1}. z I s * rat_of_int (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s})) (scharacterize_consistent_signs_at_roots_copr p qs. z I s * rat_of_int (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s})) = (sset (map (consistent_sign_vec_copr qs) (characterize_root_list_p p)). z I s * rat_of_int (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s})) ‹vec_of_list (mtx_row (characterize_consistent_signs_at_roots_copr p qs) I)  construct_lhs_vector p qs (characterize_consistent_signs_at_roots_copr p qs) = (scharacterize_consistent_signs_at_roots_copr p qs. z I s * rat_of_int (card {x. poly p x = 0  consistent_sign_vec_copr qs x = s})) ‹vec_of_list (mtx_row signs I)  construct_lhs_vector p qs signs = vec_of_list (mtx_row (characterize_consistent_signs_at_roots_copr p qs) I)  construct_lhs_vector p qs (characterize_consistent_signs_at_roots_copr p qs)
    by linarith
qed

(* A clean restatement of matrix_equation_helper_step: the right-hand side is written as
   construct_NofI p (retrieve_polys qs I), which construct_NofI_prop rewrites into the
   difference of root counts proved by the helper lemma. *)
lemma matrix_equation_main_step:
  fixes p:: "real poly"
  fixes qs:: "real poly list"
  fixes I:: "nat list" 
  fixes signs:: "rat list list"
  assumes nonzero: "p0"
  assumes distinct_signs: "distinct signs"
  assumes all_info: "set (characterize_consistent_signs_at_roots_copr p qs)  set(signs)"
  assumes welldefined: "list_constr I (length qs)"
  assumes pairwise_rel_prime_1: "q. ((List.member qs q)  (coprime p q))"
  shows "(vec_of_list (mtx_row signs I)  (construct_lhs_vector p qs signs)) =  
    construct_NofI p (retrieve_polys qs I)"
    unfolding construct_NofI_prop[OF nonzero]
    using matrix_equation_helper_step[OF assms]
    by linarith

(* Vectors built with vec_of_list agree after map_vec whenever the underlying lists agree
   after map. Used as an introduction rule in the proof of matrix_equation. *)
lemma map_vec_vec_of_list_eq_intro:
  assumes "map f xs = map g ys"
  shows "map_vec f (vec_of_list xs) = map_vec g (vec_of_list ys)"
  by (metis assms vec_of_list_map)

(* Main correctness theorem for the matrix construction: satisfy_equation p qs subsets signs
   holds provided that
   - signs contains every element of characterize_consistent_signs_at_roots_copr p qs (all_info),
   - every subset in subsets is well defined for qs (welldefined),
   - p is nonzero (nonzero), signs has no duplicates (distinct_signs), and every q in qs is
     coprime to p (pairwise_rel_prime). *)
theorem matrix_equation:
  fixes p:: "real poly"
  fixes qs:: "real poly list"
  fixes subsets:: "nat list list" 
  fixes signs:: "rat list list"
  assumes nonzero: "p0"
  assumes distinct_signs: "distinct signs"
  assumes all_info: "set (characterize_consistent_signs_at_roots_copr p qs)  set(signs)"
  assumes pairwise_rel_prime: "q. ((List.member qs q)  (coprime p q))"
  assumes welldefined: "all_list_constr (subsets) (length qs)"
  shows "satisfy_equation p qs subsets signs"
  unfolding satisfy_equation_def matrix_A_def
    construct_lhs_vector_def construct_rhs_vector_def all_list_constr_def
  apply (subst mult_mat_vec_of_list)
    apply (auto simp add: mtx_row_length intro!: map_vec_vec_of_list_eq_intro)
  (* The remaining goal follows from matrix_equation_main_step for each subset; its
     list_constr premise (left open as _) is discharged from welldefined. *)
  using matrix_equation_main_step[OF assms(1-3) _ assms(4), unfolded construct_lhs_vector_def]
  using all_list_constr_def in_set_member welldefined by fastforce

(* Restating the main theorem with more readable definitions and names *)
(* The set of real roots of p. *)
definition roots:: "real poly  real set"
  where "roots p = {x. poly p x = 0}"

(* Sign of x as a rational number: 1 if x > 0, -1 if x < 0, and 0 otherwise. *)
definition sgn::"'a::linordered_field  rat"
  where "sgn x = (if x > 0 then 1
                  else if x < 0 then -1
                  else 0)"

(* The vector of signs of the polynomials qs evaluated at x, in the order of qs. *)
definition sgn_vec::"real poly list  real  rat list"
  where "sgn_vec qs x   map (sgn  (λq. poly q x)) qs"

(* All sign vectors of qs that are realized at some root of p. *)
definition consistent_signs_at_roots:: "real poly  real poly list  rat list set"
  where "consistent_signs_at_roots p qs =
    (sgn_vec qs) ` (roots p)"

(* For nonzero p, the set-based consistent_signs_at_roots coincides with the list-based
   characterize_consistent_signs_at_roots. The nonzero assumption makes the root set finite
   (poly_roots_finite), so sorted_list_of_set enumerates exactly the roots. *)
lemma consistent_signs_at_roots_eq:
  assumes "p  0"
  shows "consistent_signs_at_roots p qs =
         set (characterize_consistent_signs_at_roots p qs)"
  unfolding consistent_signs_at_roots_def characterize_consistent_signs_at_roots_def
    characterize_root_list_p_def
  apply auto
  apply (subst set_sorted_list_of_set)
  using assms poly_roots_finite apply blast
  unfolding sgn_vec_def sgn_def signs_at_def squash_def o_def
  using roots_def apply auto[1]
  by (smt Collect_cong assms image_iff poly_roots_finite roots_def sorted_list_of_set(1))

(* Short names for the pieces of the matrix equation M_mat signs subsets *v w_vec = v_vec:
   w_vec is construct_lhs_vector, v_vec is construct_rhs_vector, and M_mat is matrix_A. *)
abbreviation w_vec:: "real poly  real poly list  rat list list   rat vec"
  where "w_vec  construct_lhs_vector"

abbreviation v_vec:: "real poly  real poly list  nat list list  rat vec"
  where "v_vec  construct_rhs_vector"

abbreviation M_mat:: "rat list list  nat list list  rat mat"
  where "M_mat  matrix_A"

(* matrix_equation restated with the readable names w_vec, v_vec and M_mat, and with the
   assumptions phrased using set membership instead of List.member / all_list_constr. *)
theorem matrix_equation_pretty:
  assumes "p0"
  assumes "q. q  set qs  coprime p q"
  assumes "distinct signs"
  assumes "consistent_signs_at_roots p qs  set signs"
  assumes "l i. l  set subsets  i  set l  i < length qs"
  shows "M_mat signs subsets *v w_vec p qs signs = v_vec p qs subsets"
  unfolding satisfy_equation_def[symmetric]
  apply (rule matrix_equation[OF assms(1) assms(3)])
  (* all_info, via consistent_signs_at_roots_eq and assumption 4 *)
  apply (metis assms(1) assms(2) assms(4) consistent_signs_at_roots_eq csa_list_copr_rel member_def)
  (* pairwise_rel_prime, from assumption 2 *)
  apply (simp add: assms(2) in_set_member)
  (* welldefined, from assumption 5 *)
  using Ball_set all_list_constr_def assms(5) list_constr_def member_def by fastforce

end

Theory BKR_Proofs

theory BKR_Proofs
  imports "Matrix_Equation_Construction"

begin

(* Every sign vector in sign_conds has length num_polys. *)
definition well_def_signs:: "nat => rat list list  bool"
  where "well_def_signs num_polys sign_conds   signs  set(sign_conds). (length signs = num_polys)"

(* Collects the properties required of a subsets/signs/matrix triple for p and qs:
   the subsets are well defined for qs, the signs are well defined and distinct, the matrix
   equation holds, matrix is invertible and equals matrix_A signs subsets, and signs contains
   every element of characterize_consistent_signs_at_roots_copr p qs. *)
definition satisfies_properties:: "real poly  real poly list  nat list list  rat list list  rat mat  bool"
  where "satisfies_properties p qs subsets signs matrix = 
  ( all_list_constr subsets (length qs)  well_def_signs (length qs) signs  distinct signs
   satisfy_equation p qs subsets signs   invertible_mat matrix   matrix = matrix_A signs subsets
   set (characterize_consistent_signs_at_roots_copr p qs)  set(signs)
  )"

section "Base Case"

(* For signs [[1],[-1]] and subsets [[],[0]], matrix_A is the 2x2 matrix with rows
   [1, 1] and [1, -1]. *)
lemma mat_base_case:
  shows "matrix_A [[1],[-1]] [[],[0]] = (mat_of_rows_list 2 [[1, 1], [1, -1]])"
  unfolding matrix_A_def mtx_row_def z_def apply (auto)
  by (simp add: numeral_2_eq_2)

(* With a single polynomial q, every consistent sign vector at the roots of p is [1] or [-1].
   NOTE(review): the proof below does not reference nonzero or rel_prime; they may be
   removable, but confirm how callers use this lemma before changing its statement. *)
lemma base_case_sgas:
  fixes q p:: "real poly"
  assumes nonzero: "p  0"
  assumes rel_prime: "coprime p q"
  shows "set (characterize_consistent_signs_at_roots_copr p [q])  {[1], [- 1]}"
  unfolding characterize_consistent_signs_at_roots_copr_def consistent_sign_vec_copr_def by auto

(* Variant of base_case_sgas for a list qs of length 1. *)
lemma base_case_sgas_alt:
  fixes p:: "real poly"
  fixes qs:: "real poly list"
  assumes len1: "length qs = 1"
  assumes nonzero: "p  0"
  assumes rel_prime: "q. (List.member qs q)  coprime p q"
  shows "set (characterize_consistent_signs_at_roots_copr p qs)  {[1], [- 1]}"
proof - 
  (* A list of length 1 is a singleton *)
  have ex_q: "(q::real poly). qs = [q]" 
    using len1    
    using length_Suc_conv[of qs 0] by auto
  then show ?thesis using base_case_sgas nonzero rel_prime
    apply (auto)
    using characterize_consistent_signs_at_roots_copr_def consistent_sign_vec_copr_def by auto
qed

(* Base case: signs [[1],[-1]] and subsets [[],[0]] satisfy the matrix equation for a single
   polynomial q coprime to p. Obtained by instantiating matrix_equation; h1-h3 discharge its
   all_info, pairwise_rel_prime and welldefined premises. *)
lemma base_case_satisfy_equation:
  fixes q p:: "real poly"
  assumes nonzero: "p  0"
  assumes rel_prime: "coprime p q"
  shows "satisfy_equation p [q] [[],[0]] [[1],[-1]]"
proof - 
  have h1: "set (characterize_consistent_signs_at_roots_copr p [q])  {[1], [- 1]}"
    using base_case_sgas assms by auto
  have h2: " qa. List.member [q] qa  coprime p qa" using rel_prime
    by (simp add: member_rec(1) member_rec(2))
  have h3: "all_list_constr [[], [0]] (Suc 0)" unfolding all_list_constr_def
    by (simp add: list_constr_def member_def)
  then show ?thesis using matrix_equation[where p = "p", where qs = "[q]", where signs = "[[1],[-1]]", where subsets = "[[],[0]]"]
      nonzero h1 h2 h3 by auto
qed

(* Variant of base_case_satisfy_equation for a list qs of length 1. *)
lemma base_case_satisfy_equation_alt:
  fixes p:: "real poly"
  fixes qs:: "real poly list"
  assumes len1: "length qs = 1"
  assumes nonzero: "p  0"
  assumes rel_prime: "q. (List.member qs q)  coprime p q"
  shows "satisfy_equation p qs [[],[0]] [[1],[-1]]"
proof - 
  (* A list of length 1 is a singleton *)
  have ex_q: "(q::real poly). qs = [q]" 
    using len1    
    using length_Suc_conv[of qs 0] by auto
  then show ?thesis using base_case_satisfy_equation nonzero rel_prime
    apply (auto)
    by (simp add: member_rec(1)) 
qed

(* The base-case matrix equation with the matrix written out explicitly, using mat_base_case
   to replace matrix_A in base_case_satisfy_equation. *)
lemma base_case_matrix_eq:
  fixes q p:: "real poly"
  assumes nonzero: "p  0"
  assumes rel_prime: "coprime p q"
  shows "(mult_mat_vec (mat_of_rows_list 2 [[1, 1], [1, -1]]) (construct_lhs_vector p [q] [[1],[-1]]) = 
    (construct_rhs_vector p [q] [[],[0]]))"                      
  using mat_base_case base_case_satisfy_equation unfolding satisfy_equation_def
  by (simp add: nonzero rel_prime)

(* Case split for indices below 2; used to unfold equalities of 2x2 matrices. *)
lemma less_two:
  shows "j < Suc (Suc 0)  j = 0  j = 1" by auto 

(* inverts_mat holds with the 1/2-entry matrix as first argument and the base-case matrix
   [[1, 1], [1, -1]] as second; inverse_mat_base_case_2 proves the other order. *)
lemma inverse_mat_base_case: 
  shows "inverts_mat (mat_of_rows_list 2 [[1/2, 1/2], [1/2, -(1/2)]]::rat mat) (mat_of_rows_list 2 [[1, 1], [1, -1]]::rat mat)"
  unfolding inverts_mat_def mat_of_rows_list_def mat_eq_iff apply auto
  unfolding less_two by (auto simp add: scalar_prod_def)

(* inverts_mat with the base-case matrix [[1, 1], [1, -1]] as first argument and the
   1/2-entry matrix as second (the reverse order of inverse_mat_base_case). *)
lemma inverse_mat_base_case_2: 
  shows "inverts_mat (mat_of_rows_list 2 [[1, 1], [1, -1]]::rat mat) (mat_of_rows_list 2 [[1/2, 1/2], [1/2, -(1/2)]]::rat mat) "
  unfolding inverts_mat_def mat_of_rows_list_def mat_eq_iff apply auto
  unfolding less_two by (auto simp add: scalar_prod_def)

(* The base-case matrix is invertible; its inverse has rows [1/2, 1/2] and [1/2, -1/2]
   (inverse_mat_base_case and inverse_mat_base_case_2). *)
lemma base_case_invertible_mat: 
  shows "invertible_mat (matrix_A [[1], [- 1]] [[], [0]])"
  unfolding invertible_mat_def using inverse_mat_base_case inverse_mat_base_case_2
  apply (auto)
   apply (simp add: mat_base_case mat_of_rows_list_def)
  using mat_base_case by auto 

section "Inductive Step"

subsection "Lemmas on smashing subsets and signs"

(* Every element of signs_smash signs1 signs2 is the concatenation of some entry of signs1
   with some entry of signs2, selected by indices n and m. *)
lemma signs_smash_property:
  fixes signs1 signs2 :: "'a list list"
  shows " (elem :: 'a list). (elem  (set (signs_smash signs1 signs2))  
    ( n m :: nat. 
      elem = ((nth signs1 n)@(nth signs2 m))))"
proof clarsimp 
  fix elem 
  assume assum: "elem  set (signs_smash signs1 signs2)"
  show "n m. elem = signs1 ! n @ signs2 ! m"
    using assum unfolding signs_smash_def apply (auto)
    by (metis in_set_conv_nth) 
qed

(* Every element of signs_smash signs1 signs2 has the form elem1 @ elem2 with elem1 a member
   of signs1 and elem2 a member of signs2. *)
lemma signs_smash_property_set:
  fixes signs1 signs2 :: "'a list list"
  shows " (elem :: 'a list). (elem  (set (signs_smash signs1 signs2))  
    ( (elem1::'a list).  (elem2::'a list). 
      (elem1  set(signs1)  elem2  set(signs2)  elem = (elem1@elem2))))"
proof clarsimp 
  fix elem 
  assume assum: "elem  set (signs_smash signs1 signs2)"
  show "elem1. elem1  set signs1  (elem2. elem2  set signs2  elem = elem1 @ elem2)"
    using assum unfolding signs_smash_def by auto
qed

(* Every element of subsets_smash n subsets1 subsets2 is an element of subsets1 followed by
   an element of subsets2 whose indices are all shifted by n. *)
lemma subsets_smash_property:
  fixes subsets1 subsets2 :: "nat list list"
  fixes n:: "nat"
  shows " (elem :: nat list). (List.member (subsets_smash n subsets1 subsets2) elem)  
    ( (elem1::nat list).  (elem2::nat list).
      (elem1  set(subsets1)  elem2  set(subsets2)  elem = (elem1 @ ((map ((+) n) elem2)))))"
proof - 
  let ?new_subsets = "(map (map ((+) n)) subsets2)"
    (* subsets_smash coincides with signs_smash applied to subsets1 and the shifted subsets2,
       which lets us reuse signs_smash_property_set *)
  have "subsets_smash n subsets1 subsets2 = signs_smash subsets1 ?new_subsets" 
    unfolding subsets_smash_def signs_smash_def apply (auto)
    by (simp add: comp_def) 
  then show ?thesis
    by (smt imageE in_set_member set_map signs_smash_property_set)
qed

  (* If subsets for smaller systems are well-defined, then subsets for combined systems
   are well-defined *)
subsection "Well-defined subsets preserved when smashing"

(* list_constr for a fixed bound is preserved under list append. *)
lemma list_constr_append:
  "list_constr a n  list_constr b n  list_constr (a@b) n"
  apply (auto) unfolding list_constr_def
  using list_all_append by blast 

(* Smashing subsets that are well defined for qs1 and qs2, with offset length qs1, yields
   subsets that are well defined for qs1 @ qs2. *)
lemma well_def_step: 
  fixes qs1 qs2 :: "real poly list"
  fixes subsets1 subsets2 :: "nat list list"
  assumes well_def_subsets1: "all_list_constr (subsets1) (length qs1)"
  assumes well_def_subsets2: "all_list_constr (subsets2) (length qs2)"
  shows "all_list_constr ((subsets_smash (length qs1) subsets1 subsets2))
     (length (qs1 @ qs2))"
proof - 
  (* Shape of each smashed subset *)
  have h1: "elem.
     List.member (subsets_smash (length qs1) subsets1 subsets2) elem 
     (elem1 elem2. elem1  set subsets1  elem2  set subsets2  elem = elem1 @ map ((+) (length qs1)) elem2)"
    using subsets_smash_property by blast
  (* Each such combination is well defined for qs1 @ qs2 *)
  have h2: "elem1 elem2. (elem1  set subsets1  elem2  set subsets2)  list_constr (elem1 @ map ((+) (length qs1)) elem2) (length (qs1 @ qs2))"
  proof clarsimp 
    fix elem1
    fix elem2
    assume e1: "elem1  set subsets1"
    assume e2: "elem2  set subsets2"
    (* The prefix stays within the larger bound *)
    have h1: "list_constr elem1  (length qs1 + length qs2) " 
    proof - 
      have h1: "list_constr elem1  (length qs1)"  using e1 well_def_subsets1 
        unfolding all_list_constr_def
        by (simp add: in_set_member) 
      then show ?thesis unfolding list_constr_def
        by (simp add: list.pred_mono_strong) 
    qed
    (* The shifted suffix stays within the larger bound *)
    have h2: "list_constr (map ((+) (length qs1)) elem2) (length qs1 + length qs2)"
    proof - 
      have h1: "list_constr elem2  (length qs2)"  using e2 well_def_subsets2 
        unfolding all_list_constr_def
        by (simp add: in_set_member) 
      then show ?thesis unfolding list_constr_def
        by (simp add: list_all_length)
    qed    
    show "list_constr (elem1 @ map ((+) (length qs1)) elem2) (length qs1 + length qs2)" 
      using h1 h2 list_constr_append
      by blast 
  qed
  then show ?thesis using h1 unfolding all_list_constr_def by auto
qed

subsection "Well def signs preserved when smashing"
(* Smashing sign lists that are well defined for qs1 and qs2 yields signs that are well
   defined for qs1 @ qs2.
   NOTE(review): the two length-positivity assumptions are not referenced by the proof;
   confirm how callers instantiate this lemma before removing them. *)
lemma well_def_signs_step: 
  fixes qs1 qs2 :: "real poly list"
  fixes signs1 signs2 :: "rat list list"
  assumes "length qs1 > 0"
  assumes "length qs2 > 0"
  assumes well_def1: "well_def_signs (length qs1) signs1"
  assumes well_def2: "well_def_signs (length qs2) signs2"
  shows "well_def_signs (length (qs1@qs2)) (signs_smash signs1 signs2)"
  using well_def1 well_def2 unfolding well_def_signs_def using signs_smash_property_set[of signs1 signs2] length_append by auto

subsection "Distinct signs preserved when smashing"

(* Prepending the same list xs to every element preserves distinctness. *)
lemma distinct_map_append:
  assumes "distinct ls"
  shows "distinct (map ((@) xs) ls)"
  unfolding distinct_map inj_on_def using assms by auto

(* When y and b have the same length, x @ y = a @ b relates to x = a and y = b. *)
lemma length_eq_append:
  assumes "length y = length b"
  shows "(x @ y = a @ b)  x=a  y = b"
  by (simp add: assms)

(* Smashing two distinct lists of sign vectors, where all vectors in each list have a common
   length (n1 resp. n2), yields a distinct list. Proved by induction on signs1. *)
lemma distinct_step:
  fixes signs1 signs2 :: "rat list list"
  assumes well_def1: "well_def_signs n1 signs1"
  assumes well_def2: "well_def_signs n2 signs2"
  assumes distinct1: "distinct signs1"
  assumes distinct2: "distinct signs2"
  shows "distinct (signs_smash signs1 signs2)"
  unfolding signs_smash_def
  using distinct1
proof(induction signs1)
  case Nil
  then show ?case by auto
next
  case (Cons a signs1)
  (* Equal lengths (from well_def_signs) let length_eq_append split equalities of appends *)
  then show ?case apply (auto simp add: distinct2 distinct_map_append)
    using assms unfolding well_def_signs_def
    by (simp add: distinct1 distinct2 length_eq_append)
qed

(* In this section we show that if signs1 contains all consistent sign assignments for qs1 and
signs2 contains all consistent sign assignments for qs2, then their smash contains all consistent
sign assignments for qs1 @ qs2.
Intuitively, this holds because signs1 and signs2 carry information about different sets of
polynomials, and their smash contains every possible combination of that information.
*)
subsection "Consistent sign assignments preserved when smashing"

(* signs_smash is monotone in both arguments with respect to set inclusion. *)
lemma subset_smash_signs: 
  fixes a1 b1 a2 b2:: "rat list list"
  assumes sub1: "set a1  set a2"
  assumes sub2: "set b1  set b2"
  shows "set (signs_smash a1 b1)  set (signs_smash a2 b2)"
proof -
  have h1: "xset (signs_smash a1 b1). xset (signs_smash a2 b2)"
  proof clarsimp 
    fix x
    assume x_in: "x  set (signs_smash a1 b1)"
    (* x is a concatenation of an entry of a1 and an entry of b1 *)
    have h1: " n m :: nat. x = (nth a1 n)@(nth b1 m)"
      using signs_smash_property[of a1 b1] x_in
      by blast
    (* those entries also occur in a2 and b2 *)
    then have h2: " p q::nat. x = (nth a2 p)@(nth b2 q)"
      using sub1 sub2 apply (auto)
      by (metis nth_find_first signs_smash_property_set subset_code(1) x_in) 
    then show "x  set (signs_smash a2 b2)" 
      unfolding signs_smash_def apply (auto)
      using signs_smash_property_set sub1 sub2 x_in by fastforce 
  qed
  then show ?thesis
    by blast 
qed

lemma subset_helper:
  fixes p::