diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 639d12992ab8..3f9caf533ec7 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -1223,6 +1223,28 @@ theorem toNat_ushiftRight_lt (x : BitVec w) (n : Nat) (hn : n ≤ w) : · apply hn · apply Nat.pow_pos (by decide) +@[simp] +theorem getMsbD_ushiftRight {x : BitVec w} {i n : Nat} : + (x >>> n).getMsbD i = (decide (i < w) && (!decide (i < n) && x.getMsbD (i - n))) := by + simp only [getMsbD, getLsbD_ushiftRight] + by_cases h : i < n + · simp [getLsbD_ge, show w ≤ (n + (w - 1 - i)) by omega] + omega + · by_cases h₁ : i < w + · simp only [h, ushiftRight_eq, getLsbD_ushiftRight, show i - n < w by omega] + congr + omega + · simp [h, h₁] + +@[simp] +theorem msb_ushiftRight {x : BitVec w} {n : Nat} : + (x >>> n).msb = (!decide (0 < n) && x.msb) := by + induction n + case zero => + simp + case succ nn ih => + simp [BitVec.ushiftRight_eq, getMsbD_ushiftRight, BitVec.msb, ih, show nn + 1 > 0 by omega] + /-! ### ushiftRight reductions from BitVec to Nat -/ @[simp] @@ -1327,7 +1349,8 @@ theorem sshiftRight_or_distrib (x y : BitVec w) (n : Nat) : <;> simp [*] /-- The msb after arithmetic shifting right equals the original msb. -/ -theorem sshiftRight_msb_eq_msb {n : Nat} {x : BitVec w} : +@[simp] +theorem msb_sshiftRight {n : Nat} {x : BitVec w} : (x.sshiftRight n).msb = x.msb := by rw [msb_eq_getLsbD_last, getLsbD_sshiftRight, msb_eq_getLsbD_last] by_cases hw₀ : w = 0 @@ -1354,7 +1377,7 @@ theorem sshiftRight_add {x : BitVec w} {m n : Nat} : by_cases h₃ : m + (n + ↑i) < w · simp [h₃] omega - · simp [h₃, sshiftRight_msb_eq_msb] + · simp [h₃, msb_sshiftRight] theorem not_sshiftRight {b : BitVec w} : ~~~b.sshiftRight n = (~~~b).sshiftRight n := by @@ -1372,11 +1395,56 @@ theorem not_sshiftRight_not {x : BitVec w} {n : Nat} : ~~~((~~~x).sshiftRight n) = x.sshiftRight n := by simp [not_sshiftRight] +@[simp] +theorem getMsbD_sshiftRight {x : BitVec w} {i n : Nat} : + getMsbD (x.sshiftRight n) i = (decide (i < w) && if i < n then x.msb else getMsbD x (i - n)) := by + simp only [getMsbD, BitVec.getLsbD_sshiftRight] + by_cases h : i < w + · simp only [h, decide_True, Bool.true_and] + by_cases h₁ : w ≤ w - 1 - i + · simp [h₁] + omega + · simp only [h₁, decide_False, Bool.not_false, Bool.true_and] + by_cases h₂ : i < n + · simp only [h₂, ↓reduceIte, ite_eq_right_iff] + omega + · simp only [show i - n < w by omega, h₂, ↓reduceIte, decide_True, Bool.true_and] + by_cases h₄ : n + (w - 1 - i) < w <;> (simp only [h₄, ↓reduceIte]; congr; omega) + · simp [h] + /-! ### sshiftRight reductions from BitVec to Nat -/ @[simp] theorem sshiftRight_eq' (x : BitVec w) : x.sshiftRight' y = x.sshiftRight y.toNat := rfl +@[simp] +theorem getLsbD_sshiftRight' {x y: BitVec w} {i : Nat} : + getLsbD (x.sshiftRight' y) i = + (!decide (w ≤ i) && if y.toNat + i < w then x.getLsbD (y.toNat + i) else x.msb) := by + simp only [BitVec.sshiftRight', BitVec.getLsbD_sshiftRight] + +@[simp] +theorem getMsbD_sshiftRight' {x y: BitVec w} {i : Nat} : + (x.sshiftRight y.toNat).getMsbD i = (decide (i < w) && if i < y.toNat then x.msb else x.getMsbD (i - y.toNat)) := by + simp only [BitVec.sshiftRight', getMsbD, BitVec.getLsbD_sshiftRight] + by_cases h : i < w + · simp only [h, decide_True, Bool.true_and] + by_cases h₁ : w ≤ w - 1 - i + · simp [h₁] + omega + · simp only [h₁, decide_False, Bool.not_false, Bool.true_and] + by_cases h₂ : i < y.toNat + · simp only [h₂, ↓reduceIte, ite_eq_right_iff] + omega + · simp only [show i - y.toNat < w by omega, h₂, ↓reduceIte, decide_True, Bool.true_and] + by_cases h₄ : y.toNat + (w - 1 - i) < w <;> (simp only [h₄, ↓reduceIte]; congr; omega) + · simp [h] + +@[simp] +theorem msb_sshiftRight' {x y: BitVec w} : + (x.sshiftRight' y).msb = x.msb := by + simp [BitVec.sshiftRight', BitVec.msb_sshiftRight] + /-! ### udiv -/ theorem udiv_def {x y : BitVec n} : x / y = BitVec.ofNat n (x.toNat / y.toNat) := by @@ -1690,6 +1758,11 @@ theorem shiftLeft_ushiftRight {x : BitVec w} {n : Nat}: · simp [hi₂] · simp [Nat.lt_one_iff, hi₂, show 1 + (i.val - 1) = i by omega] +@[simp] +theorem msb_shiftLeft {x : BitVec w} {n : Nat} : + (x <<< n).msb = x.getMsbD n := by + simp [BitVec.msb] + @[deprecated shiftRight_add (since := "2024-06-02")] theorem shiftRight_shiftRight {w : Nat} (x : BitVec w) (n m : Nat) : (x >>> n) >>> m = x >>> (n + m) := by @@ -2971,4 +3044,7 @@ abbrev zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsbD_true @[deprecated and_one_eq_setWidth_ofBool_getLsbD (since := "2024-09-18")] abbrev and_one_eq_zeroExtend_ofBool_getLsbD := @and_one_eq_setWidth_ofBool_getLsbD +@[deprecated msb_sshiftRight (since := "2024-10-03")] +abbrev sshiftRight_msb_eq_msb := @msb_sshiftRight + end BitVec