Skip to content

Commit

Permalink
feat: complete BitVec.[getMsbD|getLsbD|msb] for shifts (#5604)
Browse files Browse the repository at this point in the history
Co-authored-by: Tobias Grosser <[email protected]>
  • Loading branch information
luisacicolini and tobiasgrosser authored Oct 13, 2024
1 parent 5d65530 commit 47e0430
Showing 1 changed file with 78 additions and 2 deletions.
80 changes: 78 additions & 2 deletions src/Init/Data/BitVec/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1223,6 +1223,28 @@ theorem toNat_ushiftRight_lt (x : BitVec w) (n : Nat) (hn : n ≤ w) :
· apply hn
· apply Nat.pow_pos (by decide)

@[simp]
theorem getMsbD_ushiftRight {x : BitVec w} {i n : Nat} :
(x >>> n).getMsbD i = (decide (i < w) && (!decide (i < n) && x.getMsbD (i - n))) := by
simp only [getMsbD, getLsbD_ushiftRight]
by_cases h : i < n
· simp [getLsbD_ge, show w ≤ (n + (w - 1 - i)) by omega]
omega
· by_cases h₁ : i < w
· simp only [h, ushiftRight_eq, getLsbD_ushiftRight, show i - n < w by omega]
congr
omega
· simp [h, h₁]

@[simp]
theorem msb_ushiftRight {x : BitVec w} {n : Nat} :
(x >>> n).msb = (!decide (0 < n) && x.msb) := by
induction n
case zero =>
simp
case succ nn ih =>
simp [BitVec.ushiftRight_eq, getMsbD_ushiftRight, BitVec.msb, ih, show nn + 1 > 0 by omega]

/-! ### ushiftRight reductions from BitVec to Nat -/

@[simp]
Expand Down Expand Up @@ -1327,7 +1349,8 @@ theorem sshiftRight_or_distrib (x y : BitVec w) (n : Nat) :
<;> simp [*]

/-- The msb after arithmetic shifting right equals the original msb. -/
theorem sshiftRight_msb_eq_msb {n : Nat} {x : BitVec w} :
@[simp]
theorem msb_sshiftRight {n : Nat} {x : BitVec w} :
(x.sshiftRight n).msb = x.msb := by
rw [msb_eq_getLsbD_last, getLsbD_sshiftRight, msb_eq_getLsbD_last]
by_cases hw₀ : w = 0
Expand All @@ -1354,7 +1377,7 @@ theorem sshiftRight_add {x : BitVec w} {m n : Nat} :
by_cases h₃ : m + (n + ↑i) < w
· simp [h₃]
omega
· simp [h₃, sshiftRight_msb_eq_msb]
· simp [h₃, msb_sshiftRight]

theorem not_sshiftRight {b : BitVec w} :
~~~b.sshiftRight n = (~~~b).sshiftRight n := by
Expand All @@ -1372,11 +1395,56 @@ theorem not_sshiftRight_not {x : BitVec w} {n : Nat} :
~~~((~~~x).sshiftRight n) = x.sshiftRight n := by
simp [not_sshiftRight]

@[simp]
theorem getMsbD_sshiftRight {x : BitVec w} {i n : Nat} :
getMsbD (x.sshiftRight n) i = (decide (i < w) && if i < n then x.msb else getMsbD x (i - n)) := by
simp only [getMsbD, BitVec.getLsbD_sshiftRight]
by_cases h : i < w
· simp only [h, decide_True, Bool.true_and]
by_cases h₁ : w ≤ w - 1 - i
· simp [h₁]
omega
· simp only [h₁, decide_False, Bool.not_false, Bool.true_and]
by_cases h₂ : i < n
· simp only [h₂, ↓reduceIte, ite_eq_right_iff]
omega
· simp only [show i - n < w by omega, h₂, ↓reduceIte, decide_True, Bool.true_and]
by_cases h₄ : n + (w - 1 - i) < w <;> (simp only [h₄, ↓reduceIte]; congr; omega)
· simp [h]

/-! ### sshiftRight reductions from BitVec to Nat -/

@[simp]
theorem sshiftRight_eq' (x : BitVec w) : x.sshiftRight' y = x.sshiftRight y.toNat := rfl

@[simp]
theorem getLsbD_sshiftRight' {x y: BitVec w} {i : Nat} :
getLsbD (x.sshiftRight' y) i =
(!decide (w ≤ i) && if y.toNat + i < w then x.getLsbD (y.toNat + i) else x.msb) := by
simp only [BitVec.sshiftRight', BitVec.getLsbD_sshiftRight]

@[simp]
theorem getMsbD_sshiftRight' {x y: BitVec w} {i : Nat} :
(x.sshiftRight y.toNat).getMsbD i = (decide (i < w) && if i < y.toNat then x.msb else x.getMsbD (i - y.toNat)) := by
simp only [BitVec.sshiftRight', getMsbD, BitVec.getLsbD_sshiftRight]
by_cases h : i < w
· simp only [h, decide_True, Bool.true_and]
by_cases h₁ : w ≤ w - 1 - i
· simp [h₁]
omega
· simp only [h₁, decide_False, Bool.not_false, Bool.true_and]
by_cases h₂ : i < y.toNat
· simp only [h₂, ↓reduceIte, ite_eq_right_iff]
omega
· simp only [show i - y.toNat < w by omega, h₂, ↓reduceIte, decide_True, Bool.true_and]
by_cases h₄ : y.toNat + (w - 1 - i) < w <;> (simp only [h₄, ↓reduceIte]; congr; omega)
· simp [h]

@[simp]
theorem msb_sshiftRight' {x y: BitVec w} :
(x.sshiftRight' y).msb = x.msb := by
simp [BitVec.sshiftRight', BitVec.msb_sshiftRight]

/-! ### udiv -/

theorem udiv_def {x y : BitVec n} : x / y = BitVec.ofNat n (x.toNat / y.toNat) := by
Expand Down Expand Up @@ -1690,6 +1758,11 @@ theorem shiftLeft_ushiftRight {x : BitVec w} {n : Nat}:
· simp [hi₂]
· simp [Nat.lt_one_iff, hi₂, show 1 + (i.val - 1) = i by omega]

@[simp]
theorem msb_shiftLeft {x : BitVec w} {n : Nat} :
(x <<< n).msb = x.getMsbD n := by
simp [BitVec.msb]

@[deprecated shiftRight_add (since := "2024-06-02")]
theorem shiftRight_shiftRight {w : Nat} (x : BitVec w) (n m : Nat) :
(x >>> n) >>> m = x >>> (n + m) := by
Expand Down Expand Up @@ -2971,4 +3044,7 @@ abbrev zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsbD_true
@[deprecated and_one_eq_setWidth_ofBool_getLsbD (since := "2024-09-18")]
abbrev and_one_eq_zeroExtend_ofBool_getLsbD := @and_one_eq_setWidth_ofBool_getLsbD

@[deprecated msb_sshiftRight (since := "2024-10-03")]
abbrev sshiftRight_msb_eq_msb := @msb_sshiftRight

end BitVec

0 comments on commit 47e0430

Please sign in to comment.