Session Modular_arithmetic_LLL_and_HNF_algorithms

Theory Matrix_Change_Row

section ‹Missing Matrix Operations›

text ‹In this theory we provide an operation that can change a single
  row in a matrix efficiently, and all other rows in the matrix implementation
  will be reused.›

(* TODO: move this part into JNF-AFP-entry *)

theory Matrix_Change_Row
  imports 
    Jordan_Normal_Form.Matrix_IArray_Impl
    Polynomial_Interpolation.Missing_Unsorted
begin

(* change_row k f A: a matrix with the same dimensions as A in which each entry
   (k,j) of row k is replaced by f j (A $$ (k,j)); every other entry is taken from A
   unchanged. The function f receives the column index and the old entry. *)
definition change_row :: "nat ⇒ (nat ⇒ 'a ⇒ 'a) ⇒ 'a mat ⇒ 'a mat" where
  "change_row k f A = mat (dim_row A) (dim_col A) (λ (i,j). 
     if i = k then f j (A $$ (k,j)) else A $$ (i,j))"

(* Dimension facts (simp rules): change_row keeps dim_row and dim_col of its argument,
   so the result lies in carrier_mat nr nc exactly when A does. *)
lemma change_row_carrier[simp]: 
  "(change_row k f A ∈ carrier_mat nr nc) = (A ∈ carrier_mat nr nc)" 
  "dim_row (change_row k f A) = dim_row A" 
  "dim_col (change_row k f A) = dim_col A" 
  unfolding change_row_def carrier_mat_def by auto

(* Entry lookup in a changed matrix, for in-bounds indices. The first form states the
   bounds through carrier_mat, the second through dim_row / dim_col directly. *)
lemma change_row_index[simp]: "A ∈ carrier_mat nr nc ⟹ i < nr ⟹ j < nc ⟹
  change_row k f A $$ (i,j) = (if i = k then f j (A $$ (k,j)) else A $$ (i,j))" 
  "i < dim_row A ⟹ j < dim_col A ⟹ change_row k f A $$ (i,j) = (if i = k then f j (A $$ (k,j)) else A $$ (i,j))" 
  unfolding change_row_def by auto

(* Implementation of change_row on the mat_impl representation (nr, nc, A), where A is
   an IArray of row IArrays. Only row k is rebuilt (by mapping f over its entries paired
   with the column indices [0 ..< nc]); the list of rows is updated at position k, so all
   other row arrays are reused as they are.
   The proof obligation is that the mat_impl invariant is preserved: every row of the
   updated row list is either an old row or the freshly built row k. *)
lift_definition change_row_impl :: "nat ⇒ (nat ⇒ 'a ⇒ 'a) ⇒ 'a mat_impl ⇒ 'a mat_impl" is
  "λ k f (nr,nc,A). let Ak = IArray.sub A k; Arows = IArray.list_of A;
     Ak' = IArray.IArray (map (λ (i,c). f i c) (zip [0 ..< nc] (IArray.list_of Ak)));
     A' = IArray.IArray (Arows [k := Ak'])
     in (nr,nc,A')" 
proof (auto, goal_cases)
  case (1 k f nc b row)
  show ?case 
  proof (cases b)
    case (IArray rows)
    (* a row of the updated list is an old row, or k is in range and it is the new row *)
    with 1 have "row ∈ set rows ∨ k < length rows ∧
        row = IArray (map (λ (i,c). f i c) (zip [0 ..< nc] (IArray.list_of (rows ! k))))"
      by (cases "k < length rows", auto simp: set_list_update dest: in_set_takeD in_set_dropD)
    with 1 IArray show ?thesis by (cases, auto)
  qed
qed

(* Code equation for change_row on implemented matrices: if k is a valid row index the
   call is delegated to change_row_impl; otherwise Code.abort is invoked with an
   "index out of bounds" message, with the abstract change_row as the aborted value.
   The proof splits on k < dim_row_impl A; in the True case the matrices are compared
   entry-wise (eq_matI) after transfer to the representation. *)
lemma change_row_code[code]: "change_row k f (mat_impl A) = (if k < dim_row_impl A 
  then mat_impl (change_row_impl k f A) 
  else Code.abort (STR ''index out of bounds in change_row'') (λ _. change_row k f (mat_impl A)))"
  (is "?l = ?r")
proof (cases "k < dim_row_impl A")
  case True
  hence id: "?r = mat_impl (change_row_impl k f A)" by simp
  show ?thesis unfolding id unfolding change_row_def
  proof (rule eq_matI, goal_cases)
    case (1 i j)
    thus ?case using True
      by (transfer, auto simp: mk_mat_def)
  qed (transfer, auto)+
qed simp

end

Theory Signed_Modulo

section ‹Signed Modulo Operation›

theory Signed_Modulo
  imports 
    Berlekamp_Zassenhaus.Poly_Mod
    Sqrt_Babylonian.Sqrt_Babylonian_Auxiliary
begin

text ‹The upcoming definition of symmetric modulo 
  is different to the HOL-Library-Signed\_Division.smod, since
  here the result will be in the range $\{-m/2,...,m/2\}$, 
  whereas there -1 symmod m = m - 1.

  The advantage of having the range $\{-m/2,...,m/2\}$ is that small negative
  numbers are represented by small numbers.

  One limitation is that the symmetric modulo only works properly
  if the modulus is a positive number.›

(* Symmetric modulo, written infix as "x symmod y": the ordinary remainder x mod y
   mapped through poly_mod.inv_M for modulus y. *)
definition sym_mod :: "int ⇒ int ⇒ int" (infixl "symmod" 70) where
  "sym_mod x y = poly_mod.inv_M y (x mod y)"

(* Code equation for sym_mod: compute m = x mod y once; keep m if m + m ≤ y,
   otherwise shift it down by y. Obtained by unfolding the definitions of sym_mod
   and poly_mod.inv_M. *)
lemma sym_mod_code[code]: "sym_mod x y = (let m = x mod y
   in if m + m ≤ y then m else m - y)" 
  unfolding sym_mod_def poly_mod.inv_M_def Let_def ..

(* Zero cases (simp rules): modulus 0 leaves the argument unchanged, and 0 reduced
   by a positive modulus is 0. *)
lemma sym_mod_zero[simp]: "n symmod 0 = n" "n > 0 ⟹ 0 symmod n = 0"
  unfolding sym_mod_def poly_mod.inv_M_def by auto

(* For a positive modulus y, x symmod y lies in the interval
   {- ((y - 1) div 2) .. y div 2}.
   NOTE(review): the comparison operator in the `cases` term was lost in the source
   text; it is restored here as ≥ - confirm against the upstream theory, since the
   subsequent smt calls depend on the goals that the case split produces. *)
lemma sym_mod_range: "y > 0 ⟹ x symmod y ∈ {- ((y - 1) div 2) .. y div 2}"
  unfolding sym_mod_def poly_mod.inv_M_def using pos_mod_bound[of y x]
  by (cases "x mod y ≥ y", auto) 
    (smt (verit) Euclidean_Division.pos_mod_bound Euclidean_Division.pos_mod_sign half_nonnegative_int_iff)+

text ‹The range is optimal in the sense that exactly y elements can be represented.›
(* The interval {- ((y - 1) div 2) .. y div 2} contains exactly y integers when y > 0. *)
lemma card_sym_mod_range: "y > 0 ⟹ card {- ((y - 1) div 2) .. y div 2} = y" 
  by simp

(* Absolute-value bounds derived from sym_mod_range: for a positive modulus the result
   is strictly smaller than y in absolute value, and at most y div 2. *)
lemma sym_mod_abs: "y > 0 ⟹ ¦x symmod y¦ < y"
  "y ≥ 1 ⟹ ¦x symmod y¦ ≤ y div 2"
  using sym_mod_range[of y x] by auto


(* Idempotence (simp rule): reducing twice by the same modulus y is the same as
   reducing once. Proved from poly_mod.M_def and poly_mod.M_inv_M_id. *)
lemma sym_mod_sym_mod[simp]: "x symmod y symmod y = x symmod (y :: int)" 
  unfolding sym_mod_def using poly_mod.M_def poly_mod.M_inv_M_id by auto

(* Compatibility with subtraction: both operands of a difference may be reduced
   by symmod c before the outer symmod c without changing the result. *)
lemma sym_mod_diff_eq: "(a symmod c - b symmod c) symmod c = (a - b) symmod c" 
  unfolding sym_mod_def
  by (metis mod_diff_cong mod_mod_trivial poly_mod.M_def poly_mod.M_inv_M_id)

(* If c divides b, an inner reduction by b is absorbed by the outer reduction by c.
   Based on mod_mod_cancel for the ordinary modulo. *)
lemma sym_mod_sym_mod_cancel: "c dvd b ⟹ a symmod b symmod c = a symmod c" 
  using mod_mod_cancel[of c b] unfolding sym_mod_def
  by (metis poly_mod.M_def poly_mod.M_inv_M_id)

(* One-sided variant of sym_mod_diff_eq: only the subtrahend is pre-reduced.
   Follows from sym_mod_diff_eq together with the idempotence rule sym_mod_sym_mod. *)
lemma sym_mod_diff_right_eq: "(a - b symmod c) symmod c = (a - b) symmod c" 
  using sym_mod_diff_eq by (metis sym_mod_sym_mod)

(* Compatibility with multiplication: the right factor may be reduced by symmod c
   before the outer symmod c. Uses mod_mult_right_eq for the ordinary modulo. *)
lemma sym_mod_mult_right_eq: "a * (b symmod c) symmod c = a * b symmod c" 
  unfolding sym_mod_def by (metis poly_mod.M_def poly_mod.M_inv_M_id mod_mult_right_eq)

(* Simp rule: if the positive modulus a divides b, then b symmod a is 0. *)
lemma dvd_imp_sym_mod_0 [simp]:
  "b symmod a = 0" if "a > 0" "a dvd b"
  unfolding sym_mod_def poly_mod.inv_M_def using that by simp

(* Converse direction, registered as an aggressive destruction rule ([dest!]):
   a vanishing symmetric remainder a symmod b = 0 implies that b divides a.
   No sign assumption on b is stated here. *)
lemma sym_mod_0_imp_dvd [dest!]:
  "b dvd a" if "a symmod b = 0"
  using that unfolding sym_mod_def poly_mod.inv_M_def
  by (smt (verit) Euclidean_Division.pos_mod_bound dvd_eq_mod_eq_0)

(* Division matching sym_mod, written infix as "x symdiv y": the ordinary quotient
   d = x div y is used when the remainder m = x mod y satisfies m + m ≤ y (the case in
   which sym_mod keeps m), and d + 1 otherwise (the case in which sym_mod returns m - y). *)
definition sym_div :: "int ⇒ int ⇒ int" (infixl "symdiv" 70) where
  "sym_div x y = (let d = x div y; m = x mod y in 
       if m + m ≤ y then d else d + 1)"

(* Conversion from int to the code-generation type integer commutes with mod.
   Auxiliary fact used in the proof of sym_div_code. *)
lemma of_int_mod_integer: "(of_int (x mod y) :: integer) = (of_int x :: integer) mod (of_int y)" 
  using integer_of_int_eq_of_int modulo_integer.abs_eq by presburger

(* Code equation for sym_div on the target-language type integer: quotient and
   remainder are obtained from a single divmod_integer call, and the comparison
   m + m ≤ yy is carried out on integers before converting the result back to int. *)
lemma sym_div_code[code]: 
  "sym_div x y = (let yy = integer_of_int y in 
     (case divmod_integer (integer_of_int x) yy
     of (d, m) ⇒ if m + m ≤ yy then int_of_integer d else (int_of_integer (d + 1))))"
  unfolding sym_div_def Let_def divmod_integer_def split
  apply (rule if_cong, subst of_int_le_iff[symmetric], unfold of_int_add)
  by (subst (1 2) of_int_mod_integer, auto)

(* Division identity for the symmetric operations with positive modulus:
   x symmod y = x - (x symdiv y) * y.
   The proof rewrites x mod y as ?z = x - y * (x div y), unfolds poly_mod.inv_M, and
   recognises the resulting case distinction as the definition of sym_div. *)
lemma sym_mod_sym_div: assumes y: "y > 0" shows "x symmod y = x - sym_div x y * y"
proof -
  let ?z = "x - y * (x div y)" 
  let ?u = "y * (x div y)" 
  have "x = y * (x div y) + x mod y" using y by simp
  hence id: "x mod y = ?z" by linarith
  have "x symmod y = poly_mod.inv_M y ?z" unfolding sym_mod_def id by auto
  also have "… = (if ?z + ?z ≤ y then ?z else ?z - y)" unfolding poly_mod.inv_M_def ..
  also have "… = x - (if (x mod y) + (x mod y) ≤ y then x div y else x div y + 1) * y" 
    by (simp add: algebra_simps id)
  also have "(if (x mod y) + (x mod y) ≤ y then x div y else x div y + 1) = sym_div x y" 
    unfolding sym_div_def Let_def ..
  finally show ?thesis .
qed
  
(* Simp rule: for a positive divisor b of a, the symmetric quotient times b gives
   back a. Derived from sym_mod_sym_div. *)
lemma dvd_sym_div_mult_right [simp]:
  "(a symdiv b) * b = a" if "b > 0" "b dvd a"
  using sym_mod_sym_div[of b a] that by simp

(* Commuted form of dvd_sym_div_mult_right, with b as the left factor. *)
lemma dvd_sym_div_mult_left [simp]:
  "b * (a symdiv b) = a" if "b > 0" "b dvd a"
  using dvd_sym_div_mult_right[OF that] by (simp add: ac_simps)


end

Theory Storjohann_Mod_Operation

section ‹Storjohann's Lemma 13›

text ‹This theory contains the result that one can always perform a mod-operation on
  the entries of the $d\mu$-matrix.›

theory Storjohann_Mod_Operation
  imports 
    LLL_Basis_Reduction.LLL_Certification
    Signed_Modulo
begin 

(* Fusion of two successive map_vec applications into one map over the composition.
   Proved entry-wise via eq_vecI. *)
lemma map_vec_map_vec: "map_vec f (map_vec g v) = map_vec (f o g) v" 
  by (intro eq_vecI, auto)

context semiring_hom
begin

(* TODO: move *)
(* The entry-wise homomorphism image mat⇩h distributes over matrix addition, for two
   matrices of the same dimensions nr × nc. Proved entry-wise from hom_add. *)
lemma mat_hom_add: assumes A: "A ∈ carrier_mat nr nc" and B: "B ∈ carrier_mat nr nc"
  shows "mat⇩h (A + B) = mat⇩h A + mat⇩h B"
  by (intro eq_matI, insert A B, auto simp: hom_add)
end

text ‹We now start to prove lemma 13 of Storjohann's paper.›
(* Local setting: A is an n × n matrix over a field with non-zero determinant, and
   I is the matrix returned by mat_inverse A. *)
context
  fixes A I :: "'a :: field mat" and n :: nat
  assumes A: "A ∈ carrier_mat n n" 
  and det: "det A ≠ 0" 
  and I: "I = the (mat_inverse A)" 
begin
(* I is a two-sided inverse of A of dimension n × n, and each entry (i,j) of I is the
   determinant of A with column i replaced by the j-th unit vector, divided by det A.
   The entry formula is obtained from cramer_lemma_mat applied to A *⇩v col I j = unit_vec n j. *)
lemma inverse_via_det: "I * A = 1⇩m n" "A * I = 1⇩m n" "I ∈ carrier_mat n n" 
  "I = mat n n (λ (i,j). det (replace_col A (unit_vec n j) i) / det A)"
proof -
  from det_non_zero_imp_unit[OF A det] 
  have Unit: "A ∈ Units (ring_mat TYPE('a) n n)" .
  from mat_inverse(1)[OF A, of n] Unit I have "mat_inverse A = Some I" 
    by (cases "mat_inverse A", auto)
  from mat_inverse(2)[OF A this]
  show left: "I * A = 1⇩m n" and right: "A * I = 1⇩m n" and I: "I ∈ carrier_mat n n" 
    by blast+
  {
    fix i j
    assume i: "i < n" and j: "j < n" 
    from I i j have cI: "col I j $ i = I $$ (i,j)" by simp
    from j have uv: "unit_vec n j ∈ carrier_vec n" by auto
    from j I have col: "col I j ∈ carrier_vec n" by auto
    (* column j of A * I is column j of the identity, i.e. the j-th unit vector *)
    from col_mult2[OF A I j, unfolded right] j
    have "A *⇩v col I j = unit_vec n j" by simp
    from cramer_lemma_mat[OF A col i, unfolded this cI]
    have "I $$ (i,j) = det (replace_col A (unit_vec n j) i) / det A" using det by simp
  }
  thus "I = mat n n (λ (i,j). det (replace_col A (unit_vec n j) i) / det A)"
    by (intro eq_matI, use I in auto)
qed

(* For the matrix R whose only possibly non-zero entry is c at position (i,j), the
   explicitly given matrix (row i filled with c times the determinant quotients of
   row j of I, zero elsewhere) is a left factor X with X * A = R.
   The proof shows that this matrix equals R * I and uses I * A = 1⇩m n. *)
lemma matrix_for_singleton_entry: assumes i: "i < n" and 
  j: "j < n" 
  and Rdef: "R = mat n n ( λ ij. if ij = (i,j) then c :: 'a else 0)" 
shows "mat n n
   (λ(i', j'). if i' = i then c * det (replace_col A (unit_vec n j') j) / det A
       else 0) * A = R" 
proof -
  note I = inverse_via_det(3)
  have R: "R ∈ carrier_mat n n" unfolding Rdef by auto
  have "(R * I) * A = R * (I * A)" using I A R by auto
  also have "I * A = 1⇩m n" unfolding inverse_via_det(1) ..
  also have "R * … = R" using R by simp
  also have "R * I = mat n n (λ (i',j'). row R i' ∙ col I j')"
    using I R unfolding times_mat_def by simp
  also have "… = mat n n ( λ (i',j'). if i' = i then c * I $$ (j, j') else 0)" 
    (is "mat n n ?f = mat n n ?g")
  proof -
    {
      fix i' j'
      assume i': "i' < n" and j': "j' < n" 
      have "?f (i',j') = ?g (i',j')" 
      proof (cases "i' = i")
        case False
        (* every row of R other than row i is the zero vector *)
        hence "row R i' = 0⇩v n" unfolding Rdef using i'
          by (intro eq_vecI, auto simp: Matrix.row_def)
        thus ?thesis using False i' j' I by simp
      next
        case True
        (* row i of R is c times the j-th unit vector *)
        hence "row R i' = c ⋅⇩v unit_vec n j" unfolding Rdef using i' j' i j
          by (intro eq_vecI, auto simp: Matrix.row_def)
        with True show ?thesis using i' j' I j by simp
      qed
    }
    thus ?thesis by auto
  qed
  finally show ?thesis unfolding inverse_via_det(4) using j 
    by (auto intro!: arg_cong[of _ _ "λ x. x * A"])
qed
end

(* The matrix M m has determinant 1: its determinant is computed as the product of its
   diagonal via det_lower_triangular, and that product is shown to be neutral
   using μ.simps. *)
lemma (in gram_schmidt_fs_Rn) det_M_1: "det (M m) = 1" 
proof -
  have "det (M m) = prod_list (diag_mat (M m))" 
    by (rule det_lower_triangular[of m], auto simp: μ.simps)
  also have "… = 1" 
    by (rule prod_list_neutral, auto simp: diag_mat_def μ.simps)
  finally show ?thesis .
qed

context gram_schmidt_fs_int
begin
(* Properties of IM, the inverse of the matrix M m:
   - inv_mu_lower_triangular: entries above the diagonal are 0;
   - inv_mu_diag: diagonal entries are 1;
   - d_inv_mu_integer: d i * IM $$ (i,j) is an integer for all i, j < m;
   - inv_mu_inverse: IM is a two-sided inverse of M m of dimension m × m.
   The first two follow from the determinant formula of inverse_via_det; for the
   integrality claim, the entries below the diagonal are expressed through
   cramer_lemma_mat as determinants derived from the Gramian matrix. *)
lemma assumes IM: "IM = the (mat_inverse (M m))" 
  shows inv_mu_lower_triangular: "⋀ k i. k < i ⟹ i < m ⟹ IM $$ (k, i) = 0"
  and inv_mu_diag: "⋀ k. k < m ⟹ IM $$ (k, k) = 1"
  and d_inv_mu_integer: "⋀ i j. i < m ⟹ j < m ⟹ d i * IM $$ (i,j) ∈ ℤ" 
  and inv_mu_inverse: "IM * M m = 1⇩m m" "M m * IM = 1⇩m m" "IM ∈ carrier_mat m m" 
proof - 
  note * = inverse_via_det[OF M_dim(3) _ IM, unfolded det_M_1]
  from * show inv: "IM * M m = 1⇩m m" "M m * IM = 1⇩m m" 
    and IM: "IM ∈ carrier_mat m m"  by auto
  from * have IM_det: "IM = mat m m (λ(i, j). det (replace_col (M m) ((unit_vec m) j) i))" 
    by auto
  (* multiplying matrix_equality from the left by IM yields IM * FF = Fs *)
  from matrix_equality have "IM * FF = IM * ((M m) * Fs)" by simp
  also have "… = (IM * M m) * Fs" using M_dim(3) IM Fs_dim(3)
    by (metis assoc_mult_mat)
  also have "… = Fs" unfolding inv using Fs_dim(3) by simp
  finally have equality: "IM * FF = Fs" .
  {
    fix i k
    assume i: "k < i" "i < m" 
    show "IM $$ (k, i) = 0" using i M_dim unfolding IM_det
      by (simp, subst det_lower_triangular[of m], auto simp: replace_col_def μ.simps diag_mat_def)
  } note IM_lower_triag = this
  {
    fix k
    assume k: "k < m" 
    show "IM $$ (k,k) = 1" using k M_dim unfolding IM_det
      by (simp, subst det_lower_triangular[of m], auto simp: replace_col_def μ.simps diag_mat_def
        intro!: prod_list_neutral)
  } note IM_diag_1 = this
  {
    fix k
    assume k: "k < m" 
    let ?f = "λ i. IM $$ (k, i) ⋅⇩v fs ! i" 
    let ?sum = "M.sumlist (map ?f [0..<m])" 
    let ?sumk = "M.sumlist (map ?f [0..<k])" 
    have set: "set (map ?f [0..<m]) ⊆ carrier_vec n" using fs_carrier by auto
    hence sum: "?sum ∈ carrier_vec n" by simp
    from set k have setk: "set (map ?f [0..<k]) ⊆ carrier_vec n" by auto
    hence sumk: "?sumk ∈ carrier_vec n" by simp
    from sum have dim_sum: "dim_vec ?sum = n" by simp
    (* express gso k as the IM-weighted combination of the vectors fs ! i *)
    have "gso k = row Fs k" using k by auto
    also have "… = row (IM * FF) k" unfolding equality ..
    also have "IM * FF = mat m n (λ (i,j). row IM i ∙ col FF j)" 
      unfolding times_mat_def using IM FF_dim by auto
    also have "row … k = vec n (λ j. row IM k ∙ col FF j)" 
      unfolding Matrix.row_def using IM FF_dim k by auto
    also have "… = vec n (λ j. ∑ i < m. IM $$ (k, i) * fs ! i $ j)" 
      by (intro eq_vecI, insert IM k, auto simp: scalar_prod_def Matrix.row_def intro!: sum.cong)
    also have "… = ?sum" 
      by (intro eq_vecI, insert IM, unfold dim_sum, subst sumlist_vec_index, 
        auto simp: o_def sum_list_sum_nth intro!: sum.cong)
    (* split the sum at index k: the tail vanishes (entries above the diagonal are 0)
       and the k-th summand has coefficient 1 *)
    also have "[0..<m] = [0..<k] @ [k] @ [Suc k ..<m]" using k
      by (simp add: list_trisect)
    also have "M.sumlist (map ?f …) = ?sumk + 
      (?f k + M.sumlist (map ?f [Suc k ..< m]))" 
      unfolding map_append 
      by (subst M.sumlist_append; (subst M.sumlist_append)?, insert k fs_carrier, auto)
    also have "M.sumlist (map ?f [Suc k ..< m]) = 0⇩v n" 
      by (rule sumlist_neutral, insert IM_lower_triag, auto)
    also have "IM $$ (k,k) = 1" using IM_diag_1[OF k] .
    finally have gso: "gso k = ?sumk + fs ! k"  using k by simp
    define b where "b = vec k (λ j. fs ! j ∙ fs ! k)" 
    {
      fix j
      assume jk: "j < k" 
      with k have j: "j < m" by auto
      have "fs ! j ∙ gso k = fs ! j ∙ (?sumk + fs ! k)" 
        unfolding gso by simp
      also have "fs ! j ∙ gso k = 0" using jk k
        by (simp add: fi_scalar_prod_gso gram_schmidt_fs.μ.simps)
      also have "fs ! j ∙ (?sumk + fs ! k)
         = fs ! j ∙ ?sumk + fs ! j ∙ fs ! k" 
        by (rule scalar_prod_add_distrib[OF _ sumk], insert j k, auto)
      also have "fs ! j ∙ fs ! k = b $ j" unfolding b_def using jk by simp
      finally have "b $ j = - (fs ! j ∙ ?sumk)" by linarith
    } note b_index = this
    let ?x = "vec k (λ i. - IM $$ (k, i))" 
    have x: "?x ∈ carrier_vec k" by auto
    from k have km: "k ≤ m" by simp 
    (* b is the Gramian matrix applied to the negated first k entries of row k of IM *)
    have bGx: "b = Gramian_matrix fs k *⇩v (vec k (λ i. - IM $$ (k, i)))" 
      unfolding Gramian_matrix_alt_alt_def[OF km]
    proof (rule eq_vecI; simp)
      fix i
      assume i: "i < k" 
      have "b $ i = - (∑x←[0..<k]. fs ! i ∙ (IM $$ (k, x) ⋅⇩v fs ! x))" 
        unfolding b_index[OF i]
        by (subst scalar_prod_right_sum_distrib, insert setk i k, auto simp: o_def)
      also have "… = vec k (λj. fs ! i ∙ fs ! j) ∙ vec k (λi. - IM $$ (k, i))" 
        by (subst (3) scalar_prod_def, insert i k, auto simp: o_def sum_list_sum_nth simp flip: sum_negf
          intro!: sum.cong)
      finally show "b $ i = vec k (λj. fs ! i ∙ fs ! j) ∙ vec k (λi. - IM $$ (k, i))" .
    qed (simp add: b_def)
    have G: "Gramian_matrix fs k ∈ carrier_mat k k" 
      unfolding Gramian_matrix_alt_alt_def[OF km] by simp
    from cramer_lemma_mat[OF G x, folded bGx Gramian_determinant_def]
    have "i < k ⟹
      d k * IM $$ (k, i) = - det (replace_col (Gramian_matrix fs k) (vec k (λ j. fs ! j ∙ fs ! k)) i)" 
      for i unfolding b_def by simp
  } note IM_lower_values = this
  {
    fix i j
    assume i: "i < m" and j: "j < m" 
    from i have im: "i ≤ m" by auto
    consider (1) "j < i" | (2) "j = i" | (3) "i < j" by linarith
    thus "d i * IM $$ (i,j) ∈ ℤ"
    proof cases
      case 1
      show ?thesis unfolding IM_lower_values[OF i 1] replace_col_def Gramian_matrix_alt_alt_def[OF im]
        by (intro Ints_minus Ints_det, insert i j, auto intro!: Ints_scalar_prod[of _ n] fs_int)
    next
      case 3
      show ?thesis unfolding IM_lower_triag[OF 3 j] by simp
    next
      case 2
      show ?thesis unfolding IM_diag_1[OF i] 2 using i unfolding Gramian_determinant_def
         Gramian_matrix_alt_alt_def[OF im]
        by (intro Ints_mult Ints_det, insert i j, auto intro!: Ints_scalar_prod[of _ n] fs_int)
    qed 
  }
qed

(* Integer transformation matrix for position (i,j) and value c:
   B is the m × m matrix with c at (i,j) and 0 elsewhere; C holds the integers whose
   of_int image is d i * IM $$ (i,j), where IM is the inverse of M m (extracted with
   the_inv of_int). The result is B * C + 1⇩m m. *)
definition inv_mu_ij_mat :: "nat ⇒ nat ⇒ int ⇒ int mat" where
 "inv_mu_ij_mat i j c = (let
    B = mat m m (λ ij. if ij = (i,j) then c else 0);
    C = mat m m (λ (i,j). the_inv (of_int :: _ ⇒ 'a) (d i * the (mat_inverse (M m)) $$ (i,j)))
   in B * C + 1⇩m m)" 

(* Main properties of the transformation matrix inv_mu_ij_mat i j c for j < i < m:
   (1) multiplying it (as an of_int image) with M m adds of_int c * d j at position
       (i,j) and changes nothing else;
   (2) if c mod p = 0, multiplying an m × n integer matrix A by it leaves A unchanged
       modulo p;
   (3)-(5) it is an m × m matrix with zeros above the diagonal and ones on the diagonal. *)
lemma inv_mu_ij_mat: assumes i: "i < m" and ji: "j < i" 
  shows 
(* Effect on μ *)
   "map_mat of_int (inv_mu_ij_mat i j c) * M m =
    mat m m (λij. if ij = (i, j) then of_int c * d j else 0) + M m" (* only change value of μ_ij *)
(* Effect on A *)
  "A ∈ carrier_mat m n ⟹ c mod p = 0 ⟹ map_mat (λ x. x mod p) (inv_mu_ij_mat i j c * A) = 
    (map_mat (λ x. x mod p) A)" (* no change (mod p) *)
(* The transformation-matrix is ... *)
  "inv_mu_ij_mat i j c ∈ carrier_mat m m" (* ... of dimension m*m *)
  "i' < j' ⟹ j' < m ⟹ inv_mu_ij_mat i j c $$ (i',j') = 0" (* ... lower triangular *)
  "k < m ⟹ inv_mu_ij_mat i j c $$ (k,k) = 1" (* ... with diagonal all 1 *)  
proof -
  obtain IM where IM: "IM = the (mat_inverse (M m))" by auto
  let ?oi = "of_int :: _ ⇒ 'a" 
  let ?C = "mat m m (λ ij. if ij = (i,j) then ?oi c else 0)" 
  let ?D = "mat m m (λ (i,j). d i * IM $$ (i,j))" 
  have oi: "inj ?oi" unfolding inj_on_def by auto
  have C: "?C ∈ carrier_mat m m" by auto
  from i ji have j: "j < m" by auto
  from j have jm: "{0..<m} = {0..<j} ∪ {j} ∪ {Suc j..<m}" by auto
  note IM_props = d_inv_mu_integer[OF IM] inv_mu_inverse[OF IM]
  (* the of_int image of the integer matrix is ?C * ?D + 1⇩m m; this needs the
     integrality of d i * IM $$ (i,j) so that the_inv of_int can be cancelled *)
  have mat_oi: "map_mat ?oi (inv_mu_ij_mat i j c) = ?C * ?D + 1⇩m m" (is "?MM = _")
    unfolding inv_mu_ij_mat_def Let_def IM[symmetric]
    apply (subst of_int_hom.mat_hom_add, force, force)
    apply (rule arg_cong2[of _ _ _ _ "(+)"])
     apply (subst of_int_hom.mat_hom_mult, force, force)
     apply (rule arg_cong2[of _ _ _ _ "(*)"])
      apply force
     apply (rule eq_matI, (auto)[3], goal_cases)
  proof -
    case (1 i j)
    from IM_props(1)[OF 1]
    show ?case unfolding Ints_def using the_inv_f_f[OF oi] by auto
  qed auto
  have "map_mat ?oi (inv_mu_ij_mat i j c) * M m = (?C * ?D) * M m + M m" unfolding mat_oi
    by (subst add_mult_distrib_mat[of _ m m], auto)
  also have "(?C * ?D) * M m = ?C * (?D * M m)" 
    by (rule assoc_mult_mat, auto)
  (* ?D is the diagonal matrix of the d-values times IM, so ?D * M m is that diagonal matrix *)
  also have "?D = mat m m (λ (i,j). if i = j then d j else 0) * IM" (is "_ = ?E * _")
  proof (rule eq_matI, insert IM_props(4), auto simp: scalar_prod_def, goal_cases)
    case (1 i j)
    hence id: "{0..<m} = {0..<i} ∪ {i} ∪ {Suc i ..<m}" 
      by (auto simp add: list_trisect)
    show ?case unfolding id
      by (auto simp: sum.union_disjoint)
  qed
  also have "… * M m = ?E * (IM * M m)" 
    by (rule assoc_mult_mat[of _ m m], insert IM_props, auto)
  also have "IM * M m = 1⇩m m" by fact
  also have "?E * 1⇩m m = ?E" by simp
  also have "?C * ?E = mat m m (λ ij. if ij = (i,j) then ?oi c * d j else 0)" 
    by (rule eq_matI, auto simp: scalar_prod_def, auto simp: jm sum.union_disjoint)
  finally show "map_mat ?oi (inv_mu_ij_mat i j c) * M m = 
    mat m m (λ ij. if ij = (i,j) then ?oi c * d j else 0) + M m" .
  show carr: "inv_mu_ij_mat i j c ∈ carrier_mat m m"
    unfolding inv_mu_ij_mat_def by auto
  {
    assume k: "k < m" 
    have "of_int (inv_mu_ij_mat i j c $$ (k,k)) = ?MM $$ (k,k)" 
      using carr k by auto
    also have "… = (?C * ?D) $$ (k,k) + 1" unfolding mat_oi using k by simp
    also have "(?C * ?D) $$ (k,k) = 0" using k
      by (auto simp: scalar_prod_def, auto simp: jm sum.union_disjoint 
        inv_mu_lower_triangular[OF IM ji i])
    finally show "inv_mu_ij_mat i j c $$ (k,k) = 1" by simp
  }
  {
    assume ij': "i' < j'" "j' < m"  
    have "of_int (inv_mu_ij_mat i j c $$ (i',j')) = ?MM $$ (i',j')" 
      using carr ij' by auto
    also have "… = (?C * ?D) $$ (i',j')" unfolding mat_oi using ij' by simp
    also have "(?C * ?D) $$ (i',j') = (if i' = i then ?oi c * (d j * IM $$ (j, j')) else 0)" 
      using ij' i j by (auto simp: scalar_prod_def, auto simp: jm sum.union_disjoint)
    also have "… = 0" using inv_mu_lower_triangular[OF IM _ ij'(2), of j] ij' i ji by auto
    finally show "inv_mu_ij_mat i j c $$ (i',j') = 0" by simp
  }
  {
    assume A: "A ∈ carrier_mat m n" and c: "c mod p = 0" 
    let ?mod = "map_mat (λ x. x mod p)" 
    let ?C = "mat m m (λ ij. if ij = (i,j) then c else 0)" 
    let ?D = "mat m m (λ ij. if ij = (i,j) then 1 else (0 :: int))" 
    define B where "B = mat m m (λ (i,j). the_inv ?oi (d i * the (mat_inverse (M m)) $$ (i,j)))" 
    have B: "B ∈ carrier_mat m m" unfolding B_def by auto
    define BA where "BA = B * A" 
    have BA: "BA ∈ carrier_mat m n" unfolding BA_def using A B by auto
    define DBA where "DBA = ?D * BA" 
    have DBA: "DBA ∈ carrier_mat m n" unfolding DBA_def using BA by auto
    (* the product is A plus c times an integer matrix, so it agrees with A modulo p *)
    have "?mod (inv_mu_ij_mat i j c * A) = 
     ?mod ((?C * B + 1⇩m m) * A)" 
      unfolding inv_mu_ij_mat_def B_def by simp
    also have "(?C * B + 1⇩m m) * A = ?C * B * A + A" 
      by (subst add_mult_distrib_mat, insert A B, auto)
    also have "?C * B * A = ?C * BA" 
      unfolding BA_def
      by (rule assoc_mult_mat, insert A B, auto)
    also have "?C = c ⋅⇩m ?D" 
      by (rule eq_matI, auto)
    also have "… * BA = c ⋅⇩m DBA" using BA unfolding DBA_def by auto
    also have "?mod (… + A) = ?mod A" 
      by (rule eq_matI, insert DBA A c, auto simp: mult.assoc) 
    finally show "?mod (inv_mu_ij_mat i j c * A) = ?mod A" .
  }
qed   
end
 
(* The Gramian determinant commutes with the embedding of integer vectors into the
   rationals: for vectors fs of dimension n and j ≤ length fs, the of_int image of the
   integer Gramian determinant equals the Gramian determinant of the rat_of_int-mapped
   vectors. Proved by pushing of_int through det, the matrix product and transpose. *)
lemma Gramian_determinant_of_int: assumes fs: "set fs ⊆ carrier_vec n" 
  and j: "j ≤ length fs" 
shows "of_int (gram_schmidt.Gramian_determinant n fs j)
  = gram_schmidt.Gramian_determinant n (map (map_vec rat_of_int) fs) j" 
proof -
  from j have j: "k < j ⟹ k < length fs" for k by auto
  show ?thesis
  unfolding gram_schmidt.Gramian_determinant_def
  by (subst of_int_hom.hom_det[symmetric], rule arg_cong[of _ _ det],
      unfold gram_schmidt.Gramian_matrix_def Let_def, subst of_int_hom.mat_hom_mult, force, force,
      unfold map_mat_transpose[symmetric],
      rule arg_cong2[of _ _ _ _ "λ x y. x * y⇧T"], insert fs[unfolded set_conv_nth] 
      j, (fastforce intro!: eq_matI)+)
qed

context LLL
begin

(* this lemma might also be useful for swap/add-operation *)
(* Multiplying a linearly independent basis fs of length m (as a row matrix) from the
   left by an m × m integer matrix A that has a left inverse B yields a list fs' that
   generates the same lattice, is again linearly independent, and has length m.
   The proof derives both ?Mfs' = A * ?Mfs and ?Mfs = B * ?Mfs' and then applies
   LLL_change_basis. *)
lemma multiply_invertible_mat: assumes lin: "lin_indep fs" 
  and len: "length fs = m" 
  and A: "A ∈ carrier_mat m m" 
  and A_invertible: "∃ B. B ∈ carrier_mat m m ∧ B * A = 1⇩m m" 
  and fs'_prod: "fs' = Matrix.rows (A * mat_of_rows n fs)" 
shows "lattice_of fs' = lattice_of fs" 
  "lin_indep fs'" 
  "length fs' = m" 
proof -
  let ?Mfs = "mat_of_rows n fs" 
  let ?Mfs' = "mat_of_rows n fs'" 
  from A_invertible obtain B where B: "B ∈ carrier_mat m m" and inv: "B * A = 1⇩m m" by auto
  from lin have fs: "set fs ⊆ carrier_vec n" unfolding gs.lin_indpt_list_def by auto
  with len have Mfs: "?Mfs ∈ carrier_mat m n" by auto
  from A Mfs have prod: "A * ?Mfs ∈ carrier_mat m n" by auto
  hence fs': "length fs' = m" "set fs' ⊆ carrier_vec n" unfolding fs'_prod
    by (auto simp: Matrix.rows_def Matrix.row_def)  
  have Mfs_prod': "?Mfs' = A * ?Mfs" 
    unfolding arg_cong[OF fs'_prod, of "mat_of_rows n"]
    by (intro eq_matI, auto simp: mat_of_rows_def)
  (* recover the old basis from the new one through the left inverse B *)
  have "B * ?Mfs' = B * (A * ?Mfs)" 
    unfolding Mfs_prod' by simp
  also have "… = (B * A) * ?Mfs"
    by (subst assoc_mult_mat[OF _ A Mfs], insert B, auto)
  also have "B * A = 1⇩m m" by fact
  also have "… * ?Mfs = ?Mfs" using Mfs by auto
  finally have Mfs_prod: "?Mfs = B * ?Mfs'" ..  
  interpret LLL: LLL_with_assms n m fs 2
    by (unfold_locales, auto simp: len lin)
  from LLL.LLL_change_basis[OF fs'(2,1) B A Mfs_prod Mfs_prod']
  show latt': "lattice_of fs' = lattice_of fs" and lin': "gs.lin_indpt_list (RAT fs')" 
    and len': "length fs' = m" 
    by (auto simp add: LLL_with_assms_def)
qed

text ‹This is the key lemma.›
lemma change_single_element: assumes lin: "lin_indep fs" 
  and len: "length fs = m" 
  and i: "i < m" and ji: "j < i"  
  and A: "A = gram_schmidt_fs_int.inv_mu_ij_mat n (RAT fs)"    ― ‹the transformation matrix A›
  and fs'_prod: "fs' = Matrix.rows (A i j c * mat_of_rows n fs)" ― ‹fs' is the new basis›
  and latt: "lattice_of fs = L" 
shows "lattice_of fs' = L"
  "c mod p = 0  map (map_vec (λ x. x mod p)) fs' = map (map_vec (λ x. x mod p)) fs" 
  "lin_indep fs'" 
  "length fs' = m" 
  " k. k < m  gso fs' k = gso fs k" 
  " k. k  m  d fs' k = d fs k" 
  "i' < m  j' < m  
    μ fs' i' j' = (if (i',j') = (i,j) then rat_of_int (c * d fs j) + μ fs i' j' else μ fs i' j')" 
  "i' < m  j' < m fs' i' j' = (if (i',j') = (i,j) then c * d fs j * d fs (Suc j) +fs i' j' elsefs i' j')" 
proof -
  let ?A = "A i j c" 
  let ?Mfs = "mat_of_rows n fs" 
  let ?Mfs' = "mat_of_rows n fs'" 
  from lin have fs: "set fs  carrier_vec n" unfolding gs.lin_indpt_list_def by auto
  with len have Mfs: "?Mfs  carrier_mat m n" by auto
  interpret gsi: gram_schmidt_fs_int n "RAT fs"
    rewrites "gsi.inv_mu_ij_mat = A" using lin unfolding A
    by (unfold_locales, insert lin[unfolded gs.lin_indpt_list_def], auto simp: set_conv_nth)
  note A = gsi.inv_mu_ij_mat[unfolded length_map len, OF i ji, where c = c]
  from A(3) Mfs have prod: "?A * ?Mfs  carrier_mat m n" by auto
  hence fs': "length fs' = m" "set fs'  carrier_vec n" unfolding fs'_prod
    by (auto simp: Matrix.rows_def Matrix.row_def)  
  have Mfs_prod': "?Mfs' = ?A * ?Mfs" 
    unfolding arg_cong[OF fs'_prod, of "mat_of_rows n"]
    by (intro eq_matI, auto simp: mat_of_rows_def)
  have detA: "det ?A = 1" 
    by (subst det_lower_triangular[OF A(4) A(3)], insert A, auto intro!: prod_list_neutral 
      simp: diag_mat_def)
  have " B. B  carrier_mat m m  B * ?A = 1m m" 
    by (intro exI[of _ "adj_mat ?A"], insert adj_mat[OF A(3)], auto simp: detA)
  from multiply_invertible_mat[OF lin len A(3) this fs'_prod] latt
  show latt': "lattice_of fs' = L" and lin': "gs.lin_indpt_list (RAT fs')" 
    and len': "length fs' = m" by auto
  interpret LLL: LLL_with_assms n m fs 2
    by (unfold_locales, auto simp: len lin)
  interpret fs: fs_int_indpt n fs
    by (standard, auto simp: lin)
  interpret fs': fs_int_indpt n fs'
    by (standard, auto simp: lin')
  {
    assume c: "c mod p = 0" 
    have id: "rows (map_mat f A) = map (map_vec f) (rows A)" for f A
      unfolding rows_def by auto
    have rows_id: "set fs  carrier_vec n  rows (mat_of_rows n fs) = fs" for fs
      unfolding mat_of_rows_def rows_def 
      by (force simp: Matrix.row_def set_conv_nth intro!: nth_equalityI)
    from A(2)[OF Mfs c]
    have "rows (map_mat (λx. x mod p) ?Mfs') = rows (map_mat (λx. x mod p) ?Mfs)" unfolding Mfs_prod'
      by simp
    from this[unfolded id rows_id[OF fs] rows_id[OF fs'(2)]]
    show "map (map_vec (λ x. x mod p)) fs' = map (map_vec (λ x. x mod p)) fs" .
  }
  {
    define B where "B = ?A" 
    have gs_eq: "k < m  gso fs' k = gso fs k" for k
    proof(induct rule: nat_less_induct)
      case (1 k)
      then show ?case 
      proof(cases "k = 0")
        case True
        then show ?thesis 
        proof -
          have "row ?Mfs' 0 = row ?Mfs 0"
          proof -
            have 2: "0 {0..<m}" and 3: "{1..<m} = {0..<m} - {0}" 
              and 4: "finite {0..<m}" using 1 by auto
            have "row ?Mfs' 0 = vec n (λj. row B 0  col ?Mfs j)" 
              using row_mult A(3) Mfs 1 Mfs_prod' unfolding B_def by simp
            also have " = vec n (λj. (l{0..<m}. B $$ (0, l) * ?Mfs $$ (l, j)))"
              using Mfs A(3) len 1 B_def unfolding scalar_prod_def by auto
            also have " = vec n (λj. B $$ (0, 0) * ?Mfs $$ (0, j) + 
              (l{1..<m}. B $$ (0, l) * ?Mfs $$ (l, j)))"
              using Groups_Big.comm_monoid_add_class.sum.remove[OF 4 2] 3
              by (simp add: g. sum g {0..<m} = g 0 + sum g ({0..<m} - {0}))
            also have " = row ?Mfs 0" 
              using A(4-) 1 unfolding B_def[symmetric] by (simp add: row_def)
            finally show ?thesis by (simp add: B_def Mfs_prod')
          qed
          then show ?thesis using True 1 fs'.f_carrier fs.f_carrier 
            fs'.gs.fs0_gso0 len' len gsi.fs0_gso0 by auto
        qed
      next
        case False
        then show ?thesis
        proof -
          have gso0kcarr: "gsi.gso ` {0 ..<k}  carrier_vec n"
            using 1(2) gsi.gso_carrier len by auto
          hence gsospancarr: "gs.span(gsi.gso ` {0 ..<k})  carrier_vec n " 
            using span_is_subset2 by auto

          have fs'_gs_diff_span: 
            "(RAT fs') !  k - fs'.gs.gso k  gs.span (gsi.gso ` {0 ..<k})"
          proof -
            define gs'sum where "gs'sum =
              gs.M.sumlist (map (λja. fs'.gs.μ k ja v fs'.gs.gso ja) [0..<k])"
            define gssum where "gssum = 
              gs.M.sumlist (map (λja. fs'.gs.μ k ja v gsi.gso ja) [0..<k])"
            have "set (map (λja. fs'.gs.μ k ja v gsi.gso ja) [0..<k]) 
               gs.span(gsi.gso ` {0 ..<k})" using 1(2) gs.span_mem gso0kcarr
              by auto
            hence gssumspan: "gssum  gs.span(gsi.gso ` {0 ..<k})"
              using atLeastLessThan_iff gso0kcarr imageE set_map set_upt 
                vec_space.sumlist_in_span 
              unfolding gssum_def by (smt subsetD)
            hence gssumcarr: "gssum  carrier_vec n" 
              using gsospancarr gssum_def by blast
            have sumid: "gs'sum = gssum"
            proof -
              have "map (λja. fs'.gs.μ k ja v fs'.gs.gso ja) [0..<k] =
                map (λja. fs'.gs.μ k ja v gsi.gso ja) [0..<k]"
                using 1 by simp
              thus ?thesis unfolding gs'sum_def gssum_def by argo
            qed
            have "(RAT fs') !  k = fs'.gs.gso k + gssum" 
              using fs'.gs.fs_by_gso_def len' False 1 sumid 
              unfolding gs'sum_def by auto
            hence "(RAT fs') !  k - fs'.gs.gso k = gssum" 
              using gssumcarr 1(2) len' by auto
            thus ?thesis using gssumspan by simp
          qed

          define v2 where "v2 = sumlist (map (λja. B $$ (k, ja) v fs ! ja) [0..< k])"
          have v2carr: "v2  carrier_vec n" 
          proof -
            have "set (map (λja. B $$ (k, ja) v fs ! ja) [0..< k])  carrier_vec n"
              using len 1(2) fs.f_carrier by auto
            thus ?thesis unfolding v2_def by simp
          qed
          define ratv2 where "ratv2 = (map_vec rat_of_int v2)"
          have ratv2carr: "ratv2  carrier_vec n" 
            unfolding ratv2_def using v2carr by simp
          have fs'id: "(RAT fs') ! k = (RAT fs) ! k + ratv2"
          proof -
            have zkm: "[0..<m] = [0..<(Suc k)] @ [(Suc k)..<m]" using 1(2) 
              by (metis Suc_lessI append_Nil2 upt_append upt_rec zero_less_Suc)
            have prep: "set (map (λja. B $$ (k, ja) v fs ! ja) [0..<m])  carrier_vec n" 
              using len fs.f_carrier by auto

            have "fs' ! k = vec n (λj. row B k  col ?Mfs j)"
              using 1(2) Mfs B_def A(3) fs'_prod by simp
            also have " = sumlist (map (λja. B $$ (k, ja) v fs ! ja) [0..<m])"
            proof -
              {
                fix i
                assume i: "i < n"
                have "(vec n (λj. row B k  col ?Mfs j)) $ i = row B k  col ?Mfs i" 
                  using i by auto
                also have " = (j = 0..<m. B $$ (k, j) * ?Mfs $$ (j,i))" 
                  using A(3) unfolding B_def[symmetric] 
                  by (smt 1(2) Mfs R.finsum_cong' i atLeastLessThan_iff carrier_matD
                      dim_col index_col index_row(1) scalar_prod_def)
                also have " = (j = 0..<m. B $$ (k, j) * (fs ! j $ i))"
                  by (metis (no_types, lifting) R.finsum_cong' atLeastLessThan_iff i
                      len mat_of_rows_index)
                also have " = 
                  (j = 0..<m. (map (λja.  B $$ (k, ja) v fs ! ja) [0..<m]) ! j $ i)"
                proof -
                  have "j<m. i<n. B $$ (k, j) * (fs ! j $ i) = 
                    (map (λja.  B $$ (k, ja) v fs ! ja) [0..<m]) ! j $ i" 
                    using 1(2) i A(3) len fs.f_carrier
                    unfolding B_def[symmetric] by auto
                  then show ?thesis using i by auto
                qed
                also have " = sumlist (map (λja. B $$ (k, ja) v fs ! ja) [0..<m]) $ i"
                  using sumlist_nth i fs.f_carrier carrier_vecD len by simp
                finally have "(vec n (λj. row B k  col ?Mfs j)) $ i =
                  sumlist (map (λja. B $$ (k, ja) v fs ! ja) [0..<m]) $ i" by auto
              }
              then show ?thesis using fs.f_carrier len dim_sumlist by auto
            qed
            also have " = sumlist (map (λja. B $$ (k, ja) v fs ! ja) 
              ([0..<(Suc k)] @ [(Suc k)..<m]))" 
              using zkm by simp
            also have " = sumlist (map (λja. B $$ (k, ja) v fs ! ja) [0..<(Suc k)]) +
              sumlist (map (λja. B $$ (k, ja) v fs ! ja) [(Suc k)..<m])"
              (is " = ?L2 + ?L3")
              using fs.f_carrier len dim_sumlist sumlist_append prep zkm by auto
            also have "?L3 = 0v n"
              using A(4) fs.f_carrier len sumlist_nth carrier_vecD sumlist_carrier 
                prep zkm unfolding B_def[symmetric] by auto
            also have "?L2 = sumlist (map (λja. B $$ (k, ja) v fs ! ja) [0..<k]) +
              B $$ (k, k) v fs ! k" using prep zkm sumlist_snoc by simp
            also have " = sumlist (map (λja. B $$ (k, ja) v fs ! ja) [0..<k]) + fs ! k"
              using A(5) 1(2) unfolding B_def[symmetric] by simp
            finally have "fs' ! k = fs ! k + 
              sumlist (map (λja. B $$ (k, ja) v fs ! ja) [0..<k])"
              using prep zkm by (simp add: M.add.m_comm)
            then have "fs' !  k = fs !  k + v2" unfolding v2_def by simp
            then show ?thesis using v2carr 1(2) len len' ratv2_def by force
          qed
          have ratv2span: "ratv2  gs.span (gsi.gso ` {0 ..<k})" 
          proof -
            have rat: "ratv2 = gs.M.sumlist
              (map (λj. of_int (B $$ (k, j)) v (RAT fs) ! j) [0..<k])"
            proof -
              have "set (map (λj. of_int (B $$ (k, j)) v (RAT fs) ! j) [0..<k]) 
                 carrier_vec n"
                using fs.f_carrier 1(2) len by auto
              hence carr: "gs.M.sumlist 
                (map (λj. of_int (B $$ (k, j)) v (RAT fs) ! j) [0..<k])  carrier_vec n"
                by auto
              have "set (map (λj. B $$ (k, j) v fs ! j) [0..<k])  carrier_vec n"
                using fs.f_carrier 1(2) len by auto
              hence "i j. i < n  j < k  of_int ((B $$ (k, j) v fs ! j) $ i)
                = (of_int (B $$ (k, j)) v (RAT fs) ! j) $ i"
                using 1(2) len by fastforce
              hence "i. i < n  ratv2 $ i = gs.M.sumlist
                (map (λj. (of_int (B $$ (k, j)) v (RAT fs) ! j)) [0..<k]) $ i"
                using fs.f_carrier 1(2) len v2carr gs.sumlist_nth sumlist_nth 
                  ratv2_def v2_def by simp
              then show ?thesis using ratv2carr carr by auto
            qed
            have "i. i < k  (RAT fs) ! i = 
              gs.M.sumlist (map (λ j. gsi.μ i j v gsi.gso j) [0 ..< Suc i])"
              using gsi.fi_is_sum_of_mu_gso len 1(2) by auto
            moreover have "i. i < k  (λ j. gsi.μ i j v gsi.gso j) ` {0 ..< Suc i}
               gs.span (gsi.gso ` {0 ..<k})"
              using gs.span_mem gso0kcarr by auto
            ultimately have "i. i < k  (RAT fs) ! i  gs.span (gsi.gso ` {0 ..<k})"
              using gso0kcarr set_map set_upt vec_space.sumlist_in_span subsetD by smt
            then show ?thesis using rat atLeastLessThan_iff set_upt gso0kcarr imageE 
              set_map gs.smult_in_span vec_space.sumlist_in_span by smt
          qed
          have fs_gs_diff_span:
            "(RAT fs) !  k - fs'.gs.gso k  gs.span (gsi.gso ` {0 ..<k})"
          proof -
            from fs'_gs_diff_span obtain v3 where sp: "v3  gs.span (gsi.gso ` {0 ..<k})"
              and eq: "(RAT fs) ! k - fs'.gs.gso k = v3 - ratv2" 
              using fs'.gs.gso_carrier len' 1(2) ratv2carr fs'id by fastforce
            have "v3+(-ratv2)  gs.span(gsi.gso ` {0 ..<k})"
              by (metis sp gs.span_add1 gso0kcarr gram_schmidt.inv_in_span 
                  gso0kcarr ratv2span)
            moreover have "v3+(-ratv2) = v3-ratv2" using ratv2carr by auto
            ultimately have "v3 - ratv2  gs.span (gsi.gso ` {0 ..<k})" by simp
            then show ?thesis using eq by auto
          qed
          {
            fix i
            assume i: "i < k"
            hence "fs'.gs.gso k  fs'.gs.gso i = 0" using 1(2) fs'.gs.orthogonal len' by auto
            hence "fs'.gs.gso k  gsi.gso i = 0" using 1 i by simp
          }
          hence "x. x  gsi.gso ` {0..<k}  fs'.gs.gso k  x = 0" by auto

          then show ?thesis
            using gsi.oc_projection_unique len len' fs_gs_diff_span 1(2) by auto
        qed
      qed
    qed

    have " i' j'. i' < m  j' < m  μ fs' i' j' = 
      (map_mat of_int (A i j c) * gsi.M m) $$ (i',j')" and
      " k. k < m  gso fs' k = gso fs k"
    proof -
      define rB where "rB = map_mat rat_of_int B"
      have rBcarr: "rB  carrier_mat m m" using A(3) unfolding rB_def B_def by simp
      define rfs where "rfs = mat_of_rows n (RAT fs)"
      have rfscarr: "rfs  carrier_mat m n" using Mfs unfolding rfs_def by simp

      {
        fix i'
        fix j'
        assume i': "i' < m"
        assume j': "j' < m"
        have prep: 
          "of_int_hom.vec_hom (row (B * mat_of_rows n fs) i') = row (rB * rfs) i'" 
          using len i' B_def A(3) rB_def rfs_def by (auto simp: scalar_prod_def)
        have prep2: "row (rB * rfs) i' = vec n (λl. row rB i'  col rfs l)"
          using len fs.f_carrier i' B_def A(3) scalar_prod_def rB_def
          unfolding rfs_def by auto
        have prep3: "(vec m (λ j1. row rfs j1  gsi.gso j' / gsi.gso j'2)) =
          (vec m (λ j1. (gsi.M m) $$ (j1, j')))"
        proof -
          {
            fix x y
            assume x: "x < m" and y: "y < m"
            have "(gsi.M m) $$ (x,y) = (if y < x then map of_int_hom.vec_hom fs ! x 
               fs'.gs.gso y / fs'.gs.gso y2 else if x = y then 1 else 0)" 
              using gsi.μ.simps x y j' len gs_eq gsi.M_index by auto
            hence "row rfs x  gsi.gso y / gsi.gso y2 = (gsi.M m) $$ (x,y)"
              unfolding rfs_def 
              by (metis carrier_matD(1) divide_eq_eq fs'.gs.β_zero fs'.gs.gso_norm_beta 
                  gs_eq gsi.μ.simps gsi.fi_scalar_prod_gso gsi.fs_carrier len len' 
                  length_map nth_rows rfs_def rfscarr rows_mat_of_rows x y)
          }
          then show ?thesis using j' by auto
        qed
        have prep4: "(1 / gsi.gso j'2) v (vec m (λj1. row rfs j1  gsi.gso j')) =
          (vec m (λj1. row rfs j1  gsi.gso j' / gsi.gso j'2))" by auto

        have "map of_int_hom.vec_hom fs' ! i'  fs'.gs.gso j' / fs'.gs.gso j'2
           = map of_int_hom.vec_hom fs' ! i'  gsi.gso j' / gsi.gso j'2"
          using gs_eq j' by simp
        also have " = row (rB * rfs) i'  gsi.gso j' / gsi.gso j'2"
          using prep i' len' unfolding rB_def B_def by (simp add: fs'_prod)
        also have " = 
          (vec n (λl. row rB i'  col rfs l))  gsi.gso j' / gsi.gso j'2"
          using prep2 by auto
        also have "vec n (λl. row rB i'  col rfs l) = 
            (vec n (λl. (j1=0..<m. (row rB i') $ j1 * (col rfs l) $ j1)))"
          using gsi.gso_carrier
          by (metis (no_types) carrier_matD(1) col_def dim_vec rfscarr scalar_prod_def)
        also have " = 
            (vec n (λl. (j1=0..<m. rB $$ (i', j1) * rfs $$ (j1, l))))" 
          using rBcarr rfscarr i' by auto
        also have "  gsi.gso j' = 
            (j2=0..<n. (vec n 
            (λl. (j1=0..<m. rB $$ (i', j1) * rfs $$ (j1, l)))) $ j2 * (gsi.gso j') $ j2)"
          using gsi.gso_carrier len j' scalar_prod_def 
          by (smt gs.R.finsum_cong' gsi.gso_dim length_map)
        also have " = (j2=0..<n.
            (j1=0..<m. rB $$ (i', j1) * rfs $$ (j1, j2)) * (gsi.gso j') $ j2)"
          using gsi.gso_carrier len j' by simp
        also have " = (j2=0..<n. (j1=0..<m.
            rB $$ (i', j1) * rfs $$ (j1, j2) * (gsi.gso j') $ j2))" 
          by (smt gs.R.finsum_cong' sum_distrib_right)
        also have " = (j1=0..<m. (j2=0..<n.
            rB $$ (i', j1) * rfs $$ (j1, j2) * (gsi.gso j') $ j2))"
          using sum.swap by auto
        also have " = (j1=0..<m. rB $$ (i', j1) * (j2=0..<n. 
            rfs $$ (j1, j2) * (gsi.gso j') $ j2))"
          using gs.R.finsum_cong' sum_distrib_left by (smt gs.m_assoc)
        also have " = row rB i'  (vec m (λ j1. (j2=0..<n.
            rfs $$ (j1, j2) * (gsi.gso j') $ j2)))" 
          using rBcarr rfscarr i' scalar_prod_def
          by (smt atLeastLessThan_iff carrier_matD(1) carrier_matD(2) dim_vec 
              gs.R.finsum_cong' index_row(1) index_vec)
        also have "(vec m (λ j1. (j2=0..<n. rfs $$ (j1, j2) * (gsi.gso j') $ j2)))
            =  (vec m (λ j1. row rfs j1  gsi.gso j'))"
          using rfscarr gsi.gso_carrier len j' rfscarr by (auto simp add: scalar_prod_def)
        also have "row rB i'   / gsi.gso j'2 =
          row rB i'  vec m (λ j1. row rfs j1  gsi.gso j' / gsi.gso j'2)"
          using prep4 scalar_prod_smult_right rBcarr carrier_matD(2) dim_vec row_def 
          by (smt gs.l_one times_divide_eq_left)
        also have " = (rB * (gsi.M m)) $$ (i', j')" 
          using rBcarr i' j' prep3 gsi.M_def by (simp add: col_def)
        finally have 
          "map of_int_hom.vec_hom fs' ! i'  fs'.gs.gso j' / fs'.gs.gso j'2 =
          (rB * (gsi.M m)) $$ (i', j')" by auto
      }
      then show " i' j'. i' < m  j' < m  μ fs' i' j' = 
        (map_mat of_int (A i j c) * gsi.M m) $$ (i',j')" 
        using B_def fs'.gs.β_zero fs'.gs.fi_scalar_prod_gso fs'.gs.gso_norm_beta
          len' rB_def by auto
      show " k. k < m  gso fs' k = gso fs k" using gs_eq by auto
    qed
  } note mu_gso = this

  show " k. k < m  gso fs' k = gso fs k" by fact
  {
    fix k
    have "k  m  rat_of_int (d fs' k) = rat_of_int (d fs k)" for k
    proof (induct k)
      case 0
      show ?case by (simp add: d_def)
    next
      case (Suc k)
      hence k: "k  m" "k < m" by auto 
      show ?case
        by (subst (1 2) LLL_d_Suc[OF _ k(2)], auto simp: Suc(1)[OF k(1)] mu_gso(2)[OF k(2)]
          LLL_invariant_weak_def lin lin' len len' latt latt')
    qed
    thus "k  m  d fs' k = d fs k" by simp
  } note d = this
  {
    assume i': "i' < m" and j': "j' < m"
    have fs' i' j' = (of_int_hom.mat_hom (A i j c) * gsi.M m) $$ (i',j')" by (rule mu_gso(1)[OF i' j'])
    also have " = (if (i',j') = (i,j) then of_int c * gsi.d j else 0) + gsi.M m $$ (i',j')" 
      unfolding A(1) using i' j' by (auto simp: gsi.M_def)
    also have "gsi.M m $$ (i',j') = μ fs i' j'" 
      unfolding gsi.M_def using i' j' by simp
    also have "gsi.d j = of_int (d fs j)" 
      unfolding d_def by (subst Gramian_determinant_of_int[OF fs], insert ji i len, auto)
    finally show mu: fs' i' j' = (if (i',j') = (i,j) then rat_of_int (c * d fs j) + μ fs i' j' else μ fs i' j')" 
      by simp
    let ?d = "d fs (Suc j')" 
    have d_fs: "of_int (fs i' j') = rat_of_int ?d * μ fs i' j'" 
      unfolding dμ_def 
      using fs.fs_int_mu_d_Z_m_m[unfolded len, OF i' j'] 
      by (metis LLL.LLL.d_def assms(2) fs.fs_int_mu_d_Z_m_m fs_int.d_def i' 
          int_of_rat(2) j')
    have "rat_of_int (fs' i' j') = rat_of_int (d fs' (Suc j')) * μ fs' i' j'" 
      unfolding dμ_def 
      using fs'.fs_int_mu_d_Z_m_m[unfolded len', OF i' j']
      using LLL.LLL.d_def fs'(1) fs'.dμ fs'.dμ_def fs_int.d_def i' j' by auto
    also have "d fs' (Suc j') = ?d" by (rule d, insert j', auto)
    also have "rat_of_int  * μ fs' i' j' = 
      (if (i',j') = (i,j) then rat_of_int (c * d fs j * ?d) else 0) + of_int (fs i' j')" 
      unfolding mu d_fs by (simp add: field_simps)
    also have " = rat_of_int ((if (i',j') = (i,j) then c * d fs j * ?d else 0) +fs i' j')"
      by simp
    also have " = rat_of_int ((if (i',j') = (i,j) then c * d fs j * d fs (Suc j) +fs i' j' elsefs i' j'))"
      by simp
    finally show "dμ fs' i' j' = (if (i',j') = (i,j) then c * d fs j * d fs (Suc j) +fs i' j' elsefs i' j')" 
      by simp
  }
qed

text ‹Finally: Lemma 13 of Storjohann's paper.›
(* Modular reduction of a single dμ-entry.
   For indices j < i < m and a positive modulus p, there is a basis fs' which
   keeps the lattice, the entries modulo p (for both mod and symmod), linear
   independence, the length, the gso-vectors and the d-values of fs, and whose
   dμ-matrix is described entrywise by the last conjunct: at position (i,j) the
   entry is the symmetric remainder modulo p * d fs j * d fs (Suc j); the
   remaining entries are those of fs.
   The witness is obtained from change_single_element with coefficient c * p,
   where c = - (dμ fs i j symdiv (p * d fs j * d fs (Suc j))). *)
lemma mod_single_element: assumes lin: "lin_indep fs" 
  and len: "length fs = m" 
  and i: "i < m" and ji: "j < i"  
  and latt: "lattice_of fs = L" 
  and pgtz: "p > 0"
shows " fs'. lattice_of fs' = L  
  map (map_vec (λ x. x mod p)) fs' = map (map_vec (λ x. x mod p)) fs 
  map (map_vec (λ x. x symmod p)) fs' = map (map_vec (λ x. x symmod p)) fs 
  lin_indep fs' 
  length fs' = m  
  ( k < m. gso fs' k = gso fs k)  
  ( k  m. d fs' k = d fs k) 
  ( i' < m.  j' < m.fs' i' j' = (if (i',j') = (i,j) thenfs i j' symmod (p * d fs j' * d fs (Suc j')) elsefs i' j'))" 
proof -
  have inv: "LLL_invariant_weak fs" using LLL_invariant_weak_def assms by simp
  (* M abbreviates d fs j * d fs (Suc j); pM = p * M is the modulus for the dμ-entry *)
  let ?mult = "d fs j * d fs (Suc j)" 
  define M where "M = ?mult" 
  define pM where "pM = p * M" 
  then have pMgtz: "pM > 0" using pgtz unfolding pM_def M_def using LLL_d_pos[OF inv] i ji by simp
  let ?d = "dμ fs i j" 
  (* c is the negated symmetric quotient, so that c * pM + ?d is the symmetric remainder (d_mod) *)
  define c where "c = - (?d symdiv pM)" 
  have d_mod: "?d symmod pM = c * pM + ?d" unfolding c_def using pMgtz sym_mod_sym_div by simp
  define A where "A = gram_schmidt_fs_int.inv_mu_ij_mat n (RAT fs)" 
  (* the witness: rows of (A i j (c * p)) multiplied with the matrix of fs *)
  define fs' where fs': "fs' = Matrix.rows (A i j (c * p) * mat_of_rows n fs)"
  note main = change_single_element[OF lin len i ji A_def fs' latt]
  have "map (map_vec (λx. x mod p)) fs' = map (map_vec (λx. x mod p)) fs" 
    by (intro main, auto)
  (* transfer the mod-p equality to symmod by applying poly_mod.inv_M p entrywise *)
  from arg_cong[OF this, of "map (map_vec (poly_mod.inv_M p))"]
  have id: "map (map_vec (λx. x symmod p)) fs' = map (map_vec (λx. x symmod p)) fs" 
    unfolding map_map o_def sym_mod_def map_vec_map_vec .
  show ?thesis
  proof (intro exI[of _ fs'] conjI main allI impI id)
    (* only the dμ-conjunct needs work; the other conjuncts are discharged by main, id and the final auto *)
    fix i' j'
    assume ij: "i' < m" "j' < m" 
    have "dμ fs' i' j' = (if (i', j') = (i, j) then (c * p) * M + ?d elsefs i' j')" 
      unfolding main(8)[OF ij] M_def by simp
    also have "(c * p) * M + ?d = ?d symmod pM" 
      unfolding d_mod by (simp add: pM_def)
    finally show "dμ fs' i' j' = (if (i',j') = (i,j) thenfs i j' symmod (p * d fs j' * d fs (Suc j')) elsefs i' j')" 
      by (auto simp: pM_def M_def ac_simps)
  qed auto
qed 

text ‹A slight generalization that performs the modulo operation on an arbitrary set of indices $I$.›
(* Modular reduction of a whole set I of dμ-entries.
   I is constrained by assumption I to index pairs taken from
   {(i,j). i < m, j < i}. The conclusion has the same shape as in
   mod_single_element, except that the last conjunct now distinguishes
   entries by membership of (i',j') in I.
   The proof is by induction over the finite set I; each inserted pair is
   handled by one application of mod_single_element. *)
lemma mod_finite_set: assumes lin: "lin_indep fs" 
  and len: "length fs = m" 
  and I: "I  {(i,j). i < m  j < i}"
  and latt: "lattice_of fs = L" 
  and pgtz: "p > 0"
shows " fs'. lattice_of fs' = L 
  map (map_vec (λ x. x mod p)) fs' = map (map_vec (λ x. x mod p)) fs 
  map (map_vec (λ x. x symmod p)) fs' = map (map_vec (λ x. x symmod p)) fs 
  lin_indep fs' 
  length fs' = m  
  ( k < m. gso fs' k = gso fs k)  
  ( k  m. d fs' k = d fs k) 
  ( i' < m.  j' < m.fs' i' j' = 
     (if (i',j')  I thenfs i' j' symmod (p * d fs j' * d fs (Suc j')) elsefs i' j'))"
proof -
  (* ?exp: the entrywise description of the dμ-matrix relative to an index set;
     ?prop: all remaining conjuncts, relating a basis fs' to a reference basis fs *)
  let ?exp = "λ fs' I i' j'.fs' i' j' = (if (i',j')  I thenfs i' j' symmod (p * d fs j' * d fs (Suc j')) elsefs i' j')" 
  let ?prop = "λ fs fs'. lattice_of fs' = L  
    map (map_vec (λ x. x mod p)) fs' = map (map_vec (λ x. x mod p)) fs 
    map (map_vec (λ x. x symmod p)) fs' = map (map_vec (λ x. x symmod p)) fs 
    lin_indep fs' 
    length fs' = m  
    ( k < m. gso fs' k = gso fs k)  
    ( k  m. d fs' k = d fs k)" 
  (* finiteness of I via the finite superset {0..m} × {0..m} *)
  have "finite I" 
  proof (rule finite_subset[OF I], rule finite_subset)
    show "{(i, j). i < m  j < i}  {0..m} × {0..m}" by auto
  qed auto
  from this I have " fs'. ?prop fs fs'  ( i' < m.  j' < m. ?exp fs' I i' j')"
  proof (induct I)
    case empty
    (* empty index set: fs itself is the witness *)
    show ?case
      by (intro exI[of _ fs], insert assms, auto)
  next
    case (insert ij I)
    (* inserting the pair ij = (i,j): take gs from the induction hypothesis and
       treat its entry (i,j) with mod_single_element, yielding hs *)
    obtain i j where ij: "ij = (i,j)" by force
    from ij insert(4) have i: "i < m" "j < i" by auto
    from insert(3,4) obtain gs where gs: "?prop fs gs" 
      and exp: " i' j'. i' < m  j' < m  ?exp gs I i' j'" by auto
    from gs have "lin_indep gs" "lattice_of gs = L" "length gs = m" by auto
    from mod_single_element[OF this(1,3) i this(2), of p] 
    obtain hs where hs: "?prop gs hs" 
      and exp': " i' j'. i' < m  j' < m hs i' j' = (if (i', j') = (i, j) 
         thengs i j' symmod (p * d gs j' * d gs (Suc j')) elsegs i' j')" 
      using pgtz by auto
    (* the d-values of gs coincide with those of fs, so the moduli agree *)
    from gs i have id: "d gs j = d fs j" "d gs (Suc j) = d fs (Suc j)" by auto
    show ?case 
    proof (intro exI[of _ hs], rule conjI; (intro allI impI)?)
      show "?prop fs hs" using gs hs by auto
      fix i' j'
      assume *: "i' < m" "j' < m" 
      show "?exp hs (insert ij I) i' j'" unfolding exp'[OF *] ij using exp * i
        by (auto simp: id)
    qed
  qed
  thus ?thesis by auto
qed

end

end

Theory Storjohann

section ‹Storjohann's basis reduction algorithm (abstract version)›

text ‹This theory contains the soundness proofs of Storjohann's basis
  reduction algorithms, both for the normal and the improved-swap-order variant.

  The implementation of Storjohann's version of LLL uses modular operations throughout.
  It is an abstract implementation that is already quite close to what the actual implementation will be.
   In particular, the swap operation here is derived from the computation lemma for the swap
   operation in the old, integer-only formalization of LLL.›

theory Storjohann
  imports 
    Storjohann_Mod_Operation
    LLL_Basis_Reduction.LLL_Number_Bounds
    Sqrt_Babylonian.NthRoot_Impl
begin

subsection ‹Definition of algorithm›

text ‹In the definition of the algorithm, the first-flag determines whether only the first vector
  of the reduced basis should be computed, i.e., a short vector. In that case the modulus can be slightly
  decreased in comparison to the modulus required for computing the whole reduced matrix.›

(* Reduces a list of triples (numerator, denominator, index) to a single triple.
   A singleton list yields its element. For two or more elements, the first two
   are compared via the cross-products n1 * d2 and n2 * d1, one of them is kept,
   and the recursion continues on the shortened list.
   There is no defining equation for the empty list.
   NOTE(review): presumably the triple with the larger fraction n/d survives
   (denominators are assumed positive in lemma max_list_rats_with_index) - confirm. *)
fun max_list_rats_with_index :: "(int * int * nat) list  (int * int * nat)" where
  "max_list_rats_with_index [x] = x" |
  "max_list_rats_with_index ((n1,d1,i1) # (n2,d2,i2) # xs) 
     = max_list_rats_with_index ((if n1 * d2  n2 * d1 then (n2,d2,i2) else (n1,d1,i1)) # xs)"

context LLL
begin

(* Integer base 10; compute_mod_of_max_gso_norm returns a power of this base. *)
definition "log_base = (10 :: int)" 

(* Natural number depending on the first-flag: either 1 or m.
   NOTE(review): presumably 1 is chosen iff first holds and m is non-zero - confirm. *)
definition bound_number :: "bool  nat" where
  "bound_number first = (if first  m  0 then 1 else m)" 

(* Computes a modulus from a rational bound mn.
   The result is log_base raised to the exponent
     log_ceiling log_base (max 2 (root_rat_ceiling 2 (mn * (bound_number first + 3)) + 1)),
   i.e. it is always a power of log_base. *)
definition compute_mod_of_max_gso_norm :: "bool  rat  int" where
  "compute_mod_of_max_gso_norm first mn = log_base ^ (log_ceiling log_base (max 2 (
     root_rat_ceiling 2 (mn * (rat_of_nat (bound_number first) + 3)) + 1)))"

(* Bound predicate that depends on the first-flag: in one branch only the squared
   norm of gso fs 0 is related to b, in the other branch the predicate g_bnd b fs
   is used.
   NOTE(review): presumably the first branch is taken iff first holds and m is
   non-zero, and it states sq_norm (gso fs 0) <= b - confirm. *)
definition g_bnd_mode :: "bool  rat  int vec list  bool" where 
  "g_bnd_mode first b fs = (if first  m  0 then sq_norm (gso fs 0)  b else g_bnd b fs)" 

(* Reads a d-value out of a dmu-matrix: d_of dmu 0 = 1, and for i > 0 the value
   is the diagonal entry dmu $$ (i - 1, i - 1). *)
definition d_of where "d_of dmu i = (if i = 0 then 1 :: int else dmu $$ (i - 1, i - 1))"

(* Selects a quotient of consecutive d-values together with its index.
   For m = 0 the result is (0,0). Otherwise the triples
   (d_of dmu (Suc i), d_of dmu i, i) are built for i in [0 ..< 1] if first is set
   and for i in [0 ..< m] otherwise; max_list_rats_with_index picks one triple
   (num, denom, i), and the result is (num / denom as a rational, i). *)
definition compute_max_gso_norm :: "bool  int mat  rat × nat" where
  "compute_max_gso_norm first dmu = (if m = 0 then (0,0) else 
      case max_list_rats_with_index (map (λ i. (d_of dmu (Suc i), d_of dmu i, i)) [0 ..< (if first then 1 else m)])
      of (num, denom, i)  (of_int num / of_int denom, i))"


context
  fixes p :: int ― ‹the modulus›
    and first :: bool ― ‹only compute first vector of reduced basis›
begin

(* Row operation on row i with respect to row j, on the pair (mfs, dmu).
   The coefficient is c = round_num_denom (dmu $$ (i,j)) (d_of dmu (Suc j)).
   - If c = 0, the pair is returned unchanged.
   - Otherwise row i of mfs is replaced by mfs ! i minus c times mfs ! j, with
     every entry reduced by symmod p. In the new m x m dmu-matrix the selected
     entries of row i become dmu $$ (i,j') - c * dmu $$ (j,j'); for j' = j this
     value is stored as is, for the other selected columns it is additionally
     reduced by symmod (p * d_of dmu j' * d_of dmu (Suc j')). All other entries
     are copied from dmu.
   NOTE(review): presumably the selected entries are those with i' = i and
   j' <= j - confirm. *)
definition basis_reduction_mod_add_row :: 
  "int vec list  int mat  nat  nat  (int vec list × int mat)"  where
  "basis_reduction_mod_add_row mfs dmu i j = 
    (let c = round_num_denom (dmu $$ (i,j)) (d_of dmu (Suc j)) in
      (if c = 0 then (mfs, dmu) 
        else (mfs[ i := (map_vec (λ x. x symmod p)) (mfs ! i - c v mfs ! j)], 
             mat m m (λ(i',j'). (if (i' = i  j'  j) 
                then (if j'=j then (dmu $$ (i,j') - c * dmu $$ (j,j')) 
                      else (dmu $$ (i,j') - c * dmu $$ (j,j')) 
                            symmod (p * d_of dmu j' * d_of dmu (Suc j')))
                else (dmu $$ (i',j')))))))"

(* Applies basis_reduction_mod_add_row to row i for a descending sequence of
   column indices: called with bound Suc j it first processes index j and then
   continues with bound j; with bound 0 the pair (mfs, dmu) is returned.
   The updated pair of each step is threaded into the next one. *)
fun basis_reduction_mod_add_rows_loop where
  "basis_reduction_mod_add_rows_loop mfs dmu i 0 = (mfs, dmu)"
| "basis_reduction_mod_add_rows_loop mfs dmu i (Suc j) = (
     let (mfs', dmu') = basis_reduction_mod_add_row mfs dmu i j
      in basis_reduction_mod_add_rows_loop mfs' dmu' i j)" 

(* Builds an m x m matrix from dmu in which selected entries (i, j) are reduced
   by symmod (p * d_of dmu j * d_of dmu (Suc j)); all other entries are copied.
   NOTE(review): presumably the selected entries are those below the diagonal
   (j < i) in the columns j = k and j = k - 1 - confirm. *)
definition basis_reduction_mod_swap_dmu_mod :: "int mat  nat  int mat" where
  "basis_reduction_mod_swap_dmu_mod dmu k = mat m m (λ(i, j). (
    if j < i  (j = k  j = k - 1) then 
        dmu $$ (i, j) symmod (p * d_of dmu j * d_of dmu (Suc j))
    else dmu $$ (i, j)))"

(* Swap step at position k on the pair (mfs, dmu).
   - mfs: the vectors at positions k and k - 1 are exchanged.
   - dmu: a new m x m matrix is computed by case distinction and afterwards
     passed through basis_reduction_mod_swap_dmu_mod with the same k.
     Below the diagonal (j < i): row k - 1 takes its entries from row k of dmu;
     row k takes entries from row k - 1 (subject to a side condition on j);
     for rows i > k the entries in columns k and k - 1 are recombined from
     dmu $$ (i, k - 1), dmu $$ (i, k) resp. dmu $$ (i, j), dmu $$ (k, k - 1) and
     d-values, using integer division by d_of dmu k; other entries are copied.
     On the diagonal (i = j): position k - 1 gets a newly computed value
     (again an integer division by d_of dmu k); any other position i gets
     d_of dmu (Suc i).
     Remaining entries are copied from dmu. *)
definition basis_reduction_mod_swap where
  "basis_reduction_mod_swap mfs dmu k = 
     (mfs[k := mfs ! (k - 1), k - 1 := mfs ! k],
      basis_reduction_mod_swap_dmu_mod (mat m m (λ(i,j). (
      if j < i then
        if i = k - 1 then 
           dmu $$ (k, j)
        else if i = k  j  k - 1 then 
             dmu $$ (k - 1, j)
        else if i > k  j = k then
           ((d_of dmu (Suc k)) * dmu $$ (i, k - 1) - dmu $$ (k, k - 1) * dmu $$ (i, j)) 
              div (d_of dmu k)
        else if i > k  j = k - 1 then
           (dmu $$ (k, k - 1) * dmu $$ (i, j) + dmu $$ (i, k) * (d_of dmu (k-1)))
              div (d_of dmu k)
        else dmu $$ (i, j)
      else if i = j then 
        if i = k - 1 then 
          ((d_of dmu (Suc k)) * (d_of dmu (k-1)) + dmu $$ (k, k - 1) * dmu $$ (k, k - 1)) 
            div (d_of dmu k)
        else (d_of dmu (Suc i))
      else dmu $$ (i, j))
    )) k)" 

(* Tries to switch to a smaller modulus.
   From dmu, compute_max_gso_norm yields a bound b and an index g_idx, and
   compute_mod_of_max_gso_norm turns b into a candidate modulus p'.
   - If p' < p: every entry of mfs is reduced by symmod p', every dmu-entry
     below the diagonal (j < i) is reduced by
     symmod (p' * d_of dmu j * d_of dmu (Suc j)) (the d-values are tabulated in
     d_vec), and (p', mfs', dmu', g_idx) is returned.
   - Otherwise (p, mfs, dmu, g_idx) is returned.
   In both cases g_idx is the freshly computed index. *)
fun basis_reduction_adjust_mod where
  "basis_reduction_adjust_mod mfs dmu = 
    (let (b,g_idx) = compute_max_gso_norm first dmu;
         p' = compute_mod_of_max_gso_norm first b
        in if p' < p then 
           let mfs' = map (map_vec (λx. x symmod p')) mfs;
               d_vec = vec (Suc m) (λ i. d_of dmu i);
               dmu' = mat m m (λ (i,j). if j < i then dmu $$ (i,j) 
                 symmod (p' * d_vec $ j * d_vec $ (Suc j)) else
                 dmu $$ (i,j))
             in (p', mfs', dmu', g_idx)
           else (p, mfs, dmu, g_idx))" 

(* Combined step at position i:
   1. basis_reduction_mod_add_row for row i and column i - 1,
   2. basis_reduction_mod_swap at position i on the result,
   3. if i - 1 equals the stored index g_idx, the modulus is re-examined via
      basis_reduction_adjust_mod; otherwise p and g_idx are kept.
   Returns (modulus, mfs, dmu, g_idx). *)
definition basis_reduction_adjust_swap_add_step where
  "basis_reduction_adjust_swap_add_step mfs dmu g_idx i = (
    let i1 = i - 1; 
        (mfs1, dmu1) = basis_reduction_mod_add_row mfs dmu i i1;
        (mfs2, dmu2) = basis_reduction_mod_swap mfs1 dmu1 i
      in if i1 = g_idx then basis_reduction_adjust_mod mfs2 dmu2
         else (p, mfs2, dmu2, g_idx))"


(* One step of the main loop at position i; j is a counter carried along.
   - i = 0: only the position advances to Suc i.
   - Otherwise, with di = d_of dmu i and (num, denom) = quotient_of α, the
     products di * di * denom and num * d_of dmu (i - 1) * d_of dmu (Suc i) are
     compared. If the test succeeds, the position advances to Suc i and nothing
     else changes. If it fails, basis_reduction_adjust_swap_add_step is
     performed at i, the position goes back to i - 1 and j is incremented.
   NOTE(review): presumably the test is the <=-comparison of the two products - confirm.
   Returns (modulus, mfs, dmu, g_idx, next position, counter). *)
definition basis_reduction_mod_step where
  "basis_reduction_mod_step mfs dmu g_idx i (j :: int) = (if i = 0 then (p, mfs, dmu, g_idx, Suc i, j)
     else let di = d_of dmu i;
              (num, denom) = quotient_of α
      in if di * di * denom  num * d_of dmu (i - 1) * d_of dmu (Suc i) then
          (p, mfs, dmu, g_idx, Suc i, j)
      else let (p', mfs', dmu', g_idx') = basis_reduction_adjust_swap_add_step mfs dmu g_idx i
          in (p', mfs', dmu', g_idx', i - 1, j + 1))" 

(* Outer loop over rows: for bound Suc i it first handles all rows up to bound i
   recursively and then runs basis_reduction_mod_add_rows_loop on row Suc i with
   column bound Suc i. For bound 0 the pair (mfs, dmu) is returned unchanged. *)
primrec basis_reduction_mod_add_rows_outer_loop where
  "basis_reduction_mod_add_rows_outer_loop mfs dmu 0 = (mfs, dmu)" |
  "basis_reduction_mod_add_rows_outer_loop mfs dmu (Suc i) = 
    (let (mfs', dmu') = basis_reduction_mod_add_rows_outer_loop mfs dmu i in
      basis_reduction_mod_add_rows_loop mfs' dmu' (Suc i) (Suc i))"
end

text ‹the main loop of the normal Storjohann algorithm›
(* Tail-recursive driver: as long as i < m, performs basis_reduction_mod_step and
   continues with the state it returns (modulus, mfs, dmu, g_idx, position,
   counter). Once i < m fails, the triple (p, mfs, dmu) is returned.
   Declared as partial_function, so no termination proof is given here. *)
partial_function (tailrec) basis_reduction_mod_main where
  "basis_reduction_mod_main p first mfs dmu g_idx i (j :: int) = (
    (if i < m 
       then 
         case basis_reduction_mod_step p first mfs dmu g_idx i j
         of (p', mfs', dmu', g_idx', i', j')   
           basis_reduction_mod_main p' first mfs' dmu' g_idx' i' j'
       else
         (p, mfs, dmu)))"

(* For every i in [0 ..< m - 1] forms the triple
     (d_of dmu (i+1) * d_of dmu (i+1), d_of dmu (i+2) * d_of dmu i, Suc i)
   and selects one of them with max_list_rats_with_index.
   For m <= 1 the list is empty, a case for which max_list_rats_with_index has no
   defining equation; basis_reduction_iso_main calls this function only under
   the guard m > 1. *)
definition compute_max_gso_quot:: "int mat  (int * int * nat)" where
  "compute_max_gso_quot dmu = max_list_rats_with_index 
    (map (λi. ((d_of dmu (i+1)) * (d_of dmu (i+1)), (d_of dmu (i+2)) * (d_of dmu i), Suc i)) [0..<(m-1)])"

text ‹the main loop of Storjohann's algorithm with improved swap order›
(* Tail-recursive driver of the improved-swap-order variant; j is a counter.
   If m > 1, compute_max_gso_quot yields (max_gso_num, max_gso_denum, indx) and
   (num, denum) = quotient_of α. When max_gso_num * denum > num * max_gso_denum,
   basis_reduction_adjust_swap_add_step is performed at indx and the loop
   continues with the resulting state and j + 1; otherwise (p, mfs, dmu) is
   returned. For m <= 1 the state (p, mfs, dmu) is returned immediately.
   Declared as partial_function, so no termination proof is given here. *)
partial_function (tailrec) basis_reduction_iso_main where
  "basis_reduction_iso_main p first mfs dmu g_idx (j :: int) = (
    (if m > 1 then
      (let (max_gso_num, max_gso_denum, indx) = compute_max_gso_quot dmu;
        (num, denum) = quotient_of α in
        (if (max_gso_num * denum  > num * max_gso_denum) then
            case basis_reduction_adjust_swap_add_step p first mfs dmu g_idx indx of
              (p', mfs', dmu', g_idx') 
          basis_reduction_iso_main p' first mfs' dmu' g_idx' (j + 1) 
         else
           (p, mfs, dmu)))
     else (p, mfs, dmu)))"

(* Initial modular basis: every entry of every vector of fs_init reduced by symmod p. *)
definition compute_initial_mfs where
  "compute_initial_mfs p = map (map_vec (λx. x symmod p)) fs_init"

(* Initial modular dmu-matrix (m x m): entries below the diagonal (j' < i') are
   reduced by symmod (p * d_of dmu j' * d_of dmu (Suc j')); all other entries
   are copied from dmu. *)
definition compute_initial_dmu where
  "compute_initial_dmu p dmu = mat m m (λ(i',j'). if j' < i' 
        then dmu $$ (i', j') symmod (p * d_of dmu j' * d_of dmu (Suc j')) 
        else dmu $$ (i', j'))"

(* Initial dμ-matrix of fs_init as an m x m matrix: entries on and below the
   diagonal (j ≤ i) are looked up in the array computed by dμ_impl fs_init,
   entries above the diagonal are 0.
   The array is bound once by the let and reused for every entry, instead of
   re-evaluating dμ_impl fs_init inside the entry function. After unfolding
   Let_def the term is identical to the previous formulation. *)
definition "dmu_initial = (let dmu = dμ_impl fs_init
   in mat m m (λ (i,j). 
   if j ≤ i then dmu !! i !! j else 0))"

(* Initial state of the modular algorithms.
   Starting from dmu_initial, compute_max_gso_norm yields a bound b and an index
   g_idx, and compute_mod_of_max_gso_norm turns b into the modulus p. The result
   is (p, compute_initial_mfs p, compute_initial_dmu p dmu, g_idx). *)
definition "compute_initial_state first = 
  (let dmu = dmu_initial;
       (b, g_idx) = compute_max_gso_norm first dmu;
       p = compute_mod_of_max_gso_norm first b
     in (p, compute_initial_mfs p, compute_initial_dmu p dmu, g_idx))" 

text ‹Storjohann's algorithm›
(* Full basis reduction, normal variant (first = False):
   initial state, then basis_reduction_mod_main started at position 0 with
   counter 0, then basis_reduction_mod_add_rows_outer_loop with the final
   modulus p' and bound m - 1. Only the resulting vector list is returned. *)
definition reduce_basis_mod :: "int vec list" where
  "reduce_basis_mod = (
     let first = False;
         (p0, mfs0, dmu0, g_idx) = compute_initial_state first;
         (p', mfs', dmu') = basis_reduction_mod_main p0 first mfs0 dmu0 g_idx 0 0;
         (mfs'', dmu'') = basis_reduction_mod_add_rows_outer_loop p' mfs' dmu' (m-1)
      in mfs'')"

text ‹Storjohann's algorithm with improved swap order›
(* Full basis reduction, improved-swap-order variant (first = False):
   initial state, then basis_reduction_iso_main started with counter 0, then
   basis_reduction_mod_add_rows_outer_loop with the final modulus p' and bound
   m - 1. Only the resulting vector list is returned. *)
definition reduce_basis_iso :: "int vec list" where
  "reduce_basis_iso = (
     let first = False; 
         (p0, mfs0, dmu0, g_idx) = compute_initial_state first;
         (p', mfs', dmu') = basis_reduction_iso_main p0 first mfs0 dmu0 g_idx 0;
         (mfs'', dmu'') = basis_reduction_mod_add_rows_outer_loop p' mfs' dmu' (m-1)
      in mfs'')"

text ‹Storjohann's algorithm for computing a short vector›
(* Short-vector computation, normal variant (first = True): initial state, then
   basis_reduction_mod_main started at position 0 with counter 0; the result is
   the head of the final vector list (no final add-rows pass is performed). *)
definition 
  "short_vector_mod = (
     let first = True;
         (p0, mfs0, dmu0, g_idx) = compute_initial_state first;
         (p', mfs', dmu') = basis_reduction_mod_main p0 first mfs0 dmu0 g_idx 0 0
      in hd mfs')"

text ‹Storjohann's algorithm (iso-variant) for computing a short vector›
(* Short-vector computation, improved-swap-order variant (first = True): initial
   state, then basis_reduction_iso_main started with counter 0; the result is
   the head of the final vector list (no final add-rows pass is performed). *)
definition 
  "short_vector_iso = (
     let first = True; 
         (p0, mfs0, dmu0, g_idx) = compute_initial_state first;
         (p', mfs', dmu') = basis_reduction_iso_main p0 first mfs0 dmu0 g_idx 0
      in hd mfs')"
end

subsection ‹Towards soundness of Storjohann's algorithm›

(* Relates the triple returned by max_list_rats_with_index to the elements of its
   input list, under a length assumption on the list.
   NOTE(review): presumably the statement is "for length xs >= 1 the result is a
   member of set xs" - confirm.
   Proof by the induction rule of max_list_rats_with_index; in the recursive
   case the triple kept by the if-expression is one of the first two elements,
   so the set of the shortened list is related to the set of the original list. *)
lemma max_list_rats_with_index_in_set: 
  assumes max: "max_list_rats_with_index xs = (nm, dm, im)"
  and len: "length xs  1"
shows "(nm, dm, im)  set xs"
  using assms
proof (induct xs rule: max_list_rats_with_index.induct)
  case (2 n1 d1 i1 n2 d2 i2 xs)
  have "1  length ((if n1 * d2  n2 * d1 then (n2, d2, i2) else (n1, d1, i1)) # xs)" by simp
  moreover have "max_list_rats_with_index ((if n1 * d2  n2 * d1 then (n2, d2, i2) else (n1, d1, i1)) # xs)
              = (nm, dm, im)" using 2 by simp
  moreover have "(if n1 * d2  n2 * d1 then (n2, d2, i2) else (n1, d1, i1)) 
        set ((n1, d1, i1) # (n2, d2, i2) # xs)" by simp
  moreover then have "set ((if n1 * d2  n2 * d1 then (n2, d2, i2) else (n1, d1, i1)) # xs) 
        set ((n1, d1, i1) # (n2, d2, i2) # xs)" by auto
  ultimately show ?case using 2(1) by auto
qed auto

(* Compares the fraction n / d of an arbitrary list element (n,d,i) with the
   fraction nm / dm of the triple returned by max_list_rats_with_index, assuming
   that all denominators occurring in the list are positive.
   NOTE(review): presumably the conclusion is n / d <= nm / dm, i.e. the
   returned triple carries a maximal fraction - confirm.
   Proof by the induction rule of max_list_rats_with_index. Fact id translates
   the integer cross-multiplication test into a comparison of rational
   fractions; the case distinction is on whether (n,d,i) occurs in the tail. *)
lemma max_list_rats_with_index: assumes " n d i. (n,d,i)  set xs  d > 0" 
  and max: "max_list_rats_with_index xs = (nm, dm, im)" 
  and "(n,d,i)  set xs" 
shows "rat_of_int n / of_int d  of_int nm / of_int dm" 
  using assms
proof (induct xs arbitrary: n d i rule: max_list_rats_with_index.induct)
  case (2 n1 d1 i1 n2 d2 i2 xs n d i)
  let ?r = "rat_of_int" 
  from 2(2) have "d1 > 0" "d2 > 0" by auto
  hence d: "?r d1 > 0" "?r d2 > 0" by auto
  have "(n1 * d2  n2 * d1) = (?r n1 * ?r d2  ?r n2 * ?r d1)" 
    unfolding of_int_mult[symmetric] by presburger
  also have " = (?r n1 / ?r d1  ?r n2 / ?r d2)" using d 
    by (smt divide_strict_right_mono leD le_less_linear mult.commute nonzero_mult_div_cancel_left 
        not_less_iff_gr_or_eq times_divide_eq_right)
  finally have id: "(n1 * d2  n2 * d1) = (?r n1 / ?r d1  ?r n2 / ?r d2)" .
  (* (n',d',i') names the triple kept by the if-expression of the recursive equation *)
  obtain n' d' i' where new: "(if n1 * d2  n2 * d1 then (n2, d2, i2) else (n1, d1, i1)) = (n',d',i')" 
    by force  
  have nd': "(n',d',i')  {(n1,d1,i1), (n2, d2, i2)}" using new[symmetric] by auto
  from 2(3) have res: "max_list_rats_with_index ((n',d',i') # xs) = (nm, dm, im)" using new by auto
  note 2 = 2[unfolded new]
  show ?case 
  proof (cases "(n,d,i)  set xs")
    case True
    show ?thesis 
      by (rule 2(1)[of n d, OF 2(2) res], insert True nd', force+)
  next
    case False
    (* (n,d,i) is one of the first two elements: compare with the kept triple, then chain *)
    with 2(4) have "n = n1  d = d1  n = n2  d = d2" by auto
    hence "?r n / ?r d  ?r n' / ?r d'" using new[unfolded id]
      by (metis linear prod.inject)
    also have "?r n' / ?r d'  ?r nm / ?r dm" 
      by (rule 2(1)[of n' d', OF 2(2) res], insert nd', force+)
    finally show ?thesis .
  qed
qed auto

context LLL
begin

(* Relates log_base to 2 (presumably log_base >= 2 - confirm); immediate from log_base_def. *)
lemma log_base: "log_base  2" unfolding log_base_def by auto

(* Weak invariant with an index i. It collects: linear independence of RAT fs
   (gs.lin_indpt_list), lattice_of fs = L, weakly_reduced fs i, a relation
   between i and m, and length fs = m.
   NOTE(review): presumably the components are joined by conjunction and the
   relation is i <= m - confirm. *)
definition LLL_invariant_weak' :: "nat  int vec list  bool" where 
  "LLL_invariant_weak' i fs = ( 
    gs.lin_indpt_list (RAT fs)  
    lattice_of fs = L 
    weakly_reduced fs i 
    i  m  
    length fs = m    
  )" 

(* Destruction rule for LLL_invariant_weak': lists the facts available from the
   invariant (linear independence, lengths, carrier memberships of the basis
   vectors and of the gso-vectors, the lattice, weak reducedness and the
   relation between i and m).
   All parts are proved at once (atomize (full)) after interpreting the locale
   gram_schmidt_fs_lin_indpt for RAT fs. *)
lemma LLL_invD_weak: assumes "LLL_invariant_weak' i fs"
  shows 
  "lin_indep fs" 
  "length (RAT fs) = m" 
  "set fs  carrier_vec n"
  " i. i < m  fs ! i  carrier_vec n" 
  " i. i < m  gso fs i  carrier_vec n" 
  "length fs = m"
  "lattice_of fs = L" 
  "weakly_reduced fs i"
  "i  m"
proof (atomize (full), goal_cases)
  case 1
  interpret gs': gram_schmidt_fs_lin_indpt n "RAT fs"
    by (standard) (use assms LLL_invariant_weak'_def gs.lin_indpt_list_def in auto)
  show ?case
    using assms gs'.fs_carrier gs'.f_carrier gs'.gso_carrier
    by (auto simp add: LLL_invariant_weak'_def gram_schmidt_fs.reduced_def)
qed

(* Introduction rule for LLL_invariant_weak': the invariant follows from the
   listed assumptions on carrier, length, lattice, the index i, linear
   independence and weak reducedness; proved by unfolding the definition. *)
lemma LLL_invI_weak: assumes  
  "set fs  carrier_vec n"
  "length fs = m"
  "lattice_of fs = L" 
  "i  m"
  "lin_indep fs" 
  "weakly_reduced fs i" 
shows "LLL_invariant_weak' i fs" 
  unfolding LLL_invariant_weak'_def Let_def using assms by auto

(* Connects the indexed invariant LLL_invariant_weak' i fs with the index-free
   LLL_invariant_weak fs (used below as a rule turning the former into the
   latter); proved by unfolding both definitions. *)
lemma LLL_invw'_imp_w: "LLL_invariant_weak' i fs  LLL_invariant_weak fs" 
  unfolding LLL_invariant_weak'_def LLL_invariant_weak_def by auto
  
(* Effect of a row operation on the weak invariant: replacing row i of fs by
   fs ! i minus c times fs ! j (with j < i < m) preserves
   LLL_invariant_weak' i, and the second statement relates g_bnd B fs and
   g_bnd B fs' (presumably as an implication - confirm).
   The proof reuses basis_reduction_add_row_main for the index-free invariant;
   fact main(6) is used both for the bound and for weak reducedness. *)
lemma basis_reduction_add_row_weak: 
  assumes Linvw: "LLL_invariant_weak' i fs"
  and i: "i < m"  and j: "j < i" 
  and fs': "fs' = fs[ i := fs ! i - c v fs ! j]" 
shows "LLL_invariant_weak' i fs'"
  "g_bnd B fs  g_bnd B fs'" 
proof (atomize(full), goal_cases)
  case 1
  note Linv = LLL_invw'_imp_w[OF Linvw]
  note main = basis_reduction_add_row_main[OF Linv i j fs']
  have bnd: "g_bnd B fs  g_bnd B fs'" using main(6) unfolding g_bnd_def by auto
  note new = LLL_inv_wD[OF main(1)]
  note old = LLL_invD_weak[OF Linvw]
  have red: "weakly_reduced fs' i" using ‹weakly_reduced fs i main(6) i < m
    unfolding gram_schmidt_fs.weakly_reduced_def by auto
  have inv: "LLL_invariant_weak' i fs'" using LLL_inv_wD[OF main(1)] i < m
    by (intro LLL_invI_weak, auto intro: red)
  show ?case using inv red main bnd by auto
qed

(* The weak invariant at index m yields the weak invariant at an index i that
   satisfies assumption i (presumably i <= m - confirm).
   First weak reducedness is transferred from index m to index i, then the
   invariant is re-assembled with LLL_invI_weak from the facts of LLL_invD_weak. *)
lemma LLL_inv_weak_m_impl_i:
  assumes inv: "LLL_invariant_weak' m fs"
  and i: "i  m"
shows "LLL_invariant_weak' i fs"
proof -
  have "weakly_reduced fs i" using LLL_invD_weak(8)[OF inv]
    by (meson assms(2) gram_schmidt_fs.weakly_reduced_def le_trans less_imp_le_nat linorder_not_less)
  then show ?thesis
    using LLL_invI_weak[of fs i, OF LLL_invD_weak(3,6,7)[OF inv] _ LLL_invD_weak(1)[OF inv]] 
      LLL_invD_weak(2,4,5,8-)[OF inv] i by simp
qed
 
(* Invariant connecting a rational bound b with a modulus p: b is related to
   (p - 1)^2 / (bound_number first + 3), and p is a power of log_base.
   NOTE(review): presumably the first part reads b <= ... and the second part is
   existentially quantified over the exponent e, joined by conjunction - confirm. *)
definition mod_invariant where 
  "mod_invariant b p first = (b  rat_of_int (p - 1)^2 / (rat_of_nat (bound_number first) + 3)
      ( e. p = log_base ^ e))"  

(* Soundness of compute_mod_of_max_gso_norm.
   For a bound mn satisfying assumptions mn and m, and
   p = compute_mod_of_max_gso_norm first mn, the lemma shows
     - p > 1, and
     - mod_invariant mn p first.
   Proof idea: introduce the intermediate value
     p' = root_rat_ceiling 2 (mn * (bound_number first + 3)) + 1,
   and p'' = max 2 p'; relate p to p'' via log_ceiling_sound, then transfer the
   bound to the reals (q) to compare mn with (p' - 1)^2 / (bound_number first + 3)
   and finally with (p - 1)^2 / (bound_number first + 3). *)
lemma compute_mod_of_max_gso_norm: assumes mn: "mn  0"
  and m: "m = 0  mn = 0" 
  and p: "p = compute_mod_of_max_gso_norm first mn" 
shows  
  "p > 1" 
  "mod_invariant mn p first" 
proof -
  let ?m = "bound_number first" 
  define p' where "p' = root_rat_ceiling 2 (mn * (rat_of_nat ?m + 3)) + 1" 
  define p'' where "p'' = max 2 p'" 
  (* q is the real-valued image of the radicand used in p'. *)
  define q where "q = real_of_rat (mn * (rat_of_nat ?m + 3))" 
  have *: "-1 < (0 :: real)" by simp
  also have "0  root 2 (real_of_rat (mn * (rat_of_nat ?m + 3)))" using mn by auto
  finally have "p'  0 + 1" unfolding p'_def
    by (intro plus_left_mono, simp)
  hence p': "p' > 0" by auto
  have p'': "p'' > 1" unfolding p''_def by auto
  (* Relate the computed modulus p to p'' (and hence to p'). *)
  have pp'': "p  p''" unfolding compute_mod_of_max_gso_norm_def p  p'_def[symmetric] p''_def[symmetric]
    using log_base p'' log_ceiling_sound by auto
  hence pp': "p  p'" unfolding p''_def by auto    
  show "p > 1" using pp'' p'' by auto

  have q0: "q  0" unfolding q_def using mn m by auto
  (* Chain of equivalent reformulations of the bound on mn, moving from rat to real. *)
  have "(mn  rat_of_int (p' - 1)^2 / (rat_of_nat ?m + 3)) 
    = (real_of_rat mn  real_of_rat (rat_of_int (p' - 1)^2 / (rat_of_nat ?m + 3)))" using of_rat_less_eq by blast
  also have " = (real_of_rat mn  real_of_rat (rat_of_int (p' - 1)^2) / real_of_rat (rat_of_nat ?m + 3))" by (simp add: of_rat_divide)
  also have " = (real_of_rat mn  ((real_of_int (p' - 1))^2) / real_of_rat (rat_of_nat ?m + 3))" 
    by (metis of_rat_of_int_eq of_rat_power)
  also have " = (real_of_rat mn  (real_of_int sqrt q)^2 / real_of_rat (rat_of_nat ?m + 3))" 
    unfolding p'_def sqrt_def q_def by simp
  also have "" 
  proof -
    (* The real-valued bound itself: go through q and its square root. *)
    have "real_of_rat mn  q / real_of_rat (rat_of_nat ?m + 3)" unfolding q_def using m
      by (auto simp: of_rat_mult)
    also have "  (real_of_int sqrt q)^2 / real_of_rat (rat_of_nat ?m + 3)" 
    proof (rule divide_right_mono)
      have "q = (sqrt q)^2" using q0 by simp
      also have "  (real_of_int sqrt q)^2" 
        by (rule power_mono, auto simp: q0)
      finally show "q  (real_of_int sqrt q)^2" .
    qed auto
    finally show ?thesis .
  qed
  finally have "mn  rat_of_int (p' - 1)^2 / (rat_of_nat ?m + 3)" .
  (* Lift the bound from p' to the actual modulus p using pp'. *)
  also have "  rat_of_int (p - 1)^2 / (rat_of_nat ?m + 3)"
    unfolding power2_eq_square
    by (intro divide_right_mono mult_mono, insert p' pp', auto) 
  finally have "mn  rat_of_int (p - 1)^2 / (rat_of_nat ?m + 3)" .
  (* p is a power of log_base by construction of compute_mod_of_max_gso_norm. *)
  moreover have " e. p = log_base ^ e" unfolding p compute_mod_of_max_gso_norm_def by auto
  ultimately show "mod_invariant mn p first" unfolding mod_invariant_def by auto
qed

(* Congruence rule: g_bnd_mode depends on the basis only through the vectors
   gso fs i for i < m.  If these agree for fs and fs', the two g_bnd_mode
   statements are equal.  Proved by unfolding g_bnd_mode_def and g_bnd_def. *)
lemma g_bnd_mode_cong: assumes " i. i < m  gso fs i = gso fs' i"
  shows "g_bnd_mode first b fs = g_bnd_mode first b fs'"
  using assms unfolding g_bnd_mode_def g_bnd_def by auto

(* Invariant of the modular LLL algorithm, indexed by the current position i.
   Arguments: fs (a basis), mfs (its entrywise symmod-p image), dmu (integer
   matrix compared against the d-mu values of fs), modulus p, mode flag first,
   bound b, index i.
   Components, in order: lengths of fs and mfs equal m; i bounded by m;
   fs generates the lattice L; fs is linearly independent; fs is weakly
   reduced up to i; mfs is fs mapped through symmod p; a bound in terms of
   p * d fs j' * d fs (Suc j') for entries with j' < i' < m; agreement with
   dmu on all indices below m; p > 1; g_bnd_mode first b fs; and
   mod_invariant b p first. *)
definition LLL_invariant_mod :: "int vec list  int vec list  int mat  int  bool  rat  nat  bool" where 
  "LLL_invariant_mod fs mfs dmu p first b i = ( 
    length fs = m 
    length mfs = m 
    i  m 
    lattice_of fs = L 
    gs.lin_indpt_list (RAT fs) 
    weakly_reduced fs i 
    (map (map_vec (λx. x symmod p)) fs = mfs) 
    (i' < m.  j' < i'. ¦fs i' j'¦ < p * d fs j' * d fs (Suc j')) 
    (i' < m. j' < m.fs i' j' = dmu $$ (i',j')) 
    p > 1 
    g_bnd_mode first b fs 
    mod_invariant b p first
)"

(* Destruction rule for LLL_invariant_mod: exposes every component of the
   invariant plus derived carrier facts (rows of fs and mfs and the gso
   vectors lie in carrier_vec n).
   NOTE: the conclusions are referenced by position elsewhere in this file
   (e.g. LLL_invD_mod(9) and LLL_invD_mod(17)), so their order must not be
   changed. *)
lemma LLL_invD_mod: assumes "LLL_invariant_mod fs mfs dmu p first b i"
shows 
  "length mfs = m"
  "i  m"
  "length fs = m"
  "lattice_of fs = L"
  "gs.lin_indpt_list (RAT fs)"
  "weakly_reduced fs i"
  "(map (map_vec (λx. x symmod p)) fs = mfs)"
  "(i' < m. j' < i'. ¦fs i' j'¦ < p * d fs j' * d fs (Suc j'))"
  "(i' < m. j' < m.fs i' j' = dmu $$ (i',j'))"
  " i. i < m  fs ! i  carrier_vec n" 
  "set fs  carrier_vec n"
  " i. i < m  gso fs i  carrier_vec n" 
  " i. i < m  mfs ! i  carrier_vec n"
  "set mfs  carrier_vec n"
  "p > 1"
  "g_bnd_mode first b fs"
  "mod_invariant b p first"
proof (atomize (full), goal_cases)
  case 1
  (* Interpret the Gram-Schmidt locale for RAT fs to obtain the carrier facts. *)
  interpret gs': gram_schmidt_fs_lin_indpt n "RAT fs"
    using assms LLL_invariant_mod_def gs.lin_indpt_list_def 
    by (meson gram_schmidt_fs_Rn.intro gram_schmidt_fs_lin_indpt.intro gram_schmidt_fs_lin_indpt_axioms.intro)
  have allfs: "i < m. fs ! i  carrier_vec n" using assms gs'.f_carrier 
    by (simp add: LLL.LLL_invariant_mod_def)
  then have setfs: "set fs  carrier_vec n" by (metis LLL_invariant_mod_def assms in_set_conv_nth subsetI)
  have allgso: "(i < m. gso fs i  carrier_vec n)" using assms gs'.gso_carrier
    by (simp add: LLL.LLL_invariant_mod_def)
  show ?case
    using assms gs'.fs_carrier gs'.f_carrier gs'.gso_carrier allfs allgso 
      LLL_invariant_mod_def gram_schmidt_fs.reduced_def in_set_conv_nth setfs by fastforce
qed

(* Introduction rule for LLL_invariant_mod: lists exactly the components of
   the definition as premises.
   NOTE: the premise order (length mfs first, then the index bound, then
   length fs, ...) differs from the order in the definition and is relied on
   by positional OF-instantiations below (see the use of LLL_invI_mod[OF ...]
   in basis_reduction_mod_add_row_main). *)
lemma LLL_invI_mod: assumes 
  "length mfs = m"
  "i  m"
  "length fs = m"
  "lattice_of fs = L"
  "gs.lin_indpt_list (RAT fs)"
  "weakly_reduced fs i"
  "map (map_vec (λx. x symmod p)) fs = mfs"
  "(i' < m. j' < i'. ¦fs i' j'¦ < p * d fs j' * d fs (Suc j'))"
  "(i' < m. j' < m.fs i' j' = dmu $$ (i',j'))"
  "p > 1"
  "g_bnd_mode first b fs"
  "mod_invariant b p first"
shows "LLL_invariant_mod fs mfs dmu p first b i" 
  unfolding LLL_invariant_mod_def using assms by blast

(* Index-free ("weak") variant of LLL_invariant_mod: the same components, but
   without the index i, i.e. without the bound on i and without
   weakly_reduced fs i. *)
definition LLL_invariant_mod_weak :: "int vec list  int vec list  int mat  int  bool  rat  bool" where 
  "LLL_invariant_mod_weak fs mfs dmu p first b = ( 
    length fs = m 
    length mfs = m 
    lattice_of fs = L 
    gs.lin_indpt_list (RAT fs) 
    (map (map_vec (λx. x symmod p)) fs = mfs) 
    (i' < m.  j' < i'. ¦fs i' j'¦ < p * d fs j' * d fs (Suc j')) 
    (i' < m. j' < m.fs i' j' = dmu $$ (i',j')) 
    p > 1 
    g_bnd_mode first b fs 
    mod_invariant b p first
)"

(* Destruction rule for LLL_invariant_mod_weak: exposes all components plus
   derived carrier facts for fs, mfs and the gso vectors.
   NOTE: conclusions are referenced by position elsewhere in this file
   (e.g. LLL_invD_modw(4), (7) and (15)), so their order must not be changed.
   The proof mirrors that of LLL_invD_mod. *)
lemma LLL_invD_modw: assumes "LLL_invariant_mod_weak fs mfs dmu p first b"
shows 
  "length mfs = m"
  "length fs = m"
  "lattice_of fs = L"
  "gs.lin_indpt_list (RAT fs)"
  "(map (map_vec (λx. x symmod p)) fs = mfs)"
  "(i' < m. j' < i'. ¦fs i' j'¦ < p * d fs j' * d fs (Suc j'))"
  "(i' < m. j' < m.fs i' j' = dmu $$ (i',j'))"
  " i. i < m  fs ! i  carrier_vec n" 
  "set fs  carrier_vec n"
  " i. i < m  gso fs i  carrier_vec n" 
  " i. i < m  mfs ! i  carrier_vec n"
  "set mfs  carrier_vec n"
  "p > 1"
  "g_bnd_mode first b fs"
  "mod_invariant b p first"
proof (atomize (full), goal_cases)
  case 1
  (* Interpret the Gram-Schmidt locale for RAT fs to obtain the carrier facts. *)
  interpret gs': gram_schmidt_fs_lin_indpt n "RAT fs"
    using assms LLL_invariant_mod_weak_def gs.lin_indpt_list_def 
    by (meson gram_schmidt_fs_Rn.intro gram_schmidt_fs_lin_indpt.intro gram_schmidt_fs_lin_indpt_axioms.intro)
  have allfs: "i < m. fs ! i  carrier_vec n" using assms gs'.f_carrier 
    by (simp add: LLL.LLL_invariant_mod_weak_def)
  then have setfs: "set fs  carrier_vec n" by (metis LLL_invariant_mod_weak_def assms in_set_conv_nth subsetI)
  have allgso: "(i < m. gso fs i  carrier_vec n)" using assms gs'.gso_carrier
    by (simp add: LLL.LLL_invariant_mod_weak_def)
  show ?case
    using assms gs'.fs_carrier gs'.f_carrier gs'.gso_carrier allfs allgso 
      LLL_invariant_mod_weak_def gram_schmidt_fs.reduced_def in_set_conv_nth setfs by fastforce
qed

(* Introduction rule for LLL_invariant_mod_weak: one premise per component of
   the definition.
   NOTE: the premise order (length mfs before length fs) is relied on by the
   positional instantiation LLL_invI_modw[OF ...] in
   basis_reduction_mod_add_row_main. *)
lemma LLL_invI_modw: assumes 
  "length mfs = m"
  "length fs = m"
  "lattice_of fs = L"
  "gs.lin_indpt_list (RAT fs)"
  "map (map_vec (λx. x symmod p)) fs = mfs"
  "(i' < m. j' < i'. ¦fs i' j'¦ < p * d fs j' * d fs (Suc j'))"
  "(i' < m. j' < m.fs i' j' = dmu $$ (i',j'))"
  "p > 1"
  "g_bnd_mode first b fs"
  "mod_invariant b p first"
shows "LLL_invariant_mod_weak fs mfs dmu p first b" 
  unfolding LLL_invariant_mod_weak_def using assms by blast

(* For i < m, the value d fs (Suc i) coincides with the diagonal d-mu entry of
   fs at position (i, i).  This holds because the mu coefficient on the
   diagonal equals 1 (first step of the proof), after which the d-mu
   definition is unfolded. *)
lemma ddμ:
  assumes i: "i < m"
  shows "d fs (Suc i) =fs i i"
proof-
  have fs i i = 1" using i by (simp add: gram_schmidt_fs.μ.simps)
  then show ?thesis using dμ_def by simp
qed

(* Correctness of d_of: if the diagonal of dmu agrees with the diagonal d-mu
   entries of fs (for all indices below m), then d_of dmu i = d fs i for every
   i within the bound given by the second assumption.
   Case i = 0 is settled by unfolding d_of_def and d_def; for i > 0 the
   diagonal entry at i - 1 is used together with lemma ddμ. *)
lemma d_of_main: assumes "(i' < m.fs i' i' = dmu $$ (i',i'))"
  and "i  m"
shows "d_of dmu i = d fs i" 
proof (cases "i = 0")
  case False
  with assms have "i - 1 < m" by auto
  from assms(1)[rule_format, OF this] ddμ[OF this, of fs] False
  show ?thesis by (simp add: d_of_def)
next
  case True
  thus ?thesis unfolding d_of_def True d_def by simp
qed

(* d_of dmu i = d fs i under the full modular invariant; an instance of
   d_of_main using fact 9 of LLL_invD_mod (agreement of dmu with fs).
   NOTE(review): the free variables are written "b first" here, whereas all
   other lemmas in this file write "first b".  Since they are free variables
   this only swaps their names (the fifth argument is still the bool, the
   sixth the rat), but it is inconsistent with the surrounding lemmas. *)
lemma d_of: assumes inv: "LLL_invariant_mod fs mfs dmu p b first j"
  and "i  m" 
shows "d_of dmu i = d fs i" 
  by (rule d_of_main[OF _ assms(2)], insert LLL_invD_mod(9)[OF inv], auto)

(* d_of dmu i = d fs i under the weak modular invariant; an instance of
   d_of_main using fact 7 of LLL_invD_modw (agreement of dmu with fs). *)
lemma d_of_weak: assumes inv: "LLL_invariant_mod_weak fs mfs dmu p first b"
  and "i  m" 
shows "d_of dmu i = d fs i" 
  by (rule d_of_main[OF _ assms(2)], insert LLL_invD_modw(7)[OF inv], auto)

(* Soundness of compute_max_gso_norm.
   Assumptions: the diagonal of dmu agrees with the diagonal d-mu entries of
   fs, and LLL_invariant_weak fs.
   Conclusions about the first component of compute_max_gso_norm first dmu:
     1. it satisfies g_bnd_mode first _ fs,
     2. a sign/non-negativity statement (derived via sq_norm (gso fs 0)),
     3. its value in the degenerate case m = 0 (closed by the final
        "qed (auto simp: ...)"). *)
lemma compute_max_gso_norm: assumes dmu: "(i' < m.fs i' i' = dmu $$ (i',i'))" 
  and Linv: "LLL_invariant_weak fs" 
shows "g_bnd_mode first (fst (compute_max_gso_norm first dmu)) fs" 
  "fst (compute_max_gso_norm first dmu)  0" 
  "m = 0  fst (compute_max_gso_norm first dmu) = 0" 
proof -
  show gbnd: "g_bnd_mode first (fst (compute_max_gso_norm first dmu)) fs" 
  proof (cases "first  m  0")
    case False
    (* In this case g_bnd_mode reduces to the plain g_bnd statement. *)
    have "?thesis = (g_bnd (fst (compute_max_gso_norm first dmu)) fs)" unfolding g_bnd_mode_def using False by auto
    also have  unfolding g_bnd_def
    proof (intro allI impI)
      fix i
      assume i: "i < m" 
      have id: "(if first then 1 else m) = m" using False i by auto
      (* list pairs consecutive d_of values with their index; its maximum
         (as a fraction) is what compute_max_gso_norm returns. *)
      define list where "list = map (λ i. (d_of dmu (Suc i), d_of dmu i, i)) [0 ..< m ]" 
      obtain num denom j where ml: "max_list_rats_with_index list = (num, denom, j)" 
        by (metis prod_cases3)
      have dpos: "d fs i > 0" using LLL_d_pos[OF Linv, of i]  i by auto
      have pos: "(n, d, i)  set list  0 < d" for n d i 
        using LLL_d_pos[OF Linv] unfolding list_def using d_of_main[OF dmu] by auto
      from i have "list ! i  set list" using i unfolding list_def by auto
      also have "list ! i = (d_of dmu (Suc i), d_of dmu i, i)" unfolding list_def using i by auto
      also have " = (d fs (Suc i), d fs i, i)" using d_of_main[OF dmu] i by auto
      finally have "(d fs (Suc i), d fs i, i)  set list" . 
      from max_list_rats_with_index[OF pos ml this] 
      have "of_int (d fs (Suc i)) / of_int (d fs i)  fst (compute_max_gso_norm first dmu)" 
        unfolding compute_max_gso_norm_def list_def[symmetric] ml id split using i by auto
      (* The quotient of consecutive d values is the squared norm of gso fs i. *)
      also have "of_int (d fs (Suc i)) / of_int (d fs i) = sq_norm (gso fs i)" 
        using LLL_d_Suc[OF Linv i] dpos by auto
      finally show "sq_norm (gso fs i)  fst (compute_max_gso_norm first dmu)" .
    qed
    finally show ?thesis .
  next
    case True
    thus ?thesis unfolding g_bnd_mode_def compute_max_gso_norm_def using d_of_main[OF dmu] 
      LLL_d_Suc[OF Linv, of 0] LLL_d_pos[OF Linv, of 0] LLL_d_pos[OF Linv, of 1] by auto
  qed
  show "fst (compute_max_gso_norm first dmu)  0" 
  proof (cases "m = 0")
    case True
    thus ?thesis unfolding compute_max_gso_norm_def by simp
  next
    case False
    hence 0: "0 < m" by simp
    (* Use the already proven bound gbnd at index 0. *)
    have "0  sq_norm (gso fs 0)" by blast
    also have "  fst (compute_max_gso_norm first dmu)" 
      using gbnd[unfolded g_bnd_mode_def g_bnd_def] using 0 by metis
    finally show ?thesis .
  qed
qed (auto simp: LLL.compute_max_gso_norm_def)


(* Increment step of the modular algorithm.
   Assumptions: the modular invariant at index i, i < m, and the reducedness
   condition red_i relating sq_norm (gso fs (i - 1)) and
   α * sq_norm (gso fs i) when i is non-zero.
   Conclusions: the invariant holds at Suc i (same fs, mfs, dmu), and
   LLL_measure strictly decreases when going from i to Suc i. *)
lemma increase_i_mod:
  assumes Linv: "LLL_invariant_mod fs mfs dmu p first b i"
  and i: "i < m" 
  and red_i: "i  0  sq_norm (gso fs (i - 1))  α * sq_norm (gso fs i)"
shows "LLL_invariant_mod fs mfs dmu p first b (Suc i)" "LLL_measure i fs > LLL_measure (Suc i) fs" 
proof -
  note inv = LLL_invD_mod[OF Linv]
  from inv have red: "weakly_reduced fs i"  by (auto)
  (* Extend weak reducedness by one index using red_i. *)
  from red red_i i have red: "weakly_reduced fs (Suc i)" 
    unfolding gram_schmidt_fs.weakly_reduced_def
    by (intro allI impI, rename_tac ii, case_tac "Suc ii = i", auto)
  show "LLL_invariant_mod fs mfs dmu p first b (Suc i)"
    by (intro LLL_invI_mod, insert inv red i, auto)
  show "LLL_measure i fs > LLL_measure (Suc i) fs" unfolding LLL_measure_def using i by auto
qed

(* Main correctness lemma for one add-row step of the modular algorithm.
   Assumptions:
     - the weak modular invariant for (fs, mfs, dmu),
     - i < m and j < i,
     - c = round (μ fs i j),
     - mfs' is mfs with row i replaced by (row i - c * row j), reduced
       entrywise by symmod p,
     - dmu' is the m x m matrix that differs from dmu only in row i, in the
       columns selected by the condition on j' and j: there the entry becomes
       dmu(i,j') - c * dmu(j,j'), and for j' different from j it is
       additionally reduced symmod (p * d_of dmu j' * d_of dmu (Suc j')).
   Conclusion: there is a basis fs' such that the weak modular invariant
   holds for (fs', mfs', dmu'), LLL_measure i is unchanged, μ_small_row
   propagates from Suc j to j, the gso vectors and the d values are
   unchanged, the new μ fs' i j is bounded by 1/2 in absolute value, the μ
   values of rows above i are unchanged, and the full modular invariant at
   index i is preserved as well.
   Proof structure: first perform the unreduced add-row step (fs', local
   definition), then apply mod_finite_set to obtain fs'' whose d-mu entries
   at the index set I are reduced symmod; fs'' is the witness. *)
lemma basis_reduction_mod_add_row_main:
  assumes Linvmw: "LLL_invariant_mod_weak fs mfs dmu p first b"
  and i: "i < m"  and j: "j < i" 
  and c: "c = round (μ fs i j)"
  and mfs': "mfs' = mfs[ i := (map_vec (λ x. x symmod p)) (mfs ! i - c v mfs ! j)]"
  and dmu': "dmu' = mat m m (λ(i',j'). (if (i' = i  j'  j) 
        then (if j'=j then (dmu $$ (i,j') - c * dmu $$ (j,j')) 
              else (dmu $$ (i,j') - c * dmu $$ (j,j')) 
                    symmod (p * (d_of dmu j') * (d_of dmu (Suc j'))))
        else (dmu $$ (i',j'))))"
shows "(fs'. LLL_invariant_mod_weak fs' mfs' dmu' p first b 
        LLL_measure i fs' = LLL_measure i fs
         (μ_small_row i fs (Suc j)  μ_small_row i fs' j) 
         (k < m. gso fs' k = gso fs k)
         (ii  m. d fs' ii = d fs ii)
         ¦μ fs' i j¦  1 / 2
         (i' j'. i' < i  j'  i'  μ fs' i' j' = μ fs i' j')
         (LLL_invariant_mod fs mfs dmu p first b i  LLL_invariant_mod fs' mfs' dmu' p first b i))"
proof -
  (* Step 1: the unreduced add-row step on the original basis. *)
  define fs' where "fs' = fs[ i := fs ! i - c v fs ! j]"
  from LLL_invD_modw[OF Linvmw] have gbnd: "g_bnd_mode first b fs" and p1: "p > 1" and pgtz: "p > 0" by auto
  have Linvww: "LLL_invariant_weak fs" using LLL_invD_modw[OF Linvmw] LLL_invariant_weak_def by simp
  (* Facts 01-061 are the conclusions of the non-modular main lemma for fs'. *)
  have 
    Linvw': "LLL_invariant_weak fs'" and
    01: "c = round (μ fs i j)  μ_small_row i fs (Suc j)  μ_small_row i fs' j" and
    02: "LLL_measure i fs' = LLL_measure i fs" and
    03: " i. i < m  gso fs' i = gso fs i" and
    04: " i' j'. i' < m  j' < m  
      μ fs' i' j' = (if i' = i  j'  j then μ fs i j' - of_int c * μ fs j j' else μ fs i' j')" and
    05: " ii. ii  m  d fs' ii = d fs ii" and 
    06: "¦μ fs' i j¦  1 / 2" and
    061: "(i' j'. i' < i  j'  i'  μ fs i' j' = μ fs' i' j')"
    using basis_reduction_add_row_main[OF Linvww i j fs'_def] c i by auto
  have 07: "lin_indep fs'" and 
    08: "length fs' = m" and 
    09: "lattice_of fs' = L" using LLL_inv_wD Linvw' by auto
  have 091: "fs_int_indpt n fs'" using 07 using Gram_Schmidt_2.fs_int_indpt.intro by simp
  (* Step 2: I is the set of positions in row i with column index below j;
     these are the d-mu entries that get reduced symmod (see fact 17). *)
  define I where "I = {(i',j'). i' = i  j' < j}"
  have 10: "I  {(i',j'). i' < m  j' < i'}" "(i,j) I" "j'  j. (i,j')  I" using I_def i j by auto
  (* fs'' is obtained from mod_finite_set: same lattice, same symmod-p image,
     same gso and d values as fs', with d-mu reduced exactly on I. *)
  obtain fs'' where 
    11: "lattice_of fs'' = L" and
    12: "map (map_vec (λ x. x symmod p)) fs'' = map (map_vec (λ x. x symmod p)) fs'" and
    13: "lin_indep fs''" and
    14: "length fs'' = m" and
    15: "( k < m. gso fs'' k = gso fs' k)" and
    16: "( k  m. d fs'' k = d fs' k)" and
    17: "( i' < m.  j' < m.fs'' i' j' = 
      (if (i',j')  I thenfs' i' j' symmod (p * d fs' j' * d fs' (Suc j')) elsefs' i' j'))"
    using mod_finite_set[OF 07 08 10(1) 09 pgtz] by blast
  (* Step 3: the μ values of rows above i are the same for fs'' and fs'. *)
  have 171: "(i' j'. i' < i  j'  i'  μ fs'' i' j' = μ fs' i' j')"
  proof -
    {
      fix i' j'
      assume i'j': "i' < i" "j'  i'"
      have "rat_of_int (fs'' i' j') = rat_of_int (fs' i' j')" using "17" I_def i i'j' by auto
      then have "rat_of_int (int_of_rat (rat_of_int (d fs'' (Suc j')) * μ fs'' i' j')) = 
        rat_of_int (int_of_rat (rat_of_int (d fs' (Suc j')) * μ fs' i' j'))"
        using dμ_def i'j' j by auto
      then have "rat_of_int (d fs'' (Suc j')) * μ fs'' i' j' = 
        rat_of_int (d fs' (Suc j')) * μ fs' i' j'" 
        by (smt "08" "091" "13" "14" d_def dual_order.strict_trans fs_int.d_def 
            fs_int_indpt.fs_int_mu_d_Z fs_int_indpt.intro i i'j'(1) i'j'(2) int_of_rat(2))
      then have fs'' i' j' = μ fs' i' j'" by (smt "16" 
            LLL_d_pos[OF Linvw'] Suc_leI int_of_rat(1)
            dual_order.strict_trans fs'_def i i'j' j 
            le_neq_implies_less nonzero_mult_div_cancel_left of_int_hom.hom_zero)
    }
    then show ?thesis by simp
  qed
  then have 172: "(i' j'. i' < i  j'  i'  μ fs'' i' j' = μ fs i' j')" using 061 by simp (* goal *)
  have 18: "LLL_measure i fs'' = LLL_measure i fs'" using 16 LLL_measure_def logD_def D_def by simp
  have 19: "(k < m. gso fs'' k = gso fs k)" using 03 15 by simp
  (* Step 4: in row i, columns j..m-1 are outside I, so d-mu and μ agree
     for fs'' and fs' there. *)
  have "j'  {j..(m-1)}. j' < m" using j i by auto
  then have 20: "j'  {j..(m-1)}.fs'' i j' =fs' i j'" 
    using 10(3) 17 Suc_lessD less_trans_Suc by (meson atLeastAtMost_iff i)
  have 21: "j'  {j..(m-1)}. μ fs'' i j' = μ fs' i j'" 
  proof -
    {
      fix j'
      assume j': "j'  {j..(m-1)}"
      define μ'' :: rat where "μ'' = μ fs'' i j'"
      define μ' :: rat where "μ' = μ fs' i j'"
      have "rat_of_int (fs'' i j') = rat_of_int (fs' i j')" using 20 j' by simp
      moreover have "j' < length fs'" using i j' 08 by auto
      ultimately have "rat_of_int (d fs' (Suc j')) * gram_schmidt_fs.μ n (map of_int_hom.vec_hom fs') i j'
        = rat_of_int (d fs'' (Suc j')) * gram_schmidt_fs.μ n (map of_int_hom.vec_hom fs'') i j'"
        using 20 08 091 13 14 fs_int_indpt.dμ_def fs_int.d_def fs_int_indpt.dμ dμ_def d_def i fs_int_indpt.intro j'
        by metis
      then have "rat_of_int (d fs' (Suc j')) * μ'' = rat_of_int (d fs' (Suc j')) * μ'" 
        using 16 i j' μ'_def μ''_def unfolding dμ_def by auto
      moreover have "0 < d fs' (Suc j')" using LLL_d_pos[OF Linvw', of "Suc j'"] i j' by auto
      ultimately have fs'' i j' = μ fs' i j'" using μ'_def μ''_def by simp
    }
    then show ?thesis by simp
  qed
  then have 22: fs'' i j = μ fs' i j" using i j by simp
  then have 23: "¦μ fs'' i j¦  1 / 2" using 06 by simp (* goal *)
  have 24: "LLL_measure i fs'' = LLL_measure i fs" using 02 18 by simp (* goal *)
  have 25: "( k  m. d fs'' k = d fs k)" using 16 05 by simp (* goal *)
  have 26: "( k < m. gso fs'' k = gso fs k)" using 15 03 by simp (* goal *)
  have 27: "μ_small_row i fs (Suc j)  μ_small_row i fs'' j"
    using 21 01 μ_small_row_def i j c by auto (* goal *)
  (* Step 5: length and carrier facts needed for the entrywise reasoning. *)
  have 28: "length fs = m" "length mfs = m" using LLL_invD_modw[OF Linvmw] by auto
  have 29: "map (map_vec (λx. x symmod p)) fs = mfs" using assms LLL_invD_modw by simp
  have 30: " i. i < m  fs ! i  carrier_vec n" " i. i < m  mfs ! i  carrier_vec n"
    using LLL_invD_modw[OF Linvmw] by auto
  have 31: " i. i < m  fs' ! i  carrier_vec n" using fs'_def 30(1) 
    using "08" "091" fs_int_indpt.f_carrier by blast
  have 32: " i. i < m  mfs' ! i  carrier_vec n" unfolding mfs' using 30(2) 28(2) 
    by (metis (no_types, lifting) Suc_lessD j less_trans_Suc map_carrier_vec minus_carrier_vec 
        nth_list_update_eq nth_list_update_neq smult_closed)
  have 33: "length mfs' = m" using 28(2) mfs' by simp (* invariant goal *)
  (* Step 6: mfs' is the symmod-p image of fs'; shown entrywise, with a case
     split on whether the row is the updated row i. *)
  then have 34: "map (map_vec (λx. x symmod p)) fs' = mfs'"
  proof -
    {
      fix i' j'
      have j2: "j < m" using j i by auto
      assume i': "i' < m"
      assume j': "j' < n"
      then have fsij: "(fs ! i' $ j') symmod p = mfs ! i' $ j'" using 30 i' j' 28 29 by fastforce
      have "mfs' ! i $ j' = (mfs ! i $ j'- (c v mfs ! j) $ j') symmod p"
        unfolding mfs' using 30(2) j' 28 j2 
        by (metis (no_types, lifting) carrier_vecD i index_map_vec(1) index_minus_vec(1) 
            index_minus_vec(2) index_smult_vec(2) nth_list_update_eq)
      then have mfs'ij: "mfs' ! i $ j' = (mfs ! i $ j'- c * mfs ! j $ j') symmod p" 
        unfolding mfs' using 30(2) i' j' 28 j2 by fastforce
      have "(fs' ! i' $ j') symmod p = mfs' ! i' $ j'"
      proof(cases "i' = i")
        case True
        show ?thesis using fs'_def mfs' True 28 fsij 
        proof -
          have "fs' ! i' $ j' = (fs ! i' - c v fs ! j) $ j'" using fs'_def True i' j' 28(1) by simp
          also have " = fs ! i' $ j' - (c v fs ! j) $ j'" using i' j' 30(1)
            by (metis Suc_lessD carrier_vecD i index_minus_vec(1) index_smult_vec(2) j less_trans_Suc)
          finally have "fs' ! i' $ j' = fs ! i' $ j' - (c v fs ! j) $ j'" by auto
          then have "(fs' ! i' $ j') symmod p = (fs ! i' $ j' - (c v fs ! j) $ j') symmod p" by auto
          also have " = ((fs ! i' $ j') symmod p - ((c v fs ! j) $ j') symmod p) symmod p"
            by (simp add: sym_mod_diff_eq)
          also have "(c v fs ! j) $ j' = c * (fs ! j $ j')" 
            using i' j' True 28 30(1) j
            by (metis Suc_lessD carrier_vecD index_smult_vec(1) less_trans_Suc)
          also have "((fs ! i' $ j') symmod p - (c * (fs ! j $ j')) symmod p) symmod p = 
            ((fs ! i' $ j') symmod p - c * ((fs ! j $ j') symmod p)) symmod p" 
            using i' j' True 28 30(1) j by (metis sym_mod_diff_right_eq sym_mod_mult_right_eq)
          also have "((fs ! j $ j') symmod p) = mfs ! j $ j'" using 30 i' j' 28 29 j2 by fastforce
          also have "((fs ! i' $ j') symmod p - c * mfs ! j $ j') symmod p = 
            (mfs ! i' $ j' - c * mfs ! j $ j') symmod p" using fsij by simp
          finally show ?thesis using mfs'ij by (simp add: True)
        qed
      next
        case False
        show ?thesis using fs'_def mfs' False 28 fsij by simp
      qed
    }
    then have "i' < m. (map_vec (λx. x symmod p)) (fs' ! i') = mfs' ! i'"
      using 31 32 33 08 by fastforce
    then show ?thesis using 31 32 33 08 by (simp add: map_nth_eq_conv)
  qed
  then have 35: "map (map_vec (λx. x symmod p)) fs'' = mfs'" using 12 by simp (* invariant req. *)
  have 36: "lin_indep fs''"  using 13 by simp (* invariant req. *)
  have Linvw'': "LLL_invariant_weak fs''" using LLL_invariant_weak_def 11 13 14 by simp
  (* Step 7: the size bound on the d-mu entries of fs''.  Case split:
     row i with j' < j (entry was reduced symmod), row i with j' = j (bounded
     via fact 23), row i with j' > j and all other rows (entry unchanged
     from fs, bound inherited from the invariant). *)
  have 39: "(i' < m. j' < i'. ¦fs'' i' j'¦ < p * d fs'' j' * d fs'' (Suc j'))" (* invariant req. *)
  proof -
    {
      fix i' j'
      assume i': "i' < m"
      assume j': "j' < i'"
      define pdd where "pdd = (p * d fs'' j' * d fs'' (Suc j'))"
      then have pddgtz: "pdd > 0" 
        using pgtz j' LLL_d_pos[OF Linvw', of "Suc j'"] LLL_d_pos[OF Linvw', of j'] j' i' 16 by simp
      have "¦fs'' i' j'¦ < p * d fs'' j' * d fs'' (Suc j')"
      proof(cases "i' = i")
        case i'i: True
        then show ?thesis
        proof (cases "j' < j")
          case True
          then have eq'': "dμ fs'' i' j' =fs' i' j' symmod (p * d fs'' j' * d fs'' (Suc j'))"
            using 16 17 10 I_def True i' j' i'i by simp
          have "0 < pdd" using pddgtz by simp
          then show ?thesis unfolding eq'' unfolding pdd_def[symmetric] using sym_mod_abs by blast
        next
          case fls: False
          then have "(i',j')  I" using I_def i'i by simp
          then have dmufs''fs': "dμ fs'' i' j' =fs' i' j'" using 17 i' j' by simp
          show ?thesis
          proof (cases "j' = j")
            case True
            define μ'' where "μ'' = μ fs'' i' j'" 
            define d'' where "d'' = d fs'' (Suc j')"
            have pge1: "p  1" using pgtz by simp
            have lh: "¦μ''¦  1 / 2" using 23 True i'i μ''_def by simp
            moreover have eq: "dμ fs'' i' j' = μ'' * d''" using dμ_def i' j' μ''_def d''_def 
              by (smt "14" "36" LLL.d_def Suc_lessD fs_int.d_def fs_int_indpt.dμ fs_int_indpt.intro 
                  int_of_rat(1) less_trans_Suc mult_of_int_commute of_rat_mult of_rat_of_int_eq)
            moreover have Sj': "Suc j'  m" "j'  m" using True j' i i' by auto
            moreover then have gtz: "0 < d''" using LLL_d_pos[OF Linvw''] d''_def by simp
            moreover have "rat_of_int ¦fs'' i' j'¦ = ¦μ'' * (rat_of_int d'')¦" 
              using eq by (metis of_int_abs of_rat_hom.injectivity of_rat_mult of_rat_of_int_eq)
            moreover then have "¦μ'' * rat_of_int d'' ¦ =  ¦μ''¦ * rat_of_int ¦d''¦"
              by (metis (mono_tags, hide_lams) abs_mult of_int_abs)
            moreover have " = ¦μ''¦ * rat_of_int d'' " using gtz by simp
            moreover have " < rat_of_int d''" using lh gtz by simp
            ultimately have "rat_of_int ¦fs'' i' j'¦ < rat_of_int d''" by simp
            then have "¦fs'' i' j'¦ <  d fs'' (Suc j')" using d''_def by simp
            then have "¦fs'' i' j'¦ < p * d fs'' (Suc j')" using pge1
              by (smt mult_less_cancel_right2)
            then show ?thesis using pge1 LLL_d_pos[OF Linvw'' Sj'(2)] gtz unfolding d''_def
              by (smt mult_less_cancel_left2 mult_right_less_imp_less)
          next
            case False
            have "j' < m" using i' j' by simp
            moreover have "j' > j" using False fls by simp
            ultimately have fs' i' j' = μ fs i' j'" using i' 04 i by simp
            then have "dμ fs' i' j' =fs i' j'" using dμ_def i' j' 05 by simp
            then have "dμ fs'' i' j' =fs i' j'" using dmufs''fs' by simp
            then show ?thesis using LLL_invD_modw[OF Linvmw] i' j' 25 by simp
          qed
        qed
      next
        case False
        then have "(i',j')  I" using I_def by simp
        then have dmufs''fs': "dμ fs'' i' j' =fs' i' j'" using 17 i' j' by simp
        have fs' i' j' = μ fs i' j'" using i' 04 j' False by simp
        then have "dμ fs' i' j' =fs i' j'" using dμ_def i' j' 05 by simp
        moreover then have "dμ fs'' i' j' =fs i' j'" using dmufs''fs' by simp
        then show ?thesis using LLL_invD_modw[OF Linvmw] i' j' 25 by simp
      qed
    }
    then show ?thesis by simp
  qed
  (* Step 8: relate the d-mu entries to the matrices.
     40: entries of fs' outside the updated part of row i agree with dmu.
     41: entries of fs' in the updated part of row i equal
         dmu(i,j') - c * dmu(j,j').
     42: all entries of fs'' agree with dmu'. *)
  have 40: "(i' < m. j' < m. i'  i  j' > j fs' i' j' = dmu $$ (i',j'))"
  proof -
    {
      fix i' j'
      assume i': "i' < m" and j': "j' < m"
      assume assm: "i'  i  j' > j"
      have "dμ fs' i' j' = dmu $$ (i',j')"
      proof (cases "i'  i")
        case True
        then show ?thesis using fs'_def LLL_invD_modw[OF Linvmw] dμ_def i i' j j'
          04 28(1) LLL_invI_weak basis_reduction_add_row_main(8)[OF Linvww] by auto
      next
        case False
        then show ?thesis 
          using 05 LLL_invD_modw[OF Linvmw] dμ_def i j j' 04 assm by simp
      qed
    }
    then show ?thesis by simp
  qed
  have 41: "j'  j.fs' i j' = dmu $$ (i,j') - c * dmu $$ (j,j')"
  proof -
    {
      let ?oi = "of_int :: _  rat" 
      fix j'
      assume j': "j'  j"
      define dj' μi μj where "dj' = d fs (Suc j')" and "μi = μ fs i j'" and "μj = μ fs j j'"
      have "?oi (fs' i j') = ?oi (d fs (Suc j')) * (μ fs i j' - ?oi c * μ fs j j')"
        using j' 04 dμ_def 
        by (smt "05" "08" "091" Suc_leI d_def diff_diff_cancel fs_int.d_def 
            fs_int_indpt.fs_int_mu_d_Z i int_of_rat(2) j less_imp_diff_less less_imp_le_nat)
      also have " = (?oi dj') * (μi - of_int c * μj)" 
        using dj'_def μi_def μj_def by (simp add: of_rat_mult)
      also have " = (rat_of_int dj') * μi - of_int c * (rat_of_int dj') * μj" by algebra
      also have " = rat_of_int (fs i j') - ?oi c * rat_of_int (fs j j')" unfolding dj'_def μi_def μj_def
        using i j j' dμ_def
        using "28"(1) LLL.LLL_invD_modw(4) Linvmw d_def fs_int.d_def fs_int_indpt.fs_int_mu_d_Z fs_int_indpt.intro by auto
      also have " = rat_of_int (dmu $$ (i,j')) - ?oi c * rat_of_int (dmu $$ (j,j'))" 
        using LLL_invD_modw(7)[OF Linvmw] dμ_def j' i j by auto
      finally have "?oi (fs' i j') = rat_of_int (dmu $$ (i,j')) - ?oi c * rat_of_int (dmu $$ (j,j'))" by simp
      then have "dμ fs' i j' = dmu $$ (i,j') - c * dmu $$ (j,j')"
        using of_int_eq_iff by fastforce
    }
    then show ?thesis by simp
  qed
  have 42: "(i' < m. j' < m.fs'' i' j' = dmu' $$ (i',j'))"
  proof -
    {
      fix i' j'
      assume i': "i' < m" and j': "j' < m"
      have "dμ fs'' i' j' = dmu' $$ (i',j')" 
      proof (cases "i' = i")
        case i'i: True
        then show ?thesis
        proof (cases "j' > j")
          case True
          then have "(i',j')I" using I_def by simp
          moreover then have "dμ fs' i' j' =fs i' j'" using "04" "05" True Suc_leI dμ_def i' j' by simp
          moreover have "dmu' $$ (i',j') = dmu $$ (i',j')" using dmu' True i' j' by simp
          ultimately show ?thesis using "17" "40" True i' j' by auto
        next
          case False
          then have j'lej: "j'  j" by simp
          then have eq': "dμ fs' i j' = dmu $$ (i,j') - c * dmu $$ (j,j')" using 41 by simp
          have id: "d_of dmu j' = d fs j'" "d_of dmu (Suc j') = d fs (Suc j')" 
            using d_of_weak[OF Linvmw] j' < m by auto
          show ?thesis
          proof (cases "j'  j")
            case True
            then have j'ltj: "j' < j" using True False by simp
            then have "(i',j')  I" using I_def True i'i by simp
            then have "dμ fs'' i' j' = 
              (dmu $$ (i,j') - c * dmu $$ (j,j')) symmod (p * d fs' j' * d fs' (Suc j'))"
              using 17 i' 41 j'lej by (simp add: j' i'i)
            also have " = (dmu $$ (i,j') - c * dmu $$ (j,j')) symmod (p * d fs j' * d fs (Suc j'))"
              using 05 i j'ltj j by simp
            also have " = dmu' $$ (i,j')" 
              unfolding dmu' index_mat(1)[OF i < m j' < m] split id using j'lej True by auto
            finally show ?thesis using i'i by simp
          next
            case False
            then have j'j: "j' = j" by simp
            then have "dμ fs'' i j' =fs' i j'" using 20 j' by simp
            also have " = dmu $$ (i,j') - c * dmu $$ (j,j')" using eq' by simp
            also have " = dmu' $$ (i,j')" using dmu' j'j i j' by simp
            finally show ?thesis using i'i by simp
          qed
        qed
      next
        case False
        then have "(i',j')I" using I_def by simp
        moreover then have "dμ fs' i' j' =fs i' j'" by (simp add: "04" "05" False Suc_leI dμ_def i' j')
        moreover then have "dmu' $$ (i',j') = dmu $$ (i',j')" using dmu' False i' j' by simp
        ultimately show ?thesis using "17" "40" False i' j' by auto
      qed
    }
    then show ?thesis by simp
  qed
  (* Step 9: g_bnd_mode transfers to fs'' because the gso vectors agree (26). *)
  from gbnd 26 have gbnd: "g_bnd_mode first b fs''" using g_bnd_mode_cong[of fs'' fs] by simp
  (* Step 10: under the additional assumption of the full (indexed) invariant
     for fs, establish the full invariant for fs''. *)
  {
    assume Linv: "LLL_invariant_mod fs mfs dmu p first b i"
    have Linvw: "LLL_invariant_weak' i fs" using Linv LLL_invD_mod LLL_invI_weak by simp
    note Linvww = LLL_invw'_imp_w[OF Linvw]
    have 00: "LLL_invariant_weak' i fs'" using Linvw basis_reduction_add_row_weak[OF Linvw i j fs'_def] by auto
    have 37: "weakly_reduced fs'' i" using 15 LLL_invD_weak(8)[OF 00] gram_schmidt_fs.weakly_reduced_def 
      by (smt Suc_lessD i less_trans_Suc) (* invariant req. *)
    have 38: "LLL_invariant_weak' i fs''"
      using 00 11 14 36 37 i 31 12  LLL_invariant_weak'_def by blast
    have "LLL_invariant_mod fs'' mfs' dmu' p first b i"
      using LLL_invI_mod[OF 33 _ 14 11 13 37 35 39 42 p1 gbnd LLL_invD_mod(17)[OF Linv]] i by simp
  }
  (* Step 11: assemble the weak invariant for fs'' and conclude with fs'' as
     the existential witness. *)
  moreover have "LLL_invariant_mod_weak fs'' mfs' dmu' p first b"
    using LLL_invI_modw[OF 33 14 11 13 35 39 42 p1 gbnd LLL_invD_modw(15)[OF Linvmw]] by simp
  ultimately show ?thesis using 27 23 24 25 26 172 by auto
qed

(* D_mod dmu: natural number obtained by aggregating the values d_of dmu i
   over all i < m; it is computed from the matrix dmu alone.
   NOTE(review): presumably the modular counterpart of D (same aggregate, but
   with d_of dmu i in place of d fs i) - confirm against D_def. *)
definition D_mod :: "int mat  nat" where "D_mod dmu = nat ( i < m. d_of dmu i)"

(* logD_mod dmu: measure derived from D_mod dmu.  For α = 4/3 it is D_mod dmu
   itself; otherwise it is the floor of the logarithm of D_mod dmu to the
   base 1 / reduction, converted to nat. *)
definition logD_mod :: "int mat  nat"
  where "logD_mod dmu = (if α = 4/3 then (D_mod dmu) else nat (floor (log (1 / of_rat reduction) (D_mod dmu))))" 
end

(* Locale packaging a state of the modular algorithm: it fixes all parameters
   and assumes that LLL.LLL_invariant_mod holds for them (fact LLL_inv_mod). *)
locale fs_int'_mod = 
  fixes n m fs_init α i fs mfs dmu p first b 
  assumes LLL_inv_mod: "LLL.LLL_invariant_mod n m fs_init α fs mfs dmu p first b i"

context LLL_with_assms
begin

(* Swap step for the indexed weak invariant.
   Assumptions: LLL_invariant_weak' i fs, i < m, i non-zero, the bound on
   ¦μ fs i (i-1)¦ by 1/2, the norm condition
   sq_norm (gso fs (i - 1)) > α * sq_norm (gso fs i), and fs' is fs with
   rows i and i - 1 exchanged.
   Conclusion: LLL_invariant_weak' (i - 1) fs'.
   The invariant components come from basis_reduction_swap_main; weak
   reducedness below i - 1 is transferred from fs to fs' using main(5). *)
lemma basis_reduction_swap_weak': assumes Linvw: "LLL_invariant_weak' i fs"
  and i: "i < m"
  and i0: "i  0"
  and mu_F1_i: "¦μ fs i (i-1)¦  1 / 2"
  and norm_ineq: "sq_norm (gso fs (i - 1)) > α * sq_norm (gso fs i)" 
  and fs'_def: "fs' = fs[i := fs ! (i - 1), i - 1 := fs ! i]" 
shows "LLL_invariant_weak' (i - 1) fs'" 
proof -
  note inv = LLL_invD_weak[OF Linvw]
  note invw = LLL_invw'_imp_w[OF Linvw]
  note main = basis_reduction_swap_main[OF invw disjI2[OF mu_F1_i] i i0 norm_ineq fs'_def]
  note inv' = LLL_inv_wD[OF main(1)]
  (* Weak reducedness up to i restricts to i - 1 ... *)
  from ‹weakly_reduced fs i have "weakly_reduced fs (i - 1)" 
    unfolding gram_schmidt_fs.weakly_reduced_def by auto
  (* ... and is unaffected by the swap below index i - 1. *)
  also have "weakly_reduced fs (i - 1) = weakly_reduced fs' (i - 1)" 
    unfolding gram_schmidt_fs.weakly_reduced_def 
    by (intro all_cong, insert i0 i main(5), auto)
  finally have red: "weakly_reduced fs' (i - 1)" .
  show "LLL_invariant_weak' (i - 1) fs'" using i
    by (intro LLL_invI_weak red inv', auto)
qed

(* Termination of the add-row loop for row i: once μ_small_row i fs 0 holds
   (all columns of row i have been processed), μ_small fs i follows.
   The core step unfolds μ_small_row_def and μ_small_def. *)
lemma basis_reduction_add_row_done_weak: 
  assumes Linv: "LLL_invariant_weak' i fs"
  and i: "i < m" 
  and mu_small: "μ_small_row i fs 0" 
shows "μ_small fs i"
proof -
  note inv = LLL_invD_weak[OF Linv]
  from mu_small 
  have mu_small: "μ_small fs i" unfolding μ_small_row_def μ_small_def by auto
  show ?thesis
    using i mu_small LLL_invI_weak[OF inv(3,6,7,9,1)] by auto
qed     

(* From the modular invariant at index m, derive for an index i (constrained
   by assumption i relative to m):
     1. the modular invariant at index i,
     2. the indexed weak invariant at index m,
     3. the indexed weak invariant at index i.
   Statement 1 goes through LLL_inv_weak_m_impl_i to obtain
   weakly_reduced fs i; statement 3 is then read off from statement 1. *)
lemma LLL_invariant_mod_to_weak_m_to_i: assumes
  inv: "LLL_invariant_mod fs mfs dmu p first b m"
  and i: "i  m"
shows "LLL_invariant_mod fs mfs dmu p first b i"
  "LLL_invariant_weak' m fs"
  "LLL_invariant_weak' i fs"
proof -
  show "LLL_invariant_mod fs mfs dmu p first b i" 
  proof -
    have "LLL_invariant_weak' m fs" using LLL_invD_mod[OF inv] LLL_invI_weak by simp
    then have "LLL_invariant_weak' i fs" using LLL_inv_weak_m_impl_i i by simp
    then have "weakly_reduced fs i" using i LLL_invD_weak(8) by simp
    then show ?thesis using LLL_invD_mod[OF inv] LLL_invI_mod i by simp
  qed
  then show fsinvwi: "LLL_invariant_weak' i fs" using LLL_invD_mod LLL_invI_weak by simp
  show "LLL_invariant_weak' m fs" using LLL_invD_mod[OF inv] LLL_invI_weak by simp
qed

(* Main correctness lemma for the swap step on the modular representation.
   Assumptions: the weak modular invariant for (fs, mfs, dmu, p), an index k with
   0 < k < m, ¦μ fs k (k-1)¦ ≤ 1/2, and the violated norm inequality at k.
   mfs' is mfs with rows k-1 and k exchanged, dmu' is the updated dμ-matrix,
   and dmu'_mod additionally reduces the sub-diagonal entries of columns k-1 and k
   by symmod (p * d_j * d_(j+1)).
   Conclusion: there is a basis fs' for which the weak modular invariant holds
   with mfs' and dmu'_mod, the measure strictly decreases from index k to k-1,
   and the full modular invariant at k carries over to index k-1. *)
lemma basis_reduction_mod_swap_main: 
  assumes Linvmw: "LLL_invariant_mod_weak fs mfs dmu p first b"
  and k: "k < m"
  and k0: "k ≠ 0"
  and mu_F1_i: "¦μ fs k (k-1)¦ ≤ 1 / 2"
  and norm_ineq: "sq_norm (gso fs (k - 1)) > α * sq_norm (gso fs k)" 
  and mfs'_def: "mfs' = mfs[k := mfs ! (k - 1), k - 1 := mfs ! k]"
  and dmu'_def: "dmu' = (mat m m (λ(i,j). (
      if j < i then
        if i = k - 1 then 
           dmu $$ (k, j)
        else if i = k ∧ j ≠ k - 1 then 
             dmu $$ (k - 1, j)
        else if i > k ∧ j = k then
           ((d_of dmu (Suc k)) * dmu $$ (i, k - 1) - dmu $$ (k, k - 1) * dmu $$ (i, j)) 
              div (d_of dmu k)
        else if i > k ∧ j = k - 1 then
           (dmu $$ (k, k - 1) * dmu $$ (i, j) + dmu $$ (i, k) * (d_of dmu (k-1)))
              div (d_of dmu k)
        else dmu $$ (i, j)
      else if i = j then 
        if i = k - 1 then 
          ((d_of dmu (Suc k)) * (d_of dmu (k-1)) + dmu $$ (k, k - 1) * dmu $$ (k, k - 1)) 
            div (d_of dmu k)
        else (d_of dmu (Suc i))
      else dmu $$ (i, j))
    ))"
  and dmu'_mod_def: "dmu'_mod = mat m m (λ(i, j). (
        if j < i ∧ (j = k ∨ j = k - 1) then 
          dmu' $$ (i, j) symmod (p * (d_of dmu' j) * (d_of dmu' (Suc j)))
        else dmu' $$ (i, j)))"
shows "(∃fs'. LLL_invariant_mod_weak fs' mfs' dmu'_mod p first b ∧
        LLL_measure (k-1) fs' < LLL_measure k fs ∧
        (LLL_invariant_mod fs mfs dmu p first b k ⟶ LLL_invariant_mod fs' mfs' dmu'_mod p first b (k-1)))" 
proof - 
  (* fs' is the swapped (not yet modularly reduced) basis *)
  define fs' where "fs' = fs[k := fs ! (k - 1), k - 1 := fs ! k]"
  have pgtz: "p > 0" and p1: "p > 1" using LLL_invD_modw[OF Linvmw] by auto
  have invw: "LLL_invariant_weak fs" using LLL_invD_modw[OF Linvmw] LLL_invariant_weak_def by simp
  note swap_main = basis_reduction_swap_main(3-)[OF invw disjI2[OF mu_F1_i] k k0 norm_ineq fs'_def]
  note ddμ_swap = d_dμ_swap[OF invw disjI2[OF mu_F1_i] k k0 norm_ineq fs'_def]
  have invw': "LLL_invariant_weak fs'" using fs'_def assms invw basis_reduction_swap_main(1) by simp
  have 02: "LLL_measure k fs > LLL_measure (k - 1) fs'" by fact
  (* sub-diagonal dμ-entries of fs' expressed via the entries of fs *)
  have 03: "⋀ i j. i < m ⟹ j < i ⟹ dμ fs' i j = (
        if i = k - 1 then dμ fs k j
        else if i = k ∧ j ≠ k - 1 then dμ fs (k - 1) j
        else if i > k ∧ j = k then
           (d fs (Suc k) * dμ fs i (k - 1) - dμ fs k (k - 1) * dμ fs i j) div d fs k
        else if i > k ∧ j = k - 1 then 
           (dμ fs k (k - 1) * dμ fs i j + dμ fs i k * d fs (k - 1)) div d fs k
        else dμ fs i j)"
    using ddμ_swap by auto
  have 031: "⋀i. i < k-1 ⟹ gso fs' i = gso fs i" 
    using swap_main(2) k k0 by auto
  have 032: "⋀ ii. ii ≤ m ⟹ of_int (d fs' ii) = (if ii = k then 
           sq_norm (gso fs' (k - 1)) / sq_norm (gso fs (k - 1)) * of_int (d fs k)
           else of_int (d fs ii))" 
    by fact 
  (* the norm bound (in either mode) is preserved by the swap *)
  have gbnd: "g_bnd_mode first b fs'"
  proof (cases "first ∧ m ≠ 0")
    case True
    have "sq_norm (gso fs' 0) ≤ sq_norm (gso fs 0)" 
    proof (cases "k - 1 = 0")
      case False
      thus ?thesis using 031[of 0] by simp
    next
      case *: True
      have k_1: "k - 1 < m" using k by auto
      from * k0 have k1: "k = 1" by simp
      (* this is a copy of what is done in LLL.swap-main, should be made accessible in swap-main *)
      have "sq_norm (gso fs' 0) ≤ abs (sq_norm (gso fs' 0))" by simp
      also have "… = abs (sq_norm (gso fs 1) + μ fs 1 0 * μ fs 1 0 * sq_norm (gso fs 0))" 
        by (subst swap_main(3)[OF k_1, unfolded *], auto simp: k1)
      also have "… ≤ sq_norm (gso fs 1) + abs (μ fs 1 0) * abs (μ fs 1 0) * sq_norm (gso fs 0)"
        by (simp add: sq_norm_vec_ge_0)
      also have "… ≤ sq_norm (gso fs 1) + (1 / 2) * (1 / 2) * sq_norm (gso fs 0)" 
        using mu_F1_i[unfolded k1] 
        by (intro plus_right_mono mult_mono, auto)
      also have "… < 1 / α * sq_norm (gso fs 0) + (1 / 2) * (1 / 2) * sq_norm (gso fs 0)" 
        by (intro add_strict_right_mono, insert norm_ineq[unfolded mult.commute[of α],
          THEN mult_imp_less_div_pos[OF α0(1)]] k1, auto)
      also have "… = reduction * sq_norm (gso fs 0)" unfolding reduction_def
        using α0 by (simp add: ring_distribs add_divide_distrib)
      also have "… ≤ 1 * sq_norm (gso fs 0)" using reduction(2)
        by (intro mult_right_mono, auto)
      finally show ?thesis by simp
    qed
    thus ?thesis using LLL_invD_modw(14)[OF Linvmw] True
      unfolding g_bnd_mode_def by auto
  next
    case False
    from LLL_invD_modw(14)[OF Linvmw] False have "g_bnd b fs" unfolding g_bnd_mode_def by auto
    hence "g_bnd b fs'" using g_bnd_swap[OF k k0 invw mu_F1_i norm_ineq fs'_def] by simp
    thus ?thesis using False unfolding g_bnd_mode_def by auto
  qed
  note d_of = d_of_weak[OF Linvmw]
  (* diagonal dμ-entries of fs' expressed via dmu *)
  have 033: "⋀ i. i < m ⟹ dμ fs' i i = (
            if i = k - 1 then 
             ((d_of dmu (Suc k)) * (d_of dmu (k-1)) + dmu $$ (k, k - 1) * dmu $$ (k, k - 1)) 
                div (d_of dmu k)
            else (d_of dmu (Suc i)))"  
  proof -
    fix i
    assume i: "i < m"
    have "dμ fs' i i = d fs' (Suc i)" using ddμ i by simp
    also have "… = (if i = k - 1 then 
          (d fs (Suc k) * d fs (k - 1) + dμ fs k (k - 1) * dμ fs k (k - 1)) div d fs k 
        else d fs (Suc i))"
      by (subst ddμ_swap, insert ddμ k0 i, auto)
    also have "… = (if i = k - 1 then 
        ((d_of dmu (Suc k)) * (d_of dmu (k-1)) + dmu $$ (k, k - 1) * dmu $$ (k, k - 1)) 
          div (d_of dmu k)
       else (d_of dmu (Suc i)))" (is "_ = ?r") 
      using d_of i k LLL_invD_modw(7)[OF Linvmw] by auto
    finally show "dμ fs' i i = ?r" .
  qed
  have 04: "lin_indep fs'" "length fs' = m" "lattice_of fs' = L" using LLL_inv_wD[OF invw'] by auto
  (* I: the sub-diagonal positions in columns k-1 and k, which get reduced by symmod *)
  define I where "I = {(i, j). i < m ∧ j < i ∧ (j = k ∨ j = k - 1)}"
  then have Isubs: "I ⊆ {(i,j). i < m ∧ j < i}" using k k0 by auto
  (* fs'' is the witness: same lattice, gso and d as fs', with the dμ-entries in I reduced *)
  obtain fs'' where 
    05: "lattice_of fs'' = L" and
    06: "map (map_vec (λ x. x symmod p)) fs'' = map (map_vec (λ x. x symmod p)) fs'" and
    07: "lin_indep fs''" and
    08: "length fs'' = m" and
    09: "(∀ k < m. gso fs'' k = gso fs' k)" and
    10: "(∀ k ≤ m. d fs'' k = d fs' k)" and
    11: "(∀ i' < m. ∀ j' < m. dμ fs'' i' j' = 
           (if (i',j') ∈ I then dμ fs' i' j' symmod (p * d fs' j' * d fs' (Suc j')) else dμ fs' i' j'))"
    using mod_finite_set[OF 04(1) 04(2) Isubs 04(3) pgtz] by blast
  have 13: "length mfs' = m" using mfs'_def LLL_invD_modw(1)[OF Linvmw] by simp (* invariant requirement *)
  have 14: "map (map_vec (λ x. x symmod p)) fs'' = mfs'"  (* invariant requirement *)
    using 06 fs'_def k k0 04(2) LLL_invD_modw(5)[OF Linvmw]
    by (metis (no_types, lifting) length_list_update less_imp_diff_less map_update mfs'_def nth_map)
  have "LLL_measure (k - 1) fs'' = LLL_measure (k - 1) fs'" using 10 LLL_measure_def logD_def D_def by simp
  then have 15: "LLL_measure (k - 1) fs'' < LLL_measure k fs" using 02 by simp (* goal *)
  {
    (* outside columns k-1 and k the relevant d-values of fs'' and fs coincide *)
    fix i' j'
    assume i'j': "i'<m" "j'<i'" 
      and neq: "j' ≠ k" "j' ≠ k - 1"
    hence j'k: "j' ≠ k" "Suc j' ≠ k" using k0 by auto
    hence "d fs'' j' = d fs j'" "d fs'' (Suc j') = d fs (Suc j')" 
      using ‹k < m› i'j' k0
        10[rule_format, of j'] 032[rule_format, of j']
        10[rule_format, of "Suc j'"] 032[rule_format, of "Suc j'"] 
      by auto
  } note d_id = this

  have 16: "∀i'<m. ∀j'<i'. ¦dμ fs'' i' j'¦ < p * d fs'' j' * d fs'' (Suc j')" (* invariant requirement *)
  proof -
    {
      fix i' j'
      assume i'j': "i'<m" "j'<i'"
      have "¦dμ fs'' i' j'¦ < p * d fs'' j' * d fs'' (Suc j')"
      proof (cases "(i',j') ∈ I")
        case True
        (* entries in I are symmod-reduced, hence bounded by the modulus *)
        define pdd where "pdd = (p * d fs' j' * d fs' (Suc j'))"
        have pdd_pos: "pdd > 0" using pgtz i'j' LLL_d_pos[OF invw'] pdd_def by simp
        have "dμ fs'' i' j' = dμ fs' i' j' symmod pdd" using True 11 i'j' pdd_def by simp
        then have "¦dμ fs'' i' j'¦ < pdd" using True 11 i'j' pdd_pos sym_mod_abs by simp
        then show ?thesis unfolding pdd_def using 10 i'j' by simp
      next
        case False
        from False[unfolded I_def] i'j' have neg: "j' ≠ k" "j' ≠ k - 1" by auto
        
        consider (1) "i' = k - 1 ∨ i' = k" | (2) "¬ (i' = k - 1 ∨ i' = k)"  
          using False i'j' unfolding I_def by linarith
        thus ?thesis
        proof cases
          case **: 1
          (* rows k-1 and k were exchanged: the entry is the one of the other row of fs *)
          let ?i'' = "if i' = k - 1 then k else k -1" 
          from ** neg i'j' have i'': "?i'' < m" "j' < ?i''" using k0 k by auto
          have "dμ fs'' i' j' = dμ fs' i' j'" using 11 False i'j' by simp
          also have "… = dμ fs ?i'' j'" unfolding 03[OF ‹i' < m› ‹j' < i'›]
            using ** neg by auto
          finally show ?thesis using LLL_invD_modw(6)[OF Linvmw, rule_format, OF i''] unfolding d_id[OF i'j' neg] by auto
        next
          case **: 2
          hence neq: "j' ≠ k" "j' ≠ k - 1" using False k k0 i'j' unfolding I_def by auto
          have "dμ fs'' i' j' = dμ fs' i' j'" using 11 False i'j' by simp
          also have "… = dμ fs i' j'" unfolding 03[OF ‹i' < m› ‹j' < i'›] using ** neq by auto
          finally show ?thesis using LLL_invD_modw(6)[OF Linvmw, rule_format, OF i'j'] using d_id[OF i'j' neq] by auto
        qed
      qed
    }
    then show ?thesis by simp
  qed
  have 17: "∀i'<m. ∀j'<m. dμ fs'' i' j' = dmu'_mod $$ (i', j')" (* invariant requirement *)
  proof -
    {
      (* case j' < i': below the diagonal *)
      fix i' j'
      assume i'j': "i'<m" "j'<i'"
      have d'dmu': "∀j' < m. d fs' (Suc j') = dmu' $$ (j', j')" using ddμ dmu'_def 033 by simp
      have eq': "dμ fs' i' j' = dmu' $$ (i', j')"
      proof -
        have t00: "dμ fs k j' = dmu $$ (k, j')" and
          t01: "dμ fs (k - 1) j' =  dmu $$ (k - 1, j')" and
          t04: "dμ fs k (k - 1) = dmu $$ (k, k - 1)" and
          t05: "dμ fs i' k = dmu $$ (i', k)"
          using LLL_invD_modw(7)[OF Linvmw] i'j' k ddμ k0 by auto 
        have t03: "d fs k = dμ fs (k-1) (k-1)" using k0 k by (metis LLL.ddμ Suc_diff_1 lessI not_gr_zero)
        have t06: "d fs (k - 1) = (d_of dmu (k-1))" using d_of k by auto
        have t07: "d fs k = (d_of dmu k)" using d_of k by auto
        have j': "j' < m" using i'j' by simp
        have "dμ fs' i' j' = (if i' = k - 1 then 
                   dmu $$ (k, j')
                else if i' = k ∧ j' ≠ k - 1 then 
                   dmu $$ (k - 1, j')
                else if i' > k ∧ j' = k then
                   (dmu $$ (k, k) * dmu $$ (i', k - 1) - dmu $$ (k, k - 1) * dmu $$ (i', j')) div (d_of dmu k)
                else if i' > k ∧ j' = k - 1 then 
                   (dmu $$ (k, k - 1) * dmu $$ (i', j') + dmu $$ (i', k) * d fs (k - 1)) div (d_of dmu k)
                else dmu $$ (i', j'))"
          using ddμ k t00 t01 t03 LLL_invD_modw(7)[OF Linvmw] k i'j' j' 03 t07 by simp
        then show ?thesis using dmu'_def i'j' j' t06 t07 by (simp add: d_of_def)
      qed
      have "dμ fs'' i' j' = dmu'_mod $$ (i', j')"
      proof (cases "(i',j') ∈ I")
        case i'j'I: True
        have j': "j' < m" using i'j' by simp
        show ?thesis
        proof -
          have "dmu'_mod $$ (i',j') = dmu' $$ (i',j') 
                  symmod (p * (d_of dmu' j') * (d_of dmu' (Suc j')))"
            using dmu'_mod_def i'j' i'j'I I_def by simp
          also have "d_of dmu' j' = d fs' j'" 
            using j' d'dmu' d_def Suc_diff_1 less_imp_diff_less unfolding d_of_def 
            by (cases j', auto)
          finally have "dmu'_mod $$ (i',j') = dmu' $$ (i',j') symmod (p * d fs' j' * d fs' (Suc j'))"
            using ddμ[OF j'] d'dmu' j' by (auto simp: d_of_def)
          then show ?thesis using i'j'I 11 i'j' eq' by simp
        qed
      next
        case False
        have "dμ fs'' i' j' = dμ fs' i' j'" using False 11 i'j' by simp
        also have "… = dmu' $$ (i', j')" unfolding eq' ..
        finally show ?thesis unfolding dmu'_mod_def using False[unfolded I_def] i'j' by auto
      qed
    }
    (* case i' = j': the diagonal *)
    moreover have "∀i' j'. i' < m ⟶ j' < m ⟶ i' = j' ⟶ dμ fs'' i' j' = dmu'_mod $$ (i', j')" 
      using ddμ dmu'_def 033 10 dmu'_mod_def 11 I_def by simp
    moreover {
      (* case i' < j': above the diagonal both sides are 0 *)
      fix i' j'
      assume i'j'': "i' < m" "j' < m" "i' < j'"
      then have μz: "μ fs'' i' j' = 0" by (simp add: gram_schmidt_fs.μ.simps)
      have "dmu'_mod $$ (i',j') = dmu' $$ (i',j')" using dmu'_mod_def i'j'' by auto
      also have "… = dμ fs i' j'" using LLL_invD_modw(7)[OF Linvmw] i'j'' dmu'_def by simp
      also have "… = 0" using dμ_def i'j'' by (simp add: gram_schmidt_fs.μ.simps)
      finally have "dμ fs'' i' j' =  dmu'_mod $$ (i',j')" using μz d_def i'j'' dμ_def by simp
    }
    ultimately show ?thesis by (meson nat_neq_iff)
  qed
  from gbnd 09 have g_bnd: "g_bnd_mode first b fs''" using g_bnd_mode_cong[of fs' fs''] by auto
  {
    (* transfer of the full invariant from index k to index k - 1 *)
    assume Linv: "LLL_invariant_mod fs mfs dmu p first b k"
    have 00: "LLL_invariant_weak' k fs" using LLL_invD_mod[OF Linv] LLL_invI_weak by simp
    note swap_weak' = basis_reduction_swap_weak'[OF 00 k k0 mu_F1_i norm_ineq fs'_def]
    have 01: "LLL_invariant_weak' (k - 1) fs'" by fact
    have 12: "weakly_reduced fs'' (k-1)" (* invariant requirement *)
      using 031 09 k LLL_invD_weak(8)[OF 00] unfolding gram_schmidt_fs.weakly_reduced_def by simp
    have "LLL_invariant_mod fs'' mfs' dmu'_mod p first b (k-1)" 
      using LLL_invI_mod[OF 13 _ 08 05 07 12 14 16 17 p1 g_bnd LLL_invD_mod(17)[OF Linv]] k by simp
  }
  moreover have "LLL_invariant_mod_weak fs'' mfs' dmu'_mod p first b"
    using LLL_invI_modw[OF 13 08 05 07 14 16 17 p1 g_bnd LLL_invD_modw(15)[OF Linvmw]] by simp
  ultimately show ?thesis using 15 by auto
qed

(* Under the modular invariant, rounding the quotient of the dmu-entry (i,j) by
   d_of dmu (Suc j) (via round_num_denom) yields round (μ fs i j), for j < i < m. *)
lemma dmu_quot_is_round_of_μ:
  assumes Linv: "LLL_invariant_mod fs mfs dmu p first b i'"
    and c: "c = round_num_denom (dmu $$ (i,j)) (d_of dmu (Suc j))" 
    and i: "i < m"
    and j: "j < i"
  shows "c = round(μ fs i j)" 
proof -
  have Linvw: "LLL_invariant_weak' i' fs" using LLL_invD_mod[OF Linv] LLL_invI_weak by simp
  have j2: "j < m" using i j by simp
  then have j3: "Suc j ≤ m" by simp
  have μ1: "μ fs j j = 1" using i j by (meson gram_schmidt_fs.μ.elims less_irrefl_nat)
  (* d (Suc j) * μ i j is an integer, so it agrees with the integer dμ-entry *)
  have inZ: "rat_of_int (d fs (Suc j)) * μ fs i j ∈ ℤ" using fs_int_indpt.fs_int_mu_d_Z_m_m i j
      LLL_invD_mod(5)[OF Linv] LLL_invD_weak(2) Linvw d_def fs_int.d_def fs_int_indpt.intro by auto
  have "c = round(rat_of_int (dμ fs i j) / rat_of_int (dμ fs j j))" using LLL_invD_mod(9) Linv i j c 
    by (simp add: round_num_denom d_of_def)
  then show ?thesis using LLL_d_pos[OF LLL_invw'_imp_w[OF Linvw] j3] j i inZ dμ_def μ1 by simp
qed

(* Variant of dmu_quot_is_round_of_μ that only assumes the weak modular invariant:
   the rounded quotient of the dmu-entry (i,j) by d_of dmu (Suc j) equals
   round (μ fs i j), for j < i < m. *)
lemma dmu_quot_is_round_of_μ_weak:
  assumes Linv: "LLL_invariant_mod_weak fs mfs dmu p first b"
    and c: "c = round_num_denom (dmu $$ (i,j)) (d_of dmu (Suc j))" 
    and i: "i < m"
    and j: "j < i"
  shows "c = round(μ fs i j)" 
proof -
  have Linvww: "LLL_invariant_weak fs" using LLL_invD_modw[OF Linv] LLL_invariant_weak_def by simp
  have j2: "j < m" using i j by simp
  then have j3: "Suc j ≤ m" by simp
  have μ1: "μ fs j j = 1" using i j by (meson gram_schmidt_fs.μ.elims less_irrefl_nat)
  (* d (Suc j) * μ i j is an integer, so it agrees with the integer dμ-entry *)
  have inZ: "rat_of_int (d fs (Suc j)) * μ fs i j ∈ ℤ" using fs_int_indpt.fs_int_mu_d_Z_m_m i j
      LLL_invD_modw[OF Linv] d_def fs_int.d_def fs_int_indpt.intro by auto
  have "c = round(rat_of_int (dμ fs i j) / rat_of_int (dμ fs j j))" using LLL_invD_modw(7) Linv i j c 
    by (simp add: round_num_denom d_of_def)
  then show ?thesis using LLL_d_pos[OF Linvww j3] j i inZ dμ_def μ1 by simp
qed  

(* Correctness of the executable add-row step basis_reduction_mod_add_row.
   Given the weak modular invariant and j < i < m, the result (mfs', dmu')
   represents some basis fs' such that: the weak invariant holds, the measure and
   all d-values are unchanged, μ-smallness of row i is extended from column Suc j
   to column j, ¦μ fs' i j¦ ≤ 1/2, the μ-values of rows above i are unchanged,
   and the full invariant at i is preserved. *)
lemma basis_reduction_mod_add_row: assumes 
  Linv: "LLL_invariant_mod_weak fs mfs dmu p first b" 
  and res: "basis_reduction_mod_add_row p mfs dmu i j = (mfs', dmu')" 
  and i: "i < m"
  and j: "j < i"
  and igtz: "i ≠ 0"
shows "(∃fs'. LLL_invariant_mod_weak fs' mfs' dmu' p first b ∧
        LLL_measure i fs' = LLL_measure i fs ∧
        (μ_small_row i fs (Suc j) ⟶ μ_small_row i fs' j) ∧ 
        ¦μ fs' i j¦ ≤ 1 / 2 ∧
        (∀i' j'. i' < i ⟶ j' ≤ i' ⟶ μ fs' i' j' = μ fs i' j') ∧
        (LLL_invariant_mod fs mfs dmu p first b i ⟶ LLL_invariant_mod fs' mfs' dmu' p first b i) ∧
        (∀ii ≤ m. d fs' ii = d fs ii))"
proof -
  (* c is the rounded coefficient used by the algorithm; it equals round (μ fs i j) *)
  define c where "c = round_num_denom (dmu $$ (i,j)) (d_of dmu (Suc j))" 
  then have c: "c = round(μ fs i j)" using dmu_quot_is_round_of_μ_weak[OF Linv c_def i j] by simp
  show ?thesis
  proof (cases "c = 0")
    case True
    (* nothing is changed by the algorithm; fs itself is the witness *)
    then have pair_id: "(mfs', dmu') = (mfs, dmu)" 
      using res c_def unfolding basis_reduction_mod_add_row_def Let_def by auto
    moreover have "¦μ fs i j¦ ≤ inverse 2" using c[symmetric, unfolded True] 
      by (simp add: round_def, linarith)
    moreover then have "(μ_small_row i fs (Suc j) ⟶ μ_small_row i fs j)" 
      unfolding μ_small_row_def using Suc_leI le_neq_implies_less by blast
    ultimately show ?thesis using Linv pair_id by auto
  next
    case False
    (* unfold the algorithm to the shape expected by basis_reduction_mod_add_row_main *)
    then have pair_id: "(mfs', dmu') = (mfs[i := map_vec (λx. x symmod p) (mfs ! i - c ⋅⇩v mfs ! j)],
                mat m m (λ(i', j'). if i' = i ∧ j' ≤ j
                  then if j' = j then dmu $$ (i, j') - c * dmu $$ (j, j')
                       else (dmu $$ (i,j') - c * dmu $$ (j,j')) 
                              symmod (p * (d_of dmu j') * (d_of dmu (Suc j')))
                  else dmu $$ (i', j')))" 
      using res c_def unfolding basis_reduction_mod_add_row_def Let_def by auto
    then have mfs': "mfs' = mfs[i := map_vec (λx. x symmod p) (mfs ! i - c ⋅⇩v mfs ! j)]"
      and dmu': "dmu' = mat m m (λ(i', j'). if i' = i ∧ j' ≤ j
                  then if j' = j then dmu $$ (i, j') - c * dmu $$ (j, j')
                       else (dmu $$ (i,j') - c * dmu $$ (j,j')) 
                              symmod (p * (d_of dmu j') * (d_of dmu (Suc j')))
                  else dmu $$ (i', j'))" by auto
    show ?thesis using basis_reduction_mod_add_row_main[OF Linv i j c mfs' dmu'] by blast
  qed
qed

(* Correctness of the executable swap step basis_reduction_mod_swap: its result
   satisfies the conclusion of basis_reduction_mod_swap_main, obtained by
   unfolding the definitions of the algorithm. *)
lemma basis_reduction_mod_swap: assumes
  Linv: "LLL_invariant_mod_weak fs mfs dmu p first b"
  and mu: "¦μ fs k (k-1)¦ ≤ 1 / 2"
  and res: "basis_reduction_mod_swap p mfs dmu k = (mfs', dmu'_mod)" 
  and cond: "sq_norm (gso fs (k - 1)) > α * sq_norm (gso fs k)"
  and i: "k < m" "k ≠ 0" 
shows "(∃fs'. LLL_invariant_mod_weak fs' mfs' dmu'_mod p first b ∧
        LLL_measure (k - 1) fs' < LLL_measure k fs ∧
        (LLL_invariant_mod fs mfs dmu p first b k ⟶ LLL_invariant_mod fs' mfs' dmu'_mod p first b (k-1)))"
  using res[unfolded basis_reduction_mod_swap_def basis_reduction_mod_swap_dmu_mod_def] 
    basis_reduction_mod_swap_main[OF Linv i mu cond] by blast

(* Correctness of basis_reduction_adjust_mod, which may replace the modulus p by
   a new modulus p' and re-reduce mfs and dmu accordingly. For the result there
   exist a basis fs' and a bound b' such that the weak modular invariant holds,
   the measure is unchanged, and the full invariant at any index i is preserved. *)
lemma basis_reduction_adjust_mod: assumes 
  Linv: "LLL_invariant_mod_weak fs mfs dmu p first b" 
  and res: "basis_reduction_adjust_mod p first mfs dmu = (p', mfs', dmu', g_idx')" 
shows "(∃fs' b'. (LLL_invariant_mod fs mfs dmu p first b i ⟶ LLL_invariant_mod fs' mfs' dmu' p' first b' i) ∧
       LLL_invariant_mod_weak fs' mfs' dmu' p' first b' ∧
       LLL_measure i fs' = LLL_measure i fs)"
proof (cases "∃ g_idx. basis_reduction_adjust_mod p first mfs dmu = (p, mfs, dmu, g_idx)")
  case True
  (* the state was left unchanged *)
  thus ?thesis using res Linv by auto
next
  case False
  (* name the intermediate values computed by the algorithm *)
  obtain b' g_idx where norm: "compute_max_gso_norm first dmu = (b', g_idx)" by force
  define p'' where "p'' = compute_mod_of_max_gso_norm first b'" 
  define d_vec where "d_vec = vec (Suc m) (λi. d_of dmu i)" 
  define mfs'' where "mfs'' = map (map_vec (λx. x symmod p'')) mfs"  
  define dmu'' where "dmu'' = mat m m (λ(i, j).
                   if j < i then dmu $$ (i, j) symmod (p'' * d_vec $ j * d_vec $ Suc j)
                   else dmu $$ (i, j))" 
  note res = res False
  note res = res[unfolded basis_reduction_adjust_mod.simps Let_def norm split, 
      folded p''_def, folded d_vec_def mfs''_def, folded dmu''_def]
  from res have pp': "p'' < p" and id: "dmu' = dmu''" "mfs' = mfs''" "p' = p''" "g_idx' = g_idx"
    by (auto split: if_splits)
  (* I: all sub-diagonal positions; every such dμ-entry is reduced w.r.t. p'' *)
  define I where "I = {(i',j'). i' < m ∧ j' < i'}"
  note inv = LLL_invD_modw[OF Linv]
  from inv(4) have lin: "gs.lin_indpt_list (RAT fs)" .
  from inv(3) have lat: "lattice_of fs = L" .
  from inv(2) have len: "length fs = m" .
  have weak: "LLL_invariant_weak fs" using Linv
    by (auto simp: LLL_invariant_mod_weak_def LLL_invariant_weak_def)
  from compute_max_gso_norm[OF _ weak, of dmu first, unfolded norm] inv(7)
  have bnd: "g_bnd_mode first b' fs" and b': "b' ≥ 0" "m = 0 ⟹ b' = 0" by auto
  from compute_mod_of_max_gso_norm[OF b' p''_def] 
  have p'': "0 < p''" "1 < p''" "mod_invariant b' p'' first" 
    by auto
  (* fs' is the witness basis: same lattice, gso and d as fs, dμ-entries reduced *)
  obtain fs' where 
    01: "lattice_of fs' = L" and
    02: "map (map_vec (λ x. x symmod p'')) fs' = map (map_vec (λ x. x symmod p'')) fs" and
    03: "lin_indep fs'" and
    04: "length fs' = m" and
    05: "(∀ k < m. gso fs' k = gso fs k)" and
    06: "(∀ k ≤ m. d fs' k = d fs k)" and
    07: "(∀ i' < m. ∀ j' < m. dμ fs' i' j' = 
      (if (i',j') ∈ I then dμ fs i' j' symmod (p'' * d fs j' * d fs (Suc j')) else dμ fs i' j'))"
    using mod_finite_set[OF lin len _ lat, of I] I_def p'' by blast
  from bnd 05 have bnd: "g_bnd_mode first b' fs'" using g_bnd_mode_cong[of fs fs'] by auto
  have D: "D fs = D fs'" unfolding D_def using 06 by auto  


  have Linv': "LLL_invariant_mod_weak fs' mfs'' dmu'' p'' first b'"
  proof (intro LLL_invI_modw p'' 04 03 01 bnd)
    {
      have "mfs'' = map (map_vec (λx. x symmod p'')) mfs" by fact
      also have "… = map (map_vec (λx. x symmod p'')) (map (map_vec (λx. x symmod p)) fs)" 
        using inv by simp
      also have "… = map (map_vec (λx. x symmod p symmod p'')) fs" by auto
      also have "(λ x. x symmod p symmod p'') = (λ x. x symmod p'')" 
      proof (intro ext)
        fix x
        (* both moduli are powers of log_base and p'' < p, hence p'' divides p *)
        from ‹mod_invariant b p first›[unfolded mod_invariant_def] obtain e where 
          p: "p = log_base ^ e" by auto
        from p''[unfolded mod_invariant_def] obtain e' where
          p'': "p'' = log_base ^ e'" by auto
        from pp'[unfolded p p''] log_base have "e' ≤ e" by simp
        hence dvd: "p'' dvd p" unfolding p p'' using log_base by (metis le_imp_power_dvd)
        thus "x symmod p symmod p'' = x symmod p''"  
          by (intro sym_mod_sym_mod_cancel)
      qed
      finally show "map (map_vec (λx. x symmod p'')) fs' = mfs''" unfolding 02 ..
    }
    thus "length mfs'' = m" using 04 by auto
    show "∀i'<m. ∀j'<i'. ¦dμ fs' i' j'¦ < p'' * d fs' j' * d fs' (Suc j')"
    proof -
      {
        fix i' j'
        assume i'j': "i' < m" "j' < i'"
        then have "dμ fs' i' j' = dμ fs i' j' symmod (p'' * d fs' j' * d fs' (Suc j'))"
          using 07 06 unfolding I_def by simp
        then have "¦dμ fs' i' j'¦ < p'' * d fs' j' * d fs' (Suc j')" 
          using sym_mod_abs p'' LLL_d_pos[OF weak] mult_pos_pos
          by (smt "06" i'j' less_imp_le_nat less_trans_Suc nat_SN.gt_trans)
      }
      then show ?thesis by simp
    qed
    from inv(7) have dmu: "i' < m ⟹ j' < m ⟹ dmu $$ (i', j') = dμ fs i' j'" for i' j'
      by auto
    note d_of = d_of_weak[OF Linv]
    have dvec: "i ≤ m ⟹ d_vec $ i = d fs i" for i unfolding d_vec_def using d_of by auto
    show "∀i'<m. ∀j'<m. dμ fs' i' j' = dmu'' $$ (i', j')" 
      using 07 unfolding dmu''_def I_def 
      by (auto simp: dmu dvec)
  qed

  moreover 
  {
    (* preservation of the full invariant: weak reducedness only depends on the gso norms *)
    assume linv: "LLL_invariant_mod fs mfs dmu p first b i" 
    note inv = LLL_invD_mod[OF linv]
    hence i: "i ≤ m" by auto
    have norm: "j < m ⟹ ∥gso fs j∥⇧2 = ∥gso fs' j∥⇧2" for j
      using 05 by auto
    have "weakly_reduced fs i = weakly_reduced fs' i" 
      unfolding gram_schmidt_fs.weakly_reduced_def using i
      by (intro all_cong arg_cong2[where f = "(≤)"] arg_cong[where f = "λ x. _ * x"] norm, auto)
    with inv have "weakly_reduced fs' i" by auto
    hence "LLL_invariant_mod fs' mfs'' dmu'' p'' first b' i" using inv         
      by (intro LLL_invI_mod LLL_invD_modw[OF Linv'])
  }

  moreover have "LLL_measure i fs' = LLL_measure i fs" 
    unfolding LLL_measure_def logD_def D ..
  ultimately show ?thesis unfolding id by blast
qed

(* The integer comparison on d_of-values (with α given as the fraction num/denom)
   is equivalent to the rational comparison
   sq_norm (gso fs (i - 1)) ≤ α * sq_norm (gso fs i), for 0 < i < m. *)
lemma alpha_comparison: assumes 
  Linv: "LLL_invariant_mod_weak fs mfs dmu p first b"
  and alph: "quotient_of α = (num, denom)" 
  and i: "i < m" 
  and i0: "i ≠ 0" 
shows "(d_of dmu i * d_of dmu i * denom ≤ num * d_of dmu (i - 1) * d_of dmu (Suc i))
  = (sq_norm (gso fs (i - 1)) ≤ α * sq_norm (gso fs i))" 
proof - 
  note inv = LLL_invD_modw[OF Linv]
  interpret fs_indep: fs_int_indpt n fs
    by (unfold_locales, insert inv, auto)
  from inv(2) i have ifs: "i < length fs" by auto
  note d_of_fs = d_of_weak[OF Linv]
  (* reduce to the corresponding comparison lemma on d-values, then replace d by d_of dmu *)
  show ?thesis 
    unfolding fs_indep.d_sq_norm_comparison[OF alph ifs i0, symmetric]
    by (subst (1 2 3 4) d_of_fs, use i d_def fs_indep.d_def in auto)
qed

(* Correctness of the combined step basis_reduction_adjust_swap_add_step:
   an add-row on (i, i-1), followed by a swap at i, followed by a modulus
   adjustment if i - 1 = g_idx. Under the weak modular invariant and a failed
   integer comparison at i (0 < i < m), the result represents some basis fs'
   (with bound b') for which the weak invariant holds, the measure strictly
   decreases (both from index i to i-1 and at index m-1), and the full
   invariant at i carries over to index i-1. *)
lemma basis_reduction_adjust_swap_add_step: assumes 
  Linv: "LLL_invariant_mod_weak fs mfs dmu p first b" 
  and res: "basis_reduction_adjust_swap_add_step p first mfs dmu g_idx i = (p', mfs', dmu', g_idx')" 
  and alph: "quotient_of α = (num, denom)" 
  and ineq: "¬ (d_of dmu i * d_of dmu i * denom
              ≤ num * d_of dmu (i - 1) * d_of dmu (Suc i))" 
  and i: "i < m" 
  and i0: "i ≠ 0" 
shows "∃fs' b'. LLL_invariant_mod_weak fs' mfs' dmu' p' first b' ∧
        LLL_measure (i - 1) fs' < LLL_measure i fs ∧
        LLL_measure (m - 1) fs' < LLL_measure (m - 1) fs ∧
        (LLL_invariant_mod fs mfs dmu p first b i ⟶ 
         LLL_invariant_mod fs' mfs' dmu' p' first b' (i - 1))"
proof -
  (* name the intermediate results of the add-row and the swap *)
  obtain mfs0 dmu0 where add: "basis_reduction_mod_add_row p mfs dmu i (i-1) = (mfs0, dmu0)" by force
  obtain mfs1 dmu1 where swap: "basis_reduction_mod_swap p mfs0 dmu0 i = (mfs1, dmu1)" by force
  note res = res[unfolded basis_reduction_adjust_swap_add_step_def Let_def add split swap]
  from i0 have ii: "i - 1 < i" by auto
  from basis_reduction_mod_add_row[OF Linv add i ii i0]
  obtain fs0 where Linv0: "LLL_invariant_mod_weak fs0 mfs0 dmu0 p first b" 
    and meas0: "LLL_measure i fs0 = LLL_measure i fs" 
    and small: "¦μ fs0 i (i - 1)¦ ≤ 1 / 2" 
    and Linv0': "LLL_invariant_mod fs mfs dmu p first b i ⟹ LLL_invariant_mod fs0 mfs0 dmu0 p first b i" 
    by blast
  {
    (* the add-row leaves the relevant d_of-values untouched, so the failed
       comparison transfers to fs0 *)
    have id: "d_of dmu0 i = d_of dmu i" "d_of dmu0 (i - 1) = d_of dmu (i - 1)"
      "d_of dmu0 (Suc i) = d_of dmu (Suc i)"
      using i i0 add[unfolded basis_reduction_mod_add_row_def Let_def]
      by (auto split: if_splits simp: d_of_def)
    from ineq[folded id, unfolded alpha_comparison[OF Linv0 alph i i0]]
    have "∥gso fs0 (i - 1)∥⇧2 > α * ∥gso fs0 i∥⇧2" by simp
  } note ineq = this
  from Linv have "LLL_invariant_weak fs" 
    by (auto simp: LLL_invariant_weak_def LLL_invariant_mod_weak_def)
  from basis_reduction_mod_swap[OF Linv0 small swap ineq i i0, unfolded meas0] Linv0'
  obtain fs1 where Linv1: "LLL_invariant_mod_weak fs1 mfs1 dmu1 p first b"
    and meas1: "LLL_measure (i - 1) fs1 < LLL_measure i fs" 
    and Linv1': "LLL_invariant_mod fs mfs dmu p first b i ⟹ LLL_invariant_mod fs1 mfs1 dmu1 p first b (i - 1)" 
    by auto
  show ?thesis
  proof (cases "i - 1 = g_idx")
    case False
    (* no modulus adjustment: the result of the swap is returned *)
    with res have id: "p' = p" "mfs' = mfs1" "dmu' = dmu1" "g_idx' = g_idx" by auto
    show ?thesis unfolding id using Linv1' meas1 Linv1 by (intro exI[of _ fs1] exI[of _ b], auto simp: LLL_measure_def)
  next
    case True
    (* the modulus is adjusted after the swap *)
    with res have adjust: "basis_reduction_adjust_mod p first mfs1 dmu1 = (p', mfs', dmu', g_idx')" by simp
    from basis_reduction_adjust_mod[OF Linv1 adjust, of "i - 1"] Linv1'
    obtain fs' b' where Linvw: "LLL_invariant_mod_weak fs' mfs' dmu' p' first b'"
      and Linv: "LLL_invariant_mod fs mfs dmu p first b i ⟹ LLL_invariant_mod fs' mfs' dmu' p' first b' (i - 1)"
      and meas: "LLL_measure (i - 1) fs' = LLL_measure (i - 1) fs1" 
      by blast
    note meas = meas1[folded meas]
    from meas have meas': "LLL_measure (m - 1) fs' < LLL_measure (m - 1) fs" 
      unfolding LLL_measure_def using i by auto
    show ?thesis
      by (intro exI conjI impI, rule Linvw, rule meas, rule meas', rule Linv) 
  qed
qed


(* One step of the main loop: starting from the modular invariant at index i < m,
   basis_reduction_mod_step returns a state for which some basis fs' (with bound b')
   satisfies the invariant at the new index i' and has a strictly smaller measure. *)
lemma basis_reduction_mod_step: assumes 
  Linv: "LLL_invariant_mod fs mfs dmu p first b i" 
  and res: "basis_reduction_mod_step p first mfs dmu g_idx i j = (p', mfs', dmu', g_idx', i', j')" 
  and i: "i < m" 
shows "∃fs' b'. LLL_measure i' fs' < LLL_measure i fs ∧ LLL_invariant_mod fs' mfs' dmu' p' first b' i'"
proof -
  note res = res[unfolded basis_reduction_mod_step_def Let_def]
  from Linv have Linvw: "LLL_invariant_mod_weak fs mfs dmu p first b" 
    by (auto simp: LLL_invariant_mod_weak_def LLL_invariant_mod_def)
  show ?thesis
  proof (cases "i = 0")
    case True
    (* at index 0 the state is unchanged and the index is incremented *)
    then have ids: "mfs' = mfs" "dmu' = dmu" "i' = Suc i" "p' = p" using res by auto
    have "LLL_measure i' fs < LLL_measure i fs ∧ LLL_invariant_mod fs mfs' dmu' p first b i'"
      using increase_i_mod[OF Linv i] True res ids inv by simp
    then show ?thesis using res ids inv by auto
  next
    case False
    hence id: "(i = 0) = False" by auto
    obtain num denom where alph: "quotient_of α = (num, denom)" by force
    note res = res[unfolded id if_False alph split]
    let ?comp = "d_of dmu i * d_of dmu i * denom ≤ num * d_of dmu (i - 1) * d_of dmu (Suc i)" 
    show ?thesis
    proof (cases ?comp)
      case False
      (* comparison fails: perform the add-row/swap/adjust step, index goes to i - 1 *)
      hence id: "?comp = False" by simp
      note res = res[unfolded id if_False]
      let ?step = "basis_reduction_adjust_swap_add_step p first mfs dmu g_idx i" 
      from res have step: "?step = (p', mfs', dmu', g_idx')" 
        and i': "i' = i - 1" 
        by (cases ?step, auto)+
      from basis_reduction_adjust_swap_add_step[OF Linvw step alph False i ‹i ≠ 0›] Linv
      show ?thesis unfolding i' by blast
    next
      case True
      (* comparison succeeds: the state is unchanged and the index is incremented *)
      hence id: "?comp = True" by simp
      note res = res[unfolded id if_True]
      from res have ids: "p' = p" "mfs' = mfs" "dmu' = dmu" "i' = Suc i" by auto
      from True alpha_comparison[OF Linvw alph i False]
      have ineq: "sq_norm (gso fs (i - 1)) ≤ α * sq_norm (gso fs i)" by simp
      from increase_i_mod[OF Linv i ineq]
      show ?thesis unfolding ids by auto
    qed
  qed
qed

(* Correctness of the main loop: starting from the modular invariant at some index i,
   the result of basis_reduction_mod_main represents a basis fs' (with bound b')
   that satisfies the modular invariant at index m.
   Proved by well-founded induction on LLL_measure i fs. *)
lemma basis_reduction_mod_main: assumes "LLL_invariant_mod fs mfs dmu p first b i" 
  and res: "basis_reduction_mod_main p first mfs dmu g_idx i j = (p', mfs', dmu')" 
shows "∃fs' b'. LLL_invariant_mod fs' mfs' dmu' p' first b' m" 
  using assms
proof (induct "LLL_measure i fs" arbitrary: i mfs dmu j p b fs g_idx rule: less_induct)
  case (less i fs mfs dmu j p b g_idx)
  hence fsinv: "LLL_invariant_mod fs mfs dmu p first b i" by auto
  note res = less(3)[unfolded basis_reduction_mod_main.simps[of p first mfs dmu g_idx i j]]
  note inv = less(2)
  note IH = less(1)
  show ?case
  proof (cases "i < m")
    case i: True
    (* one step strictly decreases the measure, then apply the induction hypothesis *)
    obtain p' mfs' dmu' g_idx' i' j' where step: "basis_reduction_mod_step p first mfs dmu g_idx i j = (p', mfs', dmu', g_idx', i', j')" 
      (is "?step = _") by (cases ?step, auto)
    then obtain fs' b' where Linv: "LLL_invariant_mod fs' mfs' dmu' p' first b' i'" 
      and decr: "LLL_measure i' fs' < LLL_measure i fs"
      using basis_reduction_mod_step[OF fsinv step i] i fsinv by blast 
    note res = res[unfolded step split]
    from res i show ?thesis using IH[OF decr Linv] by auto
  next
    case False
    (* the loop has terminated: i = m and the state is returned unchanged *)
    with LLL_invD_mod[OF fsinv] res have i: "i = m" "p' = p" by auto
    then obtain fs' b' where "LLL_invariant_mod fs' mfs' dmu' p first b' m" using False res fsinv by simp
    then show ?thesis using i by auto
  qed
qed

(* Characterises the outcome of comparing the quotient returned by
   compute_max_gso_quot with the approximation factor α.

   Setting:
   - (msq_num, msq_denum, idx) is the result of compute_max_gso_quot dmu,
   - (num, denum) = quotient_of α,
   - cmp abbreviates the integer test msq_num * denum > num * msq_denum,
   - m > 1.

   Two conclusions are shown:
   1. under cmp: facts about idx (its relation to 0 and idx < m) together
      with the negated d_of-inequality that is later passed to the swap step;
   2. under ¬ cmp: the full invariant LLL_invariant_mod ... m, i.e. the basis
      is weakly reduced up to index m. *)
lemma compute_max_gso_quot_alpha: 
  assumes inv: "LLL_invariant_mod_weak fs mfs dmu p first b" 
  and max: "compute_max_gso_quot dmu = (msq_num, msq_denum, idx)"
  and alph: "quotient_of α = (num, denum)" 
  and cmp: "(msq_num * denum  > num * msq_denum) = cmp" 
  and m: "m > 1" 
shows "cmp  idx  0  idx < m  ¬ (d_of dmu idx * d_of dmu idx * denum
               num * d_of dmu (idx - 1) * d_of dmu (Suc idx))" 
  and "¬ cmp  LLL_invariant_mod fs mfs dmu p first b m" 
proof -
  (* The modular weak invariant yields the plain weak LLL invariant on fs. *)
  from inv
  have fsinv: "LLL_invariant_weak fs" 
    by (simp add: LLL_invariant_mod_weak_def LLL_invariant_weak_def)
  (* Local names for the triples (numerator, denominator, index) and for the
     list of these triples over [0..<m-1]; msqlst below connects them to
     compute_max_gso_quot by unfolding its definition. *)
  define qt where "qt = (λi. ((d_of dmu (i + 1)) * (d_of dmu (i + 1)),
            (d_of dmu (i + 2)) * (d_of dmu i), Suc i))"
  define lst where "lst = (map (λi. qt i) [0..<(m-1)])"
  have msqlst: "(msq_num, msq_denum, idx) = max_list_rats_with_index lst"
    using max lst_def qt_def unfolding compute_max_gso_quot_def by simp
  (* Every denominator occurring in lst is positive. *)
  have nz: "n d i. (n, d, i)  set lst  d > 0"
    unfolding lst_def qt_def using d_of_weak[OF inv] LLL_d_pos[OF fsinv] by auto
  (* Relation between the selected quotient and every quotient in lst, obtained
     from max_list_rats_with_index; it is used in the calculation chain of the
     weakly_reduced proof below. *)
  have geq: "(n, d, i)  set lst. rat_of_int msq_num / of_int msq_denum  rat_of_int n / of_int d"
    using max_list_rats_with_index[of lst] nz msqlst by (metis (no_types, lifting) case_prodI2)
  (* Length fact about lst (derived from m > 1); it is needed to show that the
     selected triple is a member of lst. *)
  have len: "length lst  1" using m unfolding lst_def by simp
  have inset: "(msq_num, msq_denum, idx)  set lst"
    using max_list_rats_with_index_in_set[OF msqlst[symmetric] len] nz by simp
  (* Since the third component of qt i is Suc i with i drawn from [0..<m-1],
     idx is related to the range {1..<m}; idx0 and idx are read off from that. *)
  then have idxm: "idx  {1..<m}" using lst_def[unfolded qt_def] by auto
  then have idx0: "idx  0" and idx: "idx < m" by auto
  (* The selected triple is exactly qt (idx - 1); id_qt spells out its first
     two components in terms of d_of. *)
  have 00: "(msq_num, msq_denum, idx)  = qt (idx - 1)" using lst_def inset qt_def by auto
  then have id_qt: "msq_num = d_of dmu idx * d_of dmu idx" "msq_denum = d_of dmu (Suc idx) * d_of dmu (idx - 1)" 
    unfolding qt_def by auto
  have "msq_denum = (d_of dmu (idx + 1)) * (d_of dmu (idx - 1))"
    using 00 unfolding qt_def by simp
  (* Positivity of both denominators involved in the comparison. *)
  then have dengt0: "msq_denum > 0" using d_of_weak[OF inv] idxm LLL_d_pos[OF fsinv] by auto
  have αdengt0: "denum > 0" using alph by (metis quotient_of_denom_pos)
  (* Restate cmp purely in terms of d_of, num and denum (shadows assumption cmp). *)
  from cmp[unfolded id_qt]
  have cmp: "cmp = (¬ (d_of dmu idx * d_of dmu idx * denum  num * d_of dmu (idx - 1) * d_of dmu (Suc idx)))" 
    by (auto simp: ac_simps)
  (* First conclusion: the case where cmp holds. *)
  {
    assume cmp    
    from this[unfolded cmp] 
    show "idx  0  idx < m  ¬ (d_of dmu idx * d_of dmu idx * denum
               num * d_of dmu (idx - 1) * d_of dmu (Suc idx))" using idx0 idx by auto
  }
  (* Second conclusion: the case where cmp fails. *)
  {
    assume "¬ cmp" 
    (* With cmp false, the integer inequality itself holds. *)
    from this[unfolded cmp] have small: "d_of dmu idx * d_of dmu idx * denum  num * d_of dmu (idx - 1) * d_of dmu (Suc idx)" by auto
    note d_pos = LLL_d_pos[OF fsinv]
    (* For k < m: sq_norm (gso fs k) equals d fs (Suc k) / d fs k, and is positive. *)
    have gso: "k < m  sq_norm (gso fs k) = of_int (d fs (Suc k)) / of_int (d fs k)" for k using 
        LLL_d_Suc[OF fsinv, of k] d_pos[of k] d_pos[of k] by simp
    have gso_pos: "k < m  sq_norm (gso fs k) > 0" for k 
      using gso[of k] d_pos[of k] d_pos[of "Suc k"] by auto
    (* alpha_comparison rewrites the integer inequality into a statement about
       squared GSO norms at positions idx - 1 and idx; it is then put in
       quotient form (shadows assumption alph). *)
    from small[unfolded alpha_comparison[OF inv alph idx idx0]]
    have alph: "sq_norm (gso fs (idx - 1))  α * sq_norm (gso fs idx)" .
    with gso_pos[OF idx] have alph: "sq_norm (gso fs (idx - 1)) / sq_norm (gso fs idx)  α" 
      by (metis mult_imp_div_pos_le)
    (* Weak reducedness up to m: for each index i supplied by the definition,
       the quotient of consecutive squared GSO norms is expressed via d_of,
       compared with the selected quotient (geq), which equals the quotient at
       idx - 1 / idx, which in turn is related to α by alph. *)
    have weak: "weakly_reduced fs m" unfolding gram_schmidt_fs.weakly_reduced_def
    proof (intro allI impI, goal_cases)
      case (1 i)
      from idx have idx1: "idx - 1 < m" by auto
      (* The triple for index i is a member of lst, so geq applies to it. *)
      from geq[unfolded lst_def]
      have mem: "(d_of dmu (Suc i) * d_of dmu (Suc i),
            d_of dmu (Suc (Suc i)) * d_of dmu i, Suc i)  set lst" 
        unfolding lst_def qt_def using 1 by auto
      have "sq_norm (gso fs i) / sq_norm (gso fs (Suc i)) = 
        of_int (d_of dmu (Suc i) * d_of dmu (Suc i)) / of_int (d_of dmu (Suc (Suc i)) * d_of dmu i)" 
        using gso idx0 d_of_weak[OF inv] 1 by auto
      also have "  rat_of_int msq_num / rat_of_int msq_denum" 
        using geq[rule_format, OF mem, unfolded split] by auto
      also have " = sq_norm (gso fs (idx - 1)) / sq_norm (gso fs idx)" 
        unfolding id_qt gso[OF idx] gso[OF idx1] using idx0 d_of_weak[OF inv] idx by auto
      also have "  α" by fact
      (* Multiply out the quotient using positivity of sq_norm (gso fs (Suc i)). *)
      finally show "sq_norm (gso fs i)  α * sq_norm (gso fs (Suc i))" using gso_pos[OF 1]
        using pos_divide_le_eq by blast
    qed
    (* Weak invariant plus weak reducedness up to m gives the full invariant at m. *)
    with inv show "LLL_invariant_mod fs mfs dmu p first b m" 
      by (auto simp: LLL_invariant_mod_weak_def LLL_invariant_mod_def)
  }
qed
  

(* Degenerate case for small m: under the bound on m stated in assumption m
   (the case complementary to m > 1), the weak modular invariant already
   implies the full invariant LLL_invariant_mod at index m, because
   weakly_reduced fs m follows directly from the definition and the bound. *)
lemma small_m: 
  assumes inv: "LLL_invariant_mod_weak fs mfs dmu p first b" 
  and m: "m  1" 
shows "LLL_invariant_mod fs mfs dmu p first b m" 
proof -
  (* After unfolding the definition, the bound on m lets auto discharge
     weak reducedness. *)
  have weak: "weakly_reduced fs m" unfolding gram_schmidt_fs.weakly_reduced_def using m
    by auto
  (* Combine with the weak invariant by unfolding both invariant definitions. *)
  with inv show "LLL_invariant_mod fs mfs dmu p first b m" 
    by (auto simp: LLL_invariant_mod_weak_def LLL_invariant_mod_def)
qed

lemma basis_reduction_iso_main: assumes "LLL_invariant_mod_weak fs mfs dmu p first b"
  and res: "basis_reduction_iso_main p first mfs dmu g_idx j = (p', mfs', dmu')" 
shows "fs' b'. LLL_invariant_mod fs' mfs' dmu' p' first b' m" 
  using assms
proof (induct "LLL_measure (m-1) fs" arbitrary: fs mfs dmu j p b g_idx rule: less_induct)
  case (less fs mfs dmu j p b g_idx)
  have inv: "LLL_invariant_mod_weak fs mfs dmu p first b" using less by auto
  hence fsinv: "LLL_invariant_weak fs" 
    by (simp add: LLL_invariant_mod_weak_def LLL_invariant_weak_def)
  note res = less(3)[unfolded basis_reduction_iso_main.simps[of p first mfs dmu g_idx j]]
  note IH = less(1)
  obtain msq_num msq_denum idx where max: "compute_max_gso_quot dmu = (msq_num, msq_denum, idx)" 
    by (metis prod_cases3)
  obtain num denum where alph: "quotient_of α = (num, denum)" by force
  note res = res[unfolded max alph Let_def split]
  consider (small) "m  1" | (final) "m > 1" "¬ (num * msq_denum < msq_num * denum)" | (step) "m > 1" "num * msq_denum < msq_num * denum" 
    by linarith
  thus ?case
  proof cases
    case *: step
    obtain p1 mfs1 dmu1 g_idx1 where step: "basis_reduction_adjust_swap_add_step p first mfs dmu g_idx idx = (p1, mfs1, dmu1, g_idx1)"
      by (metis prod_cases4)
    from res[unfolded step split] * have res: "basis_reduction_iso_main p1 first mfs1 dmu1 g_idx1 (j + 1) = (p', mfs', dmu')" by auto
    from compute_max_gso_quot_alpha(1)[OF inv max alph refl *]
    have idx0: "idx  0" and idx: "idx < m" and cmp: "¬ d_of dmu idx * d_of dmu idx * denum  num * d_of dmu (idx - 1) * d_of dmu (Suc idx)" by auto
    from basis_reduction_adjust_swap_add_step[OF inv step alph cmp idx idx0] obtain fs1 b1 
      where inv1: "LLL_invariant_mod