Skip to content

Commit

Permalink
Bitwise Xor Struct
Browse files Browse the repository at this point in the history
  • Loading branch information
Gali-StarkWare committed Jan 23, 2025
1 parent a23c93f commit c0071cd
Showing 1 changed file with 136 additions and 0 deletions.
136 changes: 136 additions & 0 deletions stwo_cairo_prover/crates/prover/src/cairo_air/preprocessed.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use itertools::{chain, Itertools};
use prover_types::simd::LOG_N_LANES;
use stwo_prover::constraint_framework::preprocessed_columns::{IsFirst, PreProcessedColumnId};
use stwo_prover::core::backend::simd::column::BaseColumn;
use stwo_prover::core::backend::simd::m31::{PackedM31, N_LANES};
use stwo_prover::core::backend::simd::SimdBackend;
use stwo_prover::core::backend::Col;
Expand Down Expand Up @@ -102,9 +103,94 @@ impl Seq {
}
}

/// Columns for the bitwise xor preprocessed table.
/// The table has three columns (use col_index to select the column when needed):
/// 0: first limb, 1: second limb, 2: result of the limbs' bitwise xor.
#[derive(Debug)]
pub struct BitwiseXor {
pub n_bits: u32,
pub n_expand_bits: u32,
pub col_index: usize,
}
impl BitwiseXor {
pub const fn new(n_bits: u32, n_expand_bits: u32, col_index: usize) -> Self {
assert!(col_index < 3, "col_index must be in range 0..=2");
Self {
n_bits,
n_expand_bits,
col_index,
}
}

pub fn id(&self) -> PreProcessedColumnId {
PreProcessedColumnId {
id: format!(
"preprocessed_bitwise_xor_{}_{}_{}",
self.n_bits, self.n_expand_bits, self.col_index
),
}
}
pub const fn limb_bits(&self) -> u32 {
self.n_bits - self.n_expand_bits
}

pub const fn column_bits(&self) -> u32 {
2 * self.limb_bits()
}

#[allow(clippy::type_complexity)]
pub fn gen_column_simd(&self) -> CircleEvaluation<SimdBackend, BaseField, BitReversedOrder> {
let col: BaseColumn = match self.col_index {
0 => (0..(1 << self.column_bits()))
.map(|i| BaseField::from_u32_unchecked((i >> self.limb_bits()) as u32))
.collect(),
1 => (0..(1 << self.column_bits()))
.map(|i| BaseField::from_u32_unchecked((i & ((1 << self.limb_bits()) - 1)) as u32))
.collect(),
2 => (0..(1 << self.column_bits()))
.map(|i| {
BaseField::from_u32_unchecked(
((i >> self.limb_bits()) ^ (i & ((1 << self.limb_bits()) - 1))) as u32,
)
})
.collect(),
_ => unreachable!(),
};
CircleEvaluation::new(CanonicCoset::new(self.column_bits()).circle_domain(), col)
}

pub fn packed_at(&self, vec_row: usize) -> PackedM31 {
let at_row: [BaseField; N_LANES] = match self.col_index {
0 => (vec_row * N_LANES..(vec_row + 1) * N_LANES)
.map(|i| BaseField::from_u32_unchecked((i >> self.limb_bits()) as u32))
.collect_vec()
.try_into()
.unwrap(),
1 => (vec_row * N_LANES..(vec_row + 1) * N_LANES)
.map(|i| BaseField::from_u32_unchecked((i & ((1 << self.limb_bits()) - 1)) as u32))
.collect_vec()
.try_into()
.unwrap(),
2 => (vec_row * N_LANES..(vec_row + 1) * N_LANES)
.map(|i| {
BaseField::from_u32_unchecked(
((i >> self.limb_bits()) ^ (i & ((1 << self.limb_bits()) - 1))) as u32,
)
})
.collect_vec()
.try_into()
.unwrap(),
_ => unreachable!(),
};
PackedM31::from_array(at_row)
}
}

#[cfg(test)]
mod tests {
use super::*;
const LOG_SIZE: u32 = 8;
use stwo_prover::core::backend::Column;

#[test]
fn test_columns_are_in_decending_order() {
Expand All @@ -114,4 +200,54 @@ mod tests {
.windows(2)
.all(|w| w[0].log_size() >= w[1].log_size()));
}

#[test]
fn test_gen_seq() {
let seq = Seq::new(LOG_SIZE).gen_column_simd();
for i in 0..(1 << LOG_SIZE) {
assert_eq!(seq.at(i), BaseField::from_u32_unchecked(i as u32));
}
}

#[test]
fn test_packed_at_seq() {
let seq = Seq::new(LOG_SIZE);
let expected_seq: [_; 1 << LOG_SIZE] = std::array::from_fn(|i| M31::from(i as u32));
let packed_seq = std::array::from_fn::<_, { (1 << LOG_SIZE) / N_LANES }, _>(|i| {
seq.packed_at(i).to_array()
})
.concat();
assert_eq!(packed_seq, expected_seq);
}

#[test]
fn test_gen_bitwise_xor() {
let bitwise_xor = BitwiseXor::new(LOG_SIZE, 0, 2).gen_column_simd();

for i in 0..(1 << (2 * LOG_SIZE)) {
let a = i >> LOG_SIZE;
let b = i & ((1 << LOG_SIZE) - 1);
let expected = BaseField::from_u32_unchecked((a ^ b) as u32);

assert_eq!(bitwise_xor.at(i), expected,);
}
}

#[test]
fn test_packed_at_bitwise_xor() {
let bitwise_xor = BitwiseXor::new(LOG_SIZE, 0, 2);
let expected_bitwise_xor: [_; 1 << (2 * LOG_SIZE)] = std::array::from_fn(|i| {
let a = i >> LOG_SIZE;
let b = i & ((1 << LOG_SIZE) - 1);
M31::from((a ^ b) as u32)
});

let packed_bitwise_xor =
std::array::from_fn::<_, { (1 << (2 * LOG_SIZE)) / N_LANES }, _>(|i| {
bitwise_xor.packed_at(i).to_array()
})
.concat();

assert_eq!(packed_bitwise_xor, expected_bitwise_xor);
}
}

0 comments on commit c0071cd

Please sign in to comment.