Skip to content

Commit

Permalink
Bitwise Xor Struct
Browse files Browse the repository at this point in the history
  • Loading branch information
Gali-StarkWare committed Jan 27, 2025
1 parent 13444ff commit 2b943c5
Showing 1 changed file with 136 additions and 0 deletions.
136 changes: 136 additions & 0 deletions stwo_cairo_prover/crates/prover/src/cairo_air/preprocessed.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use itertools::{chain, Itertools};
use prover_types::simd::LOG_N_LANES;
use stwo_prover::constraint_framework::preprocessed_columns::{IsFirst, PreProcessedColumnId};
use stwo_prover::core::backend::simd::column::BaseColumn;
use stwo_prover::core::backend::simd::m31::{PackedM31, N_LANES};
use stwo_prover::core::backend::simd::SimdBackend;
use stwo_prover::core::backend::Col;
Expand Down Expand Up @@ -102,9 +103,94 @@ impl Seq {
}
}

/// A preprocessed table for the xor operation of 2 n_bits numbers.
/// The index_in_table is the column index in the preprocessed table (0 for the first number, 1 for
/// the second number and 2 for the xor operation result).
#[derive(Debug)]
pub struct BitwiseXor {
pub n_bits: u32,
pub col_index: usize,
}
impl BitwiseXor {
pub const fn new(n_bits: u32, col_index: usize) -> Self {
assert!(col_index < 3, "col_index must be in range 0..=2");
Self { n_bits, col_index }
}

pub fn id(&self) -> PreProcessedColumnId {
PreProcessedColumnId {
id: format!(
"preprocessed_bitwise_xor_{}_{}",
self.n_bits, self.col_index
),
}
}

pub const fn log_size(&self) -> u32 {
2 * self.n_bits
}

#[allow(clippy::type_complexity)]
pub fn gen_column_simd(&self) -> CircleEvaluation<SimdBackend, BaseField, BitReversedOrder> {
let col: BaseColumn = match self.col_index {
0 => (0..(1 << self.log_size()))
.map(|i| BaseField::from_u32_unchecked((i >> self.n_bits) as u32))
.collect(),
1 => (0..(1 << self.log_size()))
.map(|i| BaseField::from_u32_unchecked((i & ((1 << self.n_bits) - 1)) as u32))
.collect(),
2 => (0..(1 << self.log_size()))
.map(|i| {
BaseField::from_u32_unchecked(
((i >> self.n_bits) ^ (i & ((1 << self.n_bits) - 1))) as u32,
)
})
.collect(),
_ => unreachable!(),
};
CircleEvaluation::new(CanonicCoset::new(self.log_size()).circle_domain(), col)
}

fn packed_at_lhs(&self, vec_row: usize) -> PackedM31 {
let at_row: [BaseField; N_LANES] = (vec_row * N_LANES..(vec_row + 1) * N_LANES)
.map(|i| BaseField::from_u32_unchecked((i >> self.n_bits) as u32))
.collect_vec()
.try_into()
.unwrap();
PackedM31::from_array(at_row)
}

fn packed_at_rhs(&self, vec_row: usize) -> PackedM31 {
let at_row: [BaseField; N_LANES] = (vec_row * N_LANES..(vec_row + 1) * N_LANES)
.map(|i| BaseField::from_u32_unchecked((i & ((1 << self.n_bits) - 1)) as u32))
.collect_vec()
.try_into()
.unwrap();
PackedM31::from_array(at_row)
}

pub fn packed_at(&self, vec_row: usize) -> PackedM31 {
match self.col_index {
0 => self.packed_at_lhs(vec_row),
1 => self.packed_at_rhs(vec_row),
2 => {
let lhs_array = self.packed_at_lhs(vec_row).to_array();
let rhs_array = self.packed_at_rhs(vec_row).to_array();
let at_row: [BaseField; N_LANES] = std::array::from_fn(|i| {
BaseField::from_u32_unchecked(lhs_array[i].0 ^ rhs_array[i].0)
});
PackedM31::from_array(at_row)
}
_ => unreachable!(),
}
}
}

#[cfg(test)]
mod tests {
use super::*;
const LOG_SIZE: u32 = 8;
use stwo_prover::core::backend::Column;

#[test]
fn test_columns_are_in_decending_order() {
Expand All @@ -114,4 +200,54 @@ mod tests {
.windows(2)
.all(|w| w[0].log_size() >= w[1].log_size()));
}

#[test]
fn test_gen_seq() {
let seq = Seq::new(LOG_SIZE).gen_column_simd();
for i in 0..(1 << LOG_SIZE) {
assert_eq!(seq.at(i), BaseField::from_u32_unchecked(i as u32));
}
}

#[test]
fn test_packed_at_seq() {
let seq = Seq::new(LOG_SIZE);
let expected_seq: [_; 1 << LOG_SIZE] = std::array::from_fn(|i| M31::from(i as u32));
let packed_seq = std::array::from_fn::<_, { (1 << LOG_SIZE) / N_LANES }, _>(|i| {
seq.packed_at(i).to_array()
})
.concat();
assert_eq!(packed_seq, expected_seq);
}

#[test]
fn test_gen_bitwise_xor() {
let bitwise_xor = BitwiseXor::new(LOG_SIZE, 2).gen_column_simd();

for i in 0..(1 << (2 * LOG_SIZE)) {
let a = i >> LOG_SIZE;
let b = i & ((1 << LOG_SIZE) - 1);
let expected = BaseField::from_u32_unchecked((a ^ b) as u32);

assert_eq!(bitwise_xor.at(i), expected,);
}
}

#[test]
fn test_packed_at_bitwise_xor() {
let bitwise_xor = BitwiseXor::new(LOG_SIZE, 2);
let expected_bitwise_xor: [_; 1 << (2 * LOG_SIZE)] = std::array::from_fn(|i| {
let a = i >> LOG_SIZE;
let b = i & ((1 << LOG_SIZE) - 1);
M31::from((a ^ b) as u32)
});

let packed_bitwise_xor =
std::array::from_fn::<_, { (1 << (2 * LOG_SIZE)) / N_LANES }, _>(|i| {
bitwise_xor.packed_at(i).to_array()
})
.concat();

assert_eq!(packed_bitwise_xor, expected_bitwise_xor);
}
}

0 comments on commit 2b943c5

Please sign in to comment.