From bf7cb1dce4c1a1a3a0337e278d7f5542484c56e1 Mon Sep 17 00:00:00 2001 From: Gali Michlevich Date: Tue, 21 Jan 2025 11:42:59 +0200 Subject: [PATCH] Bitwise Xor Struct --- .../prover/src/cairo_air/preprocessed.rs | 97 ++++++++++++++++++- 1 file changed, 96 insertions(+), 1 deletion(-) diff --git a/stwo_cairo_prover/crates/prover/src/cairo_air/preprocessed.rs b/stwo_cairo_prover/crates/prover/src/cairo_air/preprocessed.rs index fc76d56ae..81ae7ee3e 100644 --- a/stwo_cairo_prover/crates/prover/src/cairo_air/preprocessed.rs +++ b/stwo_cairo_prover/crates/prover/src/cairo_air/preprocessed.rs @@ -1,3 +1,5 @@ +use std::simd::{u32x16, Simd}; + use itertools::{chain, Itertools}; use prover_types::simd::LOG_N_LANES; use stwo_prover::constraint_framework::preprocessed_columns::{IsFirst, PreProcessedColumnId}; @@ -97,14 +99,69 @@ impl Seq { pub fn id(&self) -> PreProcessedColumnId { PreProcessedColumnId { - id: format!("preprocessed_seq_{}", self.log_size).to_string(), + id: format!("seq_{}", self.log_size).to_string(), } } } +/// Columns for BitwiseXor preprocessed columns. +#[derive(Debug)] +enum XorColumn { + A, + B, + C, +} + +/// A table of a,b,c, where a,b,c are `n_bits` integers and a ^ b = c. +/// The 'col_index' is the column index in the preprocessed table. +#[derive(Debug)] +pub struct BitwiseXor { + n_bits: u32, + col: XorColumn, +} +impl BitwiseXor { + pub const fn new(n_bits: u32, col: usize) -> Self { + let col = match col { + 0 => XorColumn::A, + 1 => XorColumn::B, + 2 => XorColumn::C, + _ => panic!("col_index must be in range 0..=2"), + }; + Self { n_bits, col } + } + + pub fn id(&self) -> PreProcessedColumnId { + PreProcessedColumnId { + id: format!("bitwise_xor_{}_{:?}", self.n_bits, self.col), + } + } + + pub const fn log_size(&self) -> u32 { + 2 * self.n_bits + } + + pub fn packed_at(&self, vec_row: usize) -> PackedM31 { + let lhs = || -> u32x16 { + (SIMD_ENUMERATION_0 + Simd::splat((vec_row * N_LANES) as u32)) >> self.n_bits + }; + let rhs = || -> u32x16 { + (SIMD_ENUMERATION_0 + Simd::splat((vec_row * N_LANES) as u32)) + & Simd::splat((1 << self.n_bits) - 1) + }; + let simd = match self.col { + XorColumn::A => lhs(), + XorColumn::B => rhs(), + XorColumn::C => lhs() ^ rhs(), + }; + unsafe { PackedM31::from_simd_unchecked(simd) } + } +} + #[cfg(test)] mod tests { use super::*; + const LOG_SIZE: u32 = 8; + use stwo_prover::core::backend::Column; #[test] fn test_columns_are_in_decending_order() { @@ -114,4 +171,42 @@ mod tests { .windows(2) .all(|w| w[0].log_size() >= w[1].log_size())); } + + #[test] + fn test_gen_seq() { + let seq = Seq::new(LOG_SIZE).gen_column_simd(); + for i in 0..(1 << LOG_SIZE) { + assert_eq!(seq.at(i), BaseField::from_u32_unchecked(i as u32)); + } + } + + #[test] + fn test_packed_at_seq() { + let seq = Seq::new(LOG_SIZE); + let expected_seq: [_; 1 << LOG_SIZE] = std::array::from_fn(|i| M31::from(i as u32)); + let packed_seq = std::array::from_fn::<_, { (1 << LOG_SIZE) / N_LANES }, _>(|i| { + seq.packed_at(i).to_array() + }) + .concat(); + assert_eq!(packed_seq, expected_seq); + } + + #[test] + fn test_packed_at_bitwise_xor() { + let bitwise_a = BitwiseXor::new(LOG_SIZE, 0); + let bitwise_b = BitwiseXor::new(LOG_SIZE, 1); + let bitwise_xor = BitwiseXor::new(LOG_SIZE, 2); + let index: usize = 41; + let a = index / (1 << LOG_SIZE); + let b = index % (1 << LOG_SIZE); + let expected_xor = a ^ b; + + let res_a = bitwise_a.packed_at(index / N_LANES).to_array()[index % N_LANES]; + let res_b = bitwise_b.packed_at(index / N_LANES).to_array()[index % N_LANES]; + let res_xor = bitwise_xor.packed_at(index / N_LANES).to_array()[index % N_LANES]; + + assert_eq!(res_a.0, a as u32); + assert_eq!(res_b.0, b as u32); + assert_eq!(res_xor.0, expected_xor as u32); + } }