diff --git a/stwo_cairo_prover/crates/prover/src/cairo_air/air.rs b/stwo_cairo_prover/crates/prover/src/cairo_air/air.rs index 81160089f..b520eddcd 100644 --- a/stwo_cairo_prover/crates/prover/src/cairo_air/air.rs +++ b/stwo_cairo_prover/crates/prover/src/cairo_air/air.rs @@ -32,7 +32,7 @@ use super::range_checks_air::{ RangeChecksInteractionElements, }; use crate::components::memory::{memory_address_to_id, memory_id_to_big}; -use crate::components::verify_instruction; +use crate::components::{verify_bitwise_xor_9, verify_instruction}; use crate::felt::split_f252; use crate::input::ProverInput; use crate::relations; @@ -72,6 +72,7 @@ pub struct CairoClaim { pub memory_address_to_id: memory_address_to_id::Claim, pub memory_id_to_value: memory_id_to_big::Claim, pub range_checks: RangeChecksClaim, + pub verify_bitwise_xor_9: verify_bitwise_xor_9::Claim, // ... } @@ -84,6 +85,7 @@ impl CairoClaim { self.memory_address_to_id.mix_into(channel); self.memory_id_to_value.mix_into(channel); self.range_checks.mix_into(channel); + self.verify_bitwise_xor_9.mix_into(channel); } pub fn log_sizes(&self) -> TreeVec> { @@ -94,6 +96,7 @@ impl CairoClaim { self.memory_address_to_id.log_sizes(), self.memory_id_to_value.log_sizes(), self.range_checks.log_sizes(), + self.verify_bitwise_xor_9.log_sizes(), ]; let mut log_sizes = TreeVec::concat_cols(log_sizes_list.into_iter()); @@ -163,6 +166,7 @@ pub struct CairoClaimGenerator { memory_address_to_id_trace_generator: memory_address_to_id::ClaimGenerator, memory_id_to_value_trace_generator: memory_id_to_big::ClaimGenerator, range_checks_trace_generator: RangeChecksClaimGenerator, + verify_bitwise_xor_9_trace_generator: verify_bitwise_xor_9::ClaimGenerator, // ... } impl CairoClaimGenerator { @@ -178,6 +182,7 @@ impl CairoClaimGenerator { let memory_id_to_value_trace_generator = memory_id_to_big::ClaimGenerator::new(&input.memory); let range_checks_trace_generator = RangeChecksClaimGenerator::new(); + let verify_bitwise_xor_9_trace_generator = verify_bitwise_xor_9::ClaimGenerator::new(); // Yield public memory. for addr in input @@ -215,6 +220,7 @@ impl CairoClaimGenerator { memory_address_to_id_trace_generator, memory_id_to_value_trace_generator, range_checks_trace_generator, + verify_bitwise_xor_9_trace_generator, } } @@ -261,6 +267,9 @@ impl CairoClaimGenerator { ); let (range_checks_claim, range_checks_interaction_gen) = self.range_checks_trace_generator.write_trace(tree_builder); + let (verify_bitwise_xor_9_claim, verify_bitwise_xor_9_interaction_gen) = self + .verify_bitwise_xor_9_trace_generator + .write_trace(tree_builder); span.exit(); ( CairoClaim { @@ -271,6 +280,7 @@ impl CairoClaimGenerator { memory_address_to_id: memory_address_to_id_claim, memory_id_to_value: memory_id_to_value_claim, range_checks: range_checks_claim, + verify_bitwise_xor_9: verify_bitwise_xor_9_claim, }, CairoInteractionClaimGenerator { opcodes_interaction_gen, @@ -279,6 +289,7 @@ impl CairoClaimGenerator { memory_address_to_id_interaction_gen, memory_id_to_value_interaction_gen, range_checks_interaction_gen, + verify_bitwise_xor_9_interaction_gen, }, ) } @@ -291,6 +302,7 @@ pub struct CairoInteractionClaimGenerator { memory_address_to_id_interaction_gen: memory_address_to_id::InteractionClaimGenerator, memory_id_to_value_interaction_gen: memory_id_to_big::InteractionClaimGenerator, range_checks_interaction_gen: RangeChecksInteractionClaimGenerator, + verify_bitwise_xor_9_interaction_gen: verify_bitwise_xor_9::InteractionClaimGenerator, // ... } impl CairoInteractionClaimGenerator { @@ -332,6 +344,9 @@ impl CairoInteractionClaimGenerator { let range_checks_interaction_claim = self .range_checks_interaction_gen .write_interaction_trace(tree_builder, &interaction_elements.range_checks); + let verify_bitwise_xor_9_interaction_claim = self + .verify_bitwise_xor_9_interaction_gen + .write_interaction_trace(tree_builder, &interaction_elements.verify_bitwise_xor_9); CairoInteractionClaim { opcodes: opcodes_interaction_claims, @@ -340,6 +355,7 @@ impl CairoInteractionClaimGenerator { memory_address_to_id: memory_address_to_id_interaction_claim, memory_id_to_value: memory_id_to_value_interaction_claim, range_checks: range_checks_interaction_claim, + verify_bitwise_xor_9: verify_bitwise_xor_9_interaction_claim, } } } @@ -350,6 +366,7 @@ pub struct CairoInteractionElements { pub memory_address_to_id: relations::MemoryAddressToId, pub memory_id_to_value: relations::MemoryIdToBig, pub range_checks: RangeChecksInteractionElements, + pub verify_bitwise_xor_9: relations::VerifyBitwiseXor_9, } impl CairoInteractionElements { pub fn draw(channel: &mut impl Channel) -> CairoInteractionElements { @@ -359,6 +376,7 @@ impl CairoInteractionElements { memory_address_to_id: relations::MemoryAddressToId::draw(channel), memory_id_to_value: relations::MemoryIdToBig::draw(channel), range_checks: RangeChecksInteractionElements::draw(channel), + verify_bitwise_xor_9: relations::VerifyBitwiseXor_9::draw(channel), } } } @@ -371,6 +389,7 @@ pub struct CairoInteractionClaim { pub memory_address_to_id: memory_address_to_id::InteractionClaim, pub memory_id_to_value: memory_id_to_big::InteractionClaim, pub range_checks: RangeChecksInteractionClaim, + pub verify_bitwise_xor_9: verify_bitwise_xor_9::InteractionClaim, } impl CairoInteractionClaim { pub fn mix_into(&self, channel: &mut impl Channel) { @@ -380,6 +399,7 @@ impl CairoInteractionClaim { self.memory_address_to_id.mix_into(channel); self.memory_id_to_value.mix_into(channel); self.range_checks.mix_into(channel); + self.verify_bitwise_xor_9.mix_into(channel); } } @@ -400,6 +420,7 @@ pub fn lookup_sum( sum += interaction_claim.memory_id_to_value.big_claimed_sum; sum += interaction_claim.memory_id_to_value.small_claimed_sum; sum += interaction_claim.range_checks.sum(); + sum += interaction_claim.verify_bitwise_xor_9.claimed_sum; sum } @@ -413,6 +434,7 @@ pub struct CairoComponents { memory_id_to_big::SmallComponent, ), range_checks: RangeChecksComponents, + verify_bitwise_xor_9: verify_bitwise_xor_9::Component, // ... } impl CairoComponents { @@ -491,6 +513,16 @@ impl CairoComponents { &interaction_elements.range_checks, &interaction_claim.range_checks, ); + let verify_bitwise_xor_9_component = verify_bitwise_xor_9::Component::new( + tree_span_provider, + verify_bitwise_xor_9::Eval { + claim: cairo_claim.verify_bitwise_xor_9, + verify_bitwise_xor_9_lookup_elements: interaction_elements + .verify_bitwise_xor_9 + .clone(), + }, + interaction_claim.verify_bitwise_xor_9.claimed_sum, + ); Self { opcodes: opcode_components, verify_instruction: verify_instruction_component, @@ -501,6 +533,7 @@ impl CairoComponents { small_memory_id_to_value_component, ), range_checks: range_checks_component, + verify_bitwise_xor_9: verify_bitwise_xor_9_component, } } @@ -514,7 +547,8 @@ impl CairoComponents { &self.memory_id_to_value.0 as &dyn ComponentProver, &self.memory_id_to_value.1 as &dyn ComponentProver, ], - self.range_checks.provers() + self.range_checks.provers(), + [&self.verify_bitwise_xor_9 as &dyn ComponentProver,], ) .collect() } @@ -553,6 +587,11 @@ impl std::fmt::Display for CairoComponents { indented_component_display(&self.memory_id_to_value.1) )?; writeln!(f, "RangeChecks: {}", self.range_checks)?; + writeln!( + f, + "VerifyBitwiseXor9: {}", + indented_component_display(&self.verify_bitwise_xor_9) + )?; Ok(()) } } diff --git a/stwo_cairo_prover/crates/prover/src/cairo_air/debug_tools.rs b/stwo_cairo_prover/crates/prover/src/cairo_air/debug_tools.rs index 978a1feb1..3950c0065 100644 --- a/stwo_cairo_prover/crates/prover/src/cairo_air/debug_tools.rs +++ b/stwo_cairo_prover/crates/prover/src/cairo_air/debug_tools.rs @@ -25,7 +25,8 @@ use crate::components::{ jnz_opcode, jnz_opcode_dst_base_fp, jnz_opcode_taken, jnz_opcode_taken_dst_base_fp, jump_opcode, jump_opcode_double_deref, jump_opcode_rel, jump_opcode_rel_imm, memory_address_to_id, memory_id_to_big, mul_opcode, mul_opcode_imm, - range_check_builtin_bits_128, range_check_builtin_bits_96, ret_opcode, verify_instruction, + range_check_builtin_bits_128, range_check_builtin_bits_96, ret_opcode, verify_bitwise_xor_9, + verify_instruction, }; use crate::felt::split_f252; use crate::relations; @@ -504,6 +505,17 @@ where ); } + entries.extend( + RelationTrackerComponent::new( + tree_span_provider, + verify_bitwise_xor_9::Eval { + claim: claim.verify_bitwise_xor_9, + verify_bitwise_xor_9_lookup_elements: relations::VerifyBitwiseXor_9::dummy(), + }, + 1 << crate::components::verify_bitwise_xor_9::component::LOG_SIZE, + ) + .entries(trace), + ); // Memory. entries.extend( RelationTrackerComponent::new( diff --git a/stwo_cairo_prover/crates/prover/src/cairo_air/preprocessed.rs b/stwo_cairo_prover/crates/prover/src/cairo_air/preprocessed.rs index e533fd0f3..e7f27337a 100644 --- a/stwo_cairo_prover/crates/prover/src/cairo_air/preprocessed.rs +++ b/stwo_cairo_prover/crates/prover/src/cairo_air/preprocessed.rs @@ -1,6 +1,6 @@ use std::simd::{u32x16, Simd}; -use itertools::Itertools; +use itertools::{chain, Itertools}; use prover_types::simd::LOG_N_LANES; use stwo_prover::constraint_framework::preprocessed_columns::PreProcessedColumnId; use stwo_prover::core::backend::simd::column::BaseColumn; @@ -19,6 +19,9 @@ const N_PREPROCESSED_COLUMN_SIZES: usize = (LOG_MAX_ROWS - LOG_N_LANES) as usize // List of sizes to initialize the preprocessed trace with for `PreprocessedColumn::Seq`. const SEQ_LOG_SIZES: [u32; N_PREPROCESSED_COLUMN_SIZES] = preprocessed_log_sizes(); +// Size to initialize the preprocessed trace with for `PreprocessedColumn::BitwiseXor`. +const XOR_N_BITS: u32 = 9; + /// [LOG_MAX_ROWS, LOG_MAX_ROWS - 1, ..., LOG_N_LANES] const fn preprocessed_log_sizes() -> [u32; N_PREPROCESSED_COLUMN_SIZES] { let mut arr = [0; N_PREPROCESSED_COLUMN_SIZES]; @@ -61,8 +64,10 @@ impl PreProcessedColumn { /// Returns column info for the preprocessed trace. pub fn preprocessed_trace_columns() -> Vec { let seq_columns = SEQ_LOG_SIZES.map(|log_size| PreProcessedColumn::Seq(Seq::new(log_size))); - seq_columns - .into_iter() + let bitwise_xor_columns = (0..3).map(move |col_index| { + PreProcessedColumn::BitwiseXor(BitwiseXor::new(XOR_N_BITS, col_index)) + }); + chain![seq_columns, bitwise_xor_columns] .sorted_by_key(|column| std::cmp::Reverse(column.log_size())) .collect_vec() }