Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

seq in memory id->value #411

Merged
merged 1 commit into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,25 @@ use stwo_prover::constraint_framework::{
EvalAtRow, FrameworkComponent, FrameworkEval, RelationEntry,
};
use stwo_prover::core::channel::Channel;
use stwo_prover::core::fields::m31::M31;
use stwo_prover::core::fields::qm31::SecureField;
use stwo_prover::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use stwo_prover::core::pcs::TreeVec;
use stwo_prover::relation;

use crate::cairo_air::preprocessed::Seq;
use crate::input::memory::LARGE_MEMORY_VALUE_ID_BASE;
use crate::relations;

// TODO(AlonH): Make memory size configurable.
pub const MEMORY_ID_SIZE: usize = 1;
pub const N_M31_IN_FELT252: usize = 28;
pub const N_M31_IN_SMALL_FELT252: usize = 8; // 72 bits.
pub const N_MULTIPLICITY_COLUMNS: usize = 1;
pub const BIG_MULTIPLICITY_COLUMN_OFFSET: usize = BIG_N_ID_AND_VALUE_COLUMNS;
pub const BIG_N_COLUMNS: usize = BIG_N_ID_AND_VALUE_COLUMNS + N_MULTIPLICITY_COLUMNS;
pub const BIG_N_ID_AND_VALUE_COLUMNS: usize = MEMORY_ID_SIZE + N_M31_IN_FELT252;
pub const SMALL_MULTIPLICITY_COLUMN_OFFSET: usize = SMALL_N_ID_AND_VALUE_COLUMNS;
pub const SMALL_N_COLUMNS: usize = SMALL_N_ID_AND_VALUE_COLUMNS + N_MULTIPLICITY_COLUMNS;
pub const SMALL_N_ID_AND_VALUE_COLUMNS: usize = MEMORY_ID_SIZE + N_M31_IN_SMALL_FELT252;
pub const BIG_MULTIPLICITY_COLUMN_OFFSET: usize = N_M31_IN_FELT252;
pub const BIG_N_COLUMNS: usize = N_M31_IN_FELT252 + N_MULTIPLICITY_COLUMNS;
pub const SMALL_MULTIPLICITY_COLUMN_OFFSET: usize = N_M31_IN_SMALL_FELT252;
pub const SMALL_N_COLUMNS: usize = N_M31_IN_SMALL_FELT252 + N_MULTIPLICITY_COLUMNS;

pub type BigComponent = FrameworkComponent<BigEval>;
pub type SmallComponent = FrameworkComponent<SmallEval>;
Expand Down Expand Up @@ -68,12 +69,12 @@ impl FrameworkEval for BigEval {
}

fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
let id_and_value: [E::F; MEMORY_ID_SIZE + N_M31_IN_FELT252] =
std::array::from_fn(|_| eval.next_trace_mask());
let seq = eval.get_preprocessed_column(Seq::new(self.log_size()).id());
let value: [E::F; N_M31_IN_FELT252] = std::array::from_fn(|_| eval.next_trace_mask());
let multiplicity = eval.next_trace_mask();

// Range check limbs.
for (l, r) in id_and_value[MEMORY_ID_SIZE..].iter().tuples() {
for (l, r) in value.iter().tuples() {
eval.add_to_relation(RelationEntry::new(
&self.range9_9_lookup_elements,
E::EF::one(),
Expand All @@ -82,10 +83,11 @@ impl FrameworkEval for BigEval {
}

// Yield the value.
let id = seq + E::F::from(M31::from(LARGE_MEMORY_VALUE_ID_BASE));
eval.add_to_relation(RelationEntry::new(
&self.lookup_elements,
E::EF::from(-multiplicity),
&id_and_value,
&chain!([id], value).collect_vec(),
));

eval.finalize_logup_in_pairs();
Expand Down Expand Up @@ -124,12 +126,12 @@ impl FrameworkEval for SmallEval {
}

fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
let id_and_value: [E::F; SMALL_N_ID_AND_VALUE_COLUMNS] =
std::array::from_fn(|_| eval.next_trace_mask());
let seq = eval.get_preprocessed_column(Seq::new(self.log_size()).id());
let value: [E::F; N_M31_IN_SMALL_FELT252] = std::array::from_fn(|_| eval.next_trace_mask());
let multiplicity = eval.next_trace_mask();

// Range check limbs.
for (l, r) in id_and_value[MEMORY_ID_SIZE..].iter().tuples() {
for (l, r) in value.iter().tuples() {
eval.add_to_relation(RelationEntry::new(
&self.range_check_9_9_relation,
E::EF::one(),
Expand All @@ -138,10 +140,11 @@ impl FrameworkEval for SmallEval {
}

// Yield the value.
let id = seq;
eval.add_to_relation(RelationEntry::new(
&self.lookup_elements,
E::EF::from(-multiplicity),
&id_and_value,
&chain!([id], value).collect_vec(),
));

eval.finalize_logup();
Expand All @@ -156,7 +159,7 @@ pub struct Claim {
}
impl Claim {
pub fn log_sizes(&self) -> TreeVec<Vec<u32>> {
let preprocessed_log_sizes = vec![self.big_log_size, self.small_log_size];
let preprocessed_log_sizes = vec![];
let trace_log_sizes = chain!(
vec![self.big_log_size; BIG_N_COLUMNS],
vec![self.small_log_size; SMALL_N_COLUMNS]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,14 @@ use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation};
use stwo_prover::core::poly::BitReversedOrder;

use super::component::{
Claim, InteractionClaim, BIG_N_ID_AND_VALUE_COLUMNS, MEMORY_ID_SIZE, N_M31_IN_SMALL_FELT252,
SMALL_N_ID_AND_VALUE_COLUMNS,
Claim, InteractionClaim, MEMORY_ID_SIZE, N_M31_IN_FELT252, N_M31_IN_SMALL_FELT252,
};
use crate::components::memory::MEMORY_ADDRESS_BOUND;
use crate::components::range_check_vector::range_check_9_9;
use crate::components::range_check_vector::{range_check_9_9, SIMD_ENUMERATION_0};
use crate::components::utils::AtomicMultiplicityColumn;
use crate::felt::split_f252_simd;
use crate::input::memory::{
u128_to_4_limbs, EncodedMemoryValueId, Memory, MemoryValueId, LARGE_MEMORY_VALUE_ID_TAG,
u128_to_4_limbs, EncodedMemoryValueId, Memory, MemoryValueId, LARGE_MEMORY_VALUE_ID_BASE,
};
use crate::relations;

Expand Down Expand Up @@ -130,22 +129,22 @@ impl ClaimGenerator {
gen_small_memory_trace(self.small_values, self.small_mults.into_simd_vec());

// Lookup data.
let big_ids_and_values: [_; BIG_N_ID_AND_VALUE_COLUMNS] =
let big_values: [_; N_M31_IN_FELT252] =
std::array::from_fn(|i| big_table_trace[i].data.clone());
let big_multiplicities = big_table_trace.last().unwrap().data.clone();
let small_ids_and_values: [_; SMALL_N_ID_AND_VALUE_COLUMNS] =
let small_values: [_; N_M31_IN_SMALL_FELT252] =
std::array::from_fn(|i| small_table_trace[i].data.clone());
let small_multiplicities = small_table_trace.last().unwrap().data.clone();

// Add inputs to range check that all the values are 9-bit felts.
for (col0, col1) in big_ids_and_values[MEMORY_ID_SIZE..].iter().tuples() {
for (col0, col1) in big_values.iter().tuples() {
col0.par_iter()
.zip(col1.par_iter())
.for_each(|(val0, val1)| {
range_check_9_9_trace_generator.add_packed_m31(&[*val0, *val1]);
});
}
for (col0, col1) in small_ids_and_values[MEMORY_ID_SIZE..].iter().tuples() {
for (col0, col1) in small_values.iter().tuples() {
col0.par_iter()
.zip(col1.par_iter())
.for_each(|(val0, val1)| {
Expand Down Expand Up @@ -183,9 +182,9 @@ impl ClaimGenerator {
small_log_size,
},
InteractionClaimGenerator {
big_ids_and_values,
big_values,
big_multiplicities,
small_ids_and_values,
small_values,
small_multiplicities,
},
)
Expand All @@ -203,26 +202,20 @@ fn gen_big_memory_trace(values: Vec<[u32; 8]>, mults: Vec<PackedM31>) -> Vec<Bas
})
.collect_vec();

let mut id_and_value_trace =
let mut value_trace =
std::iter::repeat_with(|| unsafe { BaseColumn::uninitialized(column_length) })
.take(BIG_N_ID_AND_VALUE_COLUMNS)
.take(N_M31_IN_FELT252)
.collect_vec();
let inc = Simd::from_array(std::array::from_fn(|i| i as u32));
for (i, values) in packed_values.iter().enumerate() {
let values = split_f252_simd(*values);
id_and_value_trace[0].data[i] = unsafe {
PackedM31::from_simd_unchecked(
Simd::splat((i * N_LANES) as u32 + LARGE_MEMORY_VALUE_ID_TAG) + inc,
)
};
for (j, value) in values.iter().enumerate() {
id_and_value_trace[j + 1].data[i] = *value;
value_trace[j].data[i] = *value;
}
}

let multiplicities = BaseColumn::from_simd(mults);

chain!(id_and_value_trace, [multiplicities]).collect_vec()
chain!(value_trace, [multiplicities]).collect_vec()
}

// Generates the trace of the small value memory table.
Expand All @@ -237,11 +230,10 @@ fn gen_small_memory_trace(values: Vec<u128>, mults: Vec<PackedM31>) -> Vec<BaseC
})
.collect_vec();

let mut id_and_value_trace =
let mut values_trace =
std::iter::repeat_with(|| unsafe { BaseColumn::uninitialized(column_length) })
.take(SMALL_N_ID_AND_VALUE_COLUMNS)
.take(N_M31_IN_SMALL_FELT252)
.collect_vec();
let inc = Simd::from_array(std::array::from_fn(|i| i as u32));
for (i, values) in packed_values.iter().enumerate() {
let values = split_f252_simd([
values[0],
Expand All @@ -253,31 +245,29 @@ fn gen_small_memory_trace(values: Vec<u128>, mults: Vec<PackedM31>) -> Vec<BaseC
Simd::splat(0),
Simd::splat(0),
]);
id_and_value_trace[0].data[i] =
unsafe { PackedM31::from_simd_unchecked(Simd::splat((i * N_LANES) as u32) + inc) };
for (j, value) in values[..N_M31_IN_SMALL_FELT252].iter().enumerate() {
id_and_value_trace[j + 1].data[i] = *value;
values_trace[j].data[i] = *value;
}
}

let multiplicities = BaseColumn::from_simd(mults);

chain!(id_and_value_trace, [multiplicities]).collect_vec()
chain!(values_trace, [multiplicities]).collect_vec()
}

#[derive(Debug)]
pub struct InteractionClaimGenerator {
pub big_ids_and_values: [Vec<PackedM31>; BIG_N_ID_AND_VALUE_COLUMNS],
pub big_values: [Vec<PackedM31>; N_M31_IN_FELT252],
pub big_multiplicities: Vec<PackedM31>,
pub small_ids_and_values: [Vec<PackedM31>; SMALL_N_ID_AND_VALUE_COLUMNS],
pub small_values: [Vec<PackedM31>; N_M31_IN_SMALL_FELT252],
pub small_multiplicities: Vec<PackedM31>,
}
impl InteractionClaimGenerator {
pub fn with_capacity(capacity: usize) -> Self {
Self {
big_ids_and_values: std::array::from_fn(|_| Vec::with_capacity(capacity)),
big_values: std::array::from_fn(|_| Vec::with_capacity(capacity)),
big_multiplicities: Vec::with_capacity(capacity),
small_ids_and_values: std::array::from_fn(|_| Vec::with_capacity(capacity)),
small_values: std::array::from_fn(|_| Vec::with_capacity(capacity)),
small_multiplicities: Vec::with_capacity(capacity),
}
}
Expand Down Expand Up @@ -313,12 +303,11 @@ impl InteractionClaimGenerator {
Vec<CircleEvaluation<SimdBackend, M31, BitReversedOrder>>,
QM31,
) {
let big_table_log_size = self.big_ids_and_values[0].len().ilog2() + LOG_N_LANES;
let big_table_log_size = self.big_values[0].len().ilog2() + LOG_N_LANES;
let mut big_values_logup_gen = LogupTraceGenerator::new(big_table_log_size);

// Every element is 9-bit.
for (limb0, limb1, limb2, lim3) in self.big_ids_and_values[MEMORY_ID_SIZE..].iter().tuples()
{
for (limb0, limb1, limb2, lim3) in self.big_values.iter().tuples() {
let mut col_gen = big_values_logup_gen.new_col();
for (vec_row, (limb0, limb1, limb2, limb3)) in
izip!(limb0, limb1, limb2, lim3).enumerate()
Expand All @@ -332,10 +321,21 @@ impl InteractionClaimGenerator {

// Yield large values.
let mut col_gen = big_values_logup_gen.new_col();
let large_memory_value_id_tag = Simd::splat(LARGE_MEMORY_VALUE_ID_BASE);
for vec_row in 0..1 << (big_table_log_size - LOG_N_LANES) {
let values: [_; BIG_N_ID_AND_VALUE_COLUMNS] =
std::array::from_fn(|i| self.big_ids_and_values[i][vec_row]);
let denom: PackedQM31 = lookup_elements.combine(&values);
let id_and_value: [_; N_M31_IN_FELT252 + MEMORY_ID_SIZE] = std::array::from_fn(|i| {
if i == 0 {
unsafe {
PackedM31::from_simd_unchecked(
(SIMD_ENUMERATION_0 + Simd::splat((vec_row * N_LANES) as u32))
| large_memory_value_id_tag,
)
}
} else {
self.big_values[i - 1][vec_row]
}
});
let denom: PackedQM31 = lookup_elements.combine(&id_and_value);
col_gen.write_frac(vec_row, (-self.big_multiplicities[vec_row]).into(), denom);
}
col_gen.finalize_col();
Expand All @@ -351,11 +351,11 @@ impl InteractionClaimGenerator {
Vec<CircleEvaluation<SimdBackend, M31, BitReversedOrder>>,
QM31,
) {
let small_table_log_size = self.small_ids_and_values[0].len().ilog2() + LOG_N_LANES;
let small_table_log_size = self.small_values[0].len().ilog2() + LOG_N_LANES;
let mut small_values_logup_gen = LogupTraceGenerator::new(small_table_log_size);

// Every element is 9-bit.
for (l, r) in self.small_ids_and_values[MEMORY_ID_SIZE..].iter().tuples() {
for (l, r) in self.small_values.iter().tuples() {
let mut col_gen = small_values_logup_gen.new_col();
for (vec_row, (l1, l2)) in zip(l, r).enumerate() {
// TOOD(alont) Add 2-batching.
Expand All @@ -371,9 +371,19 @@ impl InteractionClaimGenerator {
// Yield small values.
let mut col_gen = small_values_logup_gen.new_col();
for vec_row in 0..1 << (small_table_log_size - LOG_N_LANES) {
let values: [_; SMALL_N_ID_AND_VALUE_COLUMNS] =
std::array::from_fn(|i| self.small_ids_and_values[i][vec_row]);
let denom: PackedQM31 = lookup_elements.combine(&values);
let id_and_value: [_; N_M31_IN_SMALL_FELT252 + MEMORY_ID_SIZE] =
std::array::from_fn(|i| {
if i == 0 {
unsafe {
PackedM31::from_simd_unchecked(
SIMD_ENUMERATION_0 + Simd::splat((vec_row * N_LANES) as u32),
)
}
} else {
self.small_values[i - 1][vec_row]
}
});
let denom: PackedQM31 = lookup_elements.combine(&id_and_value);
col_gen.write_frac(vec_row, (-self.small_multiplicities[vec_row]).into(), denom);
}
col_gen.finalize_col();
Expand Down
4 changes: 2 additions & 2 deletions stwo_cairo_prover/crates/prover/src/input/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,14 +211,14 @@ impl DerefMut for MemoryBuilder {
}
}

pub const LARGE_MEMORY_VALUE_ID_TAG: u32 = 0x4000_0000;
pub const LARGE_MEMORY_VALUE_ID_BASE: u32 = 0x4000_0000;
#[derive(Copy, Clone, PartialEq, Eq, Debug, Serialize, Deserialize)]
pub struct EncodedMemoryValueId(pub u32);
impl EncodedMemoryValueId {
pub fn encode(value: MemoryValueId) -> EncodedMemoryValueId {
match value {
MemoryValueId::Small(id) => EncodedMemoryValueId(id),
MemoryValueId::F252(id) => EncodedMemoryValueId(id | LARGE_MEMORY_VALUE_ID_TAG),
MemoryValueId::F252(id) => EncodedMemoryValueId(id | LARGE_MEMORY_VALUE_ID_BASE),
}
}
pub fn decode(&self) -> MemoryValueId {
Expand Down
Loading