diff --git a/stwo_cairo_prover/crates/prover/src/components/memory/memory_id_to_big/component.rs b/stwo_cairo_prover/crates/prover/src/components/memory/memory_id_to_big/component.rs index 6661b12d0..cce58db3c 100644 --- a/stwo_cairo_prover/crates/prover/src/components/memory/memory_id_to_big/component.rs +++ b/stwo_cairo_prover/crates/prover/src/components/memory/memory_id_to_big/component.rs @@ -7,11 +7,14 @@ use stwo_prover::constraint_framework::{ EvalAtRow, FrameworkComponent, FrameworkEval, RelationEntry, }; use stwo_prover::core::channel::Channel; +use stwo_prover::core::fields::m31::M31; use stwo_prover::core::fields::qm31::SecureField; use stwo_prover::core::fields::secure_column::SECURE_EXTENSION_DEGREE; use stwo_prover::core::pcs::TreeVec; use stwo_prover::relation; +use crate::cairo_air::preprocessed::Seq; +use crate::input::memory::LARGE_MEMORY_VALUE_ID_BASE; use crate::relations; // TODO(AlonH): Make memory size configurable. @@ -19,12 +22,10 @@ pub const MEMORY_ID_SIZE: usize = 1; pub const N_M31_IN_FELT252: usize = 28; pub const N_M31_IN_SMALL_FELT252: usize = 8; // 72 bits. pub const N_MULTIPLICITY_COLUMNS: usize = 1; -pub const BIG_MULTIPLICITY_COLUMN_OFFSET: usize = BIG_N_ID_AND_VALUE_COLUMNS; -pub const BIG_N_COLUMNS: usize = BIG_N_ID_AND_VALUE_COLUMNS + N_MULTIPLICITY_COLUMNS; -pub const BIG_N_ID_AND_VALUE_COLUMNS: usize = MEMORY_ID_SIZE + N_M31_IN_FELT252; -pub const SMALL_MULTIPLICITY_COLUMN_OFFSET: usize = SMALL_N_ID_AND_VALUE_COLUMNS; -pub const SMALL_N_COLUMNS: usize = SMALL_N_ID_AND_VALUE_COLUMNS + N_MULTIPLICITY_COLUMNS; -pub const SMALL_N_ID_AND_VALUE_COLUMNS: usize = MEMORY_ID_SIZE + N_M31_IN_SMALL_FELT252; +pub const BIG_MULTIPLICITY_COLUMN_OFFSET: usize = N_M31_IN_FELT252; +pub const BIG_N_COLUMNS: usize = N_M31_IN_FELT252 + N_MULTIPLICITY_COLUMNS; +pub const SMALL_MULTIPLICITY_COLUMN_OFFSET: usize = N_M31_IN_SMALL_FELT252; +pub const SMALL_N_COLUMNS: usize = N_M31_IN_SMALL_FELT252 + N_MULTIPLICITY_COLUMNS; pub type BigComponent = FrameworkComponent; pub type SmallComponent = FrameworkComponent; @@ -68,12 +69,12 @@ impl FrameworkEval for BigEval { } fn evaluate(&self, mut eval: E) -> E { - let id_and_value: [E::F; MEMORY_ID_SIZE + N_M31_IN_FELT252] = - std::array::from_fn(|_| eval.next_trace_mask()); + let seq = eval.get_preprocessed_column(Seq::new(self.log_size()).id()); + let value: [E::F; N_M31_IN_FELT252] = std::array::from_fn(|_| eval.next_trace_mask()); let multiplicity = eval.next_trace_mask(); // Range check limbs. - for (l, r) in id_and_value[MEMORY_ID_SIZE..].iter().tuples() { + for (l, r) in value.iter().tuples() { eval.add_to_relation(RelationEntry::new( &self.range9_9_lookup_elements, E::EF::one(), @@ -85,7 +86,11 @@ impl FrameworkEval for BigEval { eval.add_to_relation(RelationEntry::new( &self.lookup_elements, E::EF::from(-multiplicity), - &id_and_value, + &chain!( + [seq + E::F::from(M31::from(LARGE_MEMORY_VALUE_ID_BASE))], + value + ) + .collect_vec(), )); eval.finalize_logup_in_pairs(); @@ -124,12 +129,12 @@ impl FrameworkEval for SmallEval { } fn evaluate(&self, mut eval: E) -> E { - let id_and_value: [E::F; SMALL_N_ID_AND_VALUE_COLUMNS] = - std::array::from_fn(|_| eval.next_trace_mask()); + let seq = eval.get_preprocessed_column(Seq::new(self.log_size()).id()); + let value: [E::F; N_M31_IN_SMALL_FELT252] = std::array::from_fn(|_| eval.next_trace_mask()); let multiplicity = eval.next_trace_mask(); // Range check limbs. - for (l, r) in id_and_value[MEMORY_ID_SIZE..].iter().tuples() { + for (l, r) in value.iter().tuples() { eval.add_to_relation(RelationEntry::new( &self.range_check_9_9_relation, E::EF::one(), @@ -141,7 +146,7 @@ impl FrameworkEval for SmallEval { eval.add_to_relation(RelationEntry::new( &self.lookup_elements, E::EF::from(-multiplicity), - &id_and_value, + &chain!([seq], value).collect_vec(), )); eval.finalize_logup(); @@ -156,7 +161,7 @@ pub struct Claim { } impl Claim { pub fn log_sizes(&self) -> TreeVec> { - let preprocessed_log_sizes = vec![self.big_log_size, self.small_log_size]; + let preprocessed_log_sizes = vec![]; let trace_log_sizes = chain!( vec![self.big_log_size; BIG_N_COLUMNS], vec![self.small_log_size; SMALL_N_COLUMNS] diff --git a/stwo_cairo_prover/crates/prover/src/components/memory/memory_id_to_big/prover.rs b/stwo_cairo_prover/crates/prover/src/components/memory/memory_id_to_big/prover.rs index eea546c93..33778ce16 100644 --- a/stwo_cairo_prover/crates/prover/src/components/memory/memory_id_to_big/prover.rs +++ b/stwo_cairo_prover/crates/prover/src/components/memory/memory_id_to_big/prover.rs @@ -19,15 +19,14 @@ use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation}; use stwo_prover::core::poly::BitReversedOrder; use super::component::{ - Claim, InteractionClaim, BIG_N_ID_AND_VALUE_COLUMNS, MEMORY_ID_SIZE, N_M31_IN_SMALL_FELT252, - SMALL_N_ID_AND_VALUE_COLUMNS, + Claim, InteractionClaim, MEMORY_ID_SIZE, N_M31_IN_FELT252, N_M31_IN_SMALL_FELT252, }; use crate::components::memory::MEMORY_ADDRESS_BOUND; -use crate::components::range_check_vector::range_check_9_9; +use crate::components::range_check_vector::{range_check_9_9, SIMD_ENUMERATION_0}; use crate::components::utils::AtomicMultiplicityColumn; use crate::felt::split_f252_simd; use crate::input::memory::{ - u128_to_4_limbs, EncodedMemoryValueId, Memory, MemoryValueId, LARGE_MEMORY_VALUE_ID_TAG, + u128_to_4_limbs, EncodedMemoryValueId, Memory, MemoryValueId, LARGE_MEMORY_VALUE_ID_BASE, }; use crate::relations; @@ -130,22 +129,22 @@ impl ClaimGenerator { gen_small_memory_trace(self.small_values, self.small_mults.into_simd_vec()); // Lookup data. - let big_ids_and_values: [_; BIG_N_ID_AND_VALUE_COLUMNS] = + let big_values: [_; N_M31_IN_FELT252] = std::array::from_fn(|i| big_table_trace[i].data.clone()); let big_multiplicities = big_table_trace.last().unwrap().data.clone(); - let small_ids_and_values: [_; SMALL_N_ID_AND_VALUE_COLUMNS] = + let small_values: [_; N_M31_IN_SMALL_FELT252] = std::array::from_fn(|i| small_table_trace[i].data.clone()); let small_multiplicities = small_table_trace.last().unwrap().data.clone(); // Add inputs to range check that all the values are 9-bit felts. - for (col0, col1) in big_ids_and_values[MEMORY_ID_SIZE..].iter().tuples() { + for (col0, col1) in big_values.iter().tuples() { col0.par_iter() .zip(col1.par_iter()) .for_each(|(val0, val1)| { range_check_9_9_trace_generator.add_packed_m31(&[*val0, *val1]); }); } - for (col0, col1) in small_ids_and_values[MEMORY_ID_SIZE..].iter().tuples() { + for (col0, col1) in small_values.iter().tuples() { col0.par_iter() .zip(col1.par_iter()) .for_each(|(val0, val1)| { @@ -183,9 +182,9 @@ impl ClaimGenerator { small_log_size, }, InteractionClaimGenerator { - big_ids_and_values, + big_values, big_multiplicities, - small_ids_and_values, + small_values, small_multiplicities, }, ) @@ -203,26 +202,20 @@ fn gen_big_memory_trace(values: Vec<[u32; 8]>, mults: Vec) -> Vec, mults: Vec) -> Vec, mults: Vec) -> Vec; BIG_N_ID_AND_VALUE_COLUMNS], + pub big_values: [Vec; N_M31_IN_FELT252], pub big_multiplicities: Vec, - pub small_ids_and_values: [Vec; SMALL_N_ID_AND_VALUE_COLUMNS], + pub small_values: [Vec; N_M31_IN_SMALL_FELT252], pub small_multiplicities: Vec, } impl InteractionClaimGenerator { pub fn with_capacity(capacity: usize) -> Self { Self { - big_ids_and_values: std::array::from_fn(|_| Vec::with_capacity(capacity)), + big_values: std::array::from_fn(|_| Vec::with_capacity(capacity)), big_multiplicities: Vec::with_capacity(capacity), - small_ids_and_values: std::array::from_fn(|_| Vec::with_capacity(capacity)), + small_values: std::array::from_fn(|_| Vec::with_capacity(capacity)), small_multiplicities: Vec::with_capacity(capacity), } } @@ -313,12 +303,11 @@ impl InteractionClaimGenerator { Vec>, QM31, ) { - let big_table_log_size = self.big_ids_and_values[0].len().ilog2() + LOG_N_LANES; + let big_table_log_size = self.big_values[0].len().ilog2() + LOG_N_LANES; let mut big_values_logup_gen = LogupTraceGenerator::new(big_table_log_size); // Every element is 9-bit. - for (limb0, limb1, limb2, lim3) in self.big_ids_and_values[MEMORY_ID_SIZE..].iter().tuples() - { + for (limb0, limb1, limb2, lim3) in self.big_values.iter().tuples() { let mut col_gen = big_values_logup_gen.new_col(); for (vec_row, (limb0, limb1, limb2, limb3)) in izip!(limb0, limb1, limb2, lim3).enumerate() @@ -332,10 +321,21 @@ impl InteractionClaimGenerator { // Yield large values. let mut col_gen = big_values_logup_gen.new_col(); + let large_memory_value_id_tag = Simd::splat(LARGE_MEMORY_VALUE_ID_BASE); for vec_row in 0..1 << (big_table_log_size - LOG_N_LANES) { - let values: [_; BIG_N_ID_AND_VALUE_COLUMNS] = - std::array::from_fn(|i| self.big_ids_and_values[i][vec_row]); - let denom: PackedQM31 = lookup_elements.combine(&values); + let id_and_value: [_; N_M31_IN_FELT252 + MEMORY_ID_SIZE] = std::array::from_fn(|i| { + if i == 0 { + unsafe { + PackedM31::from_simd_unchecked( + (SIMD_ENUMERATION_0 + Simd::splat((vec_row * N_LANES) as u32)) + | large_memory_value_id_tag, + ) + } + } else { + self.big_values[i - 1][vec_row] + } + }); + let denom: PackedQM31 = lookup_elements.combine(&id_and_value); col_gen.write_frac(vec_row, (-self.big_multiplicities[vec_row]).into(), denom); } col_gen.finalize_col(); @@ -351,11 +351,11 @@ impl InteractionClaimGenerator { Vec>, QM31, ) { - let small_table_log_size = self.small_ids_and_values[0].len().ilog2() + LOG_N_LANES; + let small_table_log_size = self.small_values[0].len().ilog2() + LOG_N_LANES; let mut small_values_logup_gen = LogupTraceGenerator::new(small_table_log_size); // Every element is 9-bit. - for (l, r) in self.small_ids_and_values[MEMORY_ID_SIZE..].iter().tuples() { + for (l, r) in self.small_values.iter().tuples() { let mut col_gen = small_values_logup_gen.new_col(); for (vec_row, (l1, l2)) in zip(l, r).enumerate() { // TOOD(alont) Add 2-batching. @@ -371,9 +371,19 @@ impl InteractionClaimGenerator { // Yield small values. let mut col_gen = small_values_logup_gen.new_col(); for vec_row in 0..1 << (small_table_log_size - LOG_N_LANES) { - let values: [_; SMALL_N_ID_AND_VALUE_COLUMNS] = - std::array::from_fn(|i| self.small_ids_and_values[i][vec_row]); - let denom: PackedQM31 = lookup_elements.combine(&values); + let id_and_value: [_; N_M31_IN_SMALL_FELT252 + MEMORY_ID_SIZE] = + std::array::from_fn(|i| { + if i == 0 { + unsafe { + PackedM31::from_simd_unchecked( + SIMD_ENUMERATION_0 + Simd::splat((vec_row * N_LANES) as u32), + ) + } + } else { + self.small_values[i - 1][vec_row] + } + }); + let denom: PackedQM31 = lookup_elements.combine(&id_and_value); col_gen.write_frac(vec_row, (-self.small_multiplicities[vec_row]).into(), denom); } col_gen.finalize_col(); diff --git a/stwo_cairo_prover/crates/prover/src/input/memory.rs b/stwo_cairo_prover/crates/prover/src/input/memory.rs index 89fced9d8..2f7454461 100644 --- a/stwo_cairo_prover/crates/prover/src/input/memory.rs +++ b/stwo_cairo_prover/crates/prover/src/input/memory.rs @@ -211,14 +211,14 @@ impl DerefMut for MemoryBuilder { } } -pub const LARGE_MEMORY_VALUE_ID_TAG: u32 = 0x4000_0000; +pub const LARGE_MEMORY_VALUE_ID_BASE: u32 = 0x4000_0000; #[derive(Copy, Clone, PartialEq, Eq, Debug, Serialize, Deserialize)] pub struct EncodedMemoryValueId(pub u32); impl EncodedMemoryValueId { pub fn encode(value: MemoryValueId) -> EncodedMemoryValueId { match value { MemoryValueId::Small(id) => EncodedMemoryValueId(id), - MemoryValueId::F252(id) => EncodedMemoryValueId(id | LARGE_MEMORY_VALUE_ID_TAG), + MemoryValueId::F252(id) => EncodedMemoryValueId(id | LARGE_MEMORY_VALUE_ID_BASE), } } pub fn decode(&self) -> MemoryValueId {