diff --git a/Cargo.lock b/Cargo.lock index 28968e4d22..ca6ef451da 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3931,6 +3931,7 @@ version = "1.0.1-rc.0" dependencies = [ "derive-new 0.6.0", "derive_more 1.0.0", + "eyre", "halo2curves-axiom", "itertools 0.14.0", "num-bigint 0.4.6", @@ -3969,8 +3970,11 @@ version = "1.0.1-rc.0" dependencies = [ "halo2curves-axiom", "num-bigint 0.4.6", + "once_cell", "openvm-algebra-complex-macros", "openvm-algebra-moduli-macros", + "openvm-custom-insn", + "openvm-rv32im-guest", "serde-big-array", "strum_macros", ] @@ -4314,6 +4318,7 @@ name = "openvm-ecc-integration-tests" version = "1.0.1-rc.0" dependencies = [ "eyre", + "halo2curves-axiom", "hex-literal", "num-bigint 0.4.6", "openvm-algebra-circuit", @@ -4413,21 +4418,6 @@ dependencies = [ "tiny-keccak", ] -[[package]] -name = "openvm-keccak256-integration-tests" -version = "1.0.1-rc.0" -dependencies = [ - "eyre", - "openvm-circuit", - "openvm-instructions", - "openvm-keccak256-circuit", - "openvm-keccak256-transpiler", - "openvm-rv32im-transpiler", - "openvm-stark-sdk", - "openvm-toolchain-tests", - "openvm-transpiler", -] - [[package]] name = "openvm-keccak256-transpiler" version = "1.0.1-rc.0" @@ -4629,6 +4619,7 @@ name = "openvm-pairing-integration-tests" version = "1.0.1-rc.0" dependencies = [ "eyre", + "halo2curves-axiom", "num-bigint 0.4.6", "num-traits", "openvm", @@ -4879,21 +4870,6 @@ dependencies = [ "sha2", ] -[[package]] -name = "openvm-sha256-integration-tests" -version = "1.0.1-rc.0" -dependencies = [ - "eyre", - "openvm-circuit", - "openvm-instructions", - "openvm-rv32im-transpiler", - "openvm-sha256-circuit", - "openvm-sha256-transpiler", - "openvm-stark-sdk", - "openvm-toolchain-tests", - "openvm-transpiler", -] - [[package]] name = "openvm-sha256-transpiler" version = "1.0.1-rc.0" diff --git a/Cargo.toml b/Cargo.toml index 1cb02d109d..25a0d99212 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -52,11 +52,9 @@ members = [ "extensions/keccak256/circuit", "extensions/keccak256/transpiler", "extensions/keccak256/guest", - "extensions/keccak256/tests", "extensions/sha256/circuit", "extensions/sha256/transpiler", "extensions/sha256/guest", - "extensions/sha256/tests", "extensions/ecc/circuit", "extensions/ecc/transpiler", "extensions/ecc/guest", diff --git a/benchmarks/guest/ecrecover/src/main.rs b/benchmarks/guest/ecrecover/src/main.rs index fa29562fa6..db91ea1c22 100644 --- a/benchmarks/guest/ecrecover/src/main.rs +++ b/benchmarks/guest/ecrecover/src/main.rs @@ -24,9 +24,6 @@ openvm_ecc_guest::sw_macros::sw_init! { } pub fn main() { - setup_all_moduli(); - setup_all_curves(); - let expected_address = read_vec(); for _ in 0..5 { let input = read_vec(); diff --git a/benchmarks/guest/kitchen-sink/src/main.rs b/benchmarks/guest/kitchen-sink/src/main.rs index 6aa679eb3f..9bea64b283 100644 --- a/benchmarks/guest/kitchen-sink/src/main.rs +++ b/benchmarks/guest/kitchen-sink/src/main.rs @@ -48,10 +48,8 @@ openvm_algebra_guest::complex_macros::complex_init! { } pub fn main() { - // Setup will materialize every chip - setup_all_moduli(); - setup_all_complex_extensions(); - setup_all_curves(); + // TODO: Since we don't explicitly call setup functions anymore, we should rewrite this test + // to use every declared modulus and curve to ensure that every chip is materialized. let [one, six] = [1, 6].map(Seven::from_u32); assert_eq!(one + six, Seven::ZERO); diff --git a/benchmarks/guest/pairing/src/main.rs b/benchmarks/guest/pairing/src/main.rs index 807c5d2866..70e9c7683c 100644 --- a/benchmarks/guest/pairing/src/main.rs +++ b/benchmarks/guest/pairing/src/main.rs @@ -22,11 +22,6 @@ openvm_algebra_guest::complex_macros::complex_init! { const PAIR_ELEMENT_LEN: usize = 32 * (2 + 4); // 1 G1Affine (2 Fp), 1 G2Affine (4 Fp) pub fn main() { - setup_all_moduli(); - setup_all_complex_extensions(); - // Pairing doesn't need G1Affine intrinsics, but we trigger it anyways to test the chips - setup_all_curves(); - // copied from https://github.com/bluealloy/revm/blob/9e39df5dbc5fdc98779c644629b28b8bee75794a/crates/precompile/src/bn128.rs#L395 let input = hex::decode( "\ diff --git a/book/src/custom-extensions/algebra.md b/book/src/custom-extensions/algebra.md index 7ce48698b3..6224413ddc 100644 --- a/book/src/custom-extensions/algebra.md +++ b/book/src/custom-extensions/algebra.md @@ -17,6 +17,9 @@ The functional part is provided by the `openvm-algebra-guest` crate, which is a - `Field` trait: Provides constants `ZERO` and `ONE` and methods for basic arithmetic operations within a field. +- `Sqrt` trait: + Implements square root in a field using hinting. + ## Modular arithmetic To [leverage](./overview.md) compile-time known moduli for performance, you declare, initialize, and then set up the arithmetic structures: @@ -26,11 +29,15 @@ To [leverage](./overview.md) compile-time known moduli for performance, you decl ```rust moduli_declare! { Bls12_381Fp { modulus = "0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab" }, - Bn254Fp { modulus = "21888242871839275222246405745257275088696311157297823662689037894645226208583" }, + Bn254Fp { + modulus = "21888242871839275222246405745257275088696311157297823662689037894645226208583", impl_field = true + }, } ``` -This creates `Bls12_381Fp` and `Bn254Fp` structs, each implementing the `IntMod` trait. The modulus parameter must be a string literal in decimal or hexadecimal format. +This creates `Bls12_381Fp` and `Bn254Fp` structs, each implementing the `IntMod` trait. +Since `impl_field = true` is specified for `Bn254Fp`, it also implements the `Field` and `Sqrt` traits. +The modulus parameter must be a string literal in decimal or hexadecimal format. 2. **Init**: Use the `init!` macro exactly once in the final binary: @@ -46,13 +53,10 @@ moduli_init! { This step enumerates the declared moduli (e.g., `0` for the first one, `1` for the second one) and sets up internal linkage so the compiler can generate the appropriate RISC-V instructions associated with each modulus. -3. **Setup**: At runtime, before performing arithmetic, a setup instruction must be sent to ensure security and correctness. For the \\(i\\)-th modulus, you call `setup_()` (e.g., `setup_0()` or `setup_1()`). Alternatively, `setup_all_moduli()` can be used to handle all declared moduli. - **Summary**: - `moduli_declare!`: Declares modular arithmetic structures and can be done multiple times. - `init!`: Called once in the final binary to assign and lock in the moduli. -- `setup_()`/`setup_all_moduli()`: Ensures at runtime that the correct modulus is in use, providing a security check and finalizing the environment for safe arithmetic operations. ## Complex field extension @@ -83,8 +87,6 @@ complex_init! { */ ``` -3. **Setup**: Similar to moduli, call `setup_complex_()` or `setup_all_complex_extensions()` at runtime to secure the environment. - ### Config parameters For the guest program to build successfully, all used moduli must be declared in the `.toml` config file in the following format: diff --git a/book/src/custom-extensions/ecc.md b/book/src/custom-extensions/ecc.md index 25f98ba183..4a1b76d6eb 100644 --- a/book/src/custom-extensions/ecc.md +++ b/book/src/custom-extensions/ecc.md @@ -54,13 +54,10 @@ sw_init! { */ ``` -3. **Setup**: Similar to the moduli and complex extensions, runtime setup instructions ensure that the correct curve parameters are being used, guaranteeing secure operation. - **Summary**: - `sw_declare!`: Declares elliptic curve structures. - `init!`: Initializes them once, linking them to the underlying moduli. -- `setup_sw_()`/`setup_all_curves()`: Secures runtime correctness. To use elliptic curve operations on a struct defined with `sw_declare!`, it is expected that the struct for the curve's coordinate field was defined using `moduli_declare!`. In particular, the coordinate field needs to be initialized and set up as described in the [algebra extension](./algebra.md) chapter. diff --git a/book/src/custom-extensions/pairing.md b/book/src/custom-extensions/pairing.md index 807de50e20..24d933dba1 100644 --- a/book/src/custom-extensions/pairing.md +++ b/book/src/custom-extensions/pairing.md @@ -36,14 +36,6 @@ Additionally, we'll need to initialize our moduli and `Fp2` struct via the follo {{ #include ../../../examples/pairing/src/main.rs:init }} ``` -And we'll run the required setup functions at the top of the guest program's `main()` function: - -```rust,no_run,noplayground -{{ #include ../../../examples/pairing/src/main.rs:setup }} -``` - -There are two moduli defined internally in the `Bls12_381` feature. The `moduli_init!` macro thus requires both of them to be initialized. However, we do not need the scalar field of BLS12-381 (which is at index 1), and thus we only initialize the modulus from index 0, thus we only use `setup_0()` (as opposed to `setup_all_moduli()`, which will save us some columns when generating the trace). - ## Input values The inputs to the pairing check are `AffinePoint`s in \\(\mathbb{F}\_p\\) and \\(\mathbb{F}\_{p^2}\\). They can be constructed via the `AffinePoint::new` function, with the inner `Fp` and `Fp2` values constructed via various `from_...` functions. diff --git a/docs/specs/ISA.md b/docs/specs/ISA.md index 1bc3d0c4a0..4fe25c59d8 100644 --- a/docs/specs/ISA.md +++ b/docs/specs/ISA.md @@ -611,9 +611,7 @@ the same format that is congruent modulo `N` to the respective operation applied For each instruction, the operand `d` is fixed to be `1` and `e` is fixed to be `2`. Each instruction performs block accesses with block size `4` in address space `1` and block size `N::BLOCK_SIZE` in -address space `2`, where `N::NUM_LIMBS` is divisible by `N::BLOCK_SIZE`. Recall that `N::BLOCK_SIZE` must be a power of - -2. +address space `2`, where `N::NUM_LIMBS` is divisible by `N::BLOCK_SIZE`. Recall that `N::BLOCK_SIZE` must be a power of 2. | Name | Operands | Description | | ------------------------- | ----------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | @@ -635,6 +633,16 @@ format with each limb having `LIMB_BITS` bits. | ISEQMOD_RV32\ | `a,b,c,1,2` | `[a:4]_1 = [r32{0}(b): N::NUM_LIMBS]_2 == [r32{0}(c): N::NUM_LIMBS]_2 (mod N) ? 1 : 0`. Enforces that `[r32{0}(b): N::NUM_LIMBS]_2, [r32{0}(c): N::NUM_LIMBS]_2` are less than `N` and then sets the register value of `[a:4]_1` to `1` or `0` depending on whether the two big integers are equal. | | SETUP_ISEQMOD_RV32\ | `a,b,c,1,2` | `assert([r32{0}(b): N::NUM_LIMBS]_2 == N)` in the chip that handles modular equality. For the sake of implementation convenience it also writes something (can be anything) into register value of `[a:4]_1` | +#### Phantom Sub-Instructions + + +| Name | Discriminant | Operands | Description | +| -------------- | ------------ | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| HintNonQr\ | 0x50 | `_,_,c_upper` | Use `c_upper` to determine the index of the modulus from the list of supported moduli. Reset the hint stream to equal a quadratic nonresidue modulo `N`. | +| HintSqrt\ | 0x51 | `a,_,c_upper` | Use `c_upper` to determine the index of the modulus from the list of supported moduli. Read from memory `x = [r32{0}(a): N::NUM_LIMBS]_2`. If `x` is a quadratic residue modulo `N`, reset the hint stream to `[1u8, 0u8, 0u8, 0u8]` followed by a square root of `x`. If `x` is not a quadratic residue, reset the hint stream to `[0u8; 4]` followed by a square root of `x * non_qr`, where `non_qr` is the quadratic nonresidue returned by `HintNonQr`. | + +# + #### Complex Extension Field A complex extension field `Fp2` is the quadratic extension of a prime field `Fp` with irreducible polynomial `X^2 + 1`. diff --git a/docs/specs/RISCV.md b/docs/specs/RISCV.md index 7948272c66..0bbe9d297c 100644 --- a/docs/specs/RISCV.md +++ b/docs/specs/RISCV.md @@ -132,7 +132,9 @@ generates classes `Bls12381` and `Bn254` that represent the elements of the corr ### Field Arithmetic -For each created modular class, one must call a corresponding `setup_*` function once at the beginning of the program. For example, for the structs above this would be `setup_0()` and `setup_1()`. This function generates the `setup` intrinsics which are distinguished by the `rs2` operand that specifies the chip this instruction is passed to.. +For each created modular class, one must call a corresponding `setup_*` function before using the intrinsics. +For example, for the structs above this would be `setup_0()` and `setup_1()`. This function generates the `setup` intrinsics which are distinguished by the `rs2` operand that specifies the chip this instruction is passed to.. +For convenience, each modulus's `setup_*` function is automatically called on the first use of any of its intrinsics. We use `config.mod_idx(N)` to denote the index of `N` in this list. In the list below, `idx` denotes `config.mod_idx(N)`. @@ -146,6 +148,8 @@ We use `config.mod_idx(N)` to denote the index of `N` in this list. In the list | divmod\ | R | 0101011 | 000 | `idx*8+3` | `[rd: N::NUM_LIMBS]_2 = [rs1: N::NUM_LIMBS]_2 / [rs2: N::NUM_LIMBS]_2 (mod N)` (undefined when `gcd([rs2: N::NUM_LIMBS]_2, N) != 1`) | | iseqmod\ | R | 0101011 | 000 | `idx*8+4` | `rd = [rs1: N::NUM_LIMBS]_2 == [rs2: N::NUM_LIMBS]_2 (mod N) ? 1 : 0`. If `rd != x0`, enforces that `[rs1: N::NUM_LIMBS]_2` and `[rs2: N::NUM_LIMBS]_2` are both less than `N` and then sets `rd` equal to boolean comparison value. If `rd = x0`, this is a no-op. | | setup\ | R | 0101011 | 000 | `idx*8+5` | `assert([rs1: N::NUM_LIMBS]_2 == N)` in the chip defined by the register index of `rs2`. For the sake of implementation convenience it also writes an unconstrained value into `[rd: N::NUM_LIMBS]_2` if `ind(rs2) = 0,1` (for add_sub, mul_div) or it overwrites the register value of `rd` with an unconstrained value if `ind(rs2) = 2` (for iseq). If `ind(rs2) = 2`, then the instruction is **invalid** if `rd = x0`. | +| hint_non_qr\ | R | 0101011 | 001 | `idx*8+6` | Reset the hint stream to equal `non_qr` where `non_qr` is a quadratic nonresidue modulo `N`. The same `non_qr` is returned in each execution of this instruction. `rd`, `rs1`, and `rs2` should be `x0`. | +| hint_sqrt\ | R | 0101011 | 001 | `idx*8+7` | Read `x = [rs1: N::NUM_LIMBS]_2`. If `x` is a quadratic residue modulo `N` then reset the hint stream to `[1u0, 0u8, 0u8, 0u8]` concatenated with a square root of `x`. If `x` is not a quadratic residue, then reset the hint stream to `[0u8; 4]` concatenated with a square root of `x * non_qr` where `non_qr` is the quadratic nonresidue returned by `hint_non_qr`. `rd` and `rs2` should be `x0`. | Since `funct7` is 7-bits, up to 16 moduli can be supported simultaneously. We use `idx*8` to leave some room for future expansion. diff --git a/docs/specs/transpiler.md b/docs/specs/transpiler.md index a8b94ef2b0..e334d446a4 100644 --- a/docs/specs/transpiler.md +++ b/docs/specs/transpiler.md @@ -186,6 +186,8 @@ Each VM extension's behavior is specified below. | divmod\ | DIVMOD_RV32\ `ind(rd), ind(rs1), ind(rs2), 1, 2` | | iseqmod\ | ISEQMOD_RV32\ `ind(rd), ind(rs1), ind(rs2), 1, 2` if `rd != x0`, otherwise PHANTOM `_, _, disc(Nop)` | | setup\ | SETUP_ADDSUBMOD_RV32\ `ind(rd), ind(rs1), x0, 1, 2` if `ind(rs2) = 0`, SETUP_MULDIVMOD_RV32\ `ind(rd), ind(rs1), x0, 1, 2` if `ind(rs2) = 1`, SETUP_ISEQMOD_RV32\ `ind(rd), ind(rs1), x0, 1, 2` if `ind(rs2) = 2` | +| hint_non_qr | PHANTOM `0, 0, phantom_c(curve_idx, HintNonQr)` | +| hint_sqrt | PHANTOM `ind(rs1), 0, phantom_c(curve_idx, HintSqrt)` | #### Complex Extension Field Arithmetic diff --git a/examples/algebra/src/main.rs b/examples/algebra/src/main.rs index 72c5d78ed7..de3d49697c 100644 --- a/examples/algebra/src/main.rs +++ b/examples/algebra/src/main.rs @@ -31,10 +31,6 @@ openvm_algebra_complex_macros::complex_init! { */ pub fn main() { - // Since we only use an arithmetic operation with `Mod1` and not `Mod2`, - // we only need to call `setup_0()` here. - setup_0(); - setup_all_complex_extensions(); let a = Complex1::new(Mod1::ZERO, Mod1::from_u32(0x3b8) * Mod1::from_u32(0x100000)); // a = -i in the corresponding field let b = Complex2::new(Mod2::ZERO, Mod2::from_u32(1000000006)); // b = -i in the corresponding field assert_eq!(a.clone() * &a * &a * &a * &a, a); // a^5 = a diff --git a/examples/ecc/src/main.rs b/examples/ecc/src/main.rs index c8e3299dc5..3468d6b26b 100644 --- a/examples/ecc/src/main.rs +++ b/examples/ecc/src/main.rs @@ -23,8 +23,6 @@ openvm_ecc_guest::sw_macros::sw_init! { // ANCHOR: main pub fn main() { - setup_all_moduli(); - setup_all_curves(); let x1 = Secp256k1Coord::from_u32(1); let y1 = Secp256k1Coord::from_le_bytes(&hex!( "EEA7767E580D75BC6FDD7F58D2A84C2614FB22586068DB63B346C6E60AF21842" diff --git a/examples/pairing/src/main.rs b/examples/pairing/src/main.rs index 60a601efe3..afb3828d0a 100644 --- a/examples/pairing/src/main.rs +++ b/examples/pairing/src/main.rs @@ -26,11 +26,6 @@ openvm_algebra_complex_macros::complex_init! { // ANCHOR: main pub fn main() { - // ANCHOR: setup - setup_0(); - setup_all_complex_extensions(); - // ANCHOR_END: setup - let p0 = AffinePoint::new( Fp::from_be_bytes(&hex!("17f1d3a73197d7942695638c4fa9ac0fc3688c4f9774b905a14e3a3f171bac586c55e83ff97a1aeffb3af00adb22c6bb")), Fp::from_be_bytes(&hex!("08b3f481e3aaa0f1a09e30ed741d8ae4fcf5e095d5d00af600db18cb2c04b3edd03cc744a2888ae40caa232946c5e7e1")) diff --git a/extensions/algebra/circuit/Cargo.toml b/extensions/algebra/circuit/Cargo.toml index 7949fb0946..258bff450b 100644 --- a/extensions/algebra/circuit/Cargo.toml +++ b/extensions/algebra/circuit/Cargo.toml @@ -30,6 +30,7 @@ derive-new = { workspace = true } serde.workspace = true serde_with = { workspace = true } serde-big-array = { workspace = true } +eyre = { workspace = true } [dev-dependencies] halo2curves-axiom = { workspace = true } diff --git a/extensions/algebra/circuit/src/fp2_extension.rs b/extensions/algebra/circuit/src/fp2_extension.rs index 30a55cf015..20ce310d23 100644 --- a/extensions/algebra/circuit/src/fp2_extension.rs +++ b/extensions/algebra/circuit/src/fp2_extension.rs @@ -36,7 +36,7 @@ impl Fp2Extension { pub fn generate_complex_init(&self, modular_config: &ModularExtension) -> String { fn get_index_of_modulus(modulus: &BigUint, modular_config: &ModularExtension) -> usize { modular_config - .supported_modulus + .supported_moduli .iter() .position(|m| m == modulus) .expect("Modulus used in Fp2Extension not found in ModularExtension") diff --git a/extensions/algebra/circuit/src/modular_extension.rs b/extensions/algebra/circuit/src/modular_extension.rs index 6050045aed..b51696841b 100644 --- a/extensions/algebra/circuit/src/modular_extension.rs +++ b/extensions/algebra/circuit/src/modular_extension.rs @@ -1,6 +1,7 @@ use derive_more::derive::From; -use num_bigint::BigUint; -use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode; +use num_bigint::{BigUint, RandBigInt}; +use num_traits::{FromPrimitive, One}; +use openvm_algebra_transpiler::{ModularPhantom, Rv32ModularArithmeticOpcode}; use openvm_circuit::{ self, arch::{SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError}, @@ -11,10 +12,11 @@ use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_instructions::{LocalOpcode, VmOpcode}; +use openvm_instructions::{LocalOpcode, PhantomDiscriminant, VmOpcode}; use openvm_mod_circuit_builder::ExprBuilderConfig; use openvm_rv32_adapters::{Rv32IsEqualModAdapterChip, Rv32VecHeapAdapterChip}; use openvm_stark_backend::p3_field::PrimeField32; +use rand::{rngs::StdRng, SeedableRng}; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DisplayFromStr}; use strum::EnumCount; @@ -27,14 +29,14 @@ use crate::modular_chip::{ #[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)] pub struct ModularExtension { #[serde_as(as = "Vec")] - pub supported_modulus: Vec, + pub supported_moduli: Vec, } impl ModularExtension { // Generates a call to the moduli_init! macro with moduli in the correct order pub fn generate_moduli_init(&self) -> String { let supported_moduli = self - .supported_modulus + .supported_moduli .iter() .map(|modulus| format!("\"{}\"", modulus)) .collect::>() @@ -99,7 +101,7 @@ impl VmExtension for ModularExtension { let iseq_opcodes = (Rv32ModularArithmeticOpcode::IS_EQ as usize) ..=(Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize); - for (i, modulus) in self.supported_modulus.iter().enumerate() { + for (i, modulus) in self.supported_moduli.iter().enumerate() { // determine the number of bytes needed to represent a prime field element let bytes = modulus.bits().div_ceil(8); let start_offset = @@ -230,7 +232,255 @@ impl VmExtension for ModularExtension { panic!("Modulus too large"); } } + let non_qr_hint_sub_ex = phantom::NonQrHintSubEx::new(self.supported_moduli.clone()); + builder.add_phantom_sub_executor( + non_qr_hint_sub_ex.clone(), + PhantomDiscriminant(ModularPhantom::HintNonQr as u16), + )?; + + let sqrt_hint_sub_ex = phantom::SqrtHintSubEx::new(non_qr_hint_sub_ex); + builder.add_phantom_sub_executor( + sqrt_hint_sub_ex, + PhantomDiscriminant(ModularPhantom::HintSqrt as u16), + )?; Ok(inventory) } } + +pub(crate) mod phantom { + use std::{ + iter::{once, repeat}, + ops::Deref, + }; + + use eyre::bail; + use num_bigint::BigUint; + use openvm_circuit::{ + arch::{PhantomSubExecutor, Streams}, + system::memory::MemoryController, + }; + use openvm_instructions::{riscv::RV32_MEMORY_AS, PhantomDiscriminant}; + use openvm_rv32im_circuit::adapters::unsafe_read_rv32_register; + use openvm_stark_backend::p3_field::PrimeField32; + + use super::{find_non_qr, mod_sqrt}; + + #[derive(derive_new::new)] + pub struct SqrtHintSubEx(NonQrHintSubEx); + + impl Deref for SqrtHintSubEx { + type Target = NonQrHintSubEx; + + fn deref(&self) -> &NonQrHintSubEx { + &self.0 + } + } + + // Given x returns either a sqrt of x or a sqrt of x * non_qr, whichever exists. + // Note that non_qr is fixed for each modulus. + impl PhantomSubExecutor for SqrtHintSubEx { + fn phantom_execute( + &mut self, + memory: &MemoryController, + streams: &mut Streams, + _: PhantomDiscriminant, + a: F, + _: F, + c_upper: u16, + ) -> eyre::Result<()> { + let mod_idx = c_upper as usize; + if mod_idx >= self.supported_moduli.len() { + bail!( + "Modulus index {mod_idx} out of range: {} supported moduli", + self.supported_moduli.len() + ); + } + let modulus = &self.supported_moduli[mod_idx]; + let num_limbs: usize = if modulus.bits().div_ceil(8) <= 32 { + 32 + } else if modulus.bits().div_ceil(8) <= 48 { + 48 + } else { + bail!("Modulus too large") + }; + + let rs1 = unsafe_read_rv32_register(memory, a); + let mut x_limbs: Vec = Vec::with_capacity(num_limbs); + for i in 0..num_limbs { + let limb = memory.unsafe_read_cell( + F::from_canonical_u32(RV32_MEMORY_AS), + F::from_canonical_u32(rs1 + i as u32), + ); + x_limbs.push(limb.as_canonical_u32() as u8); + } + let x = BigUint::from_bytes_le(&x_limbs); + + let (success, sqrt) = match mod_sqrt(&x, modulus, &self.non_qrs[mod_idx]) { + Some(sqrt) => (true, sqrt), + None => { + let sqrt = mod_sqrt( + &(&x * &self.non_qrs[mod_idx]), + modulus, + &self.non_qrs[mod_idx], + ) + .expect("Either x or x * non_qr should be a square"); + (false, sqrt) + } + }; + + let hint_bytes = once(F::from_bool(success)) + .chain(repeat(F::ZERO)) + .take(4) + .chain( + sqrt.to_bytes_le() + .into_iter() + .map(F::from_canonical_u8) + .chain(repeat(F::ZERO)) + .take(num_limbs), + ) + .collect(); + streams.hint_stream = hint_bytes; + Ok(()) + } + } + + #[derive(Clone)] + pub struct NonQrHintSubEx { + pub supported_moduli: Vec, + pub non_qrs: Vec, + } + + impl NonQrHintSubEx { + pub fn new(supported_moduli: Vec) -> Self { + let non_qrs = supported_moduli.iter().map(find_non_qr).collect(); + Self { + supported_moduli, + non_qrs, + } + } + } + + impl PhantomSubExecutor for NonQrHintSubEx { + fn phantom_execute( + &mut self, + _: &MemoryController, + streams: &mut Streams, + _: PhantomDiscriminant, + _: F, + _: F, + c_upper: u16, + ) -> eyre::Result<()> { + let mod_idx = c_upper as usize; + if mod_idx >= self.supported_moduli.len() { + bail!( + "Modulus index {mod_idx} out of range: {} supported moduli", + self.supported_moduli.len() + ); + } + let modulus = &self.supported_moduli[mod_idx]; + + let num_limbs: usize = if modulus.bits().div_ceil(8) <= 32 { + 32 + } else if modulus.bits().div_ceil(8) <= 48 { + 48 + } else { + bail!("Modulus too large") + }; + + let hint_bytes = self.non_qrs[mod_idx] + .to_bytes_le() + .into_iter() + .map(F::from_canonical_u8) + .chain(repeat(F::ZERO)) + .take(num_limbs) + .collect(); + streams.hint_stream = hint_bytes; + Ok(()) + } + } +} + +/// Find the square root of `x` modulo `modulus` with `non_qr` a +/// quadratic nonresidue of the field. +pub fn mod_sqrt(x: &BigUint, modulus: &BigUint, non_qr: &BigUint) -> Option { + if modulus % 4u32 == BigUint::from_u8(3).unwrap() { + // x^(1/2) = x^((p+1)/4) when p = 3 mod 4 + let exponent = (modulus + BigUint::one()) >> 2; + let ret = x.modpow(&exponent, modulus); + if &ret * &ret % modulus == x % modulus { + Some(ret) + } else { + None + } + } else { + // Tonelli-Shanks algorithm + // https://en.wikipedia.org/wiki/Tonelli%E2%80%93Shanks_algorithm#The_algorithm + let mut q = modulus - BigUint::one(); + let mut s = 0; + while &q % 2u32 == BigUint::ZERO { + s += 1; + q /= 2u32; + } + let z = non_qr; + let mut m = s; + let mut c = z.modpow(&q, modulus); + let mut t = x.modpow(&q, modulus); + let mut r = x.modpow(&((q + BigUint::one()) >> 1), modulus); + loop { + if t == BigUint::ZERO { + return Some(BigUint::ZERO); + } + if t == BigUint::one() { + return Some(r); + } + let mut i = 0; + let mut tmp = t.clone(); + while tmp != BigUint::one() && i < m { + tmp = &tmp * &tmp % modulus; + i += 1; + } + if i == m { + // self is not a quadratic residue + return None; + } + for _ in 0..m - i - 1 { + c = &c * &c % modulus; + } + let b = c; + m = i; + c = &b * &b % modulus; + t = ((t * &b % modulus) * &b) % modulus; + r = (r * b) % modulus; + } + } +} + +// Returns a non-quadratic residue in the field +pub fn find_non_qr(modulus: &BigUint) -> BigUint { + if modulus % 4u32 == BigUint::from(3u8) { + // p = 3 mod 4 then -1 is a quadratic residue + modulus - BigUint::one() + } else if modulus % 8u32 == BigUint::from(5u8) { + // p = 5 mod 8 then 2 is a non-quadratic residue + // since 2^((p-1)/2) = (-1)^((p^2-1)/8) + BigUint::from_u8(2u8).unwrap() + } else { + let mut rng = StdRng::from_entropy(); + let mut non_qr = rng.gen_biguint_range( + &BigUint::from_u8(2).unwrap(), + &(modulus - BigUint::from_u8(1).unwrap()), + ); + // To check if non_qr is a quadratic nonresidue, we compute non_qr^((p-1)/2) + // If the result is p-1, then non_qr is a quadratic nonresidue + // Otherwise, non_qr is a quadratic residue + let exponent = (modulus - BigUint::one()) >> 1; + while non_qr.modpow(&exponent, modulus) != modulus - BigUint::one() { + non_qr = rng.gen_biguint_range( + &BigUint::from_u8(2).unwrap(), + &(modulus - BigUint::from_u8(1).unwrap()), + ); + } + non_qr + } +} diff --git a/extensions/algebra/complex-macros/README.md b/extensions/algebra/complex-macros/README.md index 9c786a73ca..091baac614 100644 --- a/extensions/algebra/complex-macros/README.md +++ b/extensions/algebra/complex-macros/README.md @@ -27,8 +27,6 @@ openvm_algebra_complex_macros::complex_init! { */ pub fn main() { - setup_all_moduli(); - setup_all_complex_extensions(); // ... } ``` @@ -79,13 +77,9 @@ mod openvm_intrinsics_ffi_complex { pub fn setup_complex_0() { // send the setup instructions } -pub fn setup_all_complex_extensions() { - setup_complex_0(); - // call all other setup_complex_* for all the items in the moduli_init! macro -} ``` -3. Obviously, `mod_idx` in the `complex_init!` must match the position of the corresponding modulus in the `moduli_init!` macro. The order of the items in `complex_init!` affects what `setup_complex_*` function will correspond to what complex class. Also, it **must match** the order of the moduli in the chip configuration -- more specifically, in the modular extension parameters (the order of numbers in `Fp2Extension::supported_modulus`, which is usually defined with the whole `app_vm_config` in the `openvm.toml` file). However, it again imposes the restriction that we only can invoke `complex_init!` once. Again analogous to the moduli setups, we must call `setup_complex_*` for each used complex extension before doing anything with entities of that class (or one can call `setup_all_complex_extensions` to setup all of them, if all are used). +3. Obviously, `mod_idx` in the `complex_init!` must match the position of the corresponding modulus in the `moduli_init!` macro. The order of the items in `complex_init!` affects what `setup_complex_*` function will correspond to what complex class. Also, it **must match** the order of the moduli in the chip configuration -- more specifically, in the modular extension parameters (the order of numbers in `Fp2Extension::supported_modulus`, which is usually defined with the whole `app_vm_config` in the `openvm.toml` file). However, it again imposes the restriction that we only can invoke `complex_init!` once. Again analogous to the moduli setups, `setup_complex_*` is automatically called on each complex extension on first use of its intrinsics. 4. Note that, due to the nature of function names, the name of the struct used in `complex_init!` must be the same as in `complex_declare!`. To illustrate, the following code will **fail** to compile: diff --git a/extensions/algebra/complex-macros/src/lib.rs b/extensions/algebra/complex-macros/src/lib.rs index d829abeeed..08f1456f99 100644 --- a/extensions/algebra/complex-macros/src/lib.rs +++ b/extensions/algebra/complex-macros/src/lib.rs @@ -59,12 +59,15 @@ pub fn complex_declare(input: TokenStream) -> TokenStream { create_extern_func!(complex_mul_extern_func); create_extern_func!(complex_div_extern_func); + let setup_function = syn::Ident::new(&format!("setup_{}", struct_name), span.into()); + let result = TokenStream::from(quote::quote_spanned! { span.into() => extern "C" { fn #complex_add_extern_func(rd: usize, rs1: usize, rs2: usize); fn #complex_sub_extern_func(rd: usize, rs1: usize, rs2: usize); fn #complex_mul_extern_func(rd: usize, rs1: usize, rs2: usize); fn #complex_div_extern_func(rd: usize, rs1: usize, rs2: usize); + fn #setup_function(); } @@ -110,6 +113,7 @@ pub fn complex_declare(input: TokenStream) -> TokenStream { } #[cfg(target_os = "zkvm")] { + Self::assert_is_setup(); unsafe { #complex_add_extern_func( self as *mut Self as usize, @@ -130,6 +134,7 @@ pub fn complex_declare(input: TokenStream) -> TokenStream { } #[cfg(target_os = "zkvm")] { + Self::assert_is_setup(); unsafe { #complex_sub_extern_func( self as *mut Self as usize, @@ -154,6 +159,7 @@ pub fn complex_declare(input: TokenStream) -> TokenStream { } #[cfg(target_os = "zkvm")] { + Self::assert_is_setup(); unsafe { #complex_mul_extern_func( self as *mut Self as usize, @@ -179,6 +185,7 @@ pub fn complex_declare(input: TokenStream) -> TokenStream { } #[cfg(target_os = "zkvm")] { + Self::assert_is_setup(); unsafe { #complex_div_extern_func( self as *mut Self as usize, @@ -199,6 +206,7 @@ pub fn complex_declare(input: TokenStream) -> TokenStream { } #[cfg(target_os = "zkvm")] { + Self::assert_is_setup(); let mut uninit: core::mem::MaybeUninit = core::mem::MaybeUninit::uninit(); unsafe { #complex_add_extern_func( @@ -222,6 +230,7 @@ pub fn complex_declare(input: TokenStream) -> TokenStream { } #[cfg(target_os = "zkvm")] { + Self::assert_is_setup(); let mut uninit: core::mem::MaybeUninit = core::mem::MaybeUninit::uninit(); unsafe { #complex_sub_extern_func( @@ -249,6 +258,7 @@ pub fn complex_declare(input: TokenStream) -> TokenStream { } #[cfg(target_os = "zkvm")] { + Self::assert_is_setup(); unsafe { #complex_mul_extern_func( dst_ptr as usize, @@ -270,6 +280,7 @@ pub fn complex_declare(input: TokenStream) -> TokenStream { } #[cfg(target_os = "zkvm")] { + Self::assert_is_setup(); let mut uninit: core::mem::MaybeUninit = core::mem::MaybeUninit::uninit(); unsafe { #complex_div_extern_func( @@ -281,6 +292,15 @@ pub fn complex_declare(input: TokenStream) -> TokenStream { unsafe { uninit.assume_init() } } } + + // Helper function to call the setup instruction on first use + fn assert_is_setup() { + static is_setup: ::openvm_algebra_guest::once_cell::race::OnceBool = ::openvm_algebra_guest::once_cell::race::OnceBool::new(); + is_setup.get_or_init(|| { + unsafe { #setup_function(); } + true + }); + } } impl openvm_algebra_guest::field::ComplexConjugate for #struct_name { @@ -528,7 +548,6 @@ pub fn complex_init(input: TokenStream) -> TokenStream { let mut externs = Vec::new(); let mut setups = Vec::new(); - let mut setup_all_complex_extensions = Vec::new(); let span = proc_macro::Span::call_site(); @@ -587,22 +606,20 @@ pub fn complex_init(input: TokenStream) -> TokenStream { }); } - let setup_function = - syn::Ident::new(&format!("setup_complex_{}", complex_idx), span.into()); + let setup_function = syn::Ident::new(&format!("setup_{}", struct_name), span.into()); - setup_all_complex_extensions.push(quote::quote_spanned! { span.into() => - #setup_function(); - }); setups.push(quote::quote_spanned! { span.into() => #[allow(non_snake_case)] - pub fn #setup_function() { + #[no_mangle] + extern "C" fn #setup_function() { #[cfg(target_os = "zkvm")] { - let two_modulus_bytes = &openvm_intrinsics_meta_do_not_type_this_by_yourself::two_modular_limbs_list[openvm_intrinsics_meta_do_not_type_this_by_yourself::limb_list_borders[#mod_idx]..openvm_intrinsics_meta_do_not_type_this_by_yourself::limb_list_borders[#mod_idx + 1]]; + use super::openvm_intrinsics_meta_do_not_type_this_by_yourself::{two_modular_limbs_list, limb_list_borders}; + let two_modulus_bytes = &two_modular_limbs_list[limb_list_borders[#mod_idx]..limb_list_borders[#mod_idx + 1]]; // We are going to use the numeric representation of the `rs2` register to distinguish the chip to setup. // The transpiler will transform this instruction, based on whether `rs2` is `x0` or `x1`, into a `SETUP_ADDSUB` or `SETUP_MULDIV` instruction. - let mut uninit: core::mem::MaybeUninit<[u8; openvm_intrinsics_meta_do_not_type_this_by_yourself::limb_list_borders[#mod_idx + 1] - openvm_intrinsics_meta_do_not_type_this_by_yourself::limb_list_borders[#mod_idx]]> = core::mem::MaybeUninit::uninit(); + let mut uninit: core::mem::MaybeUninit<[u8; limb_list_borders[#mod_idx + 1] - limb_list_borders[#mod_idx]]> = core::mem::MaybeUninit::uninit(); openvm::platform::custom_insn_r!( opcode = ::openvm_algebra_guest::OPCODE, funct3 = ::openvm_algebra_guest::COMPLEX_EXT_FIELD_FUNCT3, @@ -632,10 +649,7 @@ pub fn complex_init(input: TokenStream) -> TokenStream { #[cfg(target_os = "zkvm")] mod openvm_intrinsics_ffi_complex { #(#externs)* - } - #(#setups)* - pub fn setup_all_complex_extensions() { - #(#setup_all_complex_extensions)* + #(#setups)* } }) } diff --git a/extensions/algebra/guest/Cargo.toml b/extensions/algebra/guest/Cargo.toml index d57c9117e3..89548bbb60 100644 --- a/extensions/algebra/guest/Cargo.toml +++ b/extensions/algebra/guest/Cargo.toml @@ -10,8 +10,11 @@ repository.workspace = true [dependencies] openvm-algebra-moduli-macros = { workspace = true } openvm-algebra-complex-macros = { workspace = true } +openvm-rv32im-guest = { workspace = true } +openvm-custom-insn = { workspace = true } serde-big-array.workspace = true strum_macros.workspace = true +once_cell = { workspace = true, features = ["race", "alloc"] } [target.'cfg(not(target_os = "zkvm"))'.dependencies] num-bigint.workspace = true diff --git a/extensions/algebra/guest/src/halo2curves.rs b/extensions/algebra/guest/src/halo2curves.rs index a9a6513ac6..7b913e0d42 100644 --- a/extensions/algebra/guest/src/halo2curves.rs +++ b/extensions/algebra/guest/src/halo2curves.rs @@ -4,37 +4,72 @@ use halo2curves_axiom::ff; use crate::{field::Field, DivAssignUnsafe, DivUnsafe}; -impl<'a, F: ff::Field> DivUnsafe<&'a F> for F { - type Output = F; +macro_rules! div_unsafe_impl { + ($($t:ty),*) => { + $( + impl DivUnsafe for $t { + type Output = $t; + + fn div_unsafe(self, other: Self) -> Self::Output { + self * other.invert().unwrap() + } + } - fn div_unsafe(self, other: &'a F) -> Self::Output { - self * other.invert().unwrap() - } -} + impl<'a> DivUnsafe<&'a $t> for $t { + type Output = $t; -impl<'a, F: ff::Field> DivUnsafe<&'a F> for &'a F { - type Output = F; + fn div_unsafe(self, other: &'a $t) -> Self::Output { + self * other.invert().unwrap() + } + } - fn div_unsafe(self, other: &'a F) -> Self::Output { - *self * other.invert().unwrap() - } -} + impl<'a> DivUnsafe<&'a $t> for &'a $t { + type Output = $t; -impl DivAssignUnsafe for F { - fn div_assign_unsafe(&mut self, other: Self) { - *self *= other.invert().unwrap(); - } -} + fn div_unsafe(self, other: &'a $t) -> Self::Output { + *self * other.invert().unwrap() + } + } -impl<'a, F: ff::Field> DivAssignUnsafe<&'a F> for F { - fn div_assign_unsafe(&mut self, other: &'a F) { - *self *= other.invert().unwrap(); - } + impl DivAssignUnsafe for $t { + fn div_assign_unsafe(&mut self, other: Self) { + *self *= other.invert().unwrap(); + } + } + + impl<'a> DivAssignUnsafe<&'a $t> for $t { + fn div_assign_unsafe(&mut self, other: &'a $t) { + *self *= other.invert().unwrap(); + } + } + )* + }; } +div_unsafe_impl!( + halo2curves_axiom::bls12_381::Fq, + halo2curves_axiom::bls12_381::Fq12, + halo2curves_axiom::bls12_381::Fq2 +); +div_unsafe_impl!( + halo2curves_axiom::bn256::Fq, + halo2curves_axiom::bn256::Fq12, + halo2curves_axiom::bn256::Fq2 +); + impl Field for F where - for<'a> &'a F: Add<&'a F, Output = F> + Sub<&'a F, Output = F> + Mul<&'a F, Output = F>, + for<'a> &'a F: Add<&'a F, Output = F> + + Sub<&'a F, Output = F> + + Mul<&'a F, Output = F> + + DivUnsafe<&'a F, Output = F>, + for<'a> F: Add<&'a F, Output = F> + + Sub<&'a F, Output = F> + + Mul<&'a F, Output = F> + + DivAssignUnsafe + + DivUnsafe + + DivAssignUnsafe<&'a F> + + DivUnsafe<&'a F, Output = F>, { const ZERO: Self = ::ZERO; const ONE: Self = ::ONE; diff --git a/extensions/algebra/guest/src/lib.rs b/extensions/algebra/guest/src/lib.rs index 1d778f0013..4bb4aaec5b 100644 --- a/extensions/algebra/guest/src/lib.rs +++ b/extensions/algebra/guest/src/lib.rs @@ -17,6 +17,8 @@ pub enum ModArithBaseFunct7 { DivMod, IsEqMod, SetupMod, + HintNonQr, + HintSqrt, } impl ModArithBaseFunct7 { @@ -54,11 +56,13 @@ pub use field::Field; use num_bigint::BigUint; pub use openvm_algebra_complex_macros as complex_macros; pub use openvm_algebra_moduli_macros as moduli_macros; +#[cfg(target_os = "zkvm")] +pub use openvm_custom_insn; +#[cfg(target_os = "zkvm")] +pub use openvm_rv32im_guest; pub use serde_big_array::BigArray; use strum_macros::FromRepr; -/// Field traits -pub mod field; /// Implementation of this library's traits on halo2curves types. /// Used for testing and also VM runtime execution. /// These should **only** be importable on a host machine. @@ -67,7 +71,10 @@ mod halo2curves; /// Exponentiation by bytes mod exp_bytes; +/// Field traits +pub mod field; pub use exp_bytes::*; +pub use once_cell; /// Division operation that is undefined behavior when the denominator is not invertible. pub trait DivUnsafe: Sized { @@ -232,3 +239,11 @@ pub trait Reduce: Sized { Self::reduce_le_bytes(&bytes.iter().rev().copied().collect::>()) } } + +// Note that we use a hint-based approach to prove whether the square root exists. +// This approach works for prime moduli, but not necessarily for composite moduli, +// which is why the Sqrt trait requires the Field trait, not just the IntMod trait. +pub trait Sqrt: Field { + /// Returns a square root of `self` if it exists. + fn sqrt(&self) -> Option; +} diff --git a/extensions/algebra/moduli-macros/README.md b/extensions/algebra/moduli-macros/README.md index 61513d0603..ba605fdc9b 100644 --- a/extensions/algebra/moduli-macros/README.md +++ b/extensions/algebra/moduli-macros/README.md @@ -6,7 +6,10 @@ Procedural macros for use in guest program to generate modular arithmetic struct ```rust openvm_algebra_moduli_macros::moduli_declare! { - Bls12381 { modulus = "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787" }, + Bls12381 { + modulus = "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787", + impl_field = true + }, Mod1e18 { modulus = "1000000000000000003" }, } @@ -28,7 +31,10 @@ openvm_algebra_moduli_macros::moduli_init! { The crate provides two macros: `moduli_declare!` and `moduli_init!`. The signatures are: -- `moduli_declare!` receives comma-separated list of moduli classes descriptions. Each description looks like `ModulusName { modulus = "modulus_value" }`. Here `ModulusName` is the name of the struct, and `modulus_value` is the modulus value in decimal or hex format. +- `moduli_declare!` receives comma-separated list of moduli classes descriptions. Each description looks like `ModulusName { modulus = "modulus_value", impl_field = }`. Here `ModulusName` is the name of the struct, and `modulus_value` is the modulus value in decimal or hex format. + - The `impl_field` argument indicates whether or not the `Field` and `Sqrt` traits should be automatically implemented on the resulting struct. + It should only be set to `true` if the modulus is prime. + If unspecified, it defaults to `false`. - `moduli_init!` receives comma-separated list of modulus values in decimal or hex format. @@ -145,15 +151,9 @@ pub fn setup_2() { // send the setup instruction designed for the chip number 2 } } -pub fn setup_all_moduli() { - setup_0(); - setup_1(); - setup_2(); - // setup functions for all the other moduli provided in the `moduli_init!` function -} ``` -The setup operation (e.g., `setup_2`) consists of reading the value `OPENVM_SERIALIZED_MODULUS_2` from memory and constraining that the read value is equal to the modulus the chip has been configured with. For each used modulus, its corresponding setup instruction **must** be called before all other operations -- this currently must be checked by inspecting the program code; it is not enforced by the virtual machine. +The setup operation (e.g., `setup_2`) consists of reading the value `OPENVM_SERIALIZED_MODULUS_2` from memory and constraining that the read value is equal to the modulus the chip has been configured with. For each used modulus, its corresponding setup instruction is automatically called on first use of any of its intrinsics. 5. It follows from the above that the `moduli_declare!` invocations may be in multiple places in various compilation units, but all the `declare!`d moduli must be specified at least once in `moduli_init!` so that there will be no linker errors due to missing function implementations. Correspondingly, the `moduli_init!` macro should only be called once in the entire program (in the guest crate as the topmost compilation unit). Finally, the order of the moduli in `moduli_init!` has nothing to do with the `moduli_declare!` invocations, but it **must match** the order of the moduli in the chip configuration -- more specifically, in the modular extension parameters (the order of numbers in `ModularExtension::supported_modulus`, which is usually defined with the whole `app_vm_config` in the `openvm.toml` file). diff --git a/extensions/algebra/moduli-macros/src/lib.rs b/extensions/algebra/moduli-macros/src/lib.rs index 5d8d921f2f..bc6f118ceb 100644 --- a/extensions/algebra/moduli-macros/src/lib.rs +++ b/extensions/algebra/moduli-macros/src/lib.rs @@ -34,6 +34,7 @@ pub fn moduli_declare(input: TokenStream) -> TokenStream { let struct_name = item.name.to_string(); let struct_name = syn::Ident::new(&struct_name, span.into()); let mut modulus: Option = None; + let mut impl_field: Option = None; for param in item.params { match param.name.to_string().as_str() { "modulus" => { @@ -44,9 +45,28 @@ pub fn moduli_declare(input: TokenStream) -> TokenStream { { modulus = Some(value.value()); } else { - return syn::Error::new_spanned(param.value, "Expected a string literal") - .to_compile_error() - .into(); + return syn::Error::new_spanned( + param.value, + "Expected a string literal for macro argument `modulus`", + ) + .to_compile_error() + .into(); + } + } + "impl_field" => { + if let syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Bool(value), + .. + }) = param.value + { + impl_field = Some(value.value()); + } else { + return syn::Error::new_spanned( + param.value, + "Expected a boolean literal for macro argument `impl_field`", + ) + .to_compile_error() + .into(); } } _ => { @@ -64,6 +84,8 @@ pub fn moduli_declare(input: TokenStream) -> TokenStream { let mut limbs = modulus_bytes.len(); let mut block_size = 32; + let impl_field = impl_field.unwrap_or(false); + if limbs <= 32 { limbs = 32; } else if limbs <= 48 { @@ -98,12 +120,16 @@ pub fn moduli_declare(input: TokenStream) -> TokenStream { create_extern_func!(mul_extern_func); create_extern_func!(div_extern_func); create_extern_func!(is_eq_extern_func); + create_extern_func!(hint_sqrt_extern_func); + create_extern_func!(hint_non_qr_extern_func); let block_size = proc_macro::Literal::usize_unsuffixed(block_size); let block_size = syn::Lit::new(block_size.to_string().parse::<_>().unwrap()); let module_name = format_ident!("algebra_impl_{}", mod_idx); + let setup_function = syn::Ident::new(&format!("setup_{}", modulus_hex), span.into()); + let result = TokenStream::from(quote::quote_spanned! { span.into() => /// An element of the ring of integers modulo a positive integer. /// The element is internally represented as a fixed size array of bytes. @@ -126,6 +152,9 @@ pub fn moduli_declare(input: TokenStream) -> TokenStream { fn #mul_extern_func(rd: usize, rs1: usize, rs2: usize); fn #div_extern_func(rd: usize, rs1: usize, rs2: usize); fn #is_eq_extern_func(rs1: usize, rs2: usize) -> bool; + fn #hint_sqrt_extern_func(rs1: usize); + fn #hint_non_qr_extern_func(); + fn #setup_function(); } impl #struct_name { @@ -152,6 +181,7 @@ pub fn moduli_declare(input: TokenStream) -> TokenStream { } #[cfg(target_os = "zkvm")] { + Self::assert_is_setup(); unsafe { #add_extern_func( self as *mut Self as usize, @@ -173,6 +203,7 @@ pub fn moduli_declare(input: TokenStream) -> TokenStream { } #[cfg(target_os = "zkvm")] { + Self::assert_is_setup(); unsafe { #sub_extern_func( self as *mut Self as usize, @@ -193,6 +224,7 @@ pub fn moduli_declare(input: TokenStream) -> TokenStream { } #[cfg(target_os = "zkvm")] { + Self::assert_is_setup(); unsafe { #mul_extern_func( self as *mut Self as usize, @@ -213,6 +245,7 @@ pub fn moduli_declare(input: TokenStream) -> TokenStream { } #[cfg(target_os = "zkvm")] { + Self::assert_is_setup(); unsafe { #div_extern_func( self as *mut Self as usize, @@ -237,6 +270,7 @@ pub fn moduli_declare(input: TokenStream) -> TokenStream { } #[cfg(target_os = "zkvm")] { + Self::assert_is_setup(); unsafe { #add_extern_func( dst_ptr as usize, @@ -261,6 +295,7 @@ pub fn moduli_declare(input: TokenStream) -> TokenStream { } #[cfg(target_os = "zkvm")] { + Self::assert_is_setup(); unsafe { #sub_extern_func( dst_ptr as usize, @@ -285,6 +320,7 @@ pub fn moduli_declare(input: TokenStream) -> TokenStream { } #[cfg(target_os = "zkvm")] { + Self::assert_is_setup(); unsafe { #mul_extern_func( dst_ptr as usize, @@ -305,6 +341,7 @@ pub fn moduli_declare(input: TokenStream) -> TokenStream { } #[cfg(target_os = "zkvm")] { + Self::assert_is_setup(); let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit(); unsafe { #div_extern_func( @@ -325,11 +362,21 @@ pub fn moduli_declare(input: TokenStream) -> TokenStream { } #[cfg(target_os = "zkvm")] { + Self::assert_is_setup(); unsafe { #is_eq_extern_func(self as *const #struct_name as usize, other as *const #struct_name as usize) } } } + + // Helper function to call the setup instruction on first use + fn assert_is_setup() { + static is_setup: ::openvm_algebra_guest::once_cell::race::OnceBool = ::openvm_algebra_guest::once_cell::race::OnceBool::new(); + is_setup.get_or_init(|| { + unsafe { #setup_function(); } + true + }); + } } // Put trait implementations in a private module to avoid conflicts @@ -699,10 +746,10 @@ pub fn moduli_declare(input: TokenStream) -> TokenStream { fn reduce_le_bytes(bytes: &[u8]) -> Self { let mut res = ::ZERO; // base should be 2 ^ #limbs which exceeds what Self can represent - let mut base = Self::from_le_bytes(&[255u8; #limbs]); + let mut base = ::from_le_bytes(&[255u8; #limbs]); base += ::ONE; for chunk in bytes.chunks(#limbs).rev() { - res = res * &base + Self::from_le_bytes(chunk); + res = res * &base + ::from_le_bytes(chunk); } res } @@ -710,6 +757,154 @@ pub fn moduli_declare(input: TokenStream) -> TokenStream { }); output.push(result); + + if impl_field { + // implement Field and Sqrt traits for prime moduli + let field_and_sqrt_impl = TokenStream::from(quote::quote_spanned! { span.into() => + impl ::openvm_algebra_guest::Field for #struct_name { + const ZERO: Self = ::ZERO; + const ONE: Self = ::ONE; + + type SelfRef<'a> = &'a Self; + + fn double_assign(&mut self) { + ::openvm_algebra_guest::IntMod::double_assign(self); + } + + fn square_assign(&mut self) { + ::openvm_algebra_guest::IntMod::square_assign(self); + } + + } + + impl openvm_algebra_guest::Sqrt for #struct_name { + // Returns a sqrt of self if it exists, otherwise None. + // Note that we use a hint-based approach to prove whether the square root exists. + // This approach works for prime moduli, but not necessarily for composite moduli, + // which is why we have the sqrt method in the Field trait, not the IntMod trait. + fn sqrt(&self) -> Option { + match self.honest_host_sqrt() { + // self is a square + Some(Some(sqrt)) => Some(sqrt), + // self is not a square + Some(None) => None, + // host is dishonest + None => { + // host is dishonest, enter infinite loop + loop { + openvm::io::println("ERROR: Square root hint is invalid. Entering infinite loop."); + } + } + } + } + } + + impl #struct_name { + // Returns None if the hint is incorrect (i.e. the host is dishonest) + // Returns Some(None) if the hint proves that self is not a quadratic residue + // Otherwise, returns Some(Some(sqrt)) where sqrt is a square root of self + fn honest_host_sqrt(&self) -> Option> { + let (is_square, sqrt) = self.hint_sqrt_impl()?; + + if is_square { + // ensure sqrt < modulus + ::assert_reduced(&sqrt); + + if &(&sqrt * &sqrt) == self { + Some(Some(sqrt)) + } else { + None + } + } else { + // ensure sqrt < modulus + ::assert_reduced(&sqrt); + + if &sqrt * &sqrt == self * Self::get_non_qr() { + Some(None) + } else { + None + } + } + } + + + // Returns None if the hint is malformed. + // Otherwise, returns Some((is_square, sqrt)) where sqrt is a square root of self if is_square is true, + // and a square root of self * non_qr if is_square is false. + fn hint_sqrt_impl(&self) -> Option<(bool, Self)> { + #[cfg(not(target_os = "zkvm"))] + { + unimplemented!(); + } + #[cfg(target_os = "zkvm")] + { + use ::openvm_algebra_guest::{openvm_custom_insn, openvm_rv32im_guest}; // needed for hint_store_u32! and hint_buffer_u32! + + let is_square = core::mem::MaybeUninit::::uninit(); + let sqrt = core::mem::MaybeUninit::<#struct_name>::uninit(); + unsafe { + #hint_sqrt_extern_func(self as *const #struct_name as usize); + let is_square_ptr = is_square.as_ptr() as *const u32; + openvm_rv32im_guest::hint_store_u32!(is_square_ptr); + openvm_rv32im_guest::hint_buffer_u32!(sqrt.as_ptr() as *const u8, <#struct_name as ::openvm_algebra_guest::IntMod>::NUM_LIMBS / 4); + let is_square = is_square.assume_init(); + if is_square == 0 || is_square == 1 { + Some((is_square == 1, sqrt.assume_init())) + } else { + None + } + } + } + } + + // Generate a non quadratic residue by using a hint + fn init_non_qr() -> alloc::boxed::Box<#struct_name> { + #[cfg(not(target_os = "zkvm"))] + { + unimplemented!(); + } + #[cfg(target_os = "zkvm")] + { + use ::openvm_algebra_guest::{openvm_custom_insn, openvm_rv32im_guest}; // needed for hint_buffer_u32! + + let mut non_qr_uninit = core::mem::MaybeUninit::::uninit(); + let mut non_qr; + unsafe { + #hint_non_qr_extern_func(); + let ptr = non_qr_uninit.as_ptr() as *const u8; + openvm_rv32im_guest::hint_buffer_u32!(ptr, ::NUM_LIMBS / 4); + non_qr = non_qr_uninit.assume_init(); + } + // ensure non_qr < modulus + ::assert_reduced(&non_qr); + + use ::openvm_algebra_guest::{DivUnsafe, ExpBytes}; + // construct exp = (p-1)/2 as an integer by first constraining exp = (p-1)/2 (mod p) and then exp < p + let exp = -::ONE.div_unsafe(Self::from_const_u8(2)); + ::assert_reduced(&exp); + + if non_qr.exp_bytes(true, &::to_be_bytes(&exp)) != -::ONE + { + // non_qr is not a non quadratic residue, so host is dishonest + loop { + openvm::io::println("ERROR: Non quadratic residue hint is invalid. Entering infinite loop."); + } + } + + alloc::boxed::Box::new(non_qr) + } + } + + // This function is public for use in tests + pub fn get_non_qr() -> &'static #struct_name { + static non_qr: ::openvm_algebra_guest::once_cell::race::OnceBox<#struct_name> = ::openvm_algebra_guest::once_cell::race::OnceBox::new(); + &non_qr.get_or_init(Self::init_non_qr) + } + } + }); + + output.push(field_and_sqrt_impl); + } } TokenStream::from_iter(output) @@ -735,7 +930,6 @@ pub fn moduli_init(input: TokenStream) -> TokenStream { let mut externs = Vec::new(); let mut setups = Vec::new(); let mut openvm_section = Vec::new(); - let mut setup_all_moduli = Vec::new(); // List of all modular limbs in one (that is, with a compile-time known size) array. let mut two_modular_limbs_flattened_list = Vec::::new(); @@ -794,7 +988,7 @@ pub fn moduli_init(input: TokenStream) -> TokenStream { span.into(), ); let serialized_len = serialized_modulus.len(); - let setup_function = syn::Ident::new(&format!("setup_{}", mod_idx), span.into()); + let setup_function = syn::Ident::new(&format!("setup_{}", modulus_hex), span.into()); openvm_section.push(quote::quote_spanned! { span.into() => #[cfg(target_os = "zkvm")] @@ -848,23 +1042,58 @@ pub fn moduli_init(input: TokenStream) -> TokenStream { } }); - setup_all_moduli.push(quote::quote_spanned! { span.into() => - #setup_function(); + let hint_non_qr_extern_func = syn::Ident::new( + &format!("hint_non_qr_extern_func_{}", modulus_hex), + span.into(), + ); + externs.push(quote::quote_spanned! { span.into() => + #[no_mangle] + extern "C" fn #hint_non_qr_extern_func() { + openvm::platform::custom_insn_r!( + opcode = ::openvm_algebra_guest::OPCODE, + funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3 as usize, + funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::HintNonQr as usize + #mod_idx * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize), + rd = Const "x0", + rs1 = Const "x0", + rs2 = Const "x0" + ); + } + }); + + // This function will be defined regardless of whether impl_field is true or false, + // but it will be called only if the impl_field is true. + let hint_sqrt_extern_func = syn::Ident::new( + &format!("hint_sqrt_extern_func_{}", modulus_hex), + span.into(), + ); + externs.push(quote::quote_spanned! { span.into() => + #[no_mangle] + extern "C" fn #hint_sqrt_extern_func(rs1: usize) { + openvm::platform::custom_insn_r!( + opcode = ::openvm_algebra_guest::OPCODE, + funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3 as usize, + funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::HintSqrt as usize + #mod_idx * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize), + rd = Const "x0", + rs1 = In rs1, + rs2 = Const "x0" + ); + } }); setups.push(quote::quote_spanned! { span.into() => #[allow(non_snake_case)] - pub fn #setup_function() { + #[no_mangle] + extern "C" fn #setup_function() { #[cfg(target_os = "zkvm")] { let mut ptr = 0; - assert_eq!(#serialized_name[ptr], 1); + assert_eq!(super::#serialized_name[ptr], 1); ptr += 1; - assert_eq!(#serialized_name[ptr], #mod_idx as u8); + assert_eq!(super::#serialized_name[ptr], #mod_idx as u8); ptr += 1; - assert_eq!(#serialized_name[ptr..ptr+4].iter().rev().fold(0, |acc, &x| acc * 256 + x as usize), #limbs); + assert_eq!(super::#serialized_name[ptr..ptr+4].iter().rev().fold(0, |acc, &x| acc * 256 + x as usize), #limbs); ptr += 4; - let remaining = &#serialized_name[ptr..]; + let remaining = &super::#serialized_name[ptr..]; // To avoid importing #struct_name, we create a placeholder struct with the same size and alignment. #[repr(C, align(#block_size))] @@ -920,15 +1149,12 @@ pub fn moduli_init(input: TokenStream) -> TokenStream { #[cfg(target_os = "zkvm")] mod openvm_intrinsics_ffi { #(#externs)* + #(#setups)* } #[allow(non_snake_case, non_upper_case_globals)] pub mod openvm_intrinsics_meta_do_not_type_this_by_yourself { pub const two_modular_limbs_list: [u8; #total_limbs_cnt] = [#(#two_modular_limbs_flattened_list),*]; pub const limb_list_borders: [usize; #cnt_limbs_list_len] = [#(#limb_list_borders),*]; } - #(#setups)* - pub fn setup_all_moduli() { - #(#setup_all_moduli)* - } }) } diff --git a/extensions/algebra/tests/programs/Cargo.toml b/extensions/algebra/tests/programs/Cargo.toml index ef4a8c4c1b..9d067de693 100644 --- a/extensions/algebra/tests/programs/Cargo.toml +++ b/extensions/algebra/tests/programs/Cargo.toml @@ -11,6 +11,8 @@ openvm-platform = { path = "../../../../crates/toolchain/platform" } openvm-algebra-guest = { path = "../../guest" } openvm-algebra-moduli-macros = { path = "../../../algebra/moduli-macros", default-features = false } openvm-algebra-complex-macros = { path = "../../../algebra/complex-macros", default-features = false } +#openvm-custom-insn = { path = "../../../../crates/toolchain/custom_insn" } + num-bigint = { version = "0.4", default-features = false } serde = { version = "1.0", default-features = false, features = [ "alloc", diff --git a/extensions/algebra/tests/programs/examples/complex_redundant_modulus.rs b/extensions/algebra/tests/programs/examples/complex_redundant_modulus.rs index 678b8b1838..950f83c1ef 100644 --- a/extensions/algebra/tests/programs/examples/complex_redundant_modulus.rs +++ b/extensions/algebra/tests/programs/examples/complex_redundant_modulus.rs @@ -19,8 +19,6 @@ openvm_algebra_complex_macros::complex_declare! { openvm::init!("openvm_init_complex_redundant_modulus.rs"); pub fn main() { - setup_all_moduli(); - setup_all_complex_extensions(); let b = Complex2::new(Mod3::ZERO, Mod3::from_u32(1000000008)); assert_eq!(b.clone() * &b * &b * &b * &b, b); } diff --git a/extensions/algebra/tests/programs/examples/complex_secp256k1.rs b/extensions/algebra/tests/programs/examples/complex_secp256k1.rs index ab02878cfd..9f47e79be1 100644 --- a/extensions/algebra/tests/programs/examples/complex_secp256k1.rs +++ b/extensions/algebra/tests/programs/examples/complex_secp256k1.rs @@ -16,8 +16,6 @@ openvm_algebra_complex_macros::complex_declare! { openvm::init!("openvm_init_complex_secp256k1.rs"); pub fn main() { - setup_all_moduli(); - setup_all_complex_extensions(); let mut a = Complex::new( Secp256k1Coord::from_repr(core::array::from_fn(|_| 10)), Secp256k1Coord::from_repr(core::array::from_fn(|_| 21)), diff --git a/extensions/algebra/tests/programs/examples/complex_two_moduli.rs b/extensions/algebra/tests/programs/examples/complex_two_moduli.rs index 3e67028722..01c1ba6a63 100644 --- a/extensions/algebra/tests/programs/examples/complex_two_moduli.rs +++ b/extensions/algebra/tests/programs/examples/complex_two_moduli.rs @@ -18,8 +18,6 @@ openvm_algebra_complex_macros::complex_declare! { openvm::init!("openvm_init_complex_two_moduli.rs"); pub fn main() { - setup_all_moduli(); - setup_all_complex_extensions(); let a = Complex1::new(Mod1::ZERO, Mod1::from_u32(998244352)); let b = Complex2::new(Mod2::ZERO, Mod2::from_u32(1000000006)); assert_eq!(a.clone() * &a * &a * &a * &a, a); diff --git a/extensions/algebra/tests/programs/examples/invalid_setup.rs b/extensions/algebra/tests/programs/examples/invalid_setup.rs index 3321dd3238..37d6946b1f 100644 --- a/extensions/algebra/tests/programs/examples/invalid_setup.rs +++ b/extensions/algebra/tests/programs/examples/invalid_setup.rs @@ -18,5 +18,7 @@ openvm_algebra_moduli_macros::moduli_init! { pub fn main() { // this should cause a debug assertion to fail - setup_all_moduli(); + let x = Mod1::from_u32(1); + let y = Mod1::from_u32(1); + let _z = x + y; } diff --git a/extensions/algebra/tests/programs/examples/little.rs b/extensions/algebra/tests/programs/examples/little.rs index 15203d98b7..caeb8de732 100644 --- a/extensions/algebra/tests/programs/examples/little.rs +++ b/extensions/algebra/tests/programs/examples/little.rs @@ -12,7 +12,6 @@ openvm_algebra_moduli_macros::moduli_declare! { openvm::init!("openvm_init_little.rs"); pub fn main() { - setup_all_moduli(); let mut pow = Secp256k1Coord::MODULUS; pow[0] -= 2; diff --git a/extensions/algebra/tests/programs/examples/moduli_setup.rs b/extensions/algebra/tests/programs/examples/moduli_setup.rs index cda883d133..b9abec0a9f 100644 --- a/extensions/algebra/tests/programs/examples/moduli_setup.rs +++ b/extensions/algebra/tests/programs/examples/moduli_setup.rs @@ -18,7 +18,6 @@ openvm_algebra_moduli_macros::moduli_declare! { openvm::init!("openvm_init_moduli_setup.rs"); pub fn main() { - setup_all_moduli(); let x = Bls12381::from_repr(core::array::from_fn(|i| i as u8)); assert_eq!(x.0.len(), 48); diff --git a/extensions/algebra/tests/programs/examples/sqrt.rs b/extensions/algebra/tests/programs/examples/sqrt.rs new file mode 100644 index 0000000000..194e522b97 --- /dev/null +++ b/extensions/algebra/tests/programs/examples/sqrt.rs @@ -0,0 +1,34 @@ +#![cfg_attr(not(feature = "std"), no_main)] +#![cfg_attr(not(feature = "std"), no_std)] + +use openvm_algebra_guest::{Field, IntMod, Sqrt}; + +extern crate alloc; + +openvm::entry!(main); + +openvm_algebra_moduli_macros::moduli_declare! { + Secp256k1Coord { + modulus = "0xFFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE FFFFFC2F", + prime = true + } +} + +openvm::init!(); + +pub fn main() { + let a = Secp256k1Coord::from_u32(4); + let sqrt = a.sqrt(); + assert_eq!(sqrt, Some(Secp256k1Coord::from_u32(2))); + + let b = ::ZERO - ::ONE; + let sqrt = b.sqrt(); + // -1 is not a quadratic residue modulo p when p = 3 mod 4 + // See https://math.stackexchange.com/questions/735400/if-p-equiv-3-mod-4-with-p-prime-prove-1-is-a-non-quadratic-residue-modulo + assert_eq!(sqrt, None); + + let expected = b * Secp256k1Coord::from_u32(2).invert(); + let c = expected.square(); + let result = c.sqrt(); + assert!(result == Some(expected.clone()) || result == Some(-expected)); +} diff --git a/extensions/algebra/tests/src/lib.rs b/extensions/algebra/tests/src/lib.rs index 3232de8ec1..181f592544 100644 --- a/extensions/algebra/tests/src/lib.rs +++ b/extensions/algebra/tests/src/lib.rs @@ -165,4 +165,20 @@ mod tests { .unwrap(); air_test(config, openvm_exe); } + + #[test] + fn test_sqrt() -> Result<()> { + let config = Rv32ModularConfig::new(vec![SECP256K1_CONFIG.modulus.clone()]); + let elf = build_example_program_at_path(get_programs_dir!(), "sqrt", &config)?; + let openvm_exe = VmExe::from_elf( + elf, + Transpiler::::default() + .with_extension(Rv32ITranspilerExtension) + .with_extension(Rv32MTranspilerExtension) + .with_extension(Rv32IoTranspilerExtension) + .with_extension(ModularTranspilerExtension), + )?; + air_test(config, openvm_exe); + Ok(()) + } } diff --git a/extensions/algebra/transpiler/src/lib.rs b/extensions/algebra/transpiler/src/lib.rs index 8785a93480..74d3f9182a 100644 --- a/extensions/algebra/transpiler/src/lib.rs +++ b/extensions/algebra/transpiler/src/lib.rs @@ -3,7 +3,8 @@ use openvm_algebra_guest::{ MODULAR_ARITHMETIC_FUNCT3, OPCODE, }; use openvm_instructions::{ - instruction::Instruction, riscv::RV32_REGISTER_NUM_LIMBS, LocalOpcode, VmOpcode, + instruction::Instruction, riscv::RV32_REGISTER_NUM_LIMBS, LocalOpcode, PhantomDiscriminant, + VmOpcode, }; use openvm_instructions_derive::LocalOpcode; use openvm_stark_backend::p3_field::PrimeField32; @@ -28,6 +29,13 @@ pub enum Rv32ModularArithmeticOpcode { SETUP_ISEQ, } +#[derive(Copy, Clone, Debug, PartialEq, Eq, FromRepr)] +#[repr(u16)] +pub enum ModularPhantom { + HintNonQr = 0x50, + HintSqrt = 0x51, +} + #[derive( Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode, )] @@ -73,10 +81,10 @@ impl TranspilerExtension for ModularTranspilerExtension { Rv32ModularArithmeticOpcode::COUNT <= ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize ); - let mod_idx_shift = ((dec_insn.funct7 as u8) + let mod_idx = ((dec_insn.funct7 as u8) / ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS) - as usize - * Rv32ModularArithmeticOpcode::COUNT; + as usize; + let mod_idx_shift = mod_idx * Rv32ModularArithmeticOpcode::COUNT; if base_funct7 == ModArithBaseFunct7::SetupMod as u8 { let local_opcode = match dec_insn.rs2 { 0 => Rv32ModularArithmeticOpcode::SETUP_ADDSUB, @@ -100,6 +108,25 @@ impl TranspilerExtension for ModularTranspilerExtension { F::ZERO, )) } + } else if base_funct7 == ModArithBaseFunct7::HintNonQr as u8 { + assert_eq!(dec_insn.rd, 0); + assert_eq!(dec_insn.rs1, 0); + assert_eq!(dec_insn.rs2, 0); + Some(Instruction::phantom( + PhantomDiscriminant(ModularPhantom::HintNonQr as u16), + F::ZERO, + F::ZERO, + mod_idx as u16, + )) + } else if base_funct7 == ModArithBaseFunct7::HintSqrt as u8 { + assert_eq!(dec_insn.rd, 0); + assert_eq!(dec_insn.rs2, 0); + Some(Instruction::phantom( + PhantomDiscriminant(ModularPhantom::HintSqrt as u16), + F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs1), + F::ZERO, + mod_idx as u16, + )) } else { let global_opcode = match ModArithBaseFunct7::from_repr(base_funct7) { Some(ModArithBaseFunct7::AddMod) => { diff --git a/extensions/ecc/circuit/src/weierstrass_extension.rs b/extensions/ecc/circuit/src/weierstrass_extension.rs index 64d01a2711..91653d4913 100644 --- a/extensions/ecc/circuit/src/weierstrass_extension.rs +++ b/extensions/ecc/circuit/src/weierstrass_extension.rs @@ -248,9 +248,10 @@ pub(crate) mod phantom { }; use eyre::bail; - use num_bigint::{BigUint, RandBigInt}; + use num_bigint::BigUint; use num_integer::Integer; - use num_traits::{FromPrimitive, One}; + use num_traits::One; + use openvm_algebra_circuit::{find_non_qr, mod_sqrt}; use openvm_circuit::{ arch::{PhantomSubExecutor, Streams}, system::memory::MemoryController, @@ -259,7 +260,6 @@ pub(crate) mod phantom { use openvm_instructions::{riscv::RV32_MEMORY_AS, PhantomDiscriminant}; use openvm_rv32im_circuit::adapters::unsafe_read_rv32_register; use openvm_stark_backend::p3_field::PrimeField32; - use rand::{rngs::StdRng, SeedableRng}; use super::CurveConfig; @@ -382,61 +382,6 @@ pub(crate) mod phantom { } } - /// Find the square root of `x` modulo `modulus` with `non_qr` a - /// quadratic nonresidue of the field. - pub fn mod_sqrt(x: &BigUint, modulus: &BigUint, non_qr: &BigUint) -> Option { - if modulus % 4u32 == BigUint::from_u8(3).unwrap() { - // x^(1/2) = x^((p+1)/4) when p = 3 mod 4 - let exponent = (modulus + BigUint::one()) >> 2; - let ret = x.modpow(&exponent, modulus); - if &ret * &ret % modulus == x % modulus { - Some(ret) - } else { - None - } - } else { - // Tonelli-Shanks algorithm - // https://en.wikipedia.org/wiki/Tonelli%E2%80%93Shanks_algorithm#The_algorithm - let mut q = modulus - BigUint::one(); - let mut s = 0; - while &q % 2u32 == BigUint::ZERO { - s += 1; - q /= 2u32; - } - let z = non_qr; - let mut m = s; - let mut c = z.modpow(&q, modulus); - let mut t = x.modpow(&q, modulus); - let mut r = x.modpow(&((q + BigUint::one()) >> 1), modulus); - loop { - if t == BigUint::ZERO { - return Some(BigUint::ZERO); - } - if t == BigUint::one() { - return Some(r); - } - let mut i = 0; - let mut tmp = t.clone(); - while tmp != BigUint::one() && i < m { - tmp = &tmp * &tmp % modulus; - i += 1; - } - if i == m { - // self is not a quadratic residue - return None; - } - for _ in 0..m - i - 1 { - c = &c * &c % modulus; - } - let b = c; - m = i; - c = &b * &b % modulus; - t = ((t * &b % modulus) * &b) % modulus; - r = (r * b) % modulus; - } - } - } - #[derive(Clone)] pub struct NonQrHintSubEx { pub supported_curves: Vec, @@ -494,33 +439,4 @@ pub(crate) mod phantom { Ok(()) } } - - // Returns a non-quadratic residue in the field - fn find_non_qr(modulus: &BigUint) -> BigUint { - if modulus % 4u32 == BigUint::from(3u8) { - // p = 3 mod 4 then -1 is a quadratic residue - modulus - BigUint::one() - } else if modulus % 8u32 == BigUint::from(5u8) { - // p = 5 mod 8 then 2 is a non-quadratic residue - // since 2^((p-1)/2) = (-1)^((p^2-1)/8) - BigUint::from_u8(2u8).unwrap() - } else { - let mut rng = StdRng::from_entropy(); - let mut non_qr = rng.gen_biguint_range( - &BigUint::from_u8(2).unwrap(), - &(modulus - BigUint::from_u8(1).unwrap()), - ); - // To check if non_qr is a quadratic nonresidue, we compute non_qr^((p-1)/2) - // If the result is p-1, then non_qr is a quadratic nonresidue - // Otherwise, non_qr is a quadratic residue - let exponent = (modulus - BigUint::one()) >> 1; - while non_qr.modpow(&exponent, modulus) != modulus - BigUint::one() { - non_qr = rng.gen_biguint_range( - &BigUint::from_u8(2).unwrap(), - &(modulus - BigUint::from_u8(1).unwrap()), - ); - } - non_qr - } - } } diff --git a/extensions/ecc/guest/src/lib.rs b/extensions/ecc/guest/src/lib.rs index e780ccf891..823c87261f 100644 --- a/extensions/ecc/guest/src/lib.rs +++ b/extensions/ecc/guest/src/lib.rs @@ -3,8 +3,6 @@ extern crate self as openvm_ecc_guest; #[macro_use] extern crate alloc; -#[cfg(feature = "halo2curves")] -pub use halo2curves_axiom as halo2curves; pub use once_cell; pub use openvm_algebra_guest as algebra; pub use openvm_ecc_sw_macros as sw_macros; diff --git a/extensions/ecc/guest/src/weierstrass.rs b/extensions/ecc/guest/src/weierstrass.rs index e1783e429d..276c2045e5 100644 --- a/extensions/ecc/guest/src/weierstrass.rs +++ b/extensions/ecc/guest/src/weierstrass.rs @@ -426,7 +426,7 @@ macro_rules! impl_sw_group_ops { p2.clone() } else if p2.is_identity() { self.clone() - } else if self.x() == p2.x() { + } else if WeierstrassPoint::x(self) == WeierstrassPoint::x(p2) { if self.y() + p2.y() == <$field as openvm_algebra_guest::Field>::ZERO { <$struct_name as WeierstrassPoint>::IDENTITY } else { @@ -444,7 +444,7 @@ macro_rules! impl_sw_group_ops { *self = p2.clone(); } else if p2.is_identity() { // do nothing - } else if self.x() == p2.x() { + } else if WeierstrassPoint::x(self) == WeierstrassPoint::x(p2) { if self.y() + p2.y() == <$field as openvm_algebra_guest::Field>::ZERO { *self = <$struct_name as WeierstrassPoint>::IDENTITY; } else { @@ -486,7 +486,7 @@ macro_rules! impl_sw_group_ops { self.clone() } else if self.is_identity() { core::ops::Neg::neg(p2) - } else if self.x() == p2.x() { + } else if WeierstrassPoint::x(self) == WeierstrassPoint::x(p2) { if self.y() == p2.y() { <$struct_name as WeierstrassPoint>::IDENTITY } else { @@ -504,7 +504,7 @@ macro_rules! impl_sw_group_ops { // do nothing } else if self.is_identity() { *self = core::ops::Neg::neg(p2); - } else if self.x() == p2.x() { + } else if WeierstrassPoint::x(self) == WeierstrassPoint::x(p2) { if self.y() == p2.y() { *self = <$struct_name as WeierstrassPoint>::IDENTITY } else { diff --git a/extensions/ecc/sw-macros/README.md b/extensions/ecc/sw-macros/README.md index 8b2da66e70..368a2e6aed 100644 --- a/extensions/ecc/sw-macros/README.md +++ b/extensions/ecc/sw-macros/README.md @@ -33,8 +33,6 @@ openvm_ecc_guest::sw_macros::sw_init! { */ pub fn main() { - setup_all_moduli(); - setup_all_curves(); // ... } ``` @@ -90,13 +88,9 @@ pub fn setup_sw_Secp256k1Point() { // ... } } -pub fn setup_all_curves() { - setup_sw_Secp256k1Point(); - // other setups -} ``` -3. Again, the `setup` function for every used curve must be called before any other instructions for that curve. If all curves are used, one can call `setup_all_curves()` to setup all of them. +3. Again, the `setup` function for every curve is automatically called on first use of any of the curve's intrinsics. 4. The order of the items in `sw_init!` **must match** the order of the moduli in the chip configuration -- more specifically, in the modular extension parameters (the order of `CurveConfig`s in `WeierstrassExtension::supported_curves`, which is usually defined with the whole `app_vm_config` in the `openvm.toml` file). diff --git a/extensions/ecc/sw-macros/src/lib.rs b/extensions/ecc/sw-macros/src/lib.rs index c739b1965c..f1ab6d02af 100644 --- a/extensions/ecc/sw-macros/src/lib.rs +++ b/extensions/ecc/sw-macros/src/lib.rs @@ -90,6 +90,8 @@ pub fn sw_declare(input: TokenStream) -> TokenStream { create_extern_func!(hint_decompress_extern_func); create_extern_func!(hint_non_qr_extern_func); + let setup_function = syn::Ident::new(&format!("setup_sw_{}", struct_name), span.into()); + let group_ops_mod_name = format_ident!("{}_ops", struct_name.to_string().to_lowercase()); let result = TokenStream::from(quote::quote_spanned! { span.into() => @@ -98,6 +100,7 @@ pub fn sw_declare(input: TokenStream) -> TokenStream { fn #sw_double_extern_func(rd: usize, rs1: usize); fn #hint_decompress_extern_func(rs1: usize, rs2: usize); fn #hint_non_qr_extern_func(); + fn #setup_function(); } #[derive(Eq, PartialEq, Clone, Debug, serde::Serialize, serde::Deserialize)] @@ -129,6 +132,7 @@ pub fn sw_declare(input: TokenStream) -> TokenStream { } #[cfg(target_os = "zkvm")] { + Self::assert_is_setup(); let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit(); unsafe { #sw_add_ne_extern_func( @@ -154,6 +158,7 @@ pub fn sw_declare(input: TokenStream) -> TokenStream { } #[cfg(target_os = "zkvm")] { + Self::assert_is_setup(); unsafe { #sw_add_ne_extern_func( self as *mut #struct_name as usize, @@ -179,6 +184,7 @@ pub fn sw_declare(input: TokenStream) -> TokenStream { } #[cfg(target_os = "zkvm")] { + Self::assert_is_setup(); let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit(); unsafe { #sw_double_extern_func( @@ -198,6 +204,7 @@ pub fn sw_declare(input: TokenStream) -> TokenStream { } #[cfg(target_os = "zkvm")] { + Self::assert_is_setup(); unsafe { #sw_double_extern_func( self as *mut #struct_name as usize, @@ -207,6 +214,14 @@ pub fn sw_declare(input: TokenStream) -> TokenStream { } } + // Helper function to call the setup instruction on first use + fn assert_is_setup() { + static is_setup: ::openvm_ecc_guest::once_cell::race::OnceBool = ::openvm_ecc_guest::once_cell::race::OnceBool::new(); + is_setup.get_or_init(|| { + unsafe { #setup_function(); } + true + }); + } } impl ::openvm_ecc_guest::weierstrass::WeierstrassPoint for #struct_name { @@ -447,7 +462,6 @@ pub fn sw_init(input: TokenStream) -> TokenStream { let mut externs = Vec::new(); let mut setups = Vec::new(); - let mut setup_all_curves = Vec::new(); let span = proc_macro::Span::call_site(); @@ -527,9 +541,12 @@ pub fn sw_init(input: TokenStream) -> TokenStream { let setup_function = syn::Ident::new(&format!("setup_sw_{}", str_path), span.into()); setups.push(quote::quote_spanned! { span.into() => #[allow(non_snake_case)] - pub fn #setup_function() { + #[no_mangle] + extern "C" fn #setup_function() { #[cfg(target_os = "zkvm")] { + openvm::io::println("setup function called"); + use super::#item; // p1 is (x1, y1), and x1 must be the modulus. // y1 can be anything for SetupEcAdd, but must equal `a` for SetupEcDouble let modulus_bytes = <<#item as openvm_ecc_guest::weierstrass::WeierstrassPoint>::Coordinate as openvm_algebra_guest::IntMod>::MODULUS; @@ -564,10 +581,6 @@ pub fn sw_init(input: TokenStream) -> TokenStream { } } }); - - setup_all_curves.push(quote::quote_spanned! { span.into() => - #setup_function(); - }); } TokenStream::from(quote::quote_spanned! { span.into() => @@ -576,10 +589,7 @@ pub fn sw_init(input: TokenStream) -> TokenStream { use ::openvm_ecc_guest::{OPCODE, SW_FUNCT3, SwBaseFunct7}; #(#externs)* - } - #(#setups)* - pub fn setup_all_curves() { - #(#setup_all_curves)* + #(#setups)* } }) } diff --git a/extensions/ecc/tests/Cargo.toml b/extensions/ecc/tests/Cargo.toml index 9a743ac00b..b877277da6 100644 --- a/extensions/ecc/tests/Cargo.toml +++ b/extensions/ecc/tests/Cargo.toml @@ -23,6 +23,8 @@ openvm-sdk.workspace = true eyre.workspace = true hex-literal.workspace = true num-bigint.workspace = true +halo2curves-axiom = { workspace = true } + [features] default = ["parallel"] parallel = ["openvm-circuit/parallel"] diff --git a/extensions/ecc/tests/programs/examples/decompress.rs b/extensions/ecc/tests/programs/examples/decompress.rs index 8692b13675..206562ef90 100644 --- a/extensions/ecc/tests/programs/examples/decompress.rs +++ b/extensions/ecc/tests/programs/examples/decompress.rs @@ -77,11 +77,6 @@ openvm::init!("openvm_init_decompress.rs"); // test decompression under an honest host pub fn main() { - setup_0(); - setup_2(); - setup_4(); - setup_all_curves(); - let bytes = read_vec(); let x = Secp256k1Coord::from_le_bytes(&bytes[..32]); let y = Secp256k1Coord::from_le_bytes(&bytes[32..64]); diff --git a/extensions/ecc/tests/programs/examples/decompress_invalid_hint.rs b/extensions/ecc/tests/programs/examples/decompress_invalid_hint.rs index c55a3e329d..c341ef6311 100644 --- a/extensions/ecc/tests/programs/examples/decompress_invalid_hint.rs +++ b/extensions/ecc/tests/programs/examples/decompress_invalid_hint.rs @@ -192,11 +192,6 @@ type CurvePoint1mod4Wrapper = CurvePointWrapper; // Check that decompress enters an infinite loop when hint_decompress returns an incorrect value. pub fn main() { - setup_0(); - setup_2(); - setup_4(); - setup_all_curves(); - let bytes = read_vec(); test_p_3_mod_4(&bytes[..32], &bytes[32..64]); diff --git a/extensions/ecc/tests/programs/examples/ec.rs b/extensions/ecc/tests/programs/examples/ec.rs index 056e12f41d..b4a03bea6c 100644 --- a/extensions/ecc/tests/programs/examples/ec.rs +++ b/extensions/ecc/tests/programs/examples/ec.rs @@ -15,9 +15,6 @@ openvm::init!("openvm_init_ec.rs"); openvm::entry!(main); pub fn main() { - setup_all_moduli(); - setup_all_curves(); - // Sample points got from https://asecuritysite.com/ecc/ecc_points2 and // https://learnmeabitcoin.com/technical/cryptography/elliptic-curve/#add let x1 = Secp256k1Coord::from_u32(1); diff --git a/extensions/ecc/tests/programs/examples/ec_nonzero_a.rs b/extensions/ecc/tests/programs/examples/ec_nonzero_a.rs index a25277b4bb..2402151bdf 100644 --- a/extensions/ecc/tests/programs/examples/ec_nonzero_a.rs +++ b/extensions/ecc/tests/programs/examples/ec_nonzero_a.rs @@ -14,9 +14,6 @@ openvm::entry!(main); openvm::init!("openvm_init_ec_nonzero_a.rs"); pub fn main() { - setup_all_moduli(); - setup_all_curves(); - // Sample points got from https://asecuritysite.com/ecc/p256p let x1 = P256Coord::from_u32(5); let y1 = P256Coord::from_le_bytes(&hex!( diff --git a/extensions/ecc/tests/programs/examples/ec_two_curves.rs b/extensions/ecc/tests/programs/examples/ec_two_curves.rs index 499d678593..21cdb2eb51 100644 --- a/extensions/ecc/tests/programs/examples/ec_two_curves.rs +++ b/extensions/ecc/tests/programs/examples/ec_two_curves.rs @@ -16,9 +16,6 @@ openvm::init!("openvm_init_ec_two_curves.rs"); openvm::entry!(main); pub fn main() { - setup_all_moduli(); - setup_all_curves(); - // Sample points got from https://asecuritysite.com/ecc/ecc_points2 and // https://learnmeabitcoin.com/technical/cryptography/elliptic-curve/#add let x1 = Secp256k1Coord::from_u32(1); diff --git a/extensions/ecc/tests/programs/examples/ecdsa.rs b/extensions/ecc/tests/programs/examples/ecdsa.rs index 80b0734941..da556cd096 100644 --- a/extensions/ecc/tests/programs/examples/ecdsa.rs +++ b/extensions/ecc/tests/programs/examples/ecdsa.rs @@ -19,9 +19,6 @@ openvm::init!("openvm_init_ecdsa.rs"); // Ref: https://docs.rs/k256/latest/k256/ecdsa/index.html pub fn main() { - setup_all_moduli(); - setup_all_curves(); - let msg = b"example message"; let signature = hex!( diff --git a/extensions/ecc/tests/programs/examples/invalid_setup.rs b/extensions/ecc/tests/programs/examples/invalid_setup.rs index 9c6e8a3a51..44b70f1d59 100644 --- a/extensions/ecc/tests/programs/examples/invalid_setup.rs +++ b/extensions/ecc/tests/programs/examples/invalid_setup.rs @@ -2,7 +2,7 @@ #![cfg_attr(not(feature = "std"), no_std)] #[allow(unused_imports)] -use openvm_ecc_guest::{k256::Secp256k1Point, p256::P256Point}; +use openvm_ecc_guest::{k256::Secp256k1Point, p256::P256Point, CyclicGroup}; openvm_algebra_moduli_macros::moduli_init! { "0xFFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE FFFFFC2F", @@ -20,7 +20,7 @@ openvm_ecc_sw_macros::sw_init! { openvm::entry!(main); pub fn main() { - setup_all_moduli(); // this should cause a debug assertion to fail - setup_all_curves(); + let p1 = Secp256k1Point::GENERATOR; + let _p2 = &p1 + &p1; } diff --git a/extensions/ecc/tests/src/lib.rs b/extensions/ecc/tests/src/lib.rs index dd3f885e90..37a71c9e3b 100644 --- a/extensions/ecc/tests/src/lib.rs +++ b/extensions/ecc/tests/src/lib.rs @@ -97,7 +97,7 @@ mod tests { #[test] fn test_decompress() -> Result<()> { - use openvm_ecc_guest::halo2curves::{group::Curve, secp256k1::Secp256k1Affine}; + use halo2curves_axiom::{group::Curve, secp256k1::Secp256k1Affine}; let config = Rv32WeierstrassConfig::new(vec![SECP256K1_CONFIG.clone(), @@ -162,7 +162,7 @@ mod tests { } fn test_decompress_invalid_specific_test(test_type: &str) -> Result<()> { - use openvm_ecc_guest::halo2curves::{group::Curve, secp256k1::Secp256k1Affine}; + use halo2curves_axiom::{group::Curve, secp256k1::Secp256k1Affine}; let config = Rv32WeierstrassConfig::new(vec![SECP256K1_CONFIG.clone(), diff --git a/extensions/keccak256/guest/src/lib.rs b/extensions/keccak256/guest/src/lib.rs index 459c4c910d..7e2bb3da54 100644 --- a/extensions/keccak256/guest/src/lib.rs +++ b/extensions/keccak256/guest/src/lib.rs @@ -1,30 +1,10 @@ #![no_std] -#[cfg(target_os = "zkvm")] -use core::mem::MaybeUninit; - /// This is custom-0 defined in RISC-V spec document pub const OPCODE: u8 = 0x0b; pub const KECCAK256_FUNCT3: u8 = 0b100; pub const KECCAK256_FUNCT7: u8 = 0; -/// The keccak256 cryptographic hash function. -#[inline(always)] -pub fn keccak256(input: &[u8]) -> [u8; 32] { - #[cfg(not(target_os = "zkvm"))] - { - let mut output = [0u8; 32]; - set_keccak256(input, &mut output); - output - } - #[cfg(target_os = "zkvm")] - { - let mut output = MaybeUninit::<[u8; 32]>::uninit(); - native_keccak256(input.as_ptr(), input.len(), output.as_mut_ptr() as *mut u8); - unsafe { output.assume_init() } - } -} - /// Native hook for keccak256 for use with `alloy-primitives` "native-keccak" feature. /// /// # Safety @@ -40,7 +20,7 @@ pub fn keccak256(input: &[u8]) -> [u8; 32] { #[cfg(target_os = "zkvm")] #[inline(always)] #[no_mangle] -extern "C" fn native_keccak256(bytes: *const u8, len: usize, output: *mut u8) { +pub extern "C" fn native_keccak256(bytes: *const u8, len: usize, output: *mut u8) { openvm_platform::custom_insn_r!( opcode = OPCODE, funct3 = KECCAK256_FUNCT3, @@ -50,16 +30,3 @@ extern "C" fn native_keccak256(bytes: *const u8, len: usize, output: *mut u8) { rs2 = In len ); } - -/// Sets `output` to the keccak256 hash of `input`. -pub fn set_keccak256(input: &[u8], output: &mut [u8; 32]) { - #[cfg(not(target_os = "zkvm"))] - { - use tiny_keccak::Hasher; - let mut hasher = tiny_keccak::Keccak::v256(); - hasher.update(input); - hasher.finalize(output); - } - #[cfg(target_os = "zkvm")] - native_keccak256(input.as_ptr(), input.len(), output.as_mut_ptr() as *mut u8); -} diff --git a/extensions/keccak256/tests/Cargo.toml b/extensions/keccak256/tests/Cargo.toml deleted file mode 100644 index d79e18dd5f..0000000000 --- a/extensions/keccak256/tests/Cargo.toml +++ /dev/null @@ -1,23 +0,0 @@ -[package] -name = "openvm-keccak256-integration-tests" -description = "Integration tests for the OpenVM keccak256 extension" -version.workspace = true -authors.workspace = true -edition.workspace = true -homepage.workspace = true -repository.workspace = true - -[dependencies] -openvm-instructions = { workspace = true } -openvm-stark-sdk.workspace = true -openvm-circuit = { workspace = true, features = ["test-utils"] } -openvm-transpiler.workspace = true -openvm-keccak256-transpiler.workspace = true -openvm-keccak256-circuit.workspace = true -openvm-rv32im-transpiler.workspace = true -openvm-toolchain-tests = { path = "../../../crates/toolchain/tests" } -eyre.workspace = true - -[features] -default = ["parallel"] -parallel = ["openvm-circuit/parallel"] diff --git a/extensions/keccak256/tests/programs/Cargo.toml b/extensions/keccak256/tests/programs/Cargo.toml deleted file mode 100644 index 8eb24c3af1..0000000000 --- a/extensions/keccak256/tests/programs/Cargo.toml +++ /dev/null @@ -1,25 +0,0 @@ -[workspace] -[package] -name = "openvm-keccak256-test-programs" -version = "0.0.0" -edition = "2021" - -[dependencies] -openvm = { path = "../../../../crates/toolchain/openvm" } -openvm-platform = { path = "../../../../crates/toolchain/platform" } -openvm-keccak256-guest = { path = "../../guest" } -hex = { version = "0.4.3", default-features = false, features = ["alloc"] } -serde = { version = "1.0", default-features = false, features = [ - "alloc", - "derive", -] } - - -[features] -default = [] -std = ["serde/std", "openvm/std"] - -[profile.release] -panic = "abort" -lto = "thin" # turn on lto = fat to decrease binary size, but this optimizes out some missing extern links so we shouldn't use it for testing -# strip = "symbols" diff --git a/extensions/keccak256/tests/programs/examples/keccak.rs b/extensions/keccak256/tests/programs/examples/keccak.rs deleted file mode 100644 index d850bfdab2..0000000000 --- a/extensions/keccak256/tests/programs/examples/keccak.rs +++ /dev/null @@ -1,32 +0,0 @@ -#![cfg_attr(not(feature = "std"), no_main)] -#![cfg_attr(not(feature = "std"), no_std)] - -extern crate alloc; - -use alloc::vec::Vec; -use core::hint::black_box; - -use hex::FromHex; -use openvm_keccak256_guest::keccak256; - -openvm::entry!(main); - -pub fn main() { - let test_vectors = [ - ("", "C5D2460186F7233C927E7DB2DCC703C0E500B653CA82273B7BFAD8045D85A470"), // ShortMsgKAT_256 Len = 0 - ("CC", "EEAD6DBFC7340A56CAEDC044696A168870549A6A7F6F56961E84A54BD9970B8A"), // ShortMsgKAT_256 Len = 8 - ("B55C10EAE0EC684C16D13463F29291BF26C82E2FA0422A99C71DB4AF14DD9C7F33EDA52FD73D017CC0F2DBE734D831F0D820D06D5F89DACC485739144F8CFD4799223B1AFF9031A105CB6A029BA71E6E5867D85A554991C38DF3C9EF8C1E1E9A7630BE61CAABCA69280C399C1FB7A12D12AEFC", "0347901965D3635005E75A1095695CCA050BC9ED2D440C0372A31B348514A889"), // ShortMsgKAT_256 Len = 920 - ("2EDC282FFB90B97118DD03AAA03B145F363905E3CBD2D50ECD692B37BF000185C651D3E9726C690D3773EC1E48510E42B17742B0B0377E7DE6B8F55E00A8A4DB4740CEE6DB0830529DD19617501DC1E9359AA3BCF147E0A76B3AB70C4984C13E339E6806BB35E683AF8527093670859F3D8A0FC7D493BCBA6BB12B5F65E71E705CA5D6C948D66ED3D730B26DB395B3447737C26FAD089AA0AD0E306CB28BF0ACF106F89AF3745F0EC72D534968CCA543CD2CA50C94B1456743254E358C1317C07A07BF2B0ECA438A709367FAFC89A57239028FC5FECFD53B8EF958EF10EE0608B7F5CB9923AD97058EC067700CC746C127A61EE3", "DD1D2A92B3F3F3902F064365838E1F5F3468730C343E2974E7A9ECFCD84AA6DB"), // ShortMsgKAT_256 Len = 1952, - ("724627916C50338643E6996F07877EAFD96BDF01DA7E991D4155B9BE1295EA7D21C9391F4C4A41C75F77E5D27389253393725F1427F57914B273AB862B9E31DABCE506E558720520D33352D119F699E784F9E548FF91BC35CA147042128709820D69A8287EA3257857615EB0321270E94B84F446942765CE882B191FAEE7E1C87E0F0BD4E0CD8A927703524B559B769CA4ECE1F6DBF313FDCF67C572EC4185C1A88E86EC11B6454B371980020F19633B6B95BD280E4FBCB0161E1A82470320CEC6ECFA25AC73D09F1536F286D3F9DACAFB2CD1D0CE72D64D197F5C7520B3CCB2FD74EB72664BA93853EF41EABF52F015DD591500D018DD162815CC993595B195", "EA0E416C0F7B4F11E3F00479FDDF954F2539E5E557753BD546F69EE375A5DE29"), // LongMsgKAT_256 Len = 2048 - ("6E1CADFB2A14C5FFB1DD69919C0124ED1B9A414B2BEA1E5E422D53B022BDD13A9C88E162972EBB9852330006B13C5B2F2AFBE754AB7BACF12479D4558D19DDBB1A6289387B3AC084981DF335330D1570850B97203DBA5F20CF7FF21775367A8401B6EBE5B822ED16C39383232003ABC412B0CE0DD7C7DA064E4BB73E8C58F222A1512D5FE6D947316E02F8AA87E7AA7A3AA1C299D92E6414AE3B927DB8FF708AC86A09B24E1884743BC34067BB0412453B4A6A6509504B550F53D518E4BCC3D9C1EFDB33DA2EACCB84C9F1CAEC81057A8508F423B25DB5500E5FC86AB3B5EB10D6D0BF033A716DDE55B09FD53451BBEA644217AE1EF91FAD2B5DCC6515249C96EE7EABFD12F1EF65256BD1CFF2087DABF2F69AD1FFB9CF3BC8CA437C7F18B6095BC08D65DF99CC7F657C418D8EB109FDC91A13DC20A438941726EF24F9738B6552751A320C4EA9C8D7E8E8592A3B69D30A419C55FB6CB0850989C029AAAE66305E2C14530B39EAA86EA3BA2A7DECF4B2848B01FAA8AA91F2440B7CC4334F63061CE78AA1589BEFA38B194711697AE3AADCB15C9FBF06743315E2F97F1A8B52236ACB444069550C2345F4ED12E5B8E881CDD472E803E5DCE63AE485C2713F81BC307F25AC74D39BAF7E3BC5E7617465C2B9C309CB0AC0A570A7E46C6116B2242E1C54F456F6589E20B1C0925BF1CD5F9344E01F63B5BA9D4671ABBF920C7ED32937A074C33836F0E019DFB6B35D865312C6058DFDAFF844C8D58B75071523E79DFBAB2EA37479DF12C474584F4FF40F00F92C6BADA025CE4DF8FAF0AFB2CE75C07773907CA288167D6B011599C3DE0FFF16C1161D31DF1C1DDE217CB574ED5A33751759F8ED2B1E6979C5088B940926B9155C9D250B479948C20ACB5578DC02C97593F646CC5C558A6A0F3D8D273258887CCFF259197CB1A7380622E371FD2EB5376225EC04F9ED1D1F2F08FA2376DB5B790E73086F581064ED1C5F47E989E955D77716B50FB64B853388FBA01DAC2CEAE99642341F2DA64C56BEFC4789C051E5EB79B063F2F084DB4491C3C5AA7B4BCF7DD7A1D7CED1554FA67DCA1F9515746A237547A4A1D22ACF649FA1ED3B9BB52BDE0C6996620F8CFDB293F8BACAD02BCE428363D0BB3D391469461D212769048219220A7ED39D1F9157DFEA3B4394CA8F5F612D9AC162BF0B961BFBC157E5F863CE659EB235CF98E8444BC8C7880BDDCD0B3B389AAA89D5E05F84D0649EEBACAB4F1C75352E89F0E9D91E4ACA264493A50D2F4AED66BD13650D1F18E7199E931C78AEB763E903807499F1CD99AF81276B615BE8EC709B039584B2B57445B014F6162577F3548329FD288B0800F936FC5EA1A412E3142E609FC8E39988CA53DF4D8FB5B5FB5F42C0A01648946AC6864CFB0E92856345B08E5DF0D235261E44CFE776456B40AEF0AC1A0DFA2FE639486666C05EA196B0C1A9D346435E03965E6139B1CE10129F8A53745F80100A94AE04D996C13AC14CF2713E39DFBB19A936CF3861318BD749B1FB82F40D73D714E406CBEB3D920EA037B7DE566455CCA51980F0F53A762D5BF8A4DBB55AAC0EDDB4B1F2AED2AA3D01449D34A57FDE4329E7FF3F6BECE4456207A4225218EE9F174C2DE0FF51CEAF2A07CF84F03D1DF316331E3E725C5421356C40ED25D5ABF9D24C4570FED618CA41000455DBD759E32E2BF0B6C5E61297C20F752C3042394CE840C70943C451DD5598EB0E4953CE26E833E5AF64FC1007C04456D19F87E45636F456B7DC9D31E757622E2739573342DE75497AE181AAE7A5425756C8E2A7EEF918E5C6A968AEFE92E8B261BBFE936B19F9E69A3C90094096DAE896450E1505ED5828EE2A7F0EA3A28E6EC47C0AF711823E7689166EA07ECA00FFC493131D65F93A4E1D03E0354AFC2115CFB8D23DAE8C6F96891031B23226B8BC82F1A73DAA5BB740FC8CC36C0975BEFA0C7895A9BBC261EDB7FD384103968F7A18353D5FE56274E4515768E4353046C785267DE01E816A2873F97AAD3AB4D7234EBFD9832716F43BE8245CF0B4408BA0F0F764CE9D24947AB6ABDD9879F24FCFF10078F5894B0D64F6A8D3EA3DD92A0C38609D3C14FDC0A44064D501926BE84BF8034F1D7A8C5F382E6989BFFA2109D4FBC56D1F091E8B6FABFF04D21BB19656929D19DECB8E8291E6AE5537A169874E0FE9890DFF11FFD159AD23D749FB9E8B676E2C31313C16D1EFA06F4D7BC191280A4EE63049FCEF23042B20303AECDD412A526D7A53F760A089FBDF13F361586F0DCA76BB928EDB41931D11F679619F948A6A9E8DBA919327769006303C6EF841438A7255C806242E2E7FF4621BB0F8AFA0B4A248EAD1A1E946F3E826FBFBBF8013CE5CC814E20FEF21FA5DB19EC7FF0B06C592247B27E500EB4705E6C37D41D09E83CB0A618008CA1AAAE8A215171D817659063C2FA385CFA3C1078D5C2B28CE7312876A276773821BE145785DFF24BBB24D590678158A61EA49F2BE56FDAC8CE7F94B05D62F15ADD351E5930FD4F31B3E7401D5C0FF7FC845B165FB6ABAFD4788A8B0615FEC91092B34B710A68DA518631622BA2AAE5D19010D307E565A161E64A4319A6B261FB2F6A90533997B1AEC32EF89CF1F232696E213DAFE4DBEB1CF1D5BBD12E5FF2EBB2809184E37CD9A0E58A4E0AF099493E6D8CC98B05A2F040A7E39515038F6EE21FC25F8D459A327B83EC1A28A234237ACD52465506942646AC248EC96EBBA6E1B092475F7ADAE4D35E009FD338613C7D4C12E381847310A10E6F02C02392FC32084FBE939689BC6518BE27AF7842DEEA8043828E3DFFE3BBAC4794CA0CC78699722709F2E4B0EAE7287DEB06A27B462423EC3F0DF227ACF589043292685F2C0E73203E8588B62554FF19D6260C7FE48DF301509D33BE0D8B31D3F658C921EF7F55449FF3887D91BFB894116DF57206098E8C5835B", "3C79A3BD824542C20AF71F21D6C28DF2213A041F77DD79A328A0078123954E7B"), // LongMsgKAT_256 Len = 16664 - ("7ADC0B6693E61C269F278E6944A5A2D8300981E40022F839AC644387BFAC9086650085C2CDC585FEA47B9D2E52D65A2B29A7DC370401EF5D60DD0D21F9E2B90FAE919319B14B8C5565B0423CEFB827D5F1203302A9D01523498A4DB10374", "4CC2AFF141987F4C2E683FA2DE30042BACDCD06087D7A7B014996E9CFEAA58CE"), // ShortMsgKAT_256 Len = 752 - ]; - for (input, expected_output) in test_vectors.iter() { - let input = Vec::from_hex(input).unwrap(); - let expected_output = Vec::from_hex(expected_output).unwrap(); - let output = keccak256(&black_box(input)); - if output != *expected_output { - panic!(); - } - } -} diff --git a/extensions/keccak256/tests/src/lib.rs b/extensions/keccak256/tests/src/lib.rs deleted file mode 100644 index 69b1ef65d3..0000000000 --- a/extensions/keccak256/tests/src/lib.rs +++ /dev/null @@ -1,32 +0,0 @@ -#[cfg(test)] -mod tests { - use eyre::Result; - use openvm_circuit::utils::air_test; - use openvm_instructions::exe::VmExe; - use openvm_keccak256_circuit::Keccak256Rv32Config; - use openvm_keccak256_transpiler::Keccak256TranspilerExtension; - use openvm_rv32im_transpiler::{ - Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, - }; - use openvm_stark_sdk::p3_baby_bear::BabyBear; - use openvm_toolchain_tests::{build_example_program_at_path, get_programs_dir}; - use openvm_transpiler::{transpiler::Transpiler, FromElf}; - - type F = BabyBear; - - #[test] - fn test_keccak256() -> Result<()> { - let config = Keccak256Rv32Config::default(); - let elf = build_example_program_at_path(get_programs_dir!(), "keccak", &config)?; - let openvm_exe = VmExe::from_elf( - elf, - Transpiler::::default() - .with_extension(Keccak256TranspilerExtension) - .with_extension(Rv32ITranspilerExtension) - .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), - )?; - air_test(config, openvm_exe); - Ok(()) - } -} diff --git a/extensions/pairing/circuit/Cargo.toml b/extensions/pairing/circuit/Cargo.toml index 6ec8891541..af16f7eeab 100644 --- a/extensions/pairing/circuit/Cargo.toml +++ b/extensions/pairing/circuit/Cargo.toml @@ -36,6 +36,7 @@ rand = { workspace = true } itertools = { workspace = true } eyre = { workspace = true } serde = { workspace = true, features = ["derive", "std"] } +halo2curves-axiom = { workspace = true } [target.'cfg(not(target_os = "zkvm"))'.dependencies] openvm-pairing-guest = { workspace = true } diff --git a/extensions/pairing/circuit/src/pairing_extension.rs b/extensions/pairing/circuit/src/pairing_extension.rs index b1e6afe38d..c75687f404 100644 --- a/extensions/pairing/circuit/src/pairing_extension.rs +++ b/extensions/pairing/circuit/src/pairing_extension.rs @@ -103,11 +103,12 @@ pub(crate) mod phantom { use std::collections::VecDeque; use eyre::bail; + use halo2curves_axiom::ff; use openvm_circuit::{ arch::{PhantomSubExecutor, Streams}, system::memory::MemoryController, }; - use openvm_ecc_guest::{algebra::field::FieldExtension, halo2curves::ff, AffinePoint}; + use openvm_ecc_guest::{algebra::field::FieldExtension, AffinePoint}; use openvm_instructions::{ riscv::{RV32_MEMORY_AS, RV32_REGISTER_NUM_LIMBS}, PhantomDiscriminant, @@ -168,7 +169,7 @@ pub(crate) mod phantom { match PairingCurve::from_repr(c_upper as usize) { Some(PairingCurve::Bn254) => { - use openvm_ecc_guest::halo2curves::bn256::{Fq, Fq12, Fq2}; + use halo2curves_axiom::bn256::{Fq, Fq12, Fq2}; use openvm_pairing_guest::halo2curves_shims::bn254::Bn254; const N: usize = BN254_NUM_LIMBS; if p_len != q_len { @@ -210,7 +211,7 @@ pub(crate) mod phantom { ); } Some(PairingCurve::Bls12_381) => { - use openvm_ecc_guest::halo2curves::bls12_381::{Fq, Fq12, Fq2}; + use halo2curves_axiom::bls12_381::{Fq, Fq12, Fq2}; use openvm_pairing_guest::halo2curves_shims::bls12_381::Bls12_381; const N: usize = BLS12_381_NUM_LIMBS; if p_len != q_len { diff --git a/extensions/pairing/tests/Cargo.toml b/extensions/pairing/tests/Cargo.toml index a88caa0de3..6a337bfab0 100644 --- a/extensions/pairing/tests/Cargo.toml +++ b/extensions/pairing/tests/Cargo.toml @@ -27,6 +27,7 @@ eyre.workspace = true rand.workspace = true num-bigint.workspace = true num-traits.workspace = true +halo2curves-axiom = { workspace = true } [features] default = ["parallel"] diff --git a/extensions/pairing/tests/programs/examples/bls_ec.rs b/extensions/pairing/tests/programs/examples/bls_ec.rs index c7e13f626a..93f72115cf 100644 --- a/extensions/pairing/tests/programs/examples/bls_ec.rs +++ b/extensions/pairing/tests/programs/examples/bls_ec.rs @@ -8,7 +8,4 @@ openvm::init!("openvm_init_bls_ec_bls12_381.rs"); openvm::entry!(main); -pub fn main() { - setup_all_moduli(); - setup_all_curves(); -} +pub fn main() {} diff --git a/extensions/pairing/tests/programs/examples/fp12_mul.rs b/extensions/pairing/tests/programs/examples/fp12_mul.rs index 0ba32c336d..aeec2fb63d 100644 --- a/extensions/pairing/tests/programs/examples/fp12_mul.rs +++ b/extensions/pairing/tests/programs/examples/fp12_mul.rs @@ -16,8 +16,6 @@ mod bn254 { openvm::init!("openvm_init_fp12_mul_bn254.rs"); pub fn test_fp12_mul(io: &[u8]) { - setup_0(); - setup_all_complex_extensions(); assert_eq!(io.len(), 32 * 36); let f0 = &io[0..32 * 12]; @@ -48,8 +46,6 @@ mod bls12_381 { openvm::init!("openvm_init_fp12_mul_bls12_381.rs"); pub fn test_fp12_mul(io: &[u8]) { - setup_0(); - setup_all_complex_extensions(); assert_eq!(io.len(), 48 * 36); let f0 = &io[0..48 * 12]; diff --git a/extensions/pairing/tests/programs/examples/pairing_check.rs b/extensions/pairing/tests/programs/examples/pairing_check.rs index e3be6a8a34..c01caded79 100644 --- a/extensions/pairing/tests/programs/examples/pairing_check.rs +++ b/extensions/pairing/tests/programs/examples/pairing_check.rs @@ -22,8 +22,6 @@ mod bn254 { openvm::init!("openvm_init_pairing_check_bn254.rs"); pub fn test_pairing_check(io: &[u8]) { - setup_0(); - setup_all_complex_extensions(); let s0 = &io[0..32 * 2]; let s1 = &io[32 * 2..32 * 4]; let q0 = &io[32 * 4..32 * 8]; @@ -57,8 +55,6 @@ mod bls12_381 { openvm::init!("openvm_init_pairing_check_bls12_381.rs"); pub fn test_pairing_check(io: &[u8]) { - setup_0(); - setup_all_complex_extensions(); let s0 = &io[0..48 * 2]; let s1 = &io[48 * 2..48 * 4]; let q0 = &io[48 * 4..48 * 8]; diff --git a/extensions/pairing/tests/programs/examples/pairing_check_fallback.rs b/extensions/pairing/tests/programs/examples/pairing_check_fallback.rs index f9b504bb71..da3bcbb16f 100644 --- a/extensions/pairing/tests/programs/examples/pairing_check_fallback.rs +++ b/extensions/pairing/tests/programs/examples/pairing_check_fallback.rs @@ -94,8 +94,6 @@ mod bn254 { } pub fn test_pairing_check(io: &[u8]) { - setup_0(); - setup_all_complex_extensions(); let s0 = &io[0..32 * 2]; let s1 = &io[32 * 2..32 * 4]; let q0 = &io[32 * 4..32 * 8]; @@ -202,8 +200,6 @@ mod bls12_381 { } pub fn test_pairing_check(io: &[u8]) { - setup_0(); - setup_all_complex_extensions(); let s0 = &io[0..48 * 2]; let s1 = &io[48 * 2..48 * 4]; let q0 = &io[48 * 4..48 * 8]; diff --git a/extensions/pairing/tests/programs/examples/pairing_line.rs b/extensions/pairing/tests/programs/examples/pairing_line.rs index 36a9d1b4ef..b36c200391 100644 --- a/extensions/pairing/tests/programs/examples/pairing_line.rs +++ b/extensions/pairing/tests/programs/examples/pairing_line.rs @@ -136,15 +136,11 @@ pub fn main() { #[cfg(feature = "bn254")] { - bn254::setup_0(); - bn254::setup_all_complex_extensions(); bn254::test_mul_013_by_013(&io[..32 * 18]); bn254::test_mul_by_01234(&io[32 * 18..32 * 52]); } #[cfg(feature = "bls12_381")] { - bls12_381::setup_0(); - bls12_381::setup_all_complex_extensions(); bls12_381::test_mul_023_by_023(&io[..48 * 18]); bls12_381::test_mul_by_02345(&io[48 * 18..48 * 52]); } diff --git a/extensions/pairing/tests/programs/examples/pairing_miller_loop.rs b/extensions/pairing/tests/programs/examples/pairing_miller_loop.rs index afc1ae37a4..a9d8f09dbd 100644 --- a/extensions/pairing/tests/programs/examples/pairing_miller_loop.rs +++ b/extensions/pairing/tests/programs/examples/pairing_miller_loop.rs @@ -20,8 +20,6 @@ mod bn254 { openvm::init!("openvm_init_pairing_miller_loop_bn254.rs"); pub fn test_miller_loop(io: &[u8]) { - setup_0(); - setup_all_complex_extensions(); let s0 = &io[0..32 * 2]; let s1 = &io[32 * 2..32 * 4]; let q0 = &io[32 * 4..32 * 8]; @@ -56,8 +54,6 @@ mod bls12_381 { openvm::init!("openvm_init_pairing_miller_loop_bls12_381.rs"); pub fn test_miller_loop(io: &[u8]) { - setup_0(); - setup_all_complex_extensions(); let s0 = &io[0..48 * 2]; let s1 = &io[48 * 2..48 * 4]; let q0 = &io[48 * 4..48 * 8]; diff --git a/extensions/pairing/tests/programs/examples/pairing_miller_step.rs b/extensions/pairing/tests/programs/examples/pairing_miller_step.rs index c944298aa6..a2d4662cf4 100644 --- a/extensions/pairing/tests/programs/examples/pairing_miller_step.rs +++ b/extensions/pairing/tests/programs/examples/pairing_miller_step.rs @@ -155,15 +155,11 @@ pub fn main() { #[cfg(feature = "bn254")] { - bn254::setup_0(); - bn254::setup_all_complex_extensions(); bn254::test_miller_step(&io[..32 * 12]); bn254::test_miller_double_and_add_step(&io[32 * 12..]); } #[cfg(feature = "bls12_381")] { - bls12_381::setup_0(); - bls12_381::setup_all_complex_extensions(); bls12_381::test_miller_step(&io[..48 * 12]); bls12_381::test_miller_double_and_add_step(&io[48 * 12..]); } diff --git a/extensions/pairing/tests/programs/openvm_init_fp12_mul_bls12_381.rs b/extensions/pairing/tests/programs/openvm_init_fp12_mul_bls12_381.rs index 00181b03ef..c130859ad8 100644 --- a/extensions/pairing/tests/programs/openvm_init_fp12_mul_bls12_381.rs +++ b/extensions/pairing/tests/programs/openvm_init_fp12_mul_bls12_381.rs @@ -1,4 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. -openvm_algebra_guest::moduli_macros::moduli_init! { "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787" } -openvm_algebra_guest::complex_macros::complex_init! { Bls12_381Fp2 { mod_idx = 0 } } +openvm_algebra_guest::moduli_macros::moduli_init! { "21888242871839275222246405745257275088696311157297823662689037894645226208583" } +openvm_algebra_guest::complex_macros::complex_init! { Bn254Fp2 { mod_idx = 0 } } openvm_ecc_guest::sw_macros::sw_init! { } diff --git a/extensions/pairing/tests/src/lib.rs b/extensions/pairing/tests/src/lib.rs index f14956f024..7f12f01241 100644 --- a/extensions/pairing/tests/src/lib.rs +++ b/extensions/pairing/tests/src/lib.rs @@ -5,6 +5,10 @@ mod bn254 { use std::iter; use eyre::Result; + use halo2curves_axiom::{ + bn256::{Fq12, Fq2, Fr, G1Affine, G2Affine}, + ff::Field, + }; use openvm_algebra_circuit::{Fp2Extension, ModularExtension}; use openvm_algebra_transpiler::{Fp2TranspilerExtension, ModularTranspilerExtension}; use openvm_circuit::{ @@ -14,10 +18,6 @@ mod bn254 { use openvm_ecc_circuit::WeierstrassExtension; use openvm_ecc_guest::{ algebra::{field::FieldExtension, IntMod}, - halo2curves::{ - bn256::{Fq12, Fq2, Fr, G1Affine, G2Affine}, - ff::Field, - }, AffinePoint, }; use openvm_instructions::exe::VmExe; @@ -425,6 +425,10 @@ mod bn254 { #[cfg(test)] mod bls12_381 { use eyre::Result; + use halo2curves_axiom::{ + bls12_381::{Fq12, Fq2, Fr, G1Affine, G2Affine}, + ff::Field, + }; use num_bigint::BigUint; use num_traits::{self, FromPrimitive}; use openvm_algebra_circuit::{Fp2Extension, ModularExtension}; @@ -436,10 +440,6 @@ mod bls12_381 { use openvm_ecc_circuit::{CurveConfig, Rv32WeierstrassConfig, WeierstrassExtension}; use openvm_ecc_guest::{ algebra::{field::FieldExtension, IntMod}, - halo2curves::{ - bls12_381::{Fq12, Fq2, Fr, G1Affine, G2Affine}, - ff::Field, - }, AffinePoint, }; use openvm_ecc_transpiler::EccTranspilerExtension; diff --git a/extensions/sha256/guest/src/lib.rs b/extensions/sha256/guest/src/lib.rs index cb34bcd5aa..1c51a272fd 100644 --- a/extensions/sha256/guest/src/lib.rs +++ b/extensions/sha256/guest/src/lib.rs @@ -5,14 +5,6 @@ pub const OPCODE: u8 = 0x0b; pub const SHA256_FUNCT3: u8 = 0b100; pub const SHA256_FUNCT7: u8 = 0x1; -/// The sha256 cryptographic hash function. -#[inline(always)] -pub fn sha256(input: &[u8]) -> [u8; 32] { - let mut output = [0u8; 32]; - set_sha256(input, &mut output); - output -} - /// zkvm native implementation of sha256 /// # Safety /// @@ -25,21 +17,6 @@ pub fn sha256(input: &[u8]) -> [u8; 32] { #[cfg(target_os = "zkvm")] #[inline(always)] #[no_mangle] -extern "C" fn zkvm_sha256_impl(bytes: *const u8, len: usize, output: *mut u8) { +pub extern "C" fn zkvm_sha256_impl(bytes: *const u8, len: usize, output: *mut u8) { openvm_platform::custom_insn_r!(opcode = OPCODE, funct3 = SHA256_FUNCT3, funct7 = SHA256_FUNCT7, rd = In output, rs1 = In bytes, rs2 = In len); } - -/// Sets `output` to the sha256 hash of `input`. -pub fn set_sha256(input: &[u8], output: &mut [u8; 32]) { - #[cfg(not(target_os = "zkvm"))] - { - use sha2::{Digest, Sha256}; - let mut hasher = Sha256::new(); - hasher.update(input); - output.copy_from_slice(hasher.finalize().as_ref()); - } - #[cfg(target_os = "zkvm")] - { - zkvm_sha256_impl(input.as_ptr(), input.len(), output.as_mut_ptr() as *mut u8); - } -} diff --git a/extensions/sha256/tests/Cargo.toml b/extensions/sha256/tests/Cargo.toml deleted file mode 100644 index 52fb11f9ea..0000000000 --- a/extensions/sha256/tests/Cargo.toml +++ /dev/null @@ -1,23 +0,0 @@ -[package] -name = "openvm-sha256-integration-tests" -description = "Integration tests for the OpenVM sha256 extension" -version.workspace = true -authors.workspace = true -edition.workspace = true -homepage.workspace = true -repository.workspace = true - -[dependencies] -openvm-instructions = { workspace = true } -openvm-stark-sdk.workspace = true -openvm-circuit = { workspace = true, features = ["test-utils"] } -openvm-transpiler.workspace = true -openvm-sha256-transpiler.workspace = true -openvm-sha256-circuit.workspace = true -openvm-rv32im-transpiler.workspace = true -openvm-toolchain-tests = { path = "../../../crates/toolchain/tests" } -eyre.workspace = true - -[features] -default = ["parallel"] -parallel = ["openvm-circuit/parallel"] diff --git a/extensions/sha256/tests/programs/Cargo.toml b/extensions/sha256/tests/programs/Cargo.toml deleted file mode 100644 index dad3842cea..0000000000 --- a/extensions/sha256/tests/programs/Cargo.toml +++ /dev/null @@ -1,25 +0,0 @@ -[workspace] -[package] -name = "openvm-keccak256-test-programs" -version = "0.0.0" -edition = "2021" - -[dependencies] -openvm = { path = "../../../../crates/toolchain/openvm" } -openvm-platform = { path = "../../../../crates/toolchain/platform" } -openvm-sha256-guest = { path = "../../guest" } -hex = { version = "0.4.3", default-features = false, features = ["alloc"] } -serde = { version = "1.0", default-features = false, features = [ - "alloc", - "derive", -] } - - -[features] -default = [] -std = ["serde/std", "openvm/std"] - -[profile.release] -panic = "abort" -lto = "thin" # turn on lto = fat to decrease binary size, but this optimizes out some missing extern links so we shouldn't use it for testing -# strip = "symbols" diff --git a/extensions/sha256/tests/programs/examples/sha.rs b/extensions/sha256/tests/programs/examples/sha.rs deleted file mode 100644 index fffbc677a7..0000000000 --- a/extensions/sha256/tests/programs/examples/sha.rs +++ /dev/null @@ -1,30 +0,0 @@ -#![cfg_attr(not(feature = "std"), no_main)] -#![cfg_attr(not(feature = "std"), no_std)] - -extern crate alloc; - -use alloc::vec::Vec; -use core::hint::black_box; - -use hex::FromHex; -use openvm_sha256_guest::sha256; - -openvm::entry!(main); - -pub fn main() { - let test_vectors = [ - ("", "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"), - ("98c1c0bdb7d5fea9a88859f06c6c439f", "b6b2c9c9b6f30e5c66c977f1bd7ad97071bee739524aecf793384890619f2b05"), - ("5b58f4163e248467cc1cd3eecafe749e8e2baaf82c0f63af06df0526347d7a11327463c115210a46b6740244eddf370be89c", "ac0e25049870b91d78ef6807bb87fce4603c81abd3c097fba2403fd18b6ce0b7"), - ("9ad198539e3160194f38ac076a782bd5210a007560d1fce9ef78f8a4a5e4d78c6b96c250cff3520009036e9c6087d5dab587394edda862862013de49a12072485a6c01165ec0f28ffddf1873fbd53e47fcd02fb6a5ccc9622d5588a92429c663ce298cb71b50022fc2ec4ba9f5bbd250974e1a607b165fee16e8f3f2be20d7348b91a2f518ce928491900d56d9f86970611580350cee08daea7717fe28a73b8dcfdea22a65ed9f5a09198de38e4e4f2cc05b0ba3dd787a5363ab6c9f39dcb66c1a29209b1d6b1152769395df8150b4316658ea6ab19af94903d643fcb0ae4d598035ebe73c8b1b687df1ab16504f633c929569c6d0e5fae6eea43838fbc8ce2c2b43161d0addc8ccf945a9c4e06294e56a67df0000f561f61b630b1983ba403e775aaeefa8d339f669d1e09ead7eae979383eda983321e1743e5404b4b328da656de79ff52d179833a6bd5129f49432d74d001996c37c68d9ab49fcff8061d193576f396c20e1f0d9ee83a51290ba60efa9c3cb2e15b756321a7ca668cdbf63f95ec33b1c450aa100101be059dc00077245b25a6a66698dee81953ed4a606944076e2858b1420de0095a7f60b08194d6d9a997009d345c71f63a7034b976e409af8a9a040ac7113664609a7adedb76b2fadf04b0348392a1650526eb2a4d6ed5e4bbcda8aabc8488b38f4f5d9a398103536bb8250ed82a9b9825f7703c263f9e", "080ad71239852124fc26758982090611b9b19abf22d22db3a57f67a06e984a23") - - ]; - for (input, expected_output) in test_vectors.iter() { - let input = Vec::from_hex(input).unwrap(); - let expected_output = Vec::from_hex(expected_output).unwrap(); - let output = sha256(&black_box(input)); - if output != *expected_output { - panic!(); - } - } -} diff --git a/extensions/sha256/tests/src/lib.rs b/extensions/sha256/tests/src/lib.rs deleted file mode 100644 index 7afaa08e56..0000000000 --- a/extensions/sha256/tests/src/lib.rs +++ /dev/null @@ -1,32 +0,0 @@ -#[cfg(test)] -mod tests { - use eyre::Result; - use openvm_circuit::utils::air_test; - use openvm_instructions::exe::VmExe; - use openvm_rv32im_transpiler::{ - Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, - }; - use openvm_sha256_circuit::Sha256Rv32Config; - use openvm_sha256_transpiler::Sha256TranspilerExtension; - use openvm_stark_sdk::p3_baby_bear::BabyBear; - use openvm_toolchain_tests::{build_example_program_at_path, get_programs_dir}; - use openvm_transpiler::{transpiler::Transpiler, FromElf}; - - type F = BabyBear; - - #[test] - fn test_sha256() -> Result<()> { - let config = Sha256Rv32Config::default(); - let elf = build_example_program_at_path(get_programs_dir!(), "sha", &config)?; - let openvm_exe = VmExe::from_elf( - elf, - Transpiler::::default() - .with_extension(Rv32ITranspilerExtension) - .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension) - .with_extension(Sha256TranspilerExtension), - )?; - air_test(config, openvm_exe); - Ok(()) - } -}