From c0071cd262e08da6fedd0685c7da84b457101292 Mon Sep 17 00:00:00 2001 From: Gali Michlevich Date: Tue, 21 Jan 2025 11:42:59 +0200 Subject: [PATCH] Bitwise Xor Struct --- .../prover/src/cairo_air/preprocessed.rs | 136 ++++++++++++++++++ 1 file changed, 136 insertions(+) diff --git a/stwo_cairo_prover/crates/prover/src/cairo_air/preprocessed.rs b/stwo_cairo_prover/crates/prover/src/cairo_air/preprocessed.rs index fc76d56ae..843c08975 100644 --- a/stwo_cairo_prover/crates/prover/src/cairo_air/preprocessed.rs +++ b/stwo_cairo_prover/crates/prover/src/cairo_air/preprocessed.rs @@ -1,6 +1,7 @@ use itertools::{chain, Itertools}; use prover_types::simd::LOG_N_LANES; use stwo_prover::constraint_framework::preprocessed_columns::{IsFirst, PreProcessedColumnId}; +use stwo_prover::core::backend::simd::column::BaseColumn; use stwo_prover::core::backend::simd::m31::{PackedM31, N_LANES}; use stwo_prover::core::backend::simd::SimdBackend; use stwo_prover::core::backend::Col; @@ -102,9 +103,94 @@ impl Seq { } } +/// Columns for the bitwise xor preprocessed table. +/// The table has three columns (use col_index to select the column when needed): +/// 0: first limb, 1: second limb, 2: result of the limbs' bitwise xor. +#[derive(Debug)] +pub struct BitwiseXor { + pub n_bits: u32, + pub n_expand_bits: u32, + pub col_index: usize, +} +impl BitwiseXor { + pub const fn new(n_bits: u32, n_expand_bits: u32, col_index: usize) -> Self { + assert!(col_index < 3, "col_index must be in range 0..=2"); + Self { + n_bits, + n_expand_bits, + col_index, + } + } + + pub fn id(&self) -> PreProcessedColumnId { + PreProcessedColumnId { + id: format!( + "preprocessed_bitwise_xor_{}_{}_{}", + self.n_bits, self.n_expand_bits, self.col_index + ), + } + } + pub const fn limb_bits(&self) -> u32 { + self.n_bits - self.n_expand_bits + } + + pub const fn column_bits(&self) -> u32 { + 2 * self.limb_bits() + } + + #[allow(clippy::type_complexity)] + pub fn gen_column_simd(&self) -> CircleEvaluation { + let col: BaseColumn = match self.col_index { + 0 => (0..(1 << self.column_bits())) + .map(|i| BaseField::from_u32_unchecked((i >> self.limb_bits()) as u32)) + .collect(), + 1 => (0..(1 << self.column_bits())) + .map(|i| BaseField::from_u32_unchecked((i & ((1 << self.limb_bits()) - 1)) as u32)) + .collect(), + 2 => (0..(1 << self.column_bits())) + .map(|i| { + BaseField::from_u32_unchecked( + ((i >> self.limb_bits()) ^ (i & ((1 << self.limb_bits()) - 1))) as u32, + ) + }) + .collect(), + _ => unreachable!(), + }; + CircleEvaluation::new(CanonicCoset::new(self.column_bits()).circle_domain(), col) + } + + pub fn packed_at(&self, vec_row: usize) -> PackedM31 { + let at_row: [BaseField; N_LANES] = match self.col_index { + 0 => (vec_row * N_LANES..(vec_row + 1) * N_LANES) + .map(|i| BaseField::from_u32_unchecked((i >> self.limb_bits()) as u32)) + .collect_vec() + .try_into() + .unwrap(), + 1 => (vec_row * N_LANES..(vec_row + 1) * N_LANES) + .map(|i| BaseField::from_u32_unchecked((i & ((1 << self.limb_bits()) - 1)) as u32)) + .collect_vec() + .try_into() + .unwrap(), + 2 => (vec_row * N_LANES..(vec_row + 1) * N_LANES) + .map(|i| { + BaseField::from_u32_unchecked( + ((i >> self.limb_bits()) ^ (i & ((1 << self.limb_bits()) - 1))) as u32, + ) + }) + .collect_vec() + .try_into() + .unwrap(), + _ => unreachable!(), + }; + PackedM31::from_array(at_row) + } +} + #[cfg(test)] mod tests { use super::*; + const LOG_SIZE: u32 = 8; + use stwo_prover::core::backend::Column; #[test] fn test_columns_are_in_decending_order() { @@ -114,4 +200,54 @@ mod tests { .windows(2) .all(|w| w[0].log_size() >= w[1].log_size())); } + + #[test] + fn test_gen_seq() { + let seq = Seq::new(LOG_SIZE).gen_column_simd(); + for i in 0..(1 << LOG_SIZE) { + assert_eq!(seq.at(i), BaseField::from_u32_unchecked(i as u32)); + } + } + + #[test] + fn test_packed_at_seq() { + let seq = Seq::new(LOG_SIZE); + let expected_seq: [_; 1 << LOG_SIZE] = std::array::from_fn(|i| M31::from(i as u32)); + let packed_seq = std::array::from_fn::<_, { (1 << LOG_SIZE) / N_LANES }, _>(|i| { + seq.packed_at(i).to_array() + }) + .concat(); + assert_eq!(packed_seq, expected_seq); + } + + #[test] + fn test_gen_bitwise_xor() { + let bitwise_xor = BitwiseXor::new(LOG_SIZE, 0, 2).gen_column_simd(); + + for i in 0..(1 << (2 * LOG_SIZE)) { + let a = i >> LOG_SIZE; + let b = i & ((1 << LOG_SIZE) - 1); + let expected = BaseField::from_u32_unchecked((a ^ b) as u32); + + assert_eq!(bitwise_xor.at(i), expected,); + } + } + + #[test] + fn test_packed_at_bitwise_xor() { + let bitwise_xor = BitwiseXor::new(LOG_SIZE, 0, 2); + let expected_bitwise_xor: [_; 1 << (2 * LOG_SIZE)] = std::array::from_fn(|i| { + let a = i >> LOG_SIZE; + let b = i & ((1 << LOG_SIZE) - 1); + M31::from((a ^ b) as u32) + }); + + let packed_bitwise_xor = + std::array::from_fn::<_, { (1 << (2 * LOG_SIZE)) / N_LANES }, _>(|i| { + bitwise_xor.packed_at(i).to_array() + }) + .concat(); + + assert_eq!(packed_bitwise_xor, expected_bitwise_xor); + } }