Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bitwise Xor Struct #379

Merged
merged 1 commit into from
Jan 30, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 88 additions & 1 deletion stwo_cairo_prover/crates/prover/src/cairo_air/preprocessed.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::simd::{u32x16, Simd};

use itertools::{chain, Itertools};
use prover_types::simd::LOG_N_LANES;
use stwo_prover::constraint_framework::preprocessed_columns::{IsFirst, PreProcessedColumnId};
Expand Down Expand Up @@ -97,14 +99,61 @@ impl Seq {

pub fn id(&self) -> PreProcessedColumnId {
PreProcessedColumnId {
id: format!("preprocessed_seq_{}", self.log_size).to_string(),
id: format!("seq_{}", self.log_size).to_string(),
}
}
}

/// A table of a,b,c, where a,b,c are integers and a ^ b = c.
///
/// # Attributes
///
/// - `n_bits`: The number of bits in each integer.
/// - `col_index`: The column index in the preprocessed table.
#[derive(Debug)]
pub struct BitwiseXor {
n_bits: u32,
col_index: usize,
}
impl BitwiseXor {
pub const fn new(n_bits: u32, col_index: usize) -> Self {
assert!(col_index < 3, "col_index must be in range 0..=2");
Self { n_bits, col_index }
}

pub fn id(&self) -> PreProcessedColumnId {
PreProcessedColumnId {
id: format!("bitwise_xor_{}_{}", self.n_bits, self.col_index),
}
}

pub const fn log_size(&self) -> u32 {
2 * self.n_bits
}

pub fn packed_at(&self, vec_row: usize) -> PackedM31 {
let lhs = || -> u32x16 {
(SIMD_ENUMERATION_0 + Simd::splat((vec_row * N_LANES) as u32)) >> self.n_bits
};
let rhs = || -> u32x16 {
(SIMD_ENUMERATION_0 + Simd::splat((vec_row * N_LANES) as u32))
& Simd::splat((1 << self.n_bits) - 1)
};
let simd = match self.col_index {
0 => lhs(),
1 => rhs(),
2 => lhs() ^ rhs(),
_ => unreachable!(),
};
unsafe { PackedM31::from_simd_unchecked(simd) }
}
}

#[cfg(test)]
mod tests {
use super::*;
const LOG_SIZE: u32 = 8;
use stwo_prover::core::backend::Column;

#[test]
fn test_columns_are_in_decending_order() {
Expand All @@ -114,4 +163,42 @@ mod tests {
.windows(2)
.all(|w| w[0].log_size() >= w[1].log_size()));
}

#[test]
fn test_gen_seq() {
let seq = Seq::new(LOG_SIZE).gen_column_simd();
for i in 0..(1 << LOG_SIZE) {
assert_eq!(seq.at(i), BaseField::from_u32_unchecked(i as u32));
}
}

#[test]
fn test_packed_at_seq() {
let seq = Seq::new(LOG_SIZE);
let expected_seq: [_; 1 << LOG_SIZE] = std::array::from_fn(|i| M31::from(i as u32));
let packed_seq = std::array::from_fn::<_, { (1 << LOG_SIZE) / N_LANES }, _>(|i| {
seq.packed_at(i).to_array()
})
.concat();
assert_eq!(packed_seq, expected_seq);
}

#[test]
fn test_packed_at_bitwise_xor() {
let bitwise_a = BitwiseXor::new(LOG_SIZE, 0);
let bitwise_b = BitwiseXor::new(LOG_SIZE, 1);
let bitwise_xor = BitwiseXor::new(LOG_SIZE, 2);
let index: usize = 1000;
let a = index / (1 << LOG_SIZE);
let b = index % (1 << LOG_SIZE);
let expected_xor = a ^ b;

let res_a = bitwise_a.packed_at(index / N_LANES).to_array()[index % N_LANES];
let res_b = bitwise_b.packed_at(index / N_LANES).to_array()[index % N_LANES];
let res_xor = bitwise_xor.packed_at(index / N_LANES).to_array()[index % N_LANES];

assert_eq!(res_a.0, a as u32);
assert_eq!(res_b.0, b as u32);
assert_eq!(res_xor.0, expected_xor as u32);
}
}
Loading