diff --git a/Cargo.lock b/Cargo.lock index 942d786eb9..f56c362d30 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4143,6 +4143,7 @@ dependencies = [ "openvm-stark-sdk", "rand", "serde", + "test-case", ] [[package]] diff --git a/crates/vm/src/arch/testing/mod.rs b/crates/vm/src/arch/testing/mod.rs index 0bc8ccdd5f..32e6c8a2f3 100644 --- a/crates/vm/src/arch/testing/mod.rs +++ b/crates/vm/src/arch/testing/mod.rs @@ -140,7 +140,11 @@ impl VmChipTestBuilder { pointer: usize, writes: Vec<[F; NUM_LIMBS]>, ) { - self.write(1usize, register, [F::from_canonical_usize(pointer)]); + self.write( + 1usize, + register, + pointer.to_le_bytes().map(F::from_canonical_u8), + ); for (i, &write) in writes.iter().enumerate() { self.write(2usize, pointer + i * NUM_LIMBS, write); } diff --git a/extensions/bigint/circuit/Cargo.toml b/extensions/bigint/circuit/Cargo.toml index 09d68a9d1b..7d133ff151 100644 --- a/extensions/bigint/circuit/Cargo.toml +++ b/extensions/bigint/circuit/Cargo.toml @@ -29,6 +29,7 @@ serde.workspace = true openvm-stark-sdk = { workspace = true } openvm-circuit = { workspace = true, features = ["test-utils"] } openvm-rv32-adapters = { workspace = true, features = ["test-utils"] } +test-case.workspace = true [features] default = ["parallel", "jemalloc"] diff --git a/extensions/bigint/circuit/src/extension.rs b/extensions/bigint/circuit/src/extension.rs index b9eeeafd99..12f161c0b8 100644 --- a/extensions/bigint/circuit/src/extension.rs +++ b/extensions/bigint/circuit/src/extension.rs @@ -5,11 +5,12 @@ use openvm_bigint_transpiler::{ }; use openvm_circuit::{ arch::{ - SystemConfig, SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError, + ExecutionBridge, SystemConfig, SystemPort, VmExtension, VmInventory, VmInventoryBuilder, + VmInventoryError, }, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, VmConfig}; +use openvm_circuit_derive::{AnyEnum, InsExecutorE1, InstructionExecutor, VmConfig}; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, @@ -25,6 +26,9 @@ use serde::{Deserialize, Serialize}; use crate::*; +// TODO: this should be decided after e2 execution +const MAX_INS_CAPACITY: usize = 1 << 22; + #[derive(Clone, Debug, VmConfig, derive_new::new, Serialize, Deserialize)] pub struct Int256Rv32Config { #[system] @@ -69,7 +73,7 @@ fn default_range_tuple_checker_sizes() -> [u32; 2] { [1 << 8, 32 * (1 << 8)] } -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, InsExecutorE1, From, AnyEnum)] pub enum Int256Executor { BaseAlu256(Rv32BaseAlu256Chip), LessThan256(Rv32LessThan256Chip), @@ -101,6 +105,8 @@ impl VmExtension for Int256 { program_bus, memory_bridge, } = builder.system_port(); + let execution_bridge = ExecutionBridge::new(execution_bus, program_bus); + let range_checker_chip = builder.system_base().range_checker_chip.clone(); let bitwise_lu_chip = if let Some(&chip) = builder .find_chip::>() @@ -113,8 +119,8 @@ impl VmExtension for Int256 { inventory.add_periphery_chip(chip.clone()); chip }; - let offline_memory = builder.system_base().offline_memory(); - let address_bits = builder.system_config().memory_config.pointer_max_bits; + // let offline_memory = builder.system_base().offline_memory(); + let pointer_max_bits = builder.system_config().memory_config.pointer_max_bits; let range_tuple_chip = if let Some(chip) = builder .find_chip::>() @@ -133,66 +139,97 @@ impl VmExtension for Int256 { }; let base_alu_chip = Rv32BaseAlu256Chip::new( - Rv32HeapAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, + VmAirWrapper::new( + Rv32HeapAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lu_chip.bus(), + pointer_max_bits, + ), + BaseAluCoreAir::new(bitwise_lu_chip.bus(), Rv32BaseAlu256Opcode::CLASS_OFFSET), + ), + Rv32BaseAlu256Step::new( + Rv32HeapAdapterStep::new(pointer_max_bits, bitwise_lu_chip.clone()), bitwise_lu_chip.clone(), + Rv32BaseAlu256Opcode::CLASS_OFFSET, ), - BaseAluCoreChip::new(bitwise_lu_chip.clone(), Rv32BaseAlu256Opcode::CLASS_OFFSET), - offline_memory.clone(), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); + inventory.add_executor( base_alu_chip, Rv32BaseAlu256Opcode::iter().map(|x| x.global_opcode()), )?; let less_than_chip = Rv32LessThan256Chip::new( - Rv32HeapAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, + VmAirWrapper::new( + Rv32HeapAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lu_chip.bus(), + pointer_max_bits, + ), + LessThanCoreAir::new(bitwise_lu_chip.bus(), Rv32LessThan256Opcode::CLASS_OFFSET), + ), + Rv32LessThan256Step::new( + Rv32HeapAdapterStep::new(pointer_max_bits, bitwise_lu_chip.clone()), bitwise_lu_chip.clone(), + Rv32LessThan256Opcode::CLASS_OFFSET, ), - LessThanCoreChip::new(bitwise_lu_chip.clone(), Rv32LessThan256Opcode::CLASS_OFFSET), - offline_memory.clone(), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); + inventory.add_executor( less_than_chip, Rv32LessThan256Opcode::iter().map(|x| x.global_opcode()), )?; let branch_equal_chip = Rv32BranchEqual256Chip::new( - Rv32HeapBranchAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), + VmAirWrapper::new( + Rv32HeapBranchAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lu_chip.bus(), + pointer_max_bits, + ), + BranchEqualCoreAir::new(Rv32BranchEqual256Opcode::CLASS_OFFSET, DEFAULT_PC_STEP), + ), + Rv32BranchEqual256Step::new( + Rv32HeapBranchAdapterStep::new(pointer_max_bits, bitwise_lu_chip.clone()), + Rv32BranchEqual256Opcode::CLASS_OFFSET, + DEFAULT_PC_STEP, ), - BranchEqualCoreChip::new(Rv32BranchEqual256Opcode::CLASS_OFFSET, DEFAULT_PC_STEP), - offline_memory.clone(), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); + inventory.add_executor( branch_equal_chip, Rv32BranchEqual256Opcode::iter().map(|x| x.global_opcode()), )?; let branch_less_than_chip = Rv32BranchLessThan256Chip::new( - Rv32HeapBranchAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), + VmAirWrapper::new( + Rv32HeapBranchAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lu_chip.bus(), + pointer_max_bits, + ), + BranchLessThanCoreAir::new( + bitwise_lu_chip.bus(), + Rv32BranchLessThan256Opcode::CLASS_OFFSET, + ), ), - BranchLessThanCoreChip::new( + Rv32BranchLessThan256Step::new( + Rv32HeapBranchAdapterStep::new(pointer_max_bits, bitwise_lu_chip.clone()), bitwise_lu_chip.clone(), Rv32BranchLessThan256Opcode::CLASS_OFFSET, ), - offline_memory.clone(), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); inventory.add_executor( branch_less_than_chip, @@ -200,36 +237,53 @@ impl VmExtension for Int256 { )?; let multiplication_chip = Rv32Multiplication256Chip::new( - Rv32HeapAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), + VmAirWrapper::new( + Rv32HeapAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lu_chip.bus(), + pointer_max_bits, + ), + MultiplicationCoreAir::new(*range_tuple_chip.bus(), Rv32Mul256Opcode::CLASS_OFFSET), + ), + Rv32Multiplication256Step::new( + Rv32HeapAdapterStep::new(pointer_max_bits, bitwise_lu_chip.clone()), + range_tuple_chip.clone(), + Rv32Mul256Opcode::CLASS_OFFSET, ), - MultiplicationCoreChip::new(range_tuple_chip, Rv32Mul256Opcode::CLASS_OFFSET), - offline_memory.clone(), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); + inventory.add_executor( multiplication_chip, Rv32Mul256Opcode::iter().map(|x| x.global_opcode()), )?; let shift_chip = Rv32Shift256Chip::new( - Rv32HeapAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), + VmAirWrapper::new( + Rv32HeapAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lu_chip.bus(), + pointer_max_bits, + ), + ShiftCoreAir::new( + bitwise_lu_chip.bus(), + range_checker_chip.bus(), + Rv32Shift256Opcode::CLASS_OFFSET, + ), ), - ShiftCoreChip::new( + Rv32Shift256Step::new( + Rv32HeapAdapterStep::new(pointer_max_bits, bitwise_lu_chip.clone()), bitwise_lu_chip.clone(), - range_checker_chip, + range_checker_chip.clone(), Rv32Shift256Opcode::CLASS_OFFSET, ), - offline_memory.clone(), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); + inventory.add_executor( shift_chip, Rv32Shift256Opcode::iter().map(|x| x.global_opcode()), diff --git a/extensions/bigint/circuit/src/lib.rs b/extensions/bigint/circuit/src/lib.rs index 295ef73db2..ba971f27a5 100644 --- a/extensions/bigint/circuit/src/lib.rs +++ b/extensions/bigint/circuit/src/lib.rs @@ -1,9 +1,15 @@ -use openvm_circuit::{self, arch::VmChipWrapper}; -use openvm_rv32_adapters::{Rv32HeapAdapterChip, Rv32HeapBranchAdapterChip}; +use openvm_circuit::{ + self, + arch::{NewVmChipWrapper, VmAirWrapper}, +}; +use openvm_rv32_adapters::{ + Rv32HeapAdapterAir, Rv32HeapAdapterStep, Rv32HeapBranchAdapterAir, Rv32HeapBranchAdapterStep, +}; use openvm_rv32im_circuit::{ adapters::{INT256_NUM_LIMBS, RV32_CELL_BITS}, - BaseAluCoreChip, BranchEqualCoreChip, BranchLessThanCoreChip, LessThanCoreChip, - MultiplicationCoreChip, ShiftCoreChip, + BaseAluCoreAir, BaseAluStep, BranchEqualCoreAir, BranchEqualStep, BranchLessThanCoreAir, + BranchLessThanStep, LessThanCoreAir, LessThanStep, MultiplicationCoreAir, MultiplicationStep, + ShiftCoreAir, ShiftStep, }; mod extension; @@ -12,38 +18,74 @@ pub use extension::*; #[cfg(test)] mod tests; -pub type Rv32BaseAlu256Chip = VmChipWrapper< - F, - Rv32HeapAdapterChip, - BaseAluCoreChip, +/// BaseAlu256 +pub type Rv32BaseAlu256Air = VmAirWrapper< + Rv32HeapAdapterAir<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + BaseAluCoreAir, >; +pub type Rv32BaseAlu256Step = BaseAluStep< + Rv32HeapAdapterStep<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + INT256_NUM_LIMBS, + RV32_CELL_BITS, +>; +pub type Rv32BaseAlu256Chip = NewVmChipWrapper; -pub type Rv32LessThan256Chip = VmChipWrapper< - F, - Rv32HeapAdapterChip, - LessThanCoreChip, +/// LessThan256 +pub type Rv32LessThan256Air = VmAirWrapper< + Rv32HeapAdapterAir<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + LessThanCoreAir, +>; +pub type Rv32LessThan256Step = LessThanStep< + Rv32HeapAdapterStep<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + INT256_NUM_LIMBS, + RV32_CELL_BITS, >; +pub type Rv32LessThan256Chip = NewVmChipWrapper; -pub type Rv32Multiplication256Chip = VmChipWrapper< - F, - Rv32HeapAdapterChip, - MultiplicationCoreChip, +/// Multiplication256 +pub type Rv32Multiplication256Air = VmAirWrapper< + Rv32HeapAdapterAir<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + MultiplicationCoreAir, >; +pub type Rv32Multiplication256Step = MultiplicationStep< + Rv32HeapAdapterStep<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + INT256_NUM_LIMBS, + RV32_CELL_BITS, +>; +pub type Rv32Multiplication256Chip = + NewVmChipWrapper; -pub type Rv32Shift256Chip = VmChipWrapper< - F, - Rv32HeapAdapterChip, - ShiftCoreChip, +/// Shift256 +pub type Rv32Shift256Air = VmAirWrapper< + Rv32HeapAdapterAir<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + ShiftCoreAir, +>; +pub type Rv32Shift256Step = ShiftStep< + Rv32HeapAdapterStep<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + INT256_NUM_LIMBS, + RV32_CELL_BITS, >; +pub type Rv32Shift256Chip = NewVmChipWrapper; -pub type Rv32BranchEqual256Chip = VmChipWrapper< - F, - Rv32HeapBranchAdapterChip, - BranchEqualCoreChip, +/// BranchEqual256 +pub type Rv32BranchEqual256Air = VmAirWrapper< + Rv32HeapBranchAdapterAir<2, INT256_NUM_LIMBS>, + BranchEqualCoreAir, >; +pub type Rv32BranchEqual256Step = + BranchEqualStep, INT256_NUM_LIMBS>; +pub type Rv32BranchEqual256Chip = + NewVmChipWrapper; -pub type Rv32BranchLessThan256Chip = VmChipWrapper< - F, - Rv32HeapBranchAdapterChip, - BranchLessThanCoreChip, +/// BranchLessThan256 +pub type Rv32BranchLessThan256Air = VmAirWrapper< + Rv32HeapBranchAdapterAir<2, INT256_NUM_LIMBS>, + BranchLessThanCoreAir, +>; +pub type Rv32BranchLessThan256Step = BranchLessThanStep< + Rv32HeapBranchAdapterStep<2, INT256_NUM_LIMBS>, + INT256_NUM_LIMBS, + RV32_CELL_BITS, >; +pub type Rv32BranchLessThan256Chip = + NewVmChipWrapper; diff --git a/extensions/bigint/circuit/src/tests.rs b/extensions/bigint/circuit/src/tests.rs index 0e26352410..8ae7a30894 100644 --- a/extensions/bigint/circuit/src/tests.rs +++ b/extensions/bigint/circuit/src/tests.rs @@ -5,7 +5,7 @@ use openvm_bigint_transpiler::{ use openvm_circuit::{ arch::{ testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS, RANGE_TUPLE_CHECKER_BUS}, - InstructionExecutor, + InstructionExecutor, VmAirWrapper, }, utils::generate_long_number, }; @@ -13,22 +13,32 @@ use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, }; -use openvm_instructions::{program::PC_BITS, riscv::RV32_CELL_BITS, LocalOpcode}; +use openvm_instructions::{ + program::{DEFAULT_PC_STEP, PC_BITS}, + riscv::RV32_CELL_BITS, + LocalOpcode, +}; use openvm_rv32_adapters::{ - rv32_heap_branch_default, rv32_write_heap_default, Rv32HeapAdapterChip, - Rv32HeapBranchAdapterChip, + rv32_heap_branch_default, rv32_write_heap_default, Rv32HeapAdapterAir, Rv32HeapAdapterStep, + Rv32HeapBranchAdapterAir, Rv32HeapBranchAdapterStep, }; use openvm_rv32im_circuit::{ adapters::{INT256_NUM_LIMBS, RV_B_TYPE_IMM_BITS}, - BaseAluCoreChip, BranchEqualCoreChip, BranchLessThanCoreChip, LessThanCoreChip, - MultiplicationCoreChip, ShiftCoreChip, + BaseAluCoreAir, BranchEqualCoreAir, BranchLessThanCoreAir, LessThanCoreAir, + MultiplicationCoreAir, ShiftCoreAir, }; use openvm_rv32im_transpiler::{ - BaseAluOpcode, BranchEqualOpcode, BranchLessThanOpcode, LessThanOpcode, ShiftOpcode, + BaseAluOpcode, BranchEqualOpcode, BranchLessThanOpcode, LessThanOpcode, MulOpcode, ShiftOpcode, }; use openvm_stark_backend::p3_field::{FieldAlgebra, PrimeField32}; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; -use rand::Rng; +use rand::{rngs::StdRng, Rng}; +use test_case::test_case; + +use crate::{ + Rv32BaseAlu256Step, Rv32BranchEqual256Step, Rv32BranchLessThan256Step, Rv32LessThan256Step, + Rv32Multiplication256Step, Rv32Shift256Step, +}; use super::{ Rv32BaseAlu256Chip, Rv32BranchEqual256Chip, Rv32BranchLessThan256Chip, Rv32LessThan256Chip, @@ -36,148 +46,144 @@ use super::{ }; type F = BabyBear; +const MAX_INS_CAPACITY: usize = 128; +const ABS_MAX_BRANCH: i32 = 1 << (RV_B_TYPE_IMM_BITS - 1); #[allow(clippy::type_complexity)] -fn run_int_256_rand_execute>( - opcode: usize, - num_ops: usize, - executor: &mut E, +fn set_and_execute_rand>( tester: &mut VmChipTestBuilder, + chip: &mut E, + rng: &mut StdRng, + opcode: usize, branch_fn: Option bool>, ) { - const ABS_MAX_BRANCH: i32 = 1 << (RV_B_TYPE_IMM_BITS - 1); - - let mut rng = create_seeded_rng(); let branch = branch_fn.is_some(); - for _ in 0..num_ops { - let b = generate_long_number::(&mut rng); - let c = generate_long_number::(&mut rng); - if branch { - let imm = rng.gen_range((-ABS_MAX_BRANCH)..ABS_MAX_BRANCH); - let instruction = rv32_heap_branch_default( - tester, - vec![b.map(F::from_canonical_u32)], - vec![c.map(F::from_canonical_u32)], - imm as isize, - opcode, - ); - - tester.execute_with_pc( - executor, - &instruction, - rng.gen_range((ABS_MAX_BRANCH as u32)..(1 << (PC_BITS - 1))), - ); - - let cmp_result = branch_fn.unwrap()(opcode, &b, &c); - let from_pc = tester.execution.last_from_pc().as_canonical_u32() as i32; - let to_pc = tester.execution.last_to_pc().as_canonical_u32() as i32; - assert_eq!(to_pc, from_pc + if cmp_result { imm } else { 4 }); - } else { - let instruction = rv32_write_heap_default( - tester, - vec![b.map(F::from_canonical_u32)], - vec![c.map(F::from_canonical_u32)], - opcode, - ); - tester.execute(executor, &instruction); - } + let b = generate_long_number::(rng); + let c = generate_long_number::(rng); + if branch { + let imm = rng.gen_range((-ABS_MAX_BRANCH)..ABS_MAX_BRANCH); + let instruction = rv32_heap_branch_default( + tester, + vec![b.map(F::from_canonical_u32)], + vec![c.map(F::from_canonical_u32)], + imm as isize, + opcode, + ); + + tester.execute_with_pc( + chip, + &instruction, + rng.gen_range((ABS_MAX_BRANCH as u32)..(1 << (PC_BITS - 1))), + ); + + let cmp_result = branch_fn.unwrap()(opcode, &b, &c); + let from_pc = tester.execution.last_from_pc().as_canonical_u32() as i32; + let to_pc = tester.execution.last_to_pc().as_canonical_u32() as i32; + assert_eq!(to_pc, from_pc + if cmp_result { imm } else { 4 }); + } else { + let instruction = rv32_write_heap_default( + tester, + vec![b.map(F::from_canonical_u32)], + vec![c.map(F::from_canonical_u32)], + opcode, + ); + tester.execute(chip, &instruction); } } +#[test_case(BaseAluOpcode::ADD, 24)] +#[test_case(BaseAluOpcode::SUB, 24)] +#[test_case(BaseAluOpcode::XOR, 24)] +#[test_case(BaseAluOpcode::OR, 24)] +#[test_case(BaseAluOpcode::AND, 24)] fn run_alu_256_rand_test(opcode: BaseAluOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let offset = Rv32BaseAlu256Opcode::CLASS_OFFSET; + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32BaseAlu256Chip::::new( - Rv32HeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), + VmAirWrapper::new( + Rv32HeapAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + tester.address_bits(), + ), + BaseAluCoreAir::new(bitwise_bus, offset), + ), + Rv32BaseAlu256Step::new( + Rv32HeapAdapterStep::new(tester.address_bits(), bitwise_chip.clone()), bitwise_chip.clone(), + offset, ), - BaseAluCoreChip::new(bitwise_chip.clone(), Rv32BaseAlu256Opcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), + MAX_INS_CAPACITY, + tester.memory_helper(), ); - run_int_256_rand_execute( - opcode.local_usize() + Rv32BaseAlu256Opcode::CLASS_OFFSET, - num_ops, - &mut chip, - &mut tester, - None, - ); + for _ in 0..num_ops { + set_and_execute_rand( + &mut tester, + &mut chip, + &mut rng, + opcode.local_usize() + offset, + None, + ); + } let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn alu_256_add_rand_test() { - run_alu_256_rand_test(BaseAluOpcode::ADD, 24); -} - -#[test] -fn alu_256_sub_rand_test() { - run_alu_256_rand_test(BaseAluOpcode::SUB, 24); -} - -#[test] -fn alu_256_xor_rand_test() { - run_alu_256_rand_test(BaseAluOpcode::XOR, 24); -} - -#[test] -fn alu_256_or_rand_test() { - run_alu_256_rand_test(BaseAluOpcode::OR, 24); -} - -#[test] -fn alu_256_and_rand_test() { - run_alu_256_rand_test(BaseAluOpcode::AND, 24); -} - +#[test_case(LessThanOpcode::SLT, 24)] +#[test_case(LessThanOpcode::SLTU, 24)] fn run_lt_256_rand_test(opcode: LessThanOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let offset = Rv32LessThan256Opcode::CLASS_OFFSET; + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let mut tester = VmChipTestBuilder::default(); let mut chip = Rv32LessThan256Chip::::new( - Rv32HeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), + VmAirWrapper::new( + Rv32HeapAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + tester.address_bits(), + ), + LessThanCoreAir::new(bitwise_bus, offset), + ), + Rv32LessThan256Step::new( + Rv32HeapAdapterStep::new(tester.address_bits(), bitwise_chip.clone()), bitwise_chip.clone(), + offset, ), - LessThanCoreChip::new(bitwise_chip.clone(), Rv32LessThan256Opcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), + MAX_INS_CAPACITY, + tester.memory_helper(), ); - run_int_256_rand_execute( - opcode.local_usize() + Rv32LessThan256Opcode::CLASS_OFFSET, - num_ops, - &mut chip, - &mut tester, - None, - ); + for _ in 0..num_ops { + set_and_execute_rand( + &mut tester, + &mut chip, + &mut rng, + opcode.local_usize() + offset, + None, + ); + } let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn lt_256_slt_rand_test() { - run_lt_256_rand_test(LessThanOpcode::SLT, 24); -} - -#[test] -fn lt_256_sltu_rand_test() { - run_lt_256_rand_test(LessThanOpcode::SLTU, 24); -} +#[test_case(MulOpcode::MUL, 24)] +fn run_mul_256_rand_test(opcode: MulOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let offset = Rv32Mul256Opcode::CLASS_OFFSET; -fn run_mul_256_rand_test(num_ops: usize) { let range_tuple_bus = RangeTupleCheckerBus::new( RANGE_TUPLE_CHECKER_BUS, [ @@ -185,105 +191,120 @@ fn run_mul_256_rand_test(num_ops: usize) { (INT256_NUM_LIMBS * (1 << RV32_CELL_BITS)) as u32, ], ); - let range_tuple_checker = SharedRangeTupleCheckerChip::new(range_tuple_bus); + let range_tuple_chip = SharedRangeTupleCheckerChip::new(range_tuple_bus); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); let mut chip = Rv32Multiplication256Chip::::new( - Rv32HeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), + VmAirWrapper::new( + Rv32HeapAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + tester.address_bits(), + ), + MultiplicationCoreAir::new(range_tuple_bus, offset), ), - MultiplicationCoreChip::new(range_tuple_checker.clone(), Rv32Mul256Opcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), + Rv32Multiplication256Step::new( + Rv32HeapAdapterStep::new(tester.address_bits(), bitwise_chip.clone()), + range_tuple_chip.clone(), + offset, + ), + MAX_INS_CAPACITY, + tester.memory_helper(), ); - run_int_256_rand_execute( - Rv32Mul256Opcode::CLASS_OFFSET, - num_ops, - &mut chip, - &mut tester, - None, - ); + for _ in 0..num_ops { + set_and_execute_rand( + &mut tester, + &mut chip, + &mut rng, + opcode.local_usize() + offset, + None, + ); + } let tester = tester .build() .load(chip) - .load(range_tuple_checker) + .load(range_tuple_chip) .load(bitwise_chip) .finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn mul_256_rand_test() { - run_mul_256_rand_test(24); -} - +#[test_case(ShiftOpcode::SLL, 24)] +#[test_case(ShiftOpcode::SRL, 24)] +#[test_case(ShiftOpcode::SRA, 24)] fn run_shift_256_rand_test(opcode: ShiftOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let offset = Rv32Shift256Opcode::CLASS_OFFSET; + + let range_checker_chip = tester.range_checker(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); let mut chip = Rv32Shift256Chip::::new( - Rv32HeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), + VmAirWrapper::new( + Rv32HeapAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + tester.address_bits(), + ), + ShiftCoreAir::new(bitwise_bus, range_checker_chip.bus(), offset), ), - ShiftCoreChip::new( + Rv32Shift256Step::new( + Rv32HeapAdapterStep::new(tester.address_bits(), bitwise_chip.clone()), bitwise_chip.clone(), - tester.memory_controller().borrow().range_checker.clone(), - Rv32Shift256Opcode::CLASS_OFFSET, + range_checker_chip.clone(), + offset, ), - tester.offline_memory_mutex_arc(), + MAX_INS_CAPACITY, + tester.memory_helper(), ); - run_int_256_rand_execute( - opcode.local_usize() + Rv32Shift256Opcode::CLASS_OFFSET, - num_ops, - &mut chip, - &mut tester, - None, - ); + for _ in 0..num_ops { + set_and_execute_rand( + &mut tester, + &mut chip, + &mut rng, + opcode.local_usize() + offset, + None, + ); + } + + drop(range_checker_chip); let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn shift_256_sll_rand_test() { - run_shift_256_rand_test(ShiftOpcode::SLL, 24); -} - -#[test] -fn shift_256_srl_rand_test() { - run_shift_256_rand_test(ShiftOpcode::SRL, 24); -} - -#[test] -fn shift_256_sra_rand_test() { - run_shift_256_rand_test(ShiftOpcode::SRA, 24); -} - +#[test_case(BranchEqualOpcode::BEQ, 24)] +#[test_case(BranchEqualOpcode::BNE, 24)] fn run_beq_256_rand_test(opcode: BranchEqualOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); + let offset = Rv32BranchEqual256Opcode::CLASS_OFFSET; + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut chip = Rv32BranchEqual256Chip::::new( - Rv32HeapBranchAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), + VmAirWrapper::new( + Rv32HeapBranchAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + tester.address_bits(), + ), + BranchEqualCoreAir::new(offset, DEFAULT_PC_STEP), + ), + Rv32BranchEqual256Step::new( + Rv32HeapBranchAdapterStep::new(tester.address_bits(), bitwise_chip.clone()), + offset, + DEFAULT_PC_STEP, ), - BranchEqualCoreChip::new(Rv32BranchEqual256Opcode::CLASS_OFFSET, 4), - tester.offline_memory_mutex_arc(), + MAX_INS_CAPACITY, + tester.memory_helper(), ); let branch_fn = |opcode: usize, x: &[u32; INT256_NUM_LIMBS], y: &[u32; INT256_NUM_LIMBS]| { @@ -294,93 +315,79 @@ fn run_beq_256_rand_test(opcode: BranchEqualOpcode, num_ops: usize) { == BranchEqualOpcode::BNE.local_usize() + Rv32BranchEqual256Opcode::CLASS_OFFSET) }; - run_int_256_rand_execute( - opcode.local_usize() + Rv32BranchEqual256Opcode::CLASS_OFFSET, - num_ops, - &mut chip, - &mut tester, - Some(branch_fn), - ); + for _ in 0..num_ops { + set_and_execute_rand( + &mut tester, + &mut chip, + &mut rng, + opcode.local_usize() + offset, + Some(branch_fn), + ); + } let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn beq_256_beq_rand_test() { - run_beq_256_rand_test(BranchEqualOpcode::BEQ, 24); -} - -#[test] -fn beq_256_bne_rand_test() { - run_beq_256_rand_test(BranchEqualOpcode::BNE, 24); -} - +#[test_case(BranchLessThanOpcode::BLT, 24)] +#[test_case(BranchLessThanOpcode::BLTU, 24)] +#[test_case(BranchLessThanOpcode::BGE, 24)] +#[test_case(BranchLessThanOpcode::BGEU, 24)] fn run_blt_256_rand_test(opcode: BranchLessThanOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let offset = Rv32BranchLessThan256Opcode::CLASS_OFFSET; + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); let mut chip = Rv32BranchLessThan256Chip::::new( - Rv32HeapBranchAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), + VmAirWrapper::new( + Rv32HeapBranchAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + tester.address_bits(), + ), + BranchLessThanCoreAir::new(bitwise_bus, offset), ), - BranchLessThanCoreChip::new( + Rv32BranchLessThan256Step::new( + Rv32HeapBranchAdapterStep::new(tester.address_bits(), bitwise_chip.clone()), bitwise_chip.clone(), - Rv32BranchLessThan256Opcode::CLASS_OFFSET, + offset, ), - tester.offline_memory_mutex_arc(), + MAX_INS_CAPACITY, + tester.memory_helper(), ); - - let branch_fn = |opcode: usize, x: &[u32; INT256_NUM_LIMBS], y: &[u32; INT256_NUM_LIMBS]| { - let opcode = - BranchLessThanOpcode::from_usize(opcode - Rv32BranchLessThan256Opcode::CLASS_OFFSET); - let (is_ge, is_signed) = match opcode { - BranchLessThanOpcode::BLT => (false, true), - BranchLessThanOpcode::BLTU => (false, false), - BranchLessThanOpcode::BGE => (true, true), - BranchLessThanOpcode::BGEU => (true, false), - }; - let x_sign = x[INT256_NUM_LIMBS - 1] >> (RV32_CELL_BITS - 1) != 0 && is_signed; - let y_sign = y[INT256_NUM_LIMBS - 1] >> (RV32_CELL_BITS - 1) != 0 && is_signed; - for (x, y) in x.iter().rev().zip(y.iter().rev()) { - if x != y { - return (x < y) ^ x_sign ^ y_sign ^ is_ge; + let branch_fn = + |opcode: usize, x: &[u32; INT256_NUM_LIMBS], y: &[u32; INT256_NUM_LIMBS]| -> bool { + let opcode = BranchLessThanOpcode::from_usize( + opcode - Rv32BranchLessThan256Opcode::CLASS_OFFSET, + ); + let (is_ge, is_signed) = match opcode { + BranchLessThanOpcode::BLT => (false, true), + BranchLessThanOpcode::BLTU => (false, false), + BranchLessThanOpcode::BGE => (true, true), + BranchLessThanOpcode::BGEU => (true, false), + }; + let x_sign = x[INT256_NUM_LIMBS - 1] >> (RV32_CELL_BITS - 1) != 0 && is_signed; + let y_sign = y[INT256_NUM_LIMBS - 1] >> (RV32_CELL_BITS - 1) != 0 && is_signed; + for (x, y) in x.iter().rev().zip(y.iter().rev()) { + if x != y { + return (x < y) ^ x_sign ^ y_sign ^ is_ge; + } } - } - is_ge - }; + is_ge + }; - run_int_256_rand_execute( - opcode.local_usize() + Rv32BranchLessThan256Opcode::CLASS_OFFSET, - num_ops, - &mut chip, - &mut tester, - Some(branch_fn), - ); + for _ in 0..num_ops { + set_and_execute_rand( + &mut tester, + &mut chip, + &mut rng, + opcode.local_usize() + offset, + Some(branch_fn), + ); + } let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } - -#[test] -fn blt_256_blt_rand_test() { - run_blt_256_rand_test(BranchLessThanOpcode::BLT, 24); -} - -#[test] -fn blt_256_bltu_rand_test() { - run_blt_256_rand_test(BranchLessThanOpcode::BLTU, 24); -} - -#[test] -fn blt_256_bge_rand_test() { - run_blt_256_rand_test(BranchLessThanOpcode::BGE, 24); -} - -#[test] -fn blt_256_bgeu_rand_test() { - run_blt_256_rand_test(BranchLessThanOpcode::BGEU, 24); -} diff --git a/extensions/rv32-adapters/src/heap.rs b/extensions/rv32-adapters/src/heap.rs index 8217f6833b..cf7e54ddef 100644 --- a/extensions/rv32-adapters/src/heap.rs +++ b/extensions/rv32-adapters/src/heap.rs @@ -1,18 +1,14 @@ -use std::{ - array::{self, from_fn}, - borrow::Borrow, - marker::PhantomData, -}; +use std::borrow::{Borrow, BorrowMut}; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, BasicAdapterInterface, + ExecutionBridge, MinimalInstruction, VmAdapterAir, }, - system::{ - memory::{offline_checker::MemoryBridge, MemoryController, OfflineMemory}, - program::ProgramBus, + system::memory::{ + offline_checker::MemoryBridge, + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, }, }; use openvm_circuit_primitives::bitwise_op_lookup::{ @@ -20,20 +16,15 @@ use openvm_circuit_primitives::bitwise_op_lookup::{ }; use openvm_instructions::{ instruction::Instruction, - program::DEFAULT_PC_STEP, - riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, }; -use openvm_rv32im_circuit::adapters::{read_rv32_register, tmp_convert_to_u8s}; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, PrimeField32}, }; -use super::{ - vec_heap_generate_trace_row_impl, Rv32VecHeapAdapterAir, Rv32VecHeapAdapterCols, - Rv32VecHeapReadRecord, Rv32VecHeapWriteRecord, -}; +use crate::{RV32VecHeapAdapterStep, Rv32VecHeapAdapterAir, Rv32VecHeapAdapterCols}; /// This adapter reads from NUM_READS <= 2 pointers and writes to 1 pointer. /// * The data is read from the heap (address space 2), and the pointers are read from registers @@ -101,139 +92,100 @@ impl< } } -pub struct Rv32HeapAdapterChip< - F: Field, +pub struct Rv32HeapAdapterStep< const NUM_READS: usize, const READ_SIZE: usize, const WRITE_SIZE: usize, -> { - pub air: Rv32HeapAdapterAir, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - _marker: PhantomData, -} +>(RV32VecHeapAdapterStep); -impl - Rv32HeapAdapterChip +impl + Rv32HeapAdapterStep { pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - address_bits: usize, + pointer_max_bits: usize, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, ) -> Self { assert!(NUM_READS <= 2); assert!( - RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits < RV32_CELL_BITS, - "address_bits={address_bits} needs to be large enough for high limb range check" + RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - pointer_max_bits < RV32_CELL_BITS, + "pointer_max_bits={pointer_max_bits} needs to be large enough for high limb range check" ); - Self { - air: Rv32HeapAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bus: bitwise_lookup_chip.bus(), - address_bits, - }, + Rv32HeapAdapterStep(RV32VecHeapAdapterStep::new( + pointer_max_bits, bitwise_lookup_chip, - _marker: PhantomData, - } + )) } } -impl - VmAdapterChip for Rv32HeapAdapterChip +impl< + F: PrimeField32, + CTX, + const NUM_READS: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, + > AdapterTraceStep for Rv32HeapAdapterStep +where + F: PrimeField32, { - type ReadRecord = Rv32VecHeapReadRecord; - type WriteRecord = Rv32VecHeapWriteRecord<1, WRITE_SIZE>; - type Air = Rv32HeapAdapterAir; - type Interface = - BasicAdapterInterface, NUM_READS, 1, READ_SIZE, WRITE_SIZE>; - - fn preprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { a, b, c, d, e, .. } = *instruction; - - debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); - debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); - - let mut rs_vals = [0; NUM_READS]; - let rs_records: [_; NUM_READS] = from_fn(|i| { - let addr = if i == 0 { b } else { c }; - let (record, val) = read_rv32_register(memory, d, addr); - rs_vals[i] = val; - record - }); - let (rd_record, rd_val) = read_rv32_register(memory, d, a); - - let read_records = rs_vals.map(|address| { - debug_assert!(address as usize + READ_SIZE - 1 < (1 << self.air.address_bits)); - [memory.read::(e, F::from_canonical_u32(address))] - }); - let read_data = read_records.map(|r| r[0].1.map(F::from_canonical_u8)); - - let record = Rv32VecHeapReadRecord { - rs: rs_records, - rd: rd_record, - rd_val: F::from_canonical_u32(rd_val), - reads: read_records.map(|r| array::from_fn(|i| r[i].0)), - }; - - Ok((read_data, record)) + const WIDTH: usize = + Rv32VecHeapAdapterCols::::width(); + type ReadData = [[u8; READ_SIZE]; NUM_READS]; + type WriteData = [[u8; WRITE_SIZE]; 1]; + + type TraceContext<'a> = (); + + fn start(pc: u32, memory: &TracingMemory, adapter_row: &mut [F]) { + let adapter_cols: &mut Rv32VecHeapAdapterCols = + adapter_row.borrow_mut(); + adapter_cols.from_state.pc = F::from_canonical_u32(pc); + adapter_cols.from_state.timestamp = F::from_canonical_u32(memory.timestamp); } - fn postprocess( - &mut self, - memory: &mut MemoryController, + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let e = instruction.e; - let writes = [memory - .write(e, read_record.rd_val, &tmp_convert_to_u8s(output.writes[0])) - .0]; - - let timestamp_delta = memory.timestamp() - from_state.timestamp; - debug_assert!( - timestamp_delta == 6, - "timestamp delta is {}, expected 6", - timestamp_delta - ); - - Ok(( - ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }, - Self::WriteRecord { from_state, writes }, - )) + adapter_row: &mut [F], + ) -> Self::ReadData { + let read_data = AdapterTraceStep::::read(&self.0, memory, instruction, adapter_row); + read_data.map(|r| r[0]) } - fn generate_trace_row( + fn write( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + memory: &mut TracingMemory, + instruction: &Instruction, + adapter_row: &mut [F], + data: &Self::WriteData, ) { - vec_heap_generate_trace_row_impl( - row_slice, - &read_record, - &write_record, - self.bitwise_lookup_chip.clone(), - self.air.address_bits, - memory, - ); + AdapterTraceStep::::write(&self.0, memory, instruction, adapter_row, data); + } + + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, ctx: (), adapter_row: &mut [F]) { + AdapterTraceStep::::fill_trace_row(&self.0, mem_helper, ctx, adapter_row); + } +} + +impl + AdapterExecutorE1 for Rv32HeapAdapterStep +{ + type ReadData = [[u8; READ_SIZE]; NUM_READS]; + type WriteData = [[u8; WRITE_SIZE]; 1]; + + #[inline(always)] + fn read(&self, memory: &mut Mem, instruction: &Instruction) -> Self::ReadData + where + Mem: GuestMemory, + { + let read_data = AdapterExecutorE1::::read(&self.0, memory, instruction); + read_data.map(|r| r[0]) } - fn air(&self) -> &Self::Air { - &self.air + #[inline(always)] + fn write(&self, memory: &mut Mem, instruction: &Instruction, data: &Self::WriteData) + where + Mem: GuestMemory, + { + AdapterExecutorE1::::write(&self.0, memory, instruction, data); } } diff --git a/extensions/rv32-adapters/src/heap_branch.rs b/extensions/rv32-adapters/src/heap_branch.rs index fc4067e84a..a80c6c214f 100644 --- a/extensions/rv32-adapters/src/heap_branch.rs +++ b/extensions/rv32-adapters/src/heap_branch.rs @@ -2,22 +2,18 @@ use std::{ array::from_fn, borrow::{Borrow, BorrowMut}, iter::once, - marker::PhantomData, }; use itertools::izip; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, ImmInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, BasicAdapterInterface, + ExecutionBridge, ExecutionState, ImmInstruction, VmAdapterAir, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, - }, - program::ProgramBus, + system::memory::{ + offline_checker::{MemoryBridge, MemoryReadAuxCols}, + online::{GuestMemory, TracingMemory}, + MemoryAddress, MemoryAuxColsFactory, }, }; use openvm_circuit_primitives::bitwise_op_lookup::{ @@ -30,15 +26,13 @@ use openvm_instructions::{ riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, }; use openvm_rv32im_circuit::adapters::{ - read_rv32_register, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + memory_read, new_read_rv32_register, tracing_read, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, }; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; /// This adapter reads from NUM_READS <= 2 pointers. /// * The data is read from the heap (address space 2), and the pointers are read from registers @@ -170,158 +164,148 @@ impl VmA } } -pub struct Rv32HeapBranchAdapterChip { - pub air: Rv32HeapBranchAdapterAir, +pub struct Rv32HeapBranchAdapterStep { + pub pointer_max_bits: usize, + // TODO(arayi): use reference to bitwise lookup chip with lifetimes instead pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - _marker: PhantomData, } -impl - Rv32HeapBranchAdapterChip +impl + Rv32HeapBranchAdapterStep { pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - address_bits: usize, + pointer_max_bits: usize, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, ) -> Self { assert!(NUM_READS <= 2); assert!( - RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits < RV32_CELL_BITS, - "address_bits={address_bits} needs to be large enough for high limb range check" + RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - pointer_max_bits < RV32_CELL_BITS, + "pointer_max_bits={pointer_max_bits} needs to be large enough for high limb range check" ); Self { - air: Rv32HeapBranchAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bus: bitwise_lookup_chip.bus(), - address_bits, - }, + pointer_max_bits, bitwise_lookup_chip, - _marker: PhantomData, } } } +impl AdapterExecutorE1 + for Rv32HeapBranchAdapterStep +{ + type ReadData = [[u8; READ_SIZE]; NUM_READS]; + type WriteData = (); -#[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct Rv32HeapBranchReadRecord { - #[serde(with = "BigArray")] - pub rs_reads: [RecordId; NUM_READS], - #[serde(with = "BigArray")] - pub heap_reads: [RecordId; NUM_READS], + fn read(&self, memory: &mut Mem, instruction: &Instruction) -> Self::ReadData + where + Mem: GuestMemory, + { + let Instruction { a, b, d, e, .. } = *instruction; + + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + debug_assert_eq!(d, RV32_REGISTER_AS); + debug_assert_eq!(e, RV32_MEMORY_AS); + + // Read register values + let rs_vals = from_fn(|i| { + let addr = if i == 0 { a } else { b }; + new_read_rv32_register(memory, d, addr.as_canonical_u32()) + }); + + // Read memory values + rs_vals.map(|address| { + assert!(address as usize + READ_SIZE - 1 < (1 << self.pointer_max_bits)); + memory_read(memory, e, address) + }) + } + + fn write(&self, _memory: &mut Mem, _instruction: &Instruction, _data: &Self::WriteData) + where + Mem: GuestMemory, + { + // This function intentionally does nothing + } } -impl VmAdapterChip - for Rv32HeapBranchAdapterChip +impl AdapterTraceStep + for Rv32HeapBranchAdapterStep +where + F: PrimeField32, { - type ReadRecord = Rv32HeapBranchReadRecord; - type WriteRecord = ExecutionState; - type Air = Rv32HeapBranchAdapterAir; - type Interface = BasicAdapterInterface, NUM_READS, 0, READ_SIZE, 0>; - - fn preprocess( - &mut self, - memory: &mut MemoryController, + const WIDTH: usize = Rv32HeapBranchAdapterCols::::width(); + type ReadData = [[u8; READ_SIZE]; NUM_READS]; + type WriteData = (); + type TraceContext<'a> = (); + + fn start(pc: u32, memory: &TracingMemory, adapter_row: &mut [F]) { + let cols: &mut Rv32HeapBranchAdapterCols<_, NUM_READS, READ_SIZE> = + adapter_row.borrow_mut(); + cols.from_state.pc = F::from_canonical_u32(pc); + cols.from_state.timestamp = F::from_canonical_u32(memory.timestamp); + } + + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { + adapter_row: &mut [F], + ) -> Self::ReadData { let Instruction { a, b, d, e, .. } = *instruction; - debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); - debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + debug_assert_eq!(d, RV32_REGISTER_AS); + debug_assert_eq!(e, RV32_MEMORY_AS); - let mut rs_vals = [0; NUM_READS]; - let rs_records: [_; NUM_READS] = from_fn(|i| { - let addr = if i == 0 { a } else { b }; - let (record, val) = read_rv32_register(memory, d, addr); - rs_vals[i] = val; - record - }); + let cols: &mut Rv32HeapBranchAdapterCols<_, NUM_READS, READ_SIZE> = + adapter_row.borrow_mut(); - let heap_records = rs_vals.map(|address| { - assert!(address as usize + READ_SIZE - 1 < (1 << self.air.address_bits)); - memory.read::(e, F::from_canonical_u32(address)) + // Read register values + let rs_vals: [_; NUM_READS] = from_fn(|i| { + let addr = if i == 0 { a } else { b }; + cols.rs_ptr[i] = addr; + let rs_val = tracing_read(memory, e, addr.as_canonical_u32(), &mut cols.rs_read_aux[i]); + cols.rs_val[i] = rs_val.map(F::from_canonical_u8); + u32::from_le_bytes(rs_val) }); - let record = Rv32HeapBranchReadRecord { - rs_reads: rs_records, - heap_reads: heap_records.map(|r| r.0), - }; - Ok((heap_records.map(|r| r.1.map(F::from_canonical_u8)), record)) + // Read memory values + from_fn(|i| { + assert!(rs_vals[i] as usize + READ_SIZE - 1 < (1 << self.pointer_max_bits)); + tracing_read(memory, e, rs_vals[i], &mut cols.heap_read_aux[i]) + }) } - fn postprocess( - &mut self, - memory: &mut MemoryController, + fn write( + &self, + _memory: &mut TracingMemory, _instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let timestamp_delta = memory.timestamp() - from_state.timestamp; - debug_assert!( - timestamp_delta == 4, - "timestamp delta is {}, expected 4", - timestamp_delta - ); - - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - from_state, - )) + _adapter_row: &mut [F], + _data: &Self::WriteData, + ) { + // This function intentionally does nothing } - fn generate_trace_row( + fn fill_trace_row( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + _mem_helper: &MemoryAuxColsFactory, + _ctx: (), + adapter_row: &mut [F], ) { - let aux_cols_factory = memory.aux_cols_factory(); - let row_slice: &mut Rv32HeapBranchAdapterCols<_, NUM_READS, READ_SIZE> = - row_slice.borrow_mut(); - row_slice.from_state = write_record.map(F::from_canonical_u32); - - let rs_reads = read_record.rs_reads.map(|r| memory.record_by_id(r)); - - for (i, rs_read) in rs_reads.iter().enumerate() { - row_slice.rs_ptr[i] = rs_read.pointer; - row_slice.rs_val[i].copy_from_slice(rs_read.data_slice()); - aux_cols_factory.generate_read_aux(rs_read, &mut row_slice.rs_read_aux[i]); - } - - for (i, heap_read) in read_record.heap_reads.iter().enumerate() { - let record = memory.record_by_id(*heap_read); - aux_cols_factory.generate_read_aux(record, &mut row_slice.heap_read_aux[i]); - } + let cols: &mut Rv32HeapBranchAdapterCols = + adapter_row.borrow_mut(); // Range checks: - let need_range_check: Vec = rs_reads + let need_range_check: Vec = cols + .rs_val .iter() - .map(|record| { - record - .data_at(RV32_REGISTER_NUM_LIMBS - 1) - .as_canonical_u32() - }) + .map(|&val| val[RV32_REGISTER_NUM_LIMBS - 1].as_canonical_u32()) .chain(once(0)) // in case NUM_READS is odd .collect(); - debug_assert!(self.air.address_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS); - let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.air.address_bits; + debug_assert!(self.pointer_max_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS); + let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.pointer_max_bits; for pair in need_range_check.chunks_exact(2) { self.bitwise_lookup_chip .request_range(pair[0] << limb_shift_bits, pair[1] << limb_shift_bits); } } - - fn air(&self) -> &Self::Air { - &self.air - } } diff --git a/extensions/rv32-adapters/src/lib.rs b/extensions/rv32-adapters/src/lib.rs index d84c82f617..c194f3e8ca 100644 --- a/extensions/rv32-adapters/src/lib.rs +++ b/extensions/rv32-adapters/src/lib.rs @@ -1,14 +1,14 @@ -mod eq_mod; +// mod eq_mod; mod heap; mod heap_branch; mod vec_heap; -mod vec_heap_two_reads; +// mod vec_heap_two_reads; -pub use eq_mod::*; +// pub use eq_mod::*; pub use heap::*; pub use heap_branch::*; pub use vec_heap::*; -pub use vec_heap_two_reads::*; +// pub use vec_heap_two_reads::*; #[cfg(any(test, feature = "test-utils"))] mod test_utils; diff --git a/extensions/rv32-adapters/src/vec_heap.rs b/extensions/rv32-adapters/src/vec_heap.rs index 0a2766e29a..4d430b3fc1 100644 --- a/extensions/rv32-adapters/src/vec_heap.rs +++ b/extensions/rv32-adapters/src/vec_heap.rs @@ -2,21 +2,18 @@ use std::{ array::from_fn, borrow::{Borrow, BorrowMut}, iter::{once, zip}, - marker::PhantomData, }; use itertools::izip; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, ExecutionBridge, ExecutionBus, ExecutionState, - Result, VecHeapAdapterInterface, VmAdapterAir, VmAdapterChip, VmAdapterInterface, + AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, ExecutionBridge, ExecutionState, + VecHeapAdapterInterface, VmAdapterAir, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, - }, - program::ProgramBus, + system::memory::{ + offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, + online::{GuestMemory, TracingMemory}, + MemoryAddress, MemoryAuxColsFactory, }, }; use openvm_circuit_primitives::bitwise_op_lookup::{ @@ -29,16 +26,14 @@ use openvm_instructions::{ riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, }; use openvm_rv32im_circuit::adapters::{ - abstract_compose, read_rv32_register, tmp_convert_to_u8s, RV32_CELL_BITS, - RV32_REGISTER_NUM_LIMBS, + abstract_compose, memory_read, memory_write, new_read_rv32_register, tracing_read, + tracing_write, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, }; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; -use serde_with::serde_as; /// This adapter reads from R (R <= 2) pointers and writes to 1 pointer. /// * The data is read from the heap (address space 2), and the pointers are read from registers @@ -47,87 +42,6 @@ use serde_with::serde_as; /// starting from the addresses in `rs[0]` (and `rs[1]` if `R = 2`). /// * Writes take the form of `BLOCKS_PER_WRITE` consecutive writes of size `WRITE_SIZE` to the /// heap, starting from the address in `rd`. -#[derive(Clone)] -pub struct Rv32VecHeapAdapterChip< - F: Field, - const NUM_READS: usize, - const BLOCKS_PER_READ: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, -> { - pub air: - Rv32VecHeapAdapterAir, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - _marker: PhantomData, -} - -impl< - F: PrimeField32, - const NUM_READS: usize, - const BLOCKS_PER_READ: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > - Rv32VecHeapAdapterChip -{ - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - address_bits: usize, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - ) -> Self { - assert!(NUM_READS <= 2); - assert!( - RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits < RV32_CELL_BITS, - "address_bits={address_bits} needs to be large enough for high limb range check" - ); - Self { - air: Rv32VecHeapAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bus: bitwise_lookup_chip.bus(), - address_bits, - }, - bitwise_lookup_chip, - _marker: PhantomData, - } - } -} - -#[repr(C)] -#[serde_as] -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -#[serde(bound = "F: Field")] -pub struct Rv32VecHeapReadRecord< - F: Field, - const NUM_READS: usize, - const BLOCKS_PER_READ: usize, - const READ_SIZE: usize, -> { - /// Read register value from address space e=1 - #[serde_as(as = "[_; NUM_READS]")] - pub rs: [RecordId; NUM_READS], - /// Read register value from address space d=1 - pub rd: RecordId, - - pub rd_val: F, - - #[serde_as(as = "[[_; BLOCKS_PER_READ]; NUM_READS]")] - pub reads: [[RecordId; BLOCKS_PER_READ]; NUM_READS], -} - -#[repr(C)] -#[serde_as] -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub struct Rv32VecHeapWriteRecord { - pub from_state: ExecutionState, - #[serde_as(as = "[_; BLOCKS_PER_WRITE]")] - pub writes: [RecordId; BLOCKS_PER_WRITE], -} - #[repr(C)] #[derive(AlignedBorrow)] pub struct Rv32VecHeapAdapterCols< @@ -347,204 +261,230 @@ impl< } } +#[derive(derive_new::new)] +pub struct RV32VecHeapAdapterStep< + const NUM_READS: usize, + const BLOCKS_PER_READ: usize, + const BLOCKS_PER_WRITE: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, +> { + pointer_max_bits: usize, + // TODO(arayi): use reference to bitwise lookup chip with lifetimes instead + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, +} + impl< F: PrimeField32, + CTX, const NUM_READS: usize, const BLOCKS_PER_READ: usize, const BLOCKS_PER_WRITE: usize, const READ_SIZE: usize, const WRITE_SIZE: usize, - > VmAdapterChip - for Rv32VecHeapAdapterChip< - F, - NUM_READS, - BLOCKS_PER_READ, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - > + > AdapterTraceStep + for RV32VecHeapAdapterStep { - type ReadRecord = Rv32VecHeapReadRecord; - type WriteRecord = Rv32VecHeapWriteRecord; - type Air = - Rv32VecHeapAdapterAir; - type Interface = VecHeapAdapterInterface< + const WIDTH: usize = Rv32VecHeapAdapterCols::< F, NUM_READS, BLOCKS_PER_READ, BLOCKS_PER_WRITE, READ_SIZE, WRITE_SIZE, - >; + >::width(); + type ReadData = [[[u8; READ_SIZE]; BLOCKS_PER_READ]; NUM_READS]; + type WriteData = [[u8; WRITE_SIZE]; BLOCKS_PER_WRITE]; + type TraceContext<'a> = (); - fn preprocess( - &mut self, - memory: &mut MemoryController, + fn start(pc: u32, memory: &TracingMemory, adapter_row: &mut [F]) { + let adapter_cols: &mut Rv32VecHeapAdapterCols< + F, + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + > = adapter_row.borrow_mut(); + adapter_cols.from_state.pc = F::from_canonical_u32(pc); + adapter_cols.from_state.timestamp = F::from_canonical_u32(memory.timestamp); + } + + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { a, b, c, d, e, .. } = *instruction; + adapter_row: &mut [F], + ) -> Self::ReadData { + let Instruction { b, c, d, e, .. } = *instruction; - debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); - debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); + let e = e.as_canonical_u32(); + let d = d.as_canonical_u32(); + debug_assert_eq!(d, RV32_REGISTER_AS); + debug_assert_eq!(e, RV32_MEMORY_AS); + + let cols: &mut Rv32VecHeapAdapterCols< + F, + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + > = adapter_row.borrow_mut(); // Read register values - let mut rs_vals = [0; NUM_READS]; - let rs_records: [_; NUM_READS] = from_fn(|i| { + let rs_vals: [_; NUM_READS] = from_fn(|i| { let addr = if i == 0 { b } else { c }; - let (record, val) = read_rv32_register(memory, d, addr); - rs_vals[i] = val; - record + cols.rs_ptr[i] = addr; + let rs_val = tracing_read(memory, e, addr.as_canonical_u32(), &mut cols.rs_read_aux[i]); + cols.rs_val[i] = rs_val.map(F::from_canonical_u8); + u32::from_le_bytes(rs_val) }); - let (rd_record, rd_val) = read_rv32_register(memory, d, a); // Read memory values - let read_records = rs_vals.map(|address| { + from_fn(|i| { assert!( - address as usize + READ_SIZE * BLOCKS_PER_READ - 1 < (1 << self.air.address_bits) + rs_vals[i] as usize + READ_SIZE * BLOCKS_PER_READ - 1 + < (1 << self.pointer_max_bits) ); - from_fn(|i| { - memory.read::( + from_fn(|j| { + tracing_read( + memory, e, - F::from_canonical_u32(address + (i * READ_SIZE) as u32), + rs_vals[i] + (j * READ_SIZE) as u32, + &mut cols.reads_aux[i][j], ) }) - }); - let read_data = read_records.map(|r| r.map(|x| x.1.map(F::from_canonical_u8))); - assert!(rd_val as usize + WRITE_SIZE * BLOCKS_PER_WRITE - 1 < (1 << self.air.address_bits)); - - let record = Rv32VecHeapReadRecord { - rs: rs_records, - rd: rd_record, - rd_val: F::from_canonical_u32(rd_val), - reads: read_records.map(|r| r.map(|x| x.0)), - }; - - Ok((read_data, record)) + }) } - fn postprocess( - &mut self, - memory: &mut MemoryController, + fn write( + &self, + memory: &mut openvm_circuit::system::memory::online::TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let e = instruction.e; - let mut i = 0; - let writes = output.writes.map(|write| { - let (record_id, _) = memory.write( + adapter_row: &mut [F], + data: &Self::WriteData, + ) { + let Instruction { a, d, e, .. } = *instruction; + + let e = e.as_canonical_u32(); + let cols: &mut Rv32VecHeapAdapterCols< + F, + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + > = adapter_row.borrow_mut(); + + cols.rd_ptr = a; + let rd_val = tracing_read( + memory, + d.as_canonical_u32(), + a.as_canonical_u32(), + &mut cols.rd_read_aux, + ); + cols.rd_val = rd_val.map(F::from_canonical_u8); + + let rd_val = u32::from_le_bytes(rd_val); + assert!(rd_val as usize + WRITE_SIZE * BLOCKS_PER_WRITE - 1 < (1 << self.pointer_max_bits)); + + for i in 0..BLOCKS_PER_WRITE { + tracing_write( + memory, e, - read_record.rd_val + F::from_canonical_u32((i * WRITE_SIZE) as u32), - &tmp_convert_to_u8s(write), + rd_val + (i * WRITE_SIZE) as u32, + &data[i], + &mut cols.writes_aux[i], ); - i += 1; - record_id - }); - - Ok(( - ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }, - Self::WriteRecord { from_state, writes }, - )) + } } - fn generate_trace_row( + fn fill_trace_row( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + _mem_helper: &MemoryAuxColsFactory, + _ctx: (), + adapter_row: &mut [F], ) { - vec_heap_generate_trace_row_impl( - row_slice, - &read_record, - &write_record, - self.bitwise_lookup_chip.clone(), - self.air.address_bits, - memory, - ) - } + let cols: &mut Rv32VecHeapAdapterCols< + F, + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + > = adapter_row.borrow_mut(); - fn air(&self) -> &Self::Air { - &self.air + // Range checks: + let need_range_check: Vec = cols + .rs_val + .iter() + .chain(std::iter::repeat_n(&cols.rd_val, 2)) + .map(|&val| val[RV32_REGISTER_NUM_LIMBS - 1].as_canonical_u32()) + .collect(); + debug_assert!(self.pointer_max_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS); + let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.pointer_max_bits; + for pair in need_range_check.chunks_exact(2) { + self.bitwise_lookup_chip + .request_range(pair[0] << limb_shift_bits, pair[1] << limb_shift_bits); + } } } -pub(super) fn vec_heap_generate_trace_row_impl< - F: PrimeField32, - const NUM_READS: usize, - const BLOCKS_PER_READ: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, ->( - row_slice: &mut [F], - read_record: &Rv32VecHeapReadRecord, - write_record: &Rv32VecHeapWriteRecord, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - address_bits: usize, - memory: &OfflineMemory, -) { - let aux_cols_factory = memory.aux_cols_factory(); - let row_slice: &mut Rv32VecHeapAdapterCols< - F, - NUM_READS, - BLOCKS_PER_READ, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - > = row_slice.borrow_mut(); - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); - - let rd = memory.record_by_id(read_record.rd); - let rs = read_record - .rs - .into_iter() - .map(|r| memory.record_by_id(r)) - .collect::>(); - - row_slice.rd_ptr = rd.pointer; - row_slice.rd_val.copy_from_slice(rd.data_slice()); - - for (i, r) in rs.iter().enumerate() { - row_slice.rs_ptr[i] = r.pointer; - row_slice.rs_val[i].copy_from_slice(r.data_slice()); - aux_cols_factory.generate_read_aux(r, &mut row_slice.rs_read_aux[i]); - } +impl< + F: PrimeField32, + const NUM_READS: usize, + const BLOCKS_PER_READ: usize, + const BLOCKS_PER_WRITE: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, + > AdapterExecutorE1 + for RV32VecHeapAdapterStep +{ + type ReadData = [[[u8; READ_SIZE]; BLOCKS_PER_READ]; NUM_READS]; + type WriteData = [[u8; WRITE_SIZE]; BLOCKS_PER_WRITE]; - aux_cols_factory.generate_read_aux(rd, &mut row_slice.rd_read_aux); + fn read(&self, memory: &mut Mem, instruction: &Instruction) -> Self::ReadData + where + Mem: GuestMemory, + { + let Instruction { b, c, d, e, .. } = *instruction; - for (i, reads) in read_record.reads.iter().enumerate() { - for (j, &x) in reads.iter().enumerate() { - let record = memory.record_by_id(x); - aux_cols_factory.generate_read_aux(record, &mut row_slice.reads_aux[i][j]); - } - } + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + debug_assert_eq!(d, RV32_REGISTER_AS); + debug_assert_eq!(e, RV32_MEMORY_AS); - for (i, &w) in write_record.writes.iter().enumerate() { - let record = memory.record_by_id(w); - aux_cols_factory.generate_write_aux(record, &mut row_slice.writes_aux[i]); - } + // Read register values + let rs_vals = from_fn(|i| { + let addr = if i == 0 { b } else { c }; + new_read_rv32_register(memory, d, addr.as_canonical_u32()) + }); - // Range checks: - let need_range_check: Vec = rs - .iter() - .chain(std::iter::repeat_n(&rd, 2)) - .map(|record| { - record - .data_at(RV32_REGISTER_NUM_LIMBS - 1) - .as_canonical_u32() + // Read memory values + rs_vals.map(|address| { + assert!( + address as usize + READ_SIZE * BLOCKS_PER_READ - 1 < (1 << self.pointer_max_bits) + ); + from_fn(|i| memory_read(memory, e, address + (i * READ_SIZE) as u32)) }) - .collect(); - debug_assert!(address_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS); - let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits; - for pair in need_range_check.chunks_exact(2) { - bitwise_lookup_chip.request_range(pair[0] << limb_shift_bits, pair[1] << limb_shift_bits); + } + + fn write(&self, memory: &mut Mem, instruction: &Instruction, data: &Self::WriteData) + where + Mem: GuestMemory, + { + let Instruction { a, d, e, .. } = *instruction; + let rd_val = new_read_rv32_register(memory, d.as_canonical_u32(), a.as_canonical_u32()); + assert!(rd_val as usize + WRITE_SIZE * BLOCKS_PER_WRITE - 1 < (1 << self.pointer_max_bits)); + + for i in 0..BLOCKS_PER_WRITE { + memory_write( + memory, + e.as_canonical_u32(), + rd_val + (i * WRITE_SIZE) as u32, + &data[i], + ); + } } } diff --git a/extensions/rv32im/circuit/Cargo.toml b/extensions/rv32im/circuit/Cargo.toml index 6e71db40db..0e28dd3093 100644 --- a/extensions/rv32im/circuit/Cargo.toml +++ b/extensions/rv32im/circuit/Cargo.toml @@ -21,7 +21,6 @@ derive-new.workspace = true derive_more = { workspace = true, features = ["from"] } rand.workspace = true eyre.workspace = true -test-case.workspace = true # for div_rem: num-bigint.workspace = true @@ -32,6 +31,7 @@ serde-big-array.workspace = true [dev-dependencies] openvm-stark-sdk = { workspace = true } openvm-circuit = { workspace = true, features = ["test-utils"] } +test-case.workspace = true [features] default = ["parallel", "jemalloc"] diff --git a/extensions/rv32im/circuit/src/adapters/alu.rs b/extensions/rv32im/circuit/src/adapters/alu.rs index f8322c143e..654d101469 100644 --- a/extensions/rv32im/circuit/src/adapters/alu.rs +++ b/extensions/rv32im/circuit/src/adapters/alu.rs @@ -12,7 +12,7 @@ use openvm_circuit::{ }, }; use openvm_circuit_primitives::{ - bitwise_op_lookup::{BitwiseOperationLookupBus, BitwiseOperationLookupChip}, + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, utils::not, }; use openvm_circuit_primitives_derive::AlignedBorrow; @@ -161,15 +161,18 @@ impl VmAdapterAir for Rv32BaseAluAdapterAir { } #[derive(derive_new::new)] -pub struct Rv32BaseAluAdapterStep; +pub struct Rv32BaseAluAdapterStep { + // TODO(arayi): use reference to bitwise lookup chip with lifetimes instead + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, +} impl AdapterTraceStep for Rv32BaseAluAdapterStep { const WIDTH: usize = size_of::>(); - type ReadData = ([u8; RV32_REGISTER_NUM_LIMBS], [u8; RV32_REGISTER_NUM_LIMBS]); - type WriteData = [u8; RV32_REGISTER_NUM_LIMBS]; - type TraceContext<'a> = &'a BitwiseOperationLookupChip; + type ReadData = [[u8; RV32_REGISTER_NUM_LIMBS]; 2]; + type WriteData = [[u8; RV32_REGISTER_NUM_LIMBS]; 1]; + type TraceContext<'a> = (); #[inline(always)] fn start(pc: u32, memory: &TracingMemory, adapter_row: &mut [F]) { @@ -218,7 +221,7 @@ impl AdapterTraceStep tracing_read_imm(memory, c.as_canonical_u32(), &mut adapter_row.rs2) }; - (rs1, rs2) + [rs1, rs2] } #[inline(always)] @@ -240,7 +243,7 @@ impl AdapterTraceStep memory, RV32_REGISTER_AS, a.as_canonical_u32(), - data, + &data[0], &mut adapter_row.writes_aux, ); } @@ -249,7 +252,7 @@ impl AdapterTraceStep fn fill_trace_row( &self, mem_helper: &MemoryAuxColsFactory, - bitwise_lookup_chip: &BitwiseOperationLookupChip, + _ctx: (), adapter_row: &mut [F], ) { let adapter_row: &mut Rv32BaseAluAdapterCols = adapter_row.borrow_mut(); @@ -264,7 +267,8 @@ impl AdapterTraceStep } else { let rs2_imm = adapter_row.rs2.as_canonical_u32(); let mask = (1 << RV32_CELL_BITS) - 1; - bitwise_lookup_chip.request_range(rs2_imm & mask, (rs2_imm >> 8) & mask); + self.bitwise_lookup_chip + .request_range(rs2_imm & mask, (rs2_imm >> 8) & mask); } timestamp += 1; @@ -277,8 +281,8 @@ where F: PrimeField32, { // TODO(ayush): directly use u32 - type ReadData = ([u8; RV32_REGISTER_NUM_LIMBS], [u8; RV32_REGISTER_NUM_LIMBS]); - type WriteData = [u8; RV32_REGISTER_NUM_LIMBS]; + type ReadData = [[u8; RV32_REGISTER_NUM_LIMBS]; 2]; + type WriteData = [[u8; RV32_REGISTER_NUM_LIMBS]; 1]; #[inline(always)] fn read(&self, memory: &mut Mem, instruction: &Instruction) -> Self::ReadData @@ -307,7 +311,7 @@ where imm_le }; - (rs1, rs2) + [rs1, rs2] } #[inline(always)] @@ -319,6 +323,6 @@ where debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); - memory_write(memory, d.as_canonical_u32(), a.as_canonical_u32(), rd); + memory_write(memory, d.as_canonical_u32(), a.as_canonical_u32(), &rd[0]); } } diff --git a/extensions/rv32im/circuit/src/adapters/branch.rs b/extensions/rv32im/circuit/src/adapters/branch.rs index 25f55bd93d..6fb6ba00c8 100644 --- a/extensions/rv32im/circuit/src/adapters/branch.rs +++ b/extensions/rv32im/circuit/src/adapters/branch.rs @@ -20,7 +20,6 @@ use openvm_stark_backend::{ p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; use super::RV32_REGISTER_NUM_LIMBS; use crate::adapters::{memory_read, tracing_read}; @@ -115,7 +114,7 @@ where F: PrimeField32, { const WIDTH: usize = size_of::>(); - type ReadData = ([u8; RV32_REGISTER_NUM_LIMBS], [u8; RV32_REGISTER_NUM_LIMBS]); + type ReadData = [[u8; RV32_REGISTER_NUM_LIMBS]; 2]; type WriteData = (); type TraceContext<'a> = (); @@ -155,7 +154,7 @@ where &mut adapter_row.reads_aux[1], ); - (rs1, rs2) + [rs1, rs2] } #[inline(always)] @@ -191,7 +190,7 @@ where F: PrimeField32, { // TODO(ayush): directly use u32 - type ReadData = ([u8; RV32_REGISTER_NUM_LIMBS], [u8; RV32_REGISTER_NUM_LIMBS]); + type ReadData = [[u8; RV32_REGISTER_NUM_LIMBS]; 2]; type WriteData = (); #[inline(always)] @@ -209,7 +208,7 @@ where let rs2: [u8; RV32_REGISTER_NUM_LIMBS] = memory_read(memory, RV32_REGISTER_AS, b.as_canonical_u32()); - (rs1, rs2) + [rs1, rs2] } #[inline(always)] diff --git a/extensions/rv32im/circuit/src/adapters/mod.rs b/extensions/rv32im/circuit/src/adapters/mod.rs index bcf906e7fd..f1a29d376b 100644 --- a/extensions/rv32im/circuit/src/adapters/mod.rs +++ b/extensions/rv32im/circuit/src/adapters/mod.rs @@ -53,7 +53,7 @@ pub fn decompose(value: u32) -> [F; RV32_REGISTER_NUM_LIMBS] { } #[inline(always)] -pub fn memory_read(memory: &Mem, address_space: u32, ptr: u32) -> [u8; RV32_REGISTER_NUM_LIMBS] +pub fn memory_read(memory: &Mem, address_space: u32, ptr: u32) -> [u8; N] where Mem: GuestMemory, { @@ -67,15 +67,15 @@ where // SAFETY: // - address space `RV32_REGISTER_AS` and `RV32_MEMORY_AS` will always have cell type `u8` and // minimum alignment of `RV32_REGISTER_NUM_LIMBS` - unsafe { memory.read::(address_space, ptr) } + unsafe { memory.read::(address_space, ptr) } } #[inline(always)] -pub fn memory_write( +pub fn memory_write( memory: &mut Mem, address_space: u32, ptr: u32, - data: &[u8; RV32_REGISTER_NUM_LIMBS], + data: &[u8; N], ) where Mem: GuestMemory, { @@ -89,18 +89,18 @@ pub fn memory_write( // SAFETY: // - address space `RV32_REGISTER_AS` and `RV32_MEMORY_AS` will always have cell type `u8` and // minimum alignment of `RV32_REGISTER_NUM_LIMBS` - unsafe { memory.write::(address_space, ptr, data) } + unsafe { memory.write::(address_space, ptr, data) } } /// Atomic read operation which increments the timestamp by 1. /// Returns `(t_prev, [ptr:4]_{address_space})` where `t_prev` is the timestamp of the last memory /// access. #[inline(always)] -pub fn timed_read( +pub fn timed_read( memory: &mut TracingMemory, address_space: u32, ptr: u32, -) -> (u32, [u8; RV32_REGISTER_NUM_LIMBS]) { +) -> (u32, [u8; N]) { debug_assert!( address_space == RV32_REGISTER_AS || address_space == RV32_MEMORY_AS @@ -110,18 +110,16 @@ pub fn timed_read( // SAFETY: // - address space `RV32_REGISTER_AS` and `RV32_MEMORY_AS` will always have cell type `u8` and // minimum alignment of `RV32_REGISTER_NUM_LIMBS` - unsafe { - memory.read::(address_space, ptr) - } + unsafe { memory.read::(address_space, ptr) } } #[inline(always)] -pub fn timed_write( +pub fn timed_write( memory: &mut TracingMemory, address_space: u32, ptr: u32, - val: &[u8; RV32_REGISTER_NUM_LIMBS], -) -> (u32, [u8; RV32_REGISTER_NUM_LIMBS]) { + data: &[u8; N], +) -> (u32, [u8; N]) { // TODO(ayush): should this allow public values address space debug_assert!( address_space == RV32_REGISTER_AS @@ -130,27 +128,21 @@ pub fn timed_write( ); // SAFETY: - // - address space `RV32_REGISTER_AS` and `RV32_MEMORY_ASwill always have cell type `u8` and + // - address space `RV32_REGISTER_AS` and `RV32_MEMORY_AS` will always have cell type `u8` and // minimum alignment of `RV32_REGISTER_NUM_LIMBS` - unsafe { - memory.write::( - address_space, - ptr, - val, - ) - } + unsafe { memory.write::(address_space, ptr, data) } } /// Reads register value at `reg_ptr` from memory and records the memory access in mutable buffer. /// Trace generation relevant to this memory access can be done fully from the recorded buffer. #[inline(always)] -pub fn tracing_read( +pub fn tracing_read( memory: &mut TracingMemory, address_space: u32, ptr: u32, aux_cols: &mut MemoryReadAuxCols, /* TODO[jpw]: switch to raw u8 * buffer */ -) -> [u8; RV32_REGISTER_NUM_LIMBS] +) -> [u8; N] where F: PrimeField32, { @@ -162,18 +154,18 @@ where /// Writes `reg_ptr, reg_val` into memory and records the memory access in mutable buffer. /// Trace generation relevant to this memory access can be done fully from the recorded buffer. #[inline(always)] -pub fn tracing_write( +pub fn tracing_write( memory: &mut TracingMemory, address_space: u32, ptr: u32, - val: &[u8; RV32_REGISTER_NUM_LIMBS], - aux_cols: &mut MemoryWriteAuxCols, /* TODO[jpw]: switch to raw - * u8 - * buffer */ + data: &[u8; N], + aux_cols: &mut MemoryWriteAuxCols, /* TODO[jpw]: switch to raw + * u8 + * buffer */ ) where F: PrimeField32, { - let (t_prev, data_prev) = timed_write(memory, address_space, ptr, val); + let (t_prev, data_prev) = timed_write(memory, address_space, ptr, data); aux_cols.set_prev( F::from_canonical_u32(t_prev), data_prev.map(F::from_canonical_u8), @@ -182,16 +174,16 @@ pub fn tracing_write( // TODO(ayush): this is bad but not sure how to avoid #[inline(always)] -pub fn tracing_write_with_base_aux( +pub fn tracing_write_with_base_aux( memory: &mut TracingMemory, address_space: u32, ptr: u32, - val: &[u8; RV32_REGISTER_NUM_LIMBS], + data: &[u8; N], base_aux_cols: &mut MemoryBaseAuxCols, ) where F: PrimeField32, { - let (t_prev, _) = timed_write(memory, address_space, ptr, val); + let (t_prev, _) = timed_write(memory, address_space, ptr, data); base_aux_cols.set_prev(F::from_canonical_u32(t_prev)); } @@ -231,6 +223,11 @@ pub fn read_rv32_register( (record.0, val) } +#[inline(always)] +pub fn new_read_rv32_register(memory: &Mem, address_space: u32, ptr: u32) -> u32 { + u32::from_le_bytes(memory_read(memory, address_space, ptr)) +} + /// Peeks at the value of a register without updating the memory state or incrementing the /// timestamp. pub fn unsafe_read_rv32_register(memory: &MemoryController, pointer: F) -> u32 { diff --git a/extensions/rv32im/circuit/src/adapters/mul.rs b/extensions/rv32im/circuit/src/adapters/mul.rs index 4b937459f2..9259f5549b 100644 --- a/extensions/rv32im/circuit/src/adapters/mul.rs +++ b/extensions/rv32im/circuit/src/adapters/mul.rs @@ -20,7 +20,6 @@ use openvm_stark_backend::{ p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; use super::{tracing_write, RV32_REGISTER_NUM_LIMBS}; use crate::adapters::{memory_read, memory_write, tracing_read}; @@ -132,8 +131,8 @@ where F: PrimeField32, { const WIDTH: usize = size_of::>(); - type ReadData = ([u8; RV32_REGISTER_NUM_LIMBS], [u8; RV32_REGISTER_NUM_LIMBS]); - type WriteData = [u8; RV32_REGISTER_NUM_LIMBS]; + type ReadData = [[u8; RV32_REGISTER_NUM_LIMBS]; 2]; + type WriteData = [[u8; RV32_REGISTER_NUM_LIMBS]; 1]; type TraceContext<'a> = (); #[inline(always)] @@ -171,7 +170,7 @@ where &mut adapter_row.reads_aux[1], ); - (rs1, rs2) + [rs1, rs2] } #[inline(always)] @@ -193,7 +192,7 @@ where memory, RV32_REGISTER_AS, a.as_canonical_u32(), - data, + &data[0], &mut adapter_row.writes_aux, ) } @@ -224,8 +223,8 @@ where F: PrimeField32, { // TODO(ayush): directly use u32 - type ReadData = ([u8; RV32_REGISTER_NUM_LIMBS], [u8; RV32_REGISTER_NUM_LIMBS]); - type WriteData = [u8; RV32_REGISTER_NUM_LIMBS]; + type ReadData = [[u8; RV32_REGISTER_NUM_LIMBS]; 2]; + type WriteData = [[u8; RV32_REGISTER_NUM_LIMBS]; 1]; #[inline(always)] fn read(&self, memory: &mut Mem, instruction: &Instruction) -> Self::ReadData @@ -241,7 +240,7 @@ where let rs2: [u8; RV32_REGISTER_NUM_LIMBS] = memory_read(memory, RV32_REGISTER_AS, c.as_canonical_u32()); - (rs1, rs2) + [rs1, rs2] } #[inline(always)] @@ -253,6 +252,6 @@ where debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); - memory_write(memory, RV32_REGISTER_AS, a.as_canonical_u32(), rd); + memory_write(memory, RV32_REGISTER_AS, a.as_canonical_u32(), &rd[0]); } } diff --git a/extensions/rv32im/circuit/src/base_alu/core.rs b/extensions/rv32im/circuit/src/base_alu/core.rs index 21281996e2..aa25a05309 100644 --- a/extensions/rv32im/circuit/src/base_alu/core.rs +++ b/extensions/rv32im/circuit/src/base_alu/core.rs @@ -15,9 +15,7 @@ use openvm_circuit::{ }, }; use openvm_circuit_primitives::{ - bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, SharedBitwiseOperationLookupChip, - }, + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, utils::not, }; use openvm_circuit_primitives_derive::AlignedBorrow; @@ -200,9 +198,9 @@ where + for<'a> AdapterTraceStep< F, CTX, - ReadData = ([u8; NUM_LIMBS], [u8; NUM_LIMBS]), - WriteData = [u8; NUM_LIMBS], - TraceContext<'a> = &'a BitwiseOperationLookupChip, + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, + TraceContext<'a> = (), >, { fn get_opcode_name(&self, opcode: usize) -> String { @@ -221,12 +219,15 @@ where let local_opcode = BaseAluOpcode::from_usize(opcode.local_opcode_idx(self.offset)); - let mut row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; A::start(*state.pc, state.memory, adapter_row); - let (rs1, rs2) = self.adapter.read(state.memory, instruction, adapter_row); + let [rs1, rs2] = self + .adapter + .read(state.memory, instruction, adapter_row) + .into(); let rd = run_alu::(local_opcode, &rs1, &rs2); @@ -241,7 +242,7 @@ where core_row.opcode_and_flag = F::from_bool(local_opcode == BaseAluOpcode::AND); self.adapter - .write(state.memory, instruction, adapter_row, &rd); + .write(state.memory, instruction, adapter_row, &[rd].into()); *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); @@ -253,8 +254,7 @@ where fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; - self.adapter - .fill_trace_row(mem_helper, self.bitwise_lookup_chip.as_ref(), adapter_row); + self.adapter.fill_trace_row(mem_helper, (), adapter_row); let core_row: &mut BaseAluCoreCols = core_row.borrow_mut(); @@ -279,8 +279,8 @@ where A: 'static + for<'a> AdapterExecutorE1< F, - ReadData = ([u8; NUM_LIMBS], [u8; NUM_LIMBS]), - WriteData = [u8; NUM_LIMBS], + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, >, { fn execute_e1( @@ -295,9 +295,9 @@ where let local_opcode = BaseAluOpcode::from_usize(opcode.local_opcode_idx(self.offset)); - let (rs1, rs2) = self.adapter.read(state.memory, instruction); + let [rs1, rs2] = self.adapter.read(state.memory, instruction).into(); let rd = run_alu::(local_opcode, &rs1, &rs2); - self.adapter.write(state.memory, instruction, &rd); + self.adapter.write(state.memory, instruction, &[rd].into()); *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); diff --git a/extensions/rv32im/circuit/src/base_alu/tests.rs b/extensions/rv32im/circuit/src/base_alu/tests.rs index 3d9456fa26..bc7ff78721 100644 --- a/extensions/rv32im/circuit/src/base_alu/tests.rs +++ b/extensions/rv32im/circuit/src/base_alu/tests.rs @@ -55,7 +55,7 @@ fn create_test_chip( BaseAluCoreAir::new(bitwise_bus, BaseAluOpcode::CLASS_OFFSET), ), Rv32BaseAluStep::new( - Rv32BaseAluAdapterStep::new(), + Rv32BaseAluAdapterStep::new(bitwise_chip.clone()), bitwise_chip.clone(), BaseAluOpcode::CLASS_OFFSET, ), diff --git a/extensions/rv32im/circuit/src/branch_eq/core.rs b/extensions/rv32im/circuit/src/branch_eq/core.rs index ba464dbd9b..3f375b31ab 100644 --- a/extensions/rv32im/circuit/src/branch_eq/core.rs +++ b/extensions/rv32im/circuit/src/branch_eq/core.rs @@ -13,9 +13,12 @@ use openvm_circuit::{ MemoryAuxColsFactory, }, }; -use openvm_circuit_primitives::utils::not; +use openvm_circuit_primitives::{ + bitwise_op_lookup::{BitwiseOperationLookupChip, SharedBitwiseOperationLookupChip}, + utils::not, +}; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{instruction::Instruction, riscv::RV32_CELL_BITS, LocalOpcode}; use openvm_rv32im_transpiler::BranchEqualOpcode; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -23,8 +26,6 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; use strum::IntoEnumIterator; #[repr(C)] @@ -140,7 +141,6 @@ where } } -#[derive(Debug)] pub struct BranchEqualStep { adapter: A, pub offset: usize, @@ -164,7 +164,7 @@ where + for<'a> AdapterTraceStep< F, CTX, - ReadData = ([u8; NUM_LIMBS], [u8; NUM_LIMBS]), + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, WriteData = (), TraceContext<'a> = (), >, @@ -185,12 +185,15 @@ where let branch_eq_opcode = BranchEqualOpcode::from_usize(opcode.local_opcode_idx(self.offset)); - let mut row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; A::start(*state.pc, state.memory, adapter_row); - let (rs1, rs2) = self.adapter.read(state.memory, instruction, adapter_row); + let [rs1, rs2] = self + .adapter + .read(state.memory, instruction, adapter_row) + .into(); let (cmp_result, diff_idx, diff_inv_val) = run_eq(branch_eq_opcode, &rs1, &rs2); @@ -225,8 +228,7 @@ where impl StepExecutorE1 for BranchEqualStep where F: PrimeField32, - A: 'static - + for<'a> AdapterExecutorE1, + A: 'static + for<'a> AdapterExecutorE1, WriteData = ()>, { fn execute_e1( &mut self, @@ -240,7 +242,7 @@ where let branch_eq_opcode = BranchEqualOpcode::from_usize(opcode.local_opcode_idx(self.offset)); - let (rs1, rs2) = self.adapter.read(state.memory, instruction); + let [rs1, rs2] = self.adapter.read(state.memory, instruction).into(); // TODO(ayush): probably don't need the other values let (cmp_result, _, _) = run_eq::(branch_eq_opcode, &rs1, &rs2); diff --git a/extensions/rv32im/circuit/src/branch_lt/core.rs b/extensions/rv32im/circuit/src/branch_lt/core.rs index 6f15ec80af..30d287731f 100644 --- a/extensions/rv32im/circuit/src/branch_lt/core.rs +++ b/extensions/rv32im/circuit/src/branch_lt/core.rs @@ -26,8 +26,6 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; use strum::IntoEnumIterator; #[repr(C)] @@ -223,7 +221,7 @@ where + for<'a> AdapterTraceStep< F, CTX, - ReadData = ([u8; NUM_LIMBS], [u8; NUM_LIMBS]), + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, WriteData = (), TraceContext<'a> = (), >, @@ -247,12 +245,15 @@ where let blt_opcode = BranchLessThanOpcode::from_usize(opcode.local_opcode_idx(self.offset)); - let mut row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; A::start(*state.pc, state.memory, adapter_row); - let (rs1, rs2) = self.adapter.read(state.memory, instruction, adapter_row); + let [rs1, rs2] = self + .adapter + .read(state.memory, instruction, adapter_row) + .into(); let (cmp_result, diff_idx, a_sign, b_sign) = run_cmp::(blt_opcode, &rs1, &rs2); @@ -355,8 +356,7 @@ impl StepExecutorE1 for BranchLessThanStep where F: PrimeField32, - A: 'static - + for<'a> AdapterExecutorE1, + A: 'static + for<'a> AdapterExecutorE1, WriteData = ()>, { fn execute_e1( &mut self, @@ -370,7 +370,7 @@ where let blt_opcode = BranchLessThanOpcode::from_usize(opcode.local_opcode_idx(self.offset)); - let (rs1, rs2) = self.adapter.read(state.memory, instruction); + let [rs1, rs2] = self.adapter.read(state.memory, instruction).into(); // TODO(ayush): probably don't need the other values let (cmp_result, _, _, _) = run_cmp::(blt_opcode, &rs1, &rs2); diff --git a/extensions/rv32im/circuit/src/divrem/core.rs b/extensions/rv32im/circuit/src/divrem/core.rs index 8e38e6aabc..a66048b05c 100644 --- a/extensions/rv32im/circuit/src/divrem/core.rs +++ b/extensions/rv32im/circuit/src/divrem/core.rs @@ -29,8 +29,7 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_big_array::BigArray; + use strum::IntoEnumIterator; #[repr(C)] @@ -401,8 +400,8 @@ where + for<'a> AdapterTraceStep< F, CTX, - ReadData = ([u8; NUM_LIMBS], [u8; NUM_LIMBS]), - WriteData = [u8; NUM_LIMBS], + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, TraceContext<'a> = (), >, { @@ -425,12 +424,15 @@ where let is_signed = divrem_opcode == DivRemOpcode::DIV || divrem_opcode == DivRemOpcode::REM; let is_div = divrem_opcode == DivRemOpcode::DIV || divrem_opcode == DivRemOpcode::DIVU; - let mut row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; A::start(*state.pc, state.memory, adapter_row); - let (rs1, rs2) = self.adapter.read(state.memory, instruction, adapter_row); + let [rs1, rs2] = self + .adapter + .read(state.memory, instruction, adapter_row) + .into(); let b = rs1.map(u32::from); let c = rs2.map(u32::from); @@ -514,7 +516,7 @@ where }; self.adapter - .write(state.memory, instruction, adapter_row, &rd); + .write(state.memory, instruction, adapter_row, &[rd].into()); *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); @@ -537,8 +539,8 @@ where A: 'static + for<'a> AdapterExecutorE1< F, - ReadData = ([u8; NUM_LIMBS], [u8; NUM_LIMBS]), - WriteData = [u8; NUM_LIMBS], + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, >, { fn execute_e1( @@ -554,7 +556,7 @@ where // Determine opcode and operation type let divrem_opcode = DivRemOpcode::from_usize(opcode.local_opcode_idx(self.offset)); - let (rs1, rs2) = self.adapter.read(state.memory, instruction); + let [rs1, rs2] = self.adapter.read(state.memory, instruction).into(); let rs1 = rs1.map(u32::from); let rs2 = rs2.map(u32::from); @@ -571,7 +573,7 @@ where r.map(|x| x as u8) }; - self.adapter.write(state.memory, instruction, &rd); + self.adapter.write(state.memory, instruction, &[rd].into()); *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); diff --git a/extensions/rv32im/circuit/src/extension.rs b/extensions/rv32im/circuit/src/extension.rs index d9dd57dc4c..c2a6a973c9 100644 --- a/extensions/rv32im/circuit/src/extension.rs +++ b/extensions/rv32im/circuit/src/extension.rs @@ -225,7 +225,7 @@ impl VmExtension for Rv32I { BaseAluCoreAir::new(bitwise_lu_chip.bus(), BaseAluOpcode::CLASS_OFFSET), ), Rv32BaseAluStep::new( - Rv32BaseAluAdapterStep::new(), + Rv32BaseAluAdapterStep::new(bitwise_lu_chip.clone()), bitwise_lu_chip.clone(), BaseAluOpcode::CLASS_OFFSET, ), @@ -247,7 +247,7 @@ impl VmExtension for Rv32I { LessThanCoreAir::new(bitwise_lu_chip.bus(), LessThanOpcode::CLASS_OFFSET), ), LessThanStep::new( - Rv32BaseAluAdapterStep::new(), + Rv32BaseAluAdapterStep::new(bitwise_lu_chip.clone()), bitwise_lu_chip.clone(), LessThanOpcode::CLASS_OFFSET, ), @@ -270,7 +270,7 @@ impl VmExtension for Rv32I { ), ), ShiftStep::new( - Rv32BaseAluAdapterStep::new(), + Rv32BaseAluAdapterStep::new(bitwise_lu_chip.clone()), bitwise_lu_chip.clone(), range_checker.clone(), ShiftOpcode::CLASS_OFFSET, diff --git a/extensions/rv32im/circuit/src/less_than/core.rs b/extensions/rv32im/circuit/src/less_than/core.rs index 5d63f03652..0ad602d63d 100644 --- a/extensions/rv32im/circuit/src/less_than/core.rs +++ b/extensions/rv32im/circuit/src/less_than/core.rs @@ -14,9 +14,7 @@ use openvm_circuit::{ }, }; use openvm_circuit_primitives::{ - bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, SharedBitwiseOperationLookupChip, - }, + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, utils::not, }; use openvm_circuit_primitives_derive::AlignedBorrow; @@ -28,8 +26,6 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_big_array::BigArray; use strum::IntoEnumIterator; #[repr(C)] @@ -199,9 +195,9 @@ where + for<'a> AdapterTraceStep< F, CTX, - ReadData = ([u8; NUM_LIMBS], [u8; NUM_LIMBS]), - WriteData = [u8; NUM_LIMBS], - TraceContext<'a> = &'a BitwiseOperationLookupChip, + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, + TraceContext<'a> = (), >, { fn get_opcode_name(&self, opcode: usize) -> String { @@ -222,12 +218,15 @@ where let local_opcode = LessThanOpcode::from_usize(opcode.local_opcode_idx(self.offset)); - let mut row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; A::start(*state.pc, state.memory, adapter_row); - let (rs1, rs2) = self.adapter.read(state.memory, instruction, adapter_row); + let [rs1, rs2] = self + .adapter + .read(state.memory, instruction, adapter_row) + .into(); let (cmp_result, _, _, _) = run_less_than::(local_opcode, &rs1, &rs2); @@ -241,7 +240,7 @@ where output[0] = cmp_result as u8; self.adapter - .write(state.memory, instruction, adapter_row, &output); + .write(state.memory, instruction, adapter_row, &[output].into()); *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); @@ -253,8 +252,7 @@ where fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; - self.adapter - .fill_trace_row(mem_helper, self.bitwise_lookup_chip.as_ref(), adapter_row); + self.adapter.fill_trace_row(mem_helper, (), adapter_row); let core_row: &mut LessThanCoreCols<_, NUM_LIMBS, LIMB_BITS> = core_row.borrow_mut(); @@ -333,8 +331,8 @@ where A: 'static + for<'a> AdapterExecutorE1< F, - ReadData = ([u8; NUM_LIMBS], [u8; NUM_LIMBS]), - WriteData = [u8; NUM_LIMBS], + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, >, { fn execute_e1( @@ -349,7 +347,7 @@ where let less_than_opcode = LessThanOpcode::from_usize(opcode.local_opcode_idx(self.offset)); - let (rs1, rs2) = self.adapter.read(state.memory, instruction); + let [rs1, rs2] = self.adapter.read(state.memory, instruction).into(); // Run the comparison let (cmp_result, _, _, _) = @@ -357,7 +355,7 @@ where let mut rd = [0u8; NUM_LIMBS]; rd[0] = cmp_result as u8; - self.adapter.write(state.memory, instruction, &rd); + self.adapter.write(state.memory, instruction, &[rd].into()); *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); diff --git a/extensions/rv32im/circuit/src/less_than/tests.rs b/extensions/rv32im/circuit/src/less_than/tests.rs index 52ac6f67c4..373c2920f7 100644 --- a/extensions/rv32im/circuit/src/less_than/tests.rs +++ b/extensions/rv32im/circuit/src/less_than/tests.rs @@ -58,7 +58,7 @@ fn create_test_chip( LessThanCoreAir::new(bitwise_bus, LessThanOpcode::CLASS_OFFSET), ), LessThanStep::new( - Rv32BaseAluAdapterStep::new(), + Rv32BaseAluAdapterStep::new(bitwise_chip.clone()), bitwise_chip.clone(), LessThanOpcode::CLASS_OFFSET, ), diff --git a/extensions/rv32im/circuit/src/loadstore/tests.rs b/extensions/rv32im/circuit/src/loadstore/tests.rs index 72821c2e37..3cea7e5067 100644 --- a/extensions/rv32im/circuit/src/loadstore/tests.rs +++ b/extensions/rv32im/circuit/src/loadstore/tests.rs @@ -39,7 +39,7 @@ const MAX_INS_CAPACITY: usize = 128; type F = BabyBear; fn create_test_chip(tester: &mut VmChipTestBuilder) -> Rv32LoadStoreChip { - let range_checker_chip = tester.memory_controller().range_checker.clone(); + let range_checker_chip = tester.range_checker(); let chip = Rv32LoadStoreChip::::new( VmAirWrapper::new( Rv32LoadStoreAdapterAir::new( diff --git a/extensions/rv32im/circuit/src/mul/core.rs b/extensions/rv32im/circuit/src/mul/core.rs index fb4469fdaf..efa115aa54 100644 --- a/extensions/rv32im/circuit/src/mul/core.rs +++ b/extensions/rv32im/circuit/src/mul/core.rs @@ -23,8 +23,6 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_big_array::BigArray; #[repr(C)] #[derive(AlignedBorrow)] @@ -160,8 +158,8 @@ where + for<'a> AdapterTraceStep< F, CTX, - ReadData = ([u8; NUM_LIMBS], [u8; NUM_LIMBS]), - WriteData = [u8; NUM_LIMBS], + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, TraceContext<'a> = (), >, { @@ -184,12 +182,15 @@ where MulOpcode::MUL ); - let mut row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; A::start(*state.pc, state.memory, adapter_row); - let (rs1, rs2) = self.adapter.read(state.memory, instruction, adapter_row); + let [rs1, rs2] = self + .adapter + .read(state.memory, instruction, adapter_row) + .into(); let (a, carry) = run_mul::(&rs1, &rs2); @@ -206,7 +207,7 @@ where // TODO(ayush): avoid this conversion self.adapter - .write(state.memory, instruction, adapter_row, &a); + .write(state.memory, instruction, adapter_row, &[a].into()); *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); @@ -229,8 +230,8 @@ where A: 'static + for<'a> AdapterExecutorE1< F, - ReadData = ([u8; NUM_LIMBS], [u8; NUM_LIMBS]), - WriteData = [u8; NUM_LIMBS], + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, >, { fn execute_e1( @@ -250,11 +251,11 @@ where MulOpcode::MUL ); - let (rs1, rs2) = self.adapter.read(state.memory, instruction); + let [rs1, rs2] = self.adapter.read(state.memory, instruction).into(); let (rd, _) = run_mul::(&rs1, &rs2); - self.adapter.write(state.memory, instruction, &rd); + self.adapter.write(state.memory, instruction, &[rd].into()); *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); diff --git a/extensions/rv32im/circuit/src/mulh/core.rs b/extensions/rv32im/circuit/src/mulh/core.rs index 62672f05dc..555a70bbd7 100644 --- a/extensions/rv32im/circuit/src/mulh/core.rs +++ b/extensions/rv32im/circuit/src/mulh/core.rs @@ -26,8 +26,7 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; + use strum::IntoEnumIterator; #[repr(C)] @@ -231,8 +230,8 @@ where + for<'a> AdapterTraceStep< F, CTX, - ReadData = ([u8; NUM_LIMBS], [u8; NUM_LIMBS]), - WriteData = [u8; NUM_LIMBS], + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, TraceContext<'a> = (), >, { @@ -255,12 +254,15 @@ where let mulh_opcode = MulHOpcode::from_usize(opcode.local_opcode_idx(MulHOpcode::CLASS_OFFSET)); - let mut row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; A::start(*state.pc, state.memory, adapter_row); - let (rs1, rs2) = self.adapter.read(state.memory, instruction, adapter_row); + let [rs1, rs2] = self + .adapter + .read(state.memory, instruction, adapter_row) + .into(); let b = rs1.map(u32::from); let c = rs2.map(u32::from); @@ -296,7 +298,7 @@ where // TODO(ayush): avoid this conversion let a = a.map(|x| x as u8); self.adapter - .write(state.memory, instruction, adapter_row, &a); + .write(state.memory, instruction, adapter_row, &[a].into()); *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); @@ -319,8 +321,8 @@ where A: 'static + for<'a> AdapterExecutorE1< F, - ReadData = ([u8; NUM_LIMBS], [u8; NUM_LIMBS]), - WriteData = [u8; NUM_LIMBS], + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, >, { fn execute_e1( @@ -335,14 +337,14 @@ where let mulh_opcode = MulHOpcode::from_usize(opcode.local_opcode_idx(MulHOpcode::CLASS_OFFSET)); - let (rs1, rs2) = self.adapter.read(state.memory, instruction); + let [rs1, rs2] = self.adapter.read(state.memory, instruction).into(); let rs1 = rs1.map(u32::from); let rs2 = rs2.map(u32::from); let (rd, _, _, _, _) = run_mulh::(mulh_opcode, &rs1, &rs2); let rd = rd.map(|x| x as u8); - self.adapter.write(state.memory, instruction, &rd); + self.adapter.write(state.memory, instruction, &[rd].into()); *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); diff --git a/extensions/rv32im/circuit/src/shift/core.rs b/extensions/rv32im/circuit/src/shift/core.rs index c31ff76b9b..ba68e61a05 100644 --- a/extensions/rv32im/circuit/src/shift/core.rs +++ b/extensions/rv32im/circuit/src/shift/core.rs @@ -14,9 +14,7 @@ use openvm_circuit::{ }, }; use openvm_circuit_primitives::{ - bitwise_op_lookup::{ - BitwiseOperationLookupBus, BitwiseOperationLookupChip, SharedBitwiseOperationLookupChip, - }, + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, utils::not, var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, }; @@ -275,9 +273,9 @@ where + for<'a> AdapterTraceStep< F, CTX, - ReadData = ([u8; NUM_LIMBS], [u8; NUM_LIMBS]), - WriteData = [u8; NUM_LIMBS], - TraceContext<'a> = &'a BitwiseOperationLookupChip, + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, + TraceContext<'a> = (), >, { fn get_opcode_name(&self, opcode: usize) -> String { @@ -296,12 +294,15 @@ where let local_opcode = ShiftOpcode::from_usize(opcode.local_opcode_idx(self.offset)); - let mut row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; A::start(*state.pc, state.memory, adapter_row); - let (rs1, rs2) = self.adapter.read(state.memory, instruction, adapter_row); + let [rs1, rs2] = self + .adapter + .read(state.memory, instruction, adapter_row) + .into(); let (output, limb_shift, bit_shift) = run_shift::(local_opcode, &rs1, &rs2); @@ -316,7 +317,7 @@ where core_row.limb_shift_marker[0] = F::from_canonical_usize(limb_shift); self.adapter - .write(state.memory, instruction, adapter_row, &output); + .write(state.memory, instruction, adapter_row, &[output].into()); *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); @@ -328,8 +329,7 @@ where fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; - self.adapter - .fill_trace_row(mem_helper, self.bitwise_lookup_chip.as_ref(), adapter_row); + self.adapter.fill_trace_row(mem_helper, (), adapter_row); let core_row: &mut ShiftCoreCols = core_row.borrow_mut(); @@ -393,8 +393,8 @@ where A: 'static + for<'a> AdapterExecutorE1< F, - ReadData = ([u8; NUM_LIMBS], [u8; NUM_LIMBS]), - WriteData = [u8; NUM_LIMBS], + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, >, { fn execute_e1( @@ -409,11 +409,11 @@ where let shift_opcode = ShiftOpcode::from_usize(opcode.local_opcode_idx(self.offset)); - let (rs1, rs2) = self.adapter.read(state.memory, instruction); + let [rs1, rs2] = self.adapter.read(state.memory, instruction).into(); let (rd, _, _) = run_shift::(shift_opcode, &rs1, &rs2); - self.adapter.write(state.memory, instruction, &rd); + self.adapter.write(state.memory, instruction, &[rd].into()); *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); diff --git a/extensions/rv32im/circuit/src/shift/tests.rs b/extensions/rv32im/circuit/src/shift/tests.rs index 7b92556ccd..fd111e2342 100644 --- a/extensions/rv32im/circuit/src/shift/tests.rs +++ b/extensions/rv32im/circuit/src/shift/tests.rs @@ -58,7 +58,7 @@ fn create_test_chip( ), ), ShiftStep::new( - Rv32BaseAluAdapterStep::new(), + Rv32BaseAluAdapterStep::new(bitwise_chip.clone()), bitwise_chip.clone(), tester.range_checker().clone(), ShiftOpcode::CLASS_OFFSET,