Skip to content

Commit

Permalink
seq in memory id->value
Browse files Browse the repository at this point in the history
  • Loading branch information
ohad-starkware committed Feb 4, 2025
1 parent 5cf254f commit 1e3171f
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,25 @@ use stwo_prover::constraint_framework::{
EvalAtRow, FrameworkComponent, FrameworkEval, RelationEntry,
};
use stwo_prover::core::channel::Channel;
use stwo_prover::core::fields::m31::M31;
use stwo_prover::core::fields::qm31::SecureField;
use stwo_prover::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use stwo_prover::core::pcs::TreeVec;
use stwo_prover::relation;

use crate::cairo_air::preprocessed::Seq;
use crate::input::memory::LARGE_MEMORY_VALUE_ID_BASE;
use crate::relations;

// TODO(AlonH): Make memory size configurable.
pub const MEMORY_ID_SIZE: usize = 1;
pub const N_M31_IN_FELT252: usize = 28;
pub const N_M31_IN_SMALL_FELT252: usize = 8; // 72 bits.
pub const N_MULTIPLICITY_COLUMNS: usize = 1;
pub const BIG_MULTIPLICITY_COLUMN_OFFSET: usize = BIG_N_ID_AND_VALUE_COLUMNS;
pub const BIG_N_COLUMNS: usize = BIG_N_ID_AND_VALUE_COLUMNS + N_MULTIPLICITY_COLUMNS;
pub const BIG_N_ID_AND_VALUE_COLUMNS: usize = MEMORY_ID_SIZE + N_M31_IN_FELT252;
pub const SMALL_MULTIPLICITY_COLUMN_OFFSET: usize = SMALL_N_ID_AND_VALUE_COLUMNS;
pub const SMALL_N_COLUMNS: usize = SMALL_N_ID_AND_VALUE_COLUMNS + N_MULTIPLICITY_COLUMNS;
pub const SMALL_N_ID_AND_VALUE_COLUMNS: usize = MEMORY_ID_SIZE + N_M31_IN_SMALL_FELT252;
pub const BIG_MULTIPLICITY_COLUMN_OFFSET: usize = N_M31_IN_FELT252;
pub const BIG_N_COLUMNS: usize = N_M31_IN_FELT252 + N_MULTIPLICITY_COLUMNS;
pub const SMALL_MULTIPLICITY_COLUMN_OFFSET: usize = N_M31_IN_SMALL_FELT252;
pub const SMALL_N_COLUMNS: usize = N_M31_IN_SMALL_FELT252 + N_MULTIPLICITY_COLUMNS;

pub type BigComponent = FrameworkComponent<BigEval>;
pub type SmallComponent = FrameworkComponent<SmallEval>;
Expand Down Expand Up @@ -68,12 +69,12 @@ impl FrameworkEval for BigEval {
}

fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
let id_and_value: [E::F; MEMORY_ID_SIZE + N_M31_IN_FELT252] =
std::array::from_fn(|_| eval.next_trace_mask());
let seq = eval.get_preprocessed_column(Seq::new(self.log_size()).id());
let value: [E::F; N_M31_IN_FELT252] = std::array::from_fn(|_| eval.next_trace_mask());
let multiplicity = eval.next_trace_mask();

// Range check limbs.
for (l, r) in id_and_value[MEMORY_ID_SIZE..].iter().tuples() {
for (l, r) in value.iter().tuples() {
eval.add_to_relation(RelationEntry::new(
&self.range9_9_lookup_elements,
E::EF::one(),
Expand All @@ -82,10 +83,11 @@ impl FrameworkEval for BigEval {
}

// Yield the value.
let id = seq + E::F::from(M31::from(LARGE_MEMORY_VALUE_ID_BASE));
eval.add_to_relation(RelationEntry::new(
&self.lookup_elements,
E::EF::from(-multiplicity),
&id_and_value,
&chain!([id], value).collect_vec(),
));

eval.finalize_logup_in_pairs();
Expand Down Expand Up @@ -124,12 +126,12 @@ impl FrameworkEval for SmallEval {
}

fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
let id_and_value: [E::F; SMALL_N_ID_AND_VALUE_COLUMNS] =
std::array::from_fn(|_| eval.next_trace_mask());
let seq = eval.get_preprocessed_column(Seq::new(self.log_size()).id());
let value: [E::F; N_M31_IN_SMALL_FELT252] = std::array::from_fn(|_| eval.next_trace_mask());
let multiplicity = eval.next_trace_mask();

// Range check limbs.
for (l, r) in id_and_value[MEMORY_ID_SIZE..].iter().tuples() {
for (l, r) in value.iter().tuples() {
eval.add_to_relation(RelationEntry::new(
&self.range_check_9_9_relation,
E::EF::one(),
Expand All @@ -138,10 +140,11 @@ impl FrameworkEval for SmallEval {
}

// Yield the value.
let id = seq;
eval.add_to_relation(RelationEntry::new(
&self.lookup_elements,
E::EF::from(-multiplicity),
&id_and_value,
&chain!([id], value).collect_vec(),
));

eval.finalize_logup();
Expand All @@ -156,7 +159,7 @@ pub struct Claim {
}
impl Claim {
pub fn log_sizes(&self) -> TreeVec<Vec<u32>> {
let preprocessed_log_sizes = vec![self.big_log_size, self.small_log_size];
let preprocessed_log_sizes = vec![];
let trace_log_sizes = chain!(
vec![self.big_log_size; BIG_N_COLUMNS],
vec![self.small_log_size; SMALL_N_COLUMNS]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,14 @@ use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation};
use stwo_prover::core::poly::BitReversedOrder;

use super::component::{
Claim, InteractionClaim, BIG_N_ID_AND_VALUE_COLUMNS, MEMORY_ID_SIZE, N_M31_IN_SMALL_FELT252,
SMALL_N_ID_AND_VALUE_COLUMNS,
Claim, InteractionClaim, MEMORY_ID_SIZE, N_M31_IN_FELT252, N_M31_IN_SMALL_FELT252,
};
use crate::components::memory::MEMORY_ADDRESS_BOUND;
use crate::components::range_check_vector::range_check_9_9;
use crate::components::range_check_vector::{range_check_9_9, SIMD_ENUMERATION_0};
use crate::components::utils::AtomicMultiplicityColumn;
use crate::felt::split_f252_simd;
use crate::input::memory::{
u128_to_4_limbs, EncodedMemoryValueId, Memory, MemoryValueId, LARGE_MEMORY_VALUE_ID_TAG,
u128_to_4_limbs, EncodedMemoryValueId, Memory, MemoryValueId, LARGE_MEMORY_VALUE_ID_BASE,
};
use crate::relations;

Expand Down Expand Up @@ -130,22 +129,22 @@ impl ClaimGenerator {
gen_small_memory_trace(self.small_values, self.small_mults.into_simd_vec());

// Lookup data.
let big_ids_and_values: [_; BIG_N_ID_AND_VALUE_COLUMNS] =
let big_values: [_; N_M31_IN_FELT252] =
std::array::from_fn(|i| big_table_trace[i].data.clone());
let big_multiplicities = big_table_trace.last().unwrap().data.clone();
let small_ids_and_values: [_; SMALL_N_ID_AND_VALUE_COLUMNS] =
let small_values: [_; N_M31_IN_SMALL_FELT252] =
std::array::from_fn(|i| small_table_trace[i].data.clone());
let small_multiplicities = small_table_trace.last().unwrap().data.clone();

// Add inputs to range check that all the values are 9-bit felts.
for (col0, col1) in big_ids_and_values[MEMORY_ID_SIZE..].iter().tuples() {
for (col0, col1) in big_values.iter().tuples() {
col0.par_iter()
.zip(col1.par_iter())
.for_each(|(val0, val1)| {
range_check_9_9_trace_generator.add_packed_m31(&[*val0, *val1]);
});
}
for (col0, col1) in small_ids_and_values[MEMORY_ID_SIZE..].iter().tuples() {
for (col0, col1) in small_values.iter().tuples() {
col0.par_iter()
.zip(col1.par_iter())
.for_each(|(val0, val1)| {
Expand Down Expand Up @@ -183,9 +182,9 @@ impl ClaimGenerator {
small_log_size,
},
InteractionClaimGenerator {
big_ids_and_values,
big_values,
big_multiplicities,
small_ids_and_values,
small_values,
small_multiplicities,
},
)
Expand All @@ -203,26 +202,20 @@ fn gen_big_memory_trace(values: Vec<[u32; 8]>, mults: Vec<PackedM31>) -> Vec<Bas
})
.collect_vec();

let mut id_and_value_trace =
let mut value_trace =
std::iter::repeat_with(|| unsafe { BaseColumn::uninitialized(column_length) })
.take(BIG_N_ID_AND_VALUE_COLUMNS)
.take(N_M31_IN_FELT252)
.collect_vec();
let inc = Simd::from_array(std::array::from_fn(|i| i as u32));
for (i, values) in packed_values.iter().enumerate() {
let values = split_f252_simd(*values);
id_and_value_trace[0].data[i] = unsafe {
PackedM31::from_simd_unchecked(
Simd::splat((i * N_LANES) as u32 + LARGE_MEMORY_VALUE_ID_TAG) + inc,
)
};
for (j, value) in values.iter().enumerate() {
id_and_value_trace[j + 1].data[i] = *value;
value_trace[j].data[i] = *value;
}
}

let multiplicities = BaseColumn::from_simd(mults);

chain!(id_and_value_trace, [multiplicities]).collect_vec()
chain!(value_trace, [multiplicities]).collect_vec()
}

// Generates the trace of the small value memory table.
Expand All @@ -237,11 +230,10 @@ fn gen_small_memory_trace(values: Vec<u128>, mults: Vec<PackedM31>) -> Vec<BaseC
})
.collect_vec();

let mut id_and_value_trace =
let mut values_trace =
std::iter::repeat_with(|| unsafe { BaseColumn::uninitialized(column_length) })
.take(SMALL_N_ID_AND_VALUE_COLUMNS)
.take(N_M31_IN_SMALL_FELT252)
.collect_vec();
let inc = Simd::from_array(std::array::from_fn(|i| i as u32));
for (i, values) in packed_values.iter().enumerate() {
let values = split_f252_simd([
values[0],
Expand All @@ -253,31 +245,29 @@ fn gen_small_memory_trace(values: Vec<u128>, mults: Vec<PackedM31>) -> Vec<BaseC
Simd::splat(0),
Simd::splat(0),
]);
id_and_value_trace[0].data[i] =
unsafe { PackedM31::from_simd_unchecked(Simd::splat((i * N_LANES) as u32) + inc) };
for (j, value) in values[..N_M31_IN_SMALL_FELT252].iter().enumerate() {
id_and_value_trace[j + 1].data[i] = *value;
values_trace[j].data[i] = *value;
}
}

let multiplicities = BaseColumn::from_simd(mults);

chain!(id_and_value_trace, [multiplicities]).collect_vec()
chain!(values_trace, [multiplicities]).collect_vec()
}

#[derive(Debug)]
pub struct InteractionClaimGenerator {
pub big_ids_and_values: [Vec<PackedM31>; BIG_N_ID_AND_VALUE_COLUMNS],
pub big_values: [Vec<PackedM31>; N_M31_IN_FELT252],
pub big_multiplicities: Vec<PackedM31>,
pub small_ids_and_values: [Vec<PackedM31>; SMALL_N_ID_AND_VALUE_COLUMNS],
pub small_values: [Vec<PackedM31>; N_M31_IN_SMALL_FELT252],
pub small_multiplicities: Vec<PackedM31>,
}
impl InteractionClaimGenerator {
pub fn with_capacity(capacity: usize) -> Self {
Self {
big_ids_and_values: std::array::from_fn(|_| Vec::with_capacity(capacity)),
big_values: std::array::from_fn(|_| Vec::with_capacity(capacity)),
big_multiplicities: Vec::with_capacity(capacity),
small_ids_and_values: std::array::from_fn(|_| Vec::with_capacity(capacity)),
small_values: std::array::from_fn(|_| Vec::with_capacity(capacity)),
small_multiplicities: Vec::with_capacity(capacity),
}
}
Expand Down Expand Up @@ -313,12 +303,11 @@ impl InteractionClaimGenerator {
Vec<CircleEvaluation<SimdBackend, M31, BitReversedOrder>>,
QM31,
) {
let big_table_log_size = self.big_ids_and_values[0].len().ilog2() + LOG_N_LANES;
let big_table_log_size = self.big_values[0].len().ilog2() + LOG_N_LANES;
let mut big_values_logup_gen = LogupTraceGenerator::new(big_table_log_size);

// Every element is 9-bit.
for (limb0, limb1, limb2, lim3) in self.big_ids_and_values[MEMORY_ID_SIZE..].iter().tuples()
{
for (limb0, limb1, limb2, lim3) in self.big_values.iter().tuples() {
let mut col_gen = big_values_logup_gen.new_col();
for (vec_row, (limb0, limb1, limb2, limb3)) in
izip!(limb0, limb1, limb2, lim3).enumerate()
Expand All @@ -332,10 +321,21 @@ impl InteractionClaimGenerator {

// Yield large values.
let mut col_gen = big_values_logup_gen.new_col();
let large_memory_value_id_tag = Simd::splat(LARGE_MEMORY_VALUE_ID_BASE);
for vec_row in 0..1 << (big_table_log_size - LOG_N_LANES) {
let values: [_; BIG_N_ID_AND_VALUE_COLUMNS] =
std::array::from_fn(|i| self.big_ids_and_values[i][vec_row]);
let denom: PackedQM31 = lookup_elements.combine(&values);
let id_and_value: [_; N_M31_IN_FELT252 + MEMORY_ID_SIZE] = std::array::from_fn(|i| {
if i == 0 {
unsafe {
PackedM31::from_simd_unchecked(
(SIMD_ENUMERATION_0 + Simd::splat((vec_row * N_LANES) as u32))
| large_memory_value_id_tag,
)
}
} else {
self.big_values[i - 1][vec_row]
}
});
let denom: PackedQM31 = lookup_elements.combine(&id_and_value);
col_gen.write_frac(vec_row, (-self.big_multiplicities[vec_row]).into(), denom);
}
col_gen.finalize_col();
Expand All @@ -351,11 +351,11 @@ impl InteractionClaimGenerator {
Vec<CircleEvaluation<SimdBackend, M31, BitReversedOrder>>,
QM31,
) {
let small_table_log_size = self.small_ids_and_values[0].len().ilog2() + LOG_N_LANES;
let small_table_log_size = self.small_values[0].len().ilog2() + LOG_N_LANES;
let mut small_values_logup_gen = LogupTraceGenerator::new(small_table_log_size);

// Every element is 9-bit.
for (l, r) in self.small_ids_and_values[MEMORY_ID_SIZE..].iter().tuples() {
for (l, r) in self.small_values.iter().tuples() {
let mut col_gen = small_values_logup_gen.new_col();
for (vec_row, (l1, l2)) in zip(l, r).enumerate() {
// TOOD(alont) Add 2-batching.
Expand All @@ -371,9 +371,19 @@ impl InteractionClaimGenerator {
// Yield small values.
let mut col_gen = small_values_logup_gen.new_col();
for vec_row in 0..1 << (small_table_log_size - LOG_N_LANES) {
let values: [_; SMALL_N_ID_AND_VALUE_COLUMNS] =
std::array::from_fn(|i| self.small_ids_and_values[i][vec_row]);
let denom: PackedQM31 = lookup_elements.combine(&values);
let id_and_value: [_; N_M31_IN_SMALL_FELT252 + MEMORY_ID_SIZE] =
std::array::from_fn(|i| {
if i == 0 {
unsafe {
PackedM31::from_simd_unchecked(
SIMD_ENUMERATION_0 + Simd::splat((vec_row * N_LANES) as u32),
)
}
} else {
self.small_values[i - 1][vec_row]
}
});
let denom: PackedQM31 = lookup_elements.combine(&id_and_value);
col_gen.write_frac(vec_row, (-self.small_multiplicities[vec_row]).into(), denom);
}
col_gen.finalize_col();
Expand Down
4 changes: 2 additions & 2 deletions stwo_cairo_prover/crates/prover/src/input/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,14 +211,14 @@ impl DerefMut for MemoryBuilder {
}
}

pub const LARGE_MEMORY_VALUE_ID_TAG: u32 = 0x4000_0000;
pub const LARGE_MEMORY_VALUE_ID_BASE: u32 = 0x4000_0000;
#[derive(Copy, Clone, PartialEq, Eq, Debug, Serialize, Deserialize)]
pub struct EncodedMemoryValueId(pub u32);
impl EncodedMemoryValueId {
pub fn encode(value: MemoryValueId) -> EncodedMemoryValueId {
match value {
MemoryValueId::Small(id) => EncodedMemoryValueId(id),
MemoryValueId::F252(id) => EncodedMemoryValueId(id | LARGE_MEMORY_VALUE_ID_TAG),
MemoryValueId::F252(id) => EncodedMemoryValueId(id | LARGE_MEMORY_VALUE_ID_BASE),
}
}
pub fn decode(&self) -> MemoryValueId {
Expand Down

0 comments on commit 1e3171f

Please sign in to comment.