diff --git a/stwo_cairo_prover/crates/prover/src/components/mod.rs b/stwo_cairo_prover/crates/prover/src/components/mod.rs index 24ccdc13..78a95c90 100644 --- a/stwo_cairo_prover/crates/prover/src/components/mod.rs +++ b/stwo_cairo_prover/crates/prover/src/components/mod.rs @@ -25,6 +25,7 @@ pub mod range_check_builtin_bits_128; pub mod range_check_vector; pub mod ret_opcode; pub mod utils; +pub mod verify_bitwise_xor_9; pub mod verify_instruction; // TODO(Ohad): mul small. diff --git a/stwo_cairo_prover/crates/prover/src/components/utils.rs b/stwo_cairo_prover/crates/prover/src/components/utils.rs index c700afab..49684fc4 100644 --- a/stwo_cairo_prover/crates/prover/src/components/utils.rs +++ b/stwo_cairo_prover/crates/prover/src/components/utils.rs @@ -51,6 +51,7 @@ impl MultiplicityColumn { } } +#[derive(Default)] /// A column of multiplicities for lookup arguments. Allows increasing the multiplicity at a given /// index. This version uses atomic operations to increase the multiplicity, and is `Send`. pub struct AtomicMultiplicityColumn { diff --git a/stwo_cairo_prover/crates/prover/src/components/verify_bitwise_xor_9/component.rs b/stwo_cairo_prover/crates/prover/src/components/verify_bitwise_xor_9/component.rs new file mode 100644 index 00000000..299c55d9 --- /dev/null +++ b/stwo_cairo_prover/crates/prover/src/components/verify_bitwise_xor_9/component.rs @@ -0,0 +1,94 @@ +#![allow(non_camel_case_types)] +#![allow(unused_imports)] +use num_traits::{One, Zero}; +use serde::{Deserialize, Serialize}; +use stwo_cairo_serialize::CairoSerialize; +use stwo_prover::constraint_framework::logup::{LogupAtRow, LogupSums, LookupElements}; +use stwo_prover::constraint_framework::{ + EvalAtRow, FrameworkComponent, FrameworkEval, RelationEntry, +}; +use stwo_prover::core::backend::simd::m31::LOG_N_LANES; +use stwo_prover::core::channel::Channel; +use stwo_prover::core::fields::m31::M31; +use stwo_prover::core::fields::qm31::SecureField; +use stwo_prover::core::fields::secure_column::SECURE_EXTENSION_DEGREE; +use stwo_prover::core::pcs::TreeVec; + +use crate::cairo_air::preprocessed::{BitwiseXor, PreProcessedColumn}; +use crate::relations; + +pub struct Eval { + pub claim: Claim, + pub verify_bitwise_xor_9_lookup_elements: relations::VerifyBitwiseXor_9, +} + +#[derive(Copy, Clone, Serialize, Deserialize, CairoSerialize)] +pub struct Claim { + pub log_size: u32, +} +impl Claim { + pub fn log_sizes(&self) -> TreeVec> { + let log_size = self.log_size; + let trace_log_sizes = vec![log_size; 1]; + let interaction_log_sizes = vec![log_size; SECURE_EXTENSION_DEGREE]; + let preprocessed_log_sizes = vec![log_size]; + TreeVec::new(vec![ + preprocessed_log_sizes, + trace_log_sizes, + interaction_log_sizes, + ]) + } + + pub fn mix_into(&self, channel: &mut impl Channel) { + channel.mix_u64(self.log_size as u64); + } +} + +#[derive(Copy, Clone, Serialize, Deserialize, CairoSerialize)] +pub struct InteractionClaim { + pub logup_sums: LogupSums, +} +impl InteractionClaim { + pub fn mix_into(&self, channel: &mut impl Channel) { + let (total_sum, claimed_sum) = self.logup_sums; + channel.mix_felts(&[total_sum]); + if let Some(claimed_sum) = claimed_sum { + channel.mix_felts(&[claimed_sum.0]); + channel.mix_u64(claimed_sum.1 as u64); + } + } +} + +pub type Component = FrameworkComponent; + +impl FrameworkEval for Eval { + fn log_size(&self) -> u32 { + self.claim.log_size + } + + fn max_constraint_log_degree_bound(&self) -> u32 { + self.log_size() + 1 + } + + #[allow(unused_parens)] + #[allow(clippy::double_parens)] + #[allow(non_snake_case)] + fn evaluate(&self, mut eval: E) -> E { + let BitwiseXor0 = eval + .get_preprocessed_column(PreProcessedColumn::BitwiseXor(BitwiseXor::new(9, 0)).id()); + let BitwiseXor1 = eval + .get_preprocessed_column(PreProcessedColumn::BitwiseXor(BitwiseXor::new(9, 1)).id()); + let BitwiseXor2 = eval + .get_preprocessed_column(PreProcessedColumn::BitwiseXor(BitwiseXor::new(9, 2)).id()); + let multiplicity = eval.next_trace_mask(); + + eval.add_to_relation(RelationEntry::new( + &self.verify_bitwise_xor_9_lookup_elements, + -E::EF::from(multiplicity), + &[BitwiseXor0, BitwiseXor1, BitwiseXor2], + )); + + eval.finalize_logup_in_pairs(); + eval + } +} diff --git a/stwo_cairo_prover/crates/prover/src/components/verify_bitwise_xor_9/mod.rs b/stwo_cairo_prover/crates/prover/src/components/verify_bitwise_xor_9/mod.rs new file mode 100644 index 00000000..3f7a8d74 --- /dev/null +++ b/stwo_cairo_prover/crates/prover/src/components/verify_bitwise_xor_9/mod.rs @@ -0,0 +1,5 @@ +pub mod component; +pub mod prover; + +pub use component::{Claim, Component, Eval, InteractionClaim}; +pub use prover::{ClaimGenerator, InputType, InteractionClaimGenerator}; diff --git a/stwo_cairo_prover/crates/prover/src/components/verify_bitwise_xor_9/prover.rs b/stwo_cairo_prover/crates/prover/src/components/verify_bitwise_xor_9/prover.rs new file mode 100644 index 00000000..3980dab9 --- /dev/null +++ b/stwo_cairo_prover/crates/prover/src/components/verify_bitwise_xor_9/prover.rs @@ -0,0 +1,165 @@ +#![allow(unused_parens)] +#![allow(dead_code)] +#![allow(unused_imports)] +use std::iter::zip; +use std::sync::Mutex; + +use itertools::{chain, zip_eq, Itertools}; +use num_traits::{One, Zero}; +use prover_types::cpu::*; +use prover_types::simd::*; +use rayon::iter::{ + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator, +}; +use stwo_air_utils::trace::component_trace::ComponentTrace; +use stwo_air_utils_derive::{IterMut, ParIterMut, Uninitialized}; +use stwo_prover::constraint_framework::logup::LogupTraceGenerator; +use stwo_prover::constraint_framework::Relation; +use stwo_prover::core::air::Component; +use stwo_prover::core::backend::simd::column::BaseColumn; +use stwo_prover::core::backend::simd::conversion::Unpack; +use stwo_prover::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES}; +use stwo_prover::core::backend::simd::qm31::PackedQM31; +use stwo_prover::core::backend::simd::SimdBackend; +use stwo_prover::core::backend::{BackendForChannel, Col, Column}; +use stwo_prover::core::channel::{Channel, MerkleChannel}; +use stwo_prover::core::fields::m31::M31; +use stwo_prover::core::fields::FieldExpOps; +use stwo_prover::core::pcs::TreeBuilder; +use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation}; +use stwo_prover::core::poly::BitReversedOrder; +use stwo_prover::core::utils::{ + bit_reverse_coset_to_circle_domain_order, bit_reverse_index, coset_index_to_circle_domain_index, +}; + +use super::component::{Claim, InteractionClaim}; +use crate::cairo_air::preprocessed::{BitwiseXor, PreProcessedColumn}; +use crate::components::utils::{pack_values, AtomicMultiplicityColumn}; +use crate::relations; + +pub type InputType = [M31; 3]; +pub type PackedInputType = [PackedM31; 3]; +const N_BITS: u32 = 9; +const N_TRACE_COLUMNS: usize = 1; +const LOG_SIZE: u32 = 18; + +#[derive(Default)] +pub struct ClaimGenerator { + pub log_size: u32, + pub mults: AtomicMultiplicityColumn, +} +impl ClaimGenerator { + pub fn new() -> Self { + Self { + log_size: LOG_SIZE, + mults: AtomicMultiplicityColumn::new(1 << LOG_SIZE), + } + } + + pub fn write_trace( + self, + tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, MC>, + ) -> (Claim, InteractionClaimGenerator) + where + SimdBackend: BackendForChannel, + { + let log_size = self.log_size; + let (trace, lookup_data) = write_trace_simd(log_size, self.mults); + + tree_builder.extend_evals(trace.to_evals()); + + ( + Claim { log_size }, + InteractionClaimGenerator { + log_size, + lookup_data, + }, + ) + } + + pub fn add_input(&self, input: &InputType) { + self.mults.increase_at((input[0].0 << N_BITS) + input[1].0); + } + + pub fn add_inputs(&mut self, inputs: &[InputType]) { + for input in inputs { + self.add_input(input); + } + } +} + +fn write_trace_simd( + log_size: u32, + mults: AtomicMultiplicityColumn, +) -> (ComponentTrace, LookupData) { + let log_n_packed_rows = log_size - LOG_N_LANES; + let (mut trace, mut lookup_data) = unsafe { + ( + ComponentTrace::::uninitialized(log_size), + LookupData::uninitialized(log_n_packed_rows), + ) + }; + + let mults = mults.into_simd_vec(); + trace + .par_iter_mut() + .enumerate() + .zip(lookup_data.par_iter_mut()) + .for_each(|((row_index, row), lookup_data)| { + *row[0] = mults[row_index]; + + *lookup_data.verify_bitwise_xor_9_0 = [ + PreProcessedColumn::BitwiseXor(BitwiseXor::new(9, 0)).packed_at(row_index), + PreProcessedColumn::BitwiseXor(BitwiseXor::new(9, 1)).packed_at(row_index), + PreProcessedColumn::BitwiseXor(BitwiseXor::new(9, 2)).packed_at(row_index), + ]; + *lookup_data.mults = mults[row_index]; + }); + + (trace, lookup_data) +} + +#[derive(Uninitialized, IterMut, ParIterMut)] +struct LookupData { + verify_bitwise_xor_9_0: Vec<[PackedM31; 3]>, + mults: Vec, +} + +pub struct InteractionClaimGenerator { + log_size: u32, + lookup_data: LookupData, +} +impl InteractionClaimGenerator { + pub fn write_interaction_trace( + self, + tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, MC>, + verify_bitwise_xor_9: &relations::VerifyBitwiseXor_9, + ) -> InteractionClaim + where + SimdBackend: BackendForChannel, + { + let log_size = self.log_size; + let mut logup_gen = LogupTraceGenerator::new(log_size); + + // Sum last logup term. + let mut col_gen = logup_gen.new_col(); + for (i, (values, mults)) in self + .lookup_data + .verify_bitwise_xor_9_0 + .iter() + .zip(self.lookup_data.mults) + .enumerate() + { + let denom = verify_bitwise_xor_9.combine(values); + col_gen.write_frac(i, -PackedQM31::one() * mults, denom); + } + col_gen.finalize_col(); + + let (trace, claimed_sum) = logup_gen.finalize_last(); + tree_builder.extend_evals(trace); + + InteractionClaim { + logup_sums: (claimed_sum, None), + } + } +} diff --git a/stwo_cairo_prover/crates/prover/src/relations/mod.rs b/stwo_cairo_prover/crates/prover/src/relations/mod.rs index 1826f79b..cd4d8353 100644 --- a/stwo_cairo_prover/crates/prover/src/relations/mod.rs +++ b/stwo_cairo_prover/crates/prover/src/relations/mod.rs @@ -11,3 +11,4 @@ relation!(RangeCheck_9_9, 2); relation!(RangeCheck_4_3, 2); relation!(RangeCheck_7_2_5, 3); relation!(VerifyInstruction, 29); +relation!(VerifyBitwiseXor_9, 3);