diff --git a/CheckedArith.v b/CheckedArith.v index 89e1013..90ea4c4 100644 --- a/CheckedArith.v +++ b/CheckedArith.v @@ -484,4 +484,424 @@ rewrite Z.eqb_neq in Heqx_zero. replace (- x mod 2 ^ 256 =? - x) with false. { trivial. } symmetry. rewrite Z.eqb_neq. intro H. apply Z.mod_small_iff in H; lia. +Qed. + + +(********************************************************************** + Binary search + **********************************************************************) + + + +(* This function descends slightly slower than a normal non-Coq implementation would, + meaning if we have 5 items to check, normally they would be split 2/3, + but we have to do 3/3 to make the recursion happy. + *) +Fixpoint binary_search_rec (p: Z -> bool) + (low: Z) + (n: positive) +{struct n} +: Z +:= match n with + | xH => + let mid := low + 1 in + if p mid + then low + else mid + | xO half | xI half => + let mid := low + Z.pos half + 1 in + if p mid + then binary_search_rec p low half + else binary_search_rec p mid half + end. + +Lemma binary_search_rec_lower (p: Z -> bool) + (low: Z) + (n: positive): + low <= binary_search_rec p low n. +Proof. +revert low. induction n; intro; cbn; + destruct p; try apply IHn; + try refine (Z.le_trans _ _ _ _ (IHn _)); lia. +Qed. + +Lemma binary_search_rec_ok (p: Z -> bool) + (low: Z) + (n: positive) + (Mono: forall x y, x < y -> p x = true -> p y = true) + (LowOk: p low = false) + (HighOk: p (low + Z.pos n + 1) = true): + let m := binary_search_rec p low n in + p m = false /\ p (Z.succ m) = true. +Proof. +revert low LowOk HighOk. induction n; intros. +{ (* 2n + 1 *) + cbn. + remember (p (low + Z.pos n + 1)) as p_mid; symmetry in Heqp_mid; destruct p_mid. + { (* lower half *) now apply IHn. } + (* upper half *) + apply IHn. { assumption. } + replace (low + Z.pos n + 1 + Z.pos n + 1) with (low + Z.pos n~1 + 1) by lia. + assumption. +} +{ (* 2n *) + cbn. + remember (p (low + Z.pos n + 1)) as p_mid; symmetry in Heqp_mid; destruct p_mid. + { (* lower half *) now apply IHn. } + (* upper half *) + apply IHn. { assumption. } + refine (Mono _ _ _ HighOk). + lia. +} +(* 1 *) +unfold m. cbn. +remember (p (low + 1)) as p_mid; symmetry in Heqp_mid; now destruct p_mid. +Qed. + +Lemma binary_search_rec_change (p q: Z -> bool) + (low: Z) + (n: positive) + (ok: forall x, low < x -> p x = q x): + binary_search_rec p low n = binary_search_rec q low n. +Proof. +revert low ok. induction n; intros; cbn; + rewrite ok by lia; repeat rewrite IHn; trivial; intros; apply ok; lia. +Qed. + +(** Perform binary search on a monotone boolean function [p] + which is false on [low] but true on [high]. + Finds a number on which [p] is false but on the next one it's true. + *) +Definition binary_search (p: Z -> bool) + (low high: Z) +: Z +:= match high - low - 1 with + | Z0 => low + | Zpos n => binary_search_rec p low n + | Zneg _ => low (* bad arguments *) + end. + +Lemma binary_search_lower (p: Z -> bool) + (low high: Z): + low <= binary_search p low high. +Proof. +unfold binary_search. +destruct (high - low - 1); try apply Z.le_refl. +apply binary_search_rec_lower. +Qed. + +Lemma binary_search_ok (p: Z -> bool) + (low high: Z) + (Mono: forall x y, x < y -> p x = true -> p y = true) + (LowOk: p low = false) + (HighOk: p high = true): + let m := binary_search p low high in + p m = false /\ p (Z.succ m) = true. +Proof. +unfold binary_search. +remember (high - low - 1) as d. symmetry in Heqd. destruct d as [|n|n]. +{ (* 0 *) assert (high = Z.succ low) by lia; now subst high. } +{ (* pos *) + apply binary_search_rec_ok; try assumption. + assert (high = low + Z.pos n + 1) by lia. + now subst high. +} +(* neg *) +(* this case contradicts [Mono] *) +assert (LE: high <= low) by lia. +apply Zle_lt_or_eq in LE. +case LE; clear LE; intro LE. +{ now rewrite (Mono _ _ LE HighOk) in LowOk. } +subst. now rewrite LowOk in HighOk. +Qed. + +Lemma binary_search_change (p q: Z -> bool) + (low high: Z) + (ok: forall x, low < x -> p x = q x): + binary_search p low high = binary_search q low high. +Proof. +unfold binary_search. destruct (high - low - 1); try easy. +now rewrite (binary_search_rec_change p q) by apply ok. +Qed. + +(****************************************************************************** + * Powers and exponents + ******************************************************************************) + +(** Assuming [base > 2], compute the maximum allowed power before an overflow *) +Definition uint256_max_power {C: VyperConfig} (base: Z) +:= binary_search (fun x => 2 ^ 256 <=? Z.pow base x) 0 162. + +Lemma uint256_max_power_ok {C: VyperConfig} (base: Z) (BaseOk: 2 < base): + let p := uint256_max_power base in + Z.pow base p < 2 ^ 256 /\ 2 ^ 256 <= Z.pow base (Z.succ p). +Proof. +unfold uint256_max_power. +remember (fun x : Z => 2 ^ 256 <=? base ^ x) as f. +assert (F0: f 0 = false) by now subst f. +assert (F162: f 162 = true). +{ + subst f. + apply Z.leb_le. + assert (T: 2 ^ 256 <= 3 ^ 162) by now apply Z.leb_le. + apply (Z.le_trans _ _ _ T). + apply Z.pow_le_mono_l. + lia. +} +assert (Mono: forall x y, x < y -> f x = true -> f y = true). +{ + intros. subst f. + rewrite Z.leb_le in *. + apply (Z.le_trans (2 ^ 256) (base ^ x) (base ^ y)). { assumption. } + apply Z.pow_le_mono_r; lia. +} +assert (B := binary_search_ok f 0 162 Mono F0 F162). +subst f. cbn zeta in B. +rewrite Z.leb_gt in B. rewrite Z.leb_le in B. +apply B. +Qed. + +(** Unchecked power modulo [2^256]. *) +Definition uint256_pow {C: VyperConfig} (a b: uint256) +: uint256 +:= uint256_of_Z (Z_of_uint256 a ^ Z_of_uint256 b). + +(** This is checked power with a constant base close to how it's compiled: + for [base == 0]: [pow == 0] + for [base == 1]: [pow; 1] + for [base == 2]: [assert pow < 256; 1 << pow] + for [base >= 3]: [assert pow <= uint256_max_power base; base ** pow + *) +Definition uint256_checked_pow_const_base {C: VyperConfig} (base pow: uint256) +: option uint256 +:= let base_Z := Z_of_uint256 base in + let pow_Z := Z_of_uint256 pow in + if base_Z =? 0 then Some (if pow_Z =? 0 then one256 else zero256) else + if base_Z =? 1 then Some one256 else + if base_Z =? 2 + then if pow_Z 0) by discriminate. + assert (M := Z.mod_small_iff (2 ^ Z_of_uint256 pow) (2 ^ 256) T). clear T. + assert (R := uint256_range pow). + assert (NN: ~(2 ^ 256 < 2 ^ Z_of_uint256 pow <= 0)) by lia. + assert (L: 0 <= 2 ^ Z_of_uint256 pow) by now apply Z.pow_nonneg. + assert (MM: 2 ^ Z_of_uint256 pow mod 2 ^ 256 = 2 ^ Z_of_uint256 pow + <-> + 2 ^ Z_of_uint256 pow < 2 ^ 256) by tauto. + clear M NN. + assert (Q: 2 ^ Z_of_uint256 pow < 2 ^ 256 <-> Z_of_uint256 pow < 256). + { symmetry. now apply Z.pow_lt_mono_r_iff. } + assert (P: 2 ^ Z_of_uint256 pow mod 2 ^ 256 = 2 ^ Z_of_uint256 pow + <-> + Z_of_uint256 pow < 256) by tauto. + clear MM Q. + remember (Z_of_uint256 pow = 3 *) +assert (L: 2 < Z_of_uint256 base). { assert (R := uint256_range base). lia. } +assert (M := uint256_max_power_ok _ L). +unfold uint256_pow. +remember (Z_of_uint256 base ^ Z_of_uint256 pow mod 2 ^ 256 + =? + Z_of_uint256 base ^ Z_of_uint256 pow) as m. +symmetry in Heqm. +remember (Z_of_uint256 pow <=? uint256_max_power (Z_of_uint256 base)) as l. symmetry in Heql. +enough (H: m = l) by now rewrite H. +destruct l. +{ + rewrite Z.leb_le in Heql. + subst m. + rewrite Z.eqb_eq. + apply Z.mod_small. + split. { apply Z.pow_nonneg. lia. } + refine (Z.le_lt_trans _ _ _ _ (proj1 M)). + apply Z.pow_le_mono_r. { lia. } exact Heql. +} +rewrite Z.leb_gt in Heql. +subst m. rewrite Z.eqb_neq. intro H. +apply Z.mod_small_iff in H. 2:discriminate. +enough (2 ^ 256 <= Z_of_uint256 base ^ Z_of_uint256 pow) by lia. +clear H. +apply (Z.le_trans _ _ _ (proj2 M)). +apply Z.pow_le_mono_r; lia. +Qed. + + +(** Assuming [pow >= 2], compute the maximum allowed base before an overflow *) +Definition uint256_max_base {C: VyperConfig} (pow: Z) +:= binary_search (fun x => 2 ^ 256 <=? Z.pow x pow) 0 (2 ^ 128). + +Lemma uint256_max_base_ok {C: VyperConfig} (pow: Z) (PowOk: 2 <= pow): + let b := uint256_max_base pow in + Z.pow b pow < 2 ^ 256 /\ 2 ^ 256 <= Z.pow (Z.succ b) pow. +Proof. +unfold uint256_max_base. +remember (fun x : Z => 2 ^ 256 <=? x ^ pow) as f. +pose (g := fun x => ((0 <=? x) && f x)%bool). +assert (E: binary_search f 0 (2 ^ 128) = binary_search g 0 (2 ^ 128)). +{ + apply binary_search_change. + intros x L. + subst f g. + cbn beta. + apply Z.lt_le_incl in L. + apply Z.leb_le in L. + rewrite L. + apply Bool.andb_true_l. +} +rewrite E. +assert (Low: g 0 = false). +{ + subst g. subst f. + replace (0 ^ pow) with 0. 2:{ symmetry. apply Z.pow_0_l. lia. } + easy. +} +assert (High: g (2 ^ 128) = true). +{ + subst g. subst f. + apply Z.leb_le. + replace (2 ^ 256) with ((2 ^ 128) ^ 2) by trivial. + now apply Z.pow_le_mono_r. +} +assert (Mono: forall x y, x < y -> g x = true -> g y = true). +{ + intros x y XY H. subst g. subst f. + rewrite Bool.andb_true_iff in *. + destruct H as (PosX, H). + split; rewrite Z.leb_le in *; try lia. + apply (Z.le_trans (2 ^ 256) (x ^ pow) (y ^ pow)). { assumption. } + apply Z.pow_le_mono_l; lia. +} +assert (B := binary_search_ok g _ _ Mono Low High). +cbn zeta in B. destruct B as (U, V). +rewrite<- E in *. +assert (L := binary_search_lower f 0 (2 ^ 128)). +remember (binary_search f 0 (2 ^ 128)) as x. +subst g. cbn in U. cbn in V. +assert (L': 0 <= Z.succ x) by lia. +rewrite<- Z.leb_le in L. +rewrite<- Z.leb_le in L'. +rewrite L in U. rewrite L' in V. +cbn in U. cbn in V. +subst f. +apply Z.leb_gt in U. apply Z.leb_le in V. +tauto. +Qed. + +(** This is ugly but there's no extensionality requirement for uint256s, + so [base ** 1] is not exactly the same as [base]. *) +Definition uint256_unary_plus {C: VyperConfig} (a: uint256) +:= uint256_of_Z (Z_of_uint256 a). + +(** This is checked power with a constant exponent close to how it's compiled: + for [pow == 0]: [base; 1] + for [pow == 1]: [+base] + for [pow >= 2]: [assert base <= uint256_max_base pow; base ** pow + *) +Definition uint256_checked_pow_const_exponent {C: VyperConfig} (base pow: uint256) +: option uint256 +:= let base_Z := Z_of_uint256 base in + let pow_Z := Z_of_uint256 pow in + if pow_Z =? 0 then Some one256 else + if pow_Z =? 1 then Some (uint256_unary_plus base) else + if base_Z <=? uint256_max_base pow_Z + then Some (uint256_pow base pow) + else None. + +Lemma uint256_checked_pow_const_exponent_ok {C: VyperConfig} (base pow: uint256): + uint256_checked_pow_const_exponent base pow = interpret_binop Pow base pow. +Proof. +cbn. unfold uint256_checked_pow_const_exponent. unfold maybe_uint256_of_Z. +rewrite uint256_ok. +remember (Z_of_uint256 pow =? 0) as pow_0. symmetry in Heqpow_0. destruct pow_0. +{ (* base ^ 0 *) + rewrite Z.eqb_eq in Heqpow_0. rewrite Heqpow_0. + now rewrite Z.pow_0_r. +} +rewrite Z.eqb_neq in Heqpow_0. +remember (Z_of_uint256 pow =? 1) as pow_1. symmetry in Heqpow_1. destruct pow_1. +{ (* base ^ 1 *) + rewrite Z.eqb_eq in Heqpow_1. rewrite Heqpow_1. + rewrite Z.pow_1_r. + assert (R := Z.mod_small _ _ (uint256_range base)). + rewrite<- Z.eqb_eq in R. + now rewrite R. +} +rewrite Z.eqb_neq in Heqpow_1. +assert (L: 2 <= Z_of_uint256 pow) by (assert (R := uint256_range pow); lia). +assert (B := uint256_max_base_ok _ L). +remember (Z_of_uint256 base ^ Z_of_uint256 pow mod 2 ^ 256 + =? + Z_of_uint256 base ^ Z_of_uint256 pow) as m. +symmetry in Heqm. +remember (Z_of_uint256 base <=? uint256_max_base (Z_of_uint256 pow)) as l. symmetry in Heql. +enough (H: m = l) by now rewrite H. +assert (R := uint256_range base). +destruct l. +{ + rewrite Z.leb_le in Heql. + subst m. + rewrite Z.eqb_eq. + apply Z.mod_small. + split. { apply Z.pow_nonneg. tauto. } + refine (Z.le_lt_trans _ _ _ _ (proj1 B)). + apply Z.pow_le_mono_l. lia. +} +rewrite Z.leb_gt in Heql. +subst m. rewrite Z.eqb_neq. intro H. +apply Z.mod_small_iff in H. 2:discriminate. +enough (2 ^ 256 <= Z_of_uint256 base ^ Z_of_uint256 pow) by lia. +clear H. +apply (Z.le_trans _ _ _ (proj2 B)). +apply Z.pow_le_mono_l. +split. 2:lia. +enough (0 <= uint256_max_base (Z_of_uint256 pow)) by lia. +unfold uint256_max_base. +apply binary_search_lower. Qed. \ No newline at end of file