diff --git a/jolt-core/src/benches/bench.rs b/jolt-core/src/benches/bench.rs index 27ec03faf..25a6d9343 100644 --- a/jolt-core/src/benches/bench.rs +++ b/jolt-core/src/benches/bench.rs @@ -1,3 +1,4 @@ +use crate::field::binius::BiniusField; use crate::field::JoltField; use crate::host; use crate::jolt::vm::rv32i_vm::{RV32IJoltVM, C, M}; @@ -6,7 +7,9 @@ use crate::poly::commitment::commitment_scheme::CommitmentScheme; use crate::poly::commitment::hyperkzg::HyperKZG; use crate::poly::commitment::hyrax::HyraxScheme; use crate::poly::commitment::zeromorph::Zeromorph; +use crate::r1cs::spartan; use ark_bn254::{Bn254, Fr, G1Projective}; +use binius_field::BinaryField128bPolyval; use serde::Serialize; #[derive(Debug, Copy, Clone, clap::ValueEnum)] @@ -14,6 +17,7 @@ pub enum PCSType { Hyrax, Zeromorph, HyperKZG, + Binius } #[derive(Debug, Copy, Clone, clap::ValueEnum)] @@ -22,6 +26,7 @@ pub enum BenchType { Sha2, Sha3, Sha2Chain, + Spartan, } #[allow(unreachable_patterns)] // good errors on new BenchTypes @@ -38,6 +43,7 @@ pub fn benchmarks( BenchType::Sha3 => sha3::>(), BenchType::Sha2Chain => sha2chain::>(), BenchType::Fibonacci => fibonacci::>(), + BenchType::Spartan => spartan::(), _ => panic!("BenchType does not have a mapping"), }, PCSType::Zeromorph => match bench_type { @@ -45,6 +51,7 @@ pub fn benchmarks( BenchType::Sha3 => sha3::>(), BenchType::Sha2Chain => sha2chain::>(), BenchType::Fibonacci => fibonacci::>(), + BenchType::Spartan => spartan::>(), _ => panic!("BenchType does not have a mapping"), }, PCSType::HyperKZG => match bench_type { @@ -82,6 +89,18 @@ where prove_example::, PCS, F>("sha3-guest", &vec![5u8; 2048]) } +fn spartan() -> Vec<(tracing::Span, Box)> +where + F: JoltField, +{ + let task = move || { + spartan::bench::bench::(); + }; + + vec![(tracing::info_span!("Spartan"), Box::new(task))] + +} + #[allow(dead_code)] fn serialize_and_print_size(name: &str, item: &impl ark_serialize::CanonicalSerialize) { use std::fs::File; diff --git a/jolt-core/src/poly/commitment/mod.rs b/jolt-core/src/poly/commitment/mod.rs index 05278ecf2..84cb5fc05 100644 --- a/jolt-core/src/poly/commitment/mod.rs +++ b/jolt-core/src/poly/commitment/mod.rs @@ -6,5 +6,5 @@ pub mod kzg; pub mod pedersen; pub mod zeromorph; -#[cfg(test)] +// #[cfg(test)] pub mod mock; diff --git a/jolt-core/src/poly/dense_mlpoly.rs b/jolt-core/src/poly/dense_mlpoly.rs index 992b7fe69..90b6ccda4 100644 --- a/jolt-core/src/poly/dense_mlpoly.rs +++ b/jolt-core/src/poly/dense_mlpoly.rs @@ -83,6 +83,7 @@ impl DensePolynomial { self.len = n; } + #[tracing::instrument(skip_all)] pub fn bound_poly_var_top_par(&mut self, r: &F) { let n = self.len() / 2; let (left, right) = self.Z.split_at_mut(n); diff --git a/jolt-core/src/poly/unipoly.rs b/jolt-core/src/poly/unipoly.rs index 5c963fb2f..e269d432b 100644 --- a/jolt-core/src/poly/unipoly.rs +++ b/jolt-core/src/poly/unipoly.rs @@ -37,7 +37,8 @@ impl UniPoly { fn vandermonde_interpolation(evals: &[F]) -> Vec { let n = evals.len(); - let xs: Vec = (0..n).map(|x| F::from_u64(x as u64).unwrap()).collect(); + // TODO(sragss): Prolly broken for Binius. + let xs: Vec = (0..n).map(|x: usize| F::from_u64(x as u64).unwrap()).collect(); let mut vandermonde: Vec> = Vec::with_capacity(n); for i in 0..n { @@ -127,10 +128,11 @@ impl UniPoly { } pub fn compress(&self) -> CompressedUniPoly { - let coeffs_except_linear_term = [&self.coeffs[..1], &self.coeffs[2..]].concat(); - debug_assert_eq!(coeffs_except_linear_term.len() + 1, self.coeffs.len()); + // TODO(sragss): Bring back compression. + // let coeffs_except_linear_term = [&self.coeffs[..1], &self.coeffs[2..]].concat(); + // debug_assert_eq!(coeffs_except_linear_term.len() + 1, self.coeffs.len()); CompressedUniPoly { - coeffs_except_linear_term, + coeffs_except_linear_term: self.coeffs.clone(), } } @@ -211,16 +213,17 @@ impl CompressedUniPoly { // we require eval(0) + eval(1) = hint, so we can solve for the linear term as: // linear_term = hint - 2 * constant_term - deg2 term - deg3 term pub fn decompress(&self, hint: &F) -> UniPoly { - let mut linear_term = - *hint - self.coeffs_except_linear_term[0] - self.coeffs_except_linear_term[0]; - for i in 1..self.coeffs_except_linear_term.len() { - linear_term -= self.coeffs_except_linear_term[i]; - } - - let mut coeffs = vec![self.coeffs_except_linear_term[0], linear_term]; - coeffs.extend(&self.coeffs_except_linear_term[1..]); - assert_eq!(self.coeffs_except_linear_term.len() + 1, coeffs.len()); - UniPoly { coeffs } + // let mut linear_term = + // *hint - self.coeffs_except_linear_term[0] - self.coeffs_except_linear_term[0]; + // for i in 1..self.coeffs_except_linear_term.len() { + // linear_term -= self.coeffs_except_linear_term[i]; + // } + + // let mut coeffs = vec![self.coeffs_except_linear_term[0], linear_term]; + // coeffs.extend(&self.coeffs_except_linear_term[1..]); + // assert_eq!(self.coeffs_except_linear_term.len() + 1, coeffs.len()); + // UniPoly { coeffs } + UniPoly { coeffs: self.coeffs_except_linear_term.clone() } } } diff --git a/jolt-core/src/r1cs/builder.rs b/jolt-core/src/r1cs/builder.rs index d0c1c409e..8740d4d1e 100644 --- a/jolt-core/src/r1cs/builder.rs +++ b/jolt-core/src/r1cs/builder.rs @@ -530,7 +530,7 @@ impl OffsetEqConstraint { } } - #[cfg(test)] + // #[cfg(test)] pub fn empty() -> Self { Self::new( (LC::new(vec![]), false), diff --git a/jolt-core/src/r1cs/jolt_constraints.rs b/jolt-core/src/r1cs/jolt_constraints.rs index 2df3cdd28..b002dd5af 100644 --- a/jolt-core/src/r1cs/jolt_constraints.rs +++ b/jolt-core/src/r1cs/jolt_constraints.rs @@ -1,6 +1,6 @@ use crate::{ assert_static_aux_index, field::JoltField, impl_r1cs_input_lc_conversions, input_range, - jolt::vm::rv32i_vm::C, + jolt::vm::rv32i_vm::C, r1cs::ops::{Term, LC}, }; use super::{ @@ -229,6 +229,7 @@ impl R1CSConstraintBuilder for UniformJoltConstraints { input_range!(JoltIn::ChunksY_0, JoltIn::ChunksY_3).to_vec(), OPERAND_SIZE, ); + cs.constrain_eq_conditional(JoltIn::OpFlags_IsConcat, chunked_x, x); cs.constrain_eq_conditional(JoltIn::OpFlags_IsConcat, chunked_y, y); diff --git a/jolt-core/src/r1cs/spartan.rs b/jolt-core/src/r1cs/spartan.rs index 42c4a8923..8c0e4a2f3 100644 --- a/jolt-core/src/r1cs/spartan.rs +++ b/jolt-core/src/r1cs/spartan.rs @@ -216,7 +216,10 @@ impl> UniformSpartanProof { let padded_segment_len = segmented_padded_witness.segment_len; let error_segment_index = i / padded_segment_len; let error_step_index = i % padded_segment_len; - panic!("witness is not a satisfying assignment. Failed on segment {error_segment_index} at step {error_step_index}"); + println!("Az {az:?}"); + println!("Bz {bz:?}"); + println!("Cz {cz:?}"); + panic!("witness is not a satisfying assignment. Failed on segment {error_segment_index} at step {error_step_index} {i}"); } } } @@ -234,6 +237,11 @@ impl> UniformSpartanProof { } }; + let tau_p = poly_tau.clone(); + let az_p = poly_Az.clone(); + let bz_p = poly_Bz.clone(); + let cz_p = poly_Cz.clone(); + let (outer_sumcheck_proof, outer_sumcheck_r, outer_sumcheck_claims) = SumcheckInstanceProof::prove_spartan_cubic::<_>( &F::zero(), // claim is zero @@ -250,6 +258,11 @@ impl> UniformSpartanProof { drop_in_background_thread(poly_Cz); drop_in_background_thread(poly_tau); + assert_eq!(tau_p.evaluate(&outer_sumcheck_r), outer_sumcheck_claims[0]); + assert_eq!(az_p.evaluate(&outer_sumcheck_r), outer_sumcheck_claims[1]); + assert_eq!(bz_p.evaluate(&outer_sumcheck_r), outer_sumcheck_claims[2]); + assert_eq!(cz_p.evaluate(&outer_sumcheck_r), outer_sumcheck_claims[3]); + // claims from the end of sum-check // claim_Az is the (scalar) value v_A = \sum_y A(r_x, y) * z(r_x) where r_x is the sumcheck randomness let (claim_Az, claim_Bz, claim_Cz): (F, F, F) = ( @@ -259,6 +272,12 @@ impl> UniformSpartanProof { ); ProofTranscript::append_scalars(transcript, [claim_Az, claim_Bz, claim_Cz].as_slice()); + // TODO(sragss): + // - Ensure first one is working + // - Check correct claims from outer_sumcheck + // - Check if next RLC is correct + // - Sumcheck recursive check in prover half. + // inner sum-check let r_inner_sumcheck_RLC: F = transcript.challenge_scalar(); let claim_inner_joint = claim_Az @@ -275,6 +294,7 @@ impl> UniformSpartanProof { let mut poly_ABC = DensePolynomial::new(key.evaluate_r1cs_mle_rlc(rx_con, rx_ts, r_inner_sumcheck_RLC)); + println!("\n\nINNER SUMCHECK"); let (inner_sumcheck_proof, inner_sumcheck_r, _claims_inner) = SumcheckInstanceProof::prove_spartan_quadratic::>( &claim_inner_joint, // r_A * v_A + r_B * v_B + r_C * v_C @@ -283,6 +303,8 @@ impl> UniformSpartanProof { &segmented_padded_witness, transcript, ); + println!("Prover _claims_inner: {_claims_inner:?}"); + println!("Prover inner_sumcheck_r: {inner_sumcheck_r:?}"); drop_in_background_thread(poly_ABC); // Requires 'r_col_segment_bits' to index the (const, segment). Within that segment we index the step using 'r_col_step' @@ -337,6 +359,7 @@ impl> UniformSpartanProof { .map(|_i| transcript.challenge_scalar()) .collect::>(); + println!("Verify OuterSumcheck"); let (claim_outer_final, r_x) = self .outer_sumcheck_proof .verify(F::zero(), num_rounds_x, 3, transcript) @@ -365,6 +388,7 @@ impl> UniformSpartanProof { + r_inner_sumcheck_RLC * self.outer_sumcheck_claims.1 + r_inner_sumcheck_RLC * r_inner_sumcheck_RLC * self.outer_sumcheck_claims.2; + println!("\n\nVerify"); let (claim_inner_final, inner_sumcheck_r) = self .inner_sumcheck_proof .verify(claim_inner_joint, num_rounds_y, 2, transcript) @@ -384,6 +408,9 @@ impl> UniformSpartanProof { + r_inner_sumcheck_RLC * r_inner_sumcheck_RLC * eval_c; let right_expected = eval_Z; let claim_inner_final_expected = left_expected * right_expected; + println!("Verifier r: {inner_sumcheck_r:?}"); + println!("Verifier rA + rB + rC {:?}", left_expected); + println!("Verifier Z(r) {:?}", right_expected); if claim_inner_final != claim_inner_final_expected { return Err(SpartanError::InvalidInnerSumcheckClaim); } @@ -459,3 +486,297 @@ mod test { .expect("Spartan verifier failed"); } } + +pub mod bench { + use super::*; + use crate::{field::binius::BiniusField, impl_r1cs_input_lc_conversions, poly::commitment::{commitment_scheme::{CommitShape, CommitmentScheme}, mock::MockCommitScheme}, r1cs::builder::{OffsetEqConstraint, R1CSBuilder, R1CSConstraintBuilder}}; + use binius_field::BinaryField128bPolyval as BF; + + #[allow(non_camel_case_types)] + #[derive( + strum_macros::EnumIter, + strum_macros::EnumCount, + Clone, + Copy, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + )] + #[repr(usize)] + pub enum MediumTestIn { + Q, + R, + S, + T, + } + impl ConstraintInput for MediumTestIn {} + impl_r1cs_input_lc_conversions!(MediumTestIn); + + pub fn bench() { + + let mut uniform_builder = R1CSBuilder::::new(); + + struct TestConstraints(); + impl R1CSConstraintBuilder for TestConstraints { + type Inputs = MediumTestIn; + fn build_constraints(&self, builder: &mut R1CSBuilder) { + // if (Q) { R * S == T } + // else { R + S == T } + // let prod = builder.allocate_prod(MediumTestIn::R, MediumTestIn::S); + // builder.constrain_eq_conditional(MediumTestIn::Q - 1, MediumTestIn::R + MediumTestIn::S, MediumTestIn::T); + + // if (Q) { R == T } + // else { R == S } + builder.constrain_eq_conditional(MediumTestIn::Q, MediumTestIn::R, MediumTestIn::T); + builder.constrain_eq_conditional(MediumTestIn::Q - 1, MediumTestIn::R, MediumTestIn::S); + } + } + + let constraints = TestConstraints(); + constraints.build_constraints(&mut uniform_builder); + + let repeat = 1 << 25; + let num_steps_pad = 4 * repeat; + let combined_builder = + CombinedUniformBuilder::construct(uniform_builder, num_steps_pad, OffsetEqConstraint::empty()); + let key = UniformSpartanKey::from_builder(&combined_builder); + + + // type Fr = BiniusField; + let witness_segments: Vec> = vec![ + vec![F::from_u64(1).unwrap(), F::from_u64(1).unwrap(), F::from_u64(0).unwrap(), F::from_u64(1).unwrap()], /* Q */ + vec![F::from_u64(5).unwrap(), F::from_u64(8).unwrap(), F::from_u64(10).unwrap(), F::from_u64(400).unwrap()], /* R */ + vec![F::from_u64(9).unwrap(), F::from_u64(4).unwrap(), F::from_u64(10).unwrap(), F::from_u64(10).unwrap()], /* S */ + vec![F::from_u64(5).unwrap(), F::from_u64(8).unwrap(), F::from_u64(7).unwrap(), F::from_u64(400).unwrap()], /* T */ + ]; + let witness_segments: Vec> = witness_segments + .into_iter() + .map(|segment| { + let mut extended_segment = Vec::with_capacity(segment.len() * repeat); + for _ in 0..repeat { + extended_segment.extend_from_slice(&segment); + } + extended_segment + }) + .collect(); + + // Create a witness and commit + let witness_segments_ref: Vec<&[F]> = witness_segments + .iter() + .map(|segment| segment.as_slice()) + .collect(); + let gens = MockCommitScheme::::setup(&[CommitShape::new(16 * repeat, BatchType::Small)]); + let witness_commitment = + MockCommitScheme::batch_commit(&witness_segments_ref, &gens, BatchType::Small); + + // Prove spartan! + let mut prover_transcript = ProofTranscript::new(b"stuff"); + let proof = + UniformSpartanProof::>::prove_precommitted::< + MediumTestIn, + >( + &gens, + combined_builder, + &key, + witness_segments, + &mut prover_transcript, + ) + .unwrap(); + + let mut verifier_transcript = ProofTranscript::new(b"stuff"); + let witness_commitment_ref: Vec<&_> = witness_commitment.iter().collect(); + proof + .verify_precommitted( + &key, + witness_commitment_ref, + &gens, + &mut verifier_transcript, + ) + .expect("Spartan verifier failed"); + } +} + + +#[cfg(test)] +mod binius_test { + use ark_std::test_rng; + use binius_field::BinaryField128bPolyval as BF; + + use crate::{field::binius::BiniusField, poly::commitment::{binius::Binius128Scheme, commitment_scheme::CommitShape, mock::MockCommitScheme}, r1cs::test::{add_mul_builder_key, simp_test_builder_key, MediumTestIn, SimpTestIn}}; + + use super::*; + + #[test] + fn sumcheck() { + // Test sum_{x \ in {0,1}^n}{eq(r, x) * [Az(x) * Bz(x) - Cz(x)} = 0 + + // 1. Compute eq_table = eq(r, _) + // 2. Assign Az, Bz, Cz + // 3. Test Az * Bz - Cz == 0 + // 4. Sumcheck + // - Evaluate at {0, 1, 2, 3} + // - Bind + // 5. Sumcheck verifier + + const LOG_LEN: usize = 4; + const LEN: usize = 1 << LOG_LEN; + + + let mut rng = test_rng(); + let r: Vec> = (0..LOG_LEN).into_iter().map(|_| BiniusField::::random(&mut rng)).collect(); + + // 1. Compute eq_table = eq(r, _) + let mut tau = DensePolynomial::new(EqPolynomial::evals(&r)); + + type F = BiniusField; + + // 2. Assign Az, Bz, Cz + let az: Vec = vec![BiniusField::::from_u64(1).unwrap(); LEN]; + let bz: Vec = vec![BiniusField::::from_u64(2).unwrap(); LEN]; + let cz: Vec = vec![BiniusField::::from_u64(2).unwrap(); LEN]; + let mut az_poly = DensePolynomial::new(az); + let mut bz_poly = DensePolynomial::new(bz); + let mut cz_poly = DensePolynomial::new(cz); + + let v_tau = tau.clone(); + let v_az_poly = az_poly.clone(); + let v_bz_poly = bz_poly.clone(); + let v_cz_poly = cz_poly.clone(); + + // 3. Test Az * Bz - Cz == 0 + for ((az, bz), cz) in az_poly.evals_ref().iter().zip(bz_poly.evals_ref()).zip(cz_poly.evals_ref()) { + assert_eq!(az * bz, *cz); + } + + // 4. Sumcheck + // - Evaluate at {0, 1, 2, 3} + // - Bind + + let comb_func_outer = |eq: &F, a: &F, b: &F, c: &F| -> F { + *eq * ( a * b - c ) + }; + + let mut transcript = ProofTranscript::new(b"test"); + let (outer_sumcheck_proof, outer_sumcheck_r, outer_sumcheck_claims) = + SumcheckInstanceProof::prove_spartan_cubic::<_>( + &F::from_u64(0).unwrap(), + LOG_LEN, + &mut tau, + &mut az_poly, + &mut bz_poly, + &mut cz_poly, + comb_func_outer, + &mut transcript, + ); + + let mut verify_transcript = ProofTranscript::new(b"test"); + let (f_r, r) = outer_sumcheck_proof.verify(F::from_u64(0).unwrap(), LOG_LEN, 3, &mut verify_transcript).unwrap(); + assert_eq!(outer_sumcheck_r, r); + + let tau_r = v_tau.evaluate(&r); + let az_r = v_az_poly.evaluate(&r); + let bz_r = v_bz_poly.evaluate(&r); + let cz_r = v_cz_poly.evaluate(&r); + assert_eq!(outer_sumcheck_claims[0], tau_r); + assert_eq!(outer_sumcheck_claims[1], az_r); + assert_eq!(outer_sumcheck_claims[2], bz_r); + assert_eq!(outer_sumcheck_claims[3], cz_r); + let verifier_eval = tau_r * (az_r * bz_r - cz_r); + + assert_eq!(f_r, verifier_eval); + } + + #[test] + fn integration() { + type Fr = BiniusField; + let (builder, key) = simp_test_builder_key(); + let witness_segments: Vec> = vec![ + vec![Fr::from_u64(1).unwrap(), Fr::from_u64(5).unwrap(), Fr::from_u64(9).unwrap(), Fr::from_u64(13).unwrap()], /* Q */ + vec![Fr::from_u64(1).unwrap(), Fr::from_u64(5).unwrap(), Fr::from_u64(9).unwrap(), Fr::from_u64(13).unwrap()], /* R */ + vec![Fr::from_u64(1).unwrap(), Fr::from_u64(5).unwrap(), Fr::from_u64(9).unwrap(), Fr::from_u64(13).unwrap()], /* S */ + ]; + + // Create a witness and commit + let witness_segments_ref: Vec<&[Fr]> = witness_segments + .iter() + .map(|segment| segment.as_slice()) + .collect(); + let gens = MockCommitScheme::::setup(&[CommitShape::new(16, BatchType::Small)]); + let witness_commitment = + MockCommitScheme::batch_commit(&witness_segments_ref, &gens, BatchType::Small); + + // Prove spartan! + let mut prover_transcript = ProofTranscript::new(b"stuff"); + let proof = + UniformSpartanProof::>::prove_precommitted::< + SimpTestIn, + >( + &gens, + builder, + &key, + witness_segments, + &mut prover_transcript, + ) + .unwrap(); + + let mut verifier_transcript = ProofTranscript::new(b"stuff"); + let witness_commitment_ref: Vec<&_> = witness_commitment.iter().collect(); + proof + .verify_precommitted( + &key, + witness_commitment_ref, + &gens, + &mut verifier_transcript, + ) + .expect("Spartan verifier failed"); + } + + #[test] + fn integration_two() { + type Fr = BiniusField; + let (builder, key) = add_mul_builder_key(); + let witness_segments: Vec> = vec![ + vec![Fr::from_u64(1).unwrap(), Fr::from_u64(1).unwrap(), Fr::from_u64(0).unwrap(), Fr::from_u64(1).unwrap()], /* Q */ + vec![Fr::from_u64(5).unwrap(), Fr::from_u64(8).unwrap(), Fr::from_u64(10).unwrap(), Fr::from_u64(400).unwrap()], /* R */ + vec![Fr::from_u64(9).unwrap(), Fr::from_u64(4).unwrap(), Fr::from_u64(10).unwrap(), Fr::from_u64(10).unwrap()], /* S */ + vec![Fr::from_u64(5).unwrap(), Fr::from_u64(8).unwrap(), Fr::from_u64(7).unwrap(), Fr::from_u64(400).unwrap()], /* T */ + ]; + + // Create a witness and commit + let witness_segments_ref: Vec<&[Fr]> = witness_segments + .iter() + .map(|segment| segment.as_slice()) + .collect(); + let gens = MockCommitScheme::::setup(&[CommitShape::new(16, BatchType::Small)]); + let witness_commitment = + MockCommitScheme::batch_commit(&witness_segments_ref, &gens, BatchType::Small); + + // Prove spartan! + let mut prover_transcript = ProofTranscript::new(b"stuff"); + let proof = + UniformSpartanProof::>::prove_precommitted::< + MediumTestIn, + >( + &gens, + builder, + &key, + witness_segments, + &mut prover_transcript, + ) + .unwrap(); + + let mut verifier_transcript = ProofTranscript::new(b"stuff"); + let witness_commitment_ref: Vec<&_> = witness_commitment.iter().collect(); + proof + .verify_precommitted( + &key, + witness_commitment_ref, + &gens, + &mut verifier_transcript, + ) + .expect("Spartan verifier failed"); + } +} diff --git a/jolt-core/src/r1cs/test.rs b/jolt-core/src/r1cs/test.rs index 086abd7ae..d7584cc7b 100644 --- a/jolt-core/src/r1cs/test.rs +++ b/jolt-core/src/r1cs/test.rs @@ -124,11 +124,12 @@ pub fn simp_test_builder_key( } } // Q[n] + 4 - S[n+1] == 0 - let offset_eq_constraint = OffsetEqConstraint::new( - (SimpTestIn::S, true), - (SimpTestIn::Q, false), - (SimpTestIn::S + -4, true), - ); + // let offset_eq_constraint = OffsetEqConstraint::new( + // (SimpTestIn::S, true), + // (SimpTestIn::Q, false), + // (SimpTestIn::S + -4, true), + // ); + let offset_eq_constraint = OffsetEqConstraint::empty(); let constraints = TestConstraints(); constraints.build_constraints(&mut uniform_builder); @@ -173,3 +174,58 @@ pub fn simp_test_big_matrices() -> (Vec, Vec, Vec) { (big_a, big_b, big_c) } + + +#[allow(non_camel_case_types)] +#[derive( + strum_macros::EnumIter, + strum_macros::EnumCount, + Clone, + Copy, + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, +)] +#[repr(usize)] +pub enum MediumTestIn { + Q, + R, + S, + T, +} +impl ConstraintInput for MediumTestIn {} +impl_r1cs_input_lc_conversions!(MediumTestIn); + +pub fn add_mul_builder_key() -> (CombinedUniformBuilder, UniformSpartanKey) { + let mut uniform_builder = R1CSBuilder::::new(); + + struct TestConstraints(); + impl R1CSConstraintBuilder for TestConstraints { + type Inputs = MediumTestIn; + fn build_constraints(&self, builder: &mut R1CSBuilder) { + // if (Q) { R * S == T } + // else { R + S == T } + // let prod = builder.allocate_prod(MediumTestIn::R, MediumTestIn::S); + // builder.constrain_eq_conditional(MediumTestIn::Q - 1, MediumTestIn::R + MediumTestIn::S, MediumTestIn::T); + + // if (Q) { R == T } + // else { R == S } + builder.constrain_eq_conditional(MediumTestIn::Q, MediumTestIn::R, MediumTestIn::T); + builder.constrain_eq_conditional(MediumTestIn::Q - 1, MediumTestIn::R, MediumTestIn::S); + } + } + + let constraints = TestConstraints(); + constraints.build_constraints(&mut uniform_builder); + + let _num_steps: usize = 3; + let num_steps_pad = 4; + let combined_builder = + CombinedUniformBuilder::construct(uniform_builder, num_steps_pad, OffsetEqConstraint::empty()); + let key = UniformSpartanKey::from_builder(&combined_builder); + + (combined_builder, key) +} diff --git a/jolt-core/src/subprotocols/sumcheck.rs b/jolt-core/src/subprotocols/sumcheck.rs index f7b43d1b2..b01c7307c 100644 --- a/jolt-core/src/subprotocols/sumcheck.rs +++ b/jolt-core/src/subprotocols/sumcheck.rs @@ -9,6 +9,7 @@ use crate::utils::errors::ProofVerifyError; use crate::utils::mul_0_optimized; use crate::utils::thread::drop_in_background_thread; use crate::utils::transcript::{AppendToTranscript, ProofTranscript}; +use crate::field::OptimizedMul; use ark_serialize::*; use rayon::prelude::*; @@ -92,13 +93,15 @@ impl SumcheckInstanceProof { let mut compressed_polys: Vec> = Vec::new(); for _round in 0..num_rounds { + let span = tracing::span!(tracing::Level::TRACE, "round"); + let _enter = span.enter(); // Vector storing evaluations of combined polynomials g(x) = P_0(x) * ... P_{num_polys} (x) // for points {0, ..., |g(x)|} - let mut eval_points = vec![F::zero(); combined_degree + 1]; + // let mut eval_points = vec![F::zero(); combined_degree + 1]; let mle_half = polys[0].len() / 2; - let accum: Vec> = (0..mle_half) + let eval_points: Vec = (0..mle_half) .into_par_iter() .map(|poly_term_i| { let mut accum = vec![F::zero(); combined_degree + 1]; @@ -112,52 +115,80 @@ impl SumcheckInstanceProof { // eval 0: bound_func is A(low) let params_zero: Vec = polys.iter().map(|poly| poly[poly_term_i]).collect(); - accum[0] += comb_func(¶ms_zero); + accum[0] = comb_func(¶ms_zero); // TODO(#28): Can be computed from prev_round_claim - eval_point_0 let params_one: Vec = polys .iter() .map(|poly| poly[mle_half + poly_term_i]) .collect(); - accum[1] += comb_func(¶ms_one); + accum[1] = comb_func(¶ms_one); + // println!("params_one {params_one:?}"); // D_n(index, r) = D_{n-1}[half + index] + r * (D_{n-1}[half + index] - D_{n-1}[index]) // D_n(index, 0) = D_{n-1}[LOW] // D_n(index, 1) = D_{n-1}[HIGH] // D_n(index, 2) = D_{n-1}[HIGH] + (D_{n-1}[HIGH] - D_{n-1}[LOW]) + // D_n(index, 2) = (1 - 2) * D[LOW] + 2 * D[HIGH] = 2 * D[HIGH] - D[LOW] // D_n(index, 3) = D_{n-1}[HIGH] + (D_{n-1}[HIGH] - D_{n-1}[LOW]) + (D_{n-1}[HIGH] - D_{n-1}[LOW]) + // D_n(index, 2) = (1 - 3) * D[LOW] + 3 * D[HIGH] = D[HIGH] + 2 * [ D[HIGH] - D[LOW] ] // ... - let mut existing_term = params_one; + + // D_n(index, r) = D_{n-1}[half + index] + r * (D_{n-1}[half + index] - D_{n-1}[index]) + // D_n(index, 0) = D_{n-1}[LOW] + // D_n(index, 1) = D_{n-1}[HIGH] + // z_0 = Binius((1,0)) + // z_1 = Binius((1,1)) + // D_n(index, z_0) = (1 + z_0) * D[LOW] + z_0 * D[HIGH] + // D_n(index, z_1) = + // let mut existing_term = params_one; for eval_i in 2..(combined_degree + 1) { let mut poly_evals = vec![F::zero(); polys.len()]; for poly_i in 0..polys.len() { + // let poly = &polys[poly_i]; + // poly_evals[poly_i] = existing_term[poly_i] + // + poly[mle_half + poly_term_i] + // - poly[poly_term_i]; let poly = &polys[poly_i]; - poly_evals[poly_i] = existing_term[poly_i] - + poly[mle_half + poly_term_i] - - poly[poly_term_i]; + let r = F::from_u64(eval_i as u64).unwrap(); + poly_evals[poly_i] = poly[poly_term_i].mul_01_optimized(F::one() - r) + mul_0_optimized(&poly[mle_half + poly_term_i], &r); } - accum[eval_i] += comb_func(&poly_evals); - existing_term = poly_evals; + accum[eval_i] = comb_func(&poly_evals); + // existing_term = poly_evals; } accum - }) - .collect(); + }).reduce( + || vec![F::zero(); combined_degree + 1], + |mut accum, item| { + for (i, val) in item.iter().enumerate() { + accum[i] += *val; + } + accum + } + ); - eval_points - .par_iter_mut() - .enumerate() - .for_each(|(poly_i, eval_point)| { - *eval_point = accum - .par_iter() - .take(mle_half) - .map(|mle| mle[poly_i]) - .sum::(); - }); + // eval_points + // .par_iter_mut() + // .enumerate() + // .for_each(|(poly_i, eval_point)| { + // *eval_point = accum + // .par_iter() + // .take(mle_half) + // .map(|mle| mle[poly_i]) + // .sum::(); + // }); + + // println!("evals: {eval_points:?}"); let round_uni_poly = UniPoly::from_evals(&eval_points); + // println!( + // "[{_round}] eval(0) + eval(1) = {:?}", + // eval_points[0] + eval_points[1] + // ); // append the prover's message to the transcript + // println!("[{_round}] appending to transcript: {round_uni_poly:?}"); round_uni_poly.append_to_transcript(transcript); let r_j = transcript.challenge_scalar(); r.push(r_j); @@ -165,15 +196,81 @@ impl SumcheckInstanceProof { // bound all tables to the verifier's challenege polys .par_iter_mut() - .for_each(|poly| poly.bound_poly_var_top(&r_j)); + .for_each(|poly| poly.bound_poly_var_top_zero_optimized(&r_j)); compressed_polys.push(round_uni_poly.compress()); } + polys + .iter() + .for_each(|poly| debug_assert_eq!(poly.len(), 1)); let final_evals = polys.iter().map(|poly| poly[0]).collect(); (SumcheckInstanceProof::new(compressed_polys), r, final_evals) } + #[tracing::instrument(skip_all, name = "Sumcheck.prove_special")] + pub fn prove_special( + _claim: &F, + num_rounds: usize, + poly_l: &mut DensePolynomial, + poly_r: &mut DensePolynomial, + transcript: &mut ProofTranscript, + ) -> (Self, Vec, Vec) + { + let num_eval_points = 3; + + let mut r: Vec = Vec::new(); + let mut compressed_polys: Vec> = Vec::new(); + + let two = F::from_u64(2).unwrap(); + let lhs_point = F::one() - two; + + for _round in 0..num_rounds { + let span = tracing::span!(tracing::Level::TRACE, "round"); + let _enter = span.enter(); + + let mle_half = poly_l.len() / 2; + + let eval_points: (F, F, F) = (0..mle_half) + .into_par_iter() + .map(|poly_term_i| { + let eval_0 = poly_l[poly_term_i].mul_01_optimized(poly_r[poly_term_i]); + let eval_1 = poly_l[mle_half + poly_term_i].mul_01_optimized(poly_r[mle_half + poly_term_i]); + let eval_left = poly_l[poly_term_i].mul_01_optimized(lhs_point) + mul_0_optimized(&poly_l[mle_half + poly_term_i], &two); + let eval_right= poly_r[poly_term_i].mul_01_optimized(lhs_point) + mul_0_optimized(&poly_r[mle_half + poly_term_i], &two); + let eval_2 = eval_left.mul_01_optimized(eval_right); + (eval_0, eval_1, eval_2) + }).reduce( + || (F::zero(), F::zero(), F::zero()), + |mut accum, item| { + accum.0 += item.0; + accum.1 += item.1; + accum.2 += item.2; + accum + } + ); + + // println!("evals: {eval_points:?}"); + let round_uni_poly = UniPoly::from_evals(&vec![eval_points.0, eval_points.1, eval_points.2]); + round_uni_poly.append_to_transcript(transcript); + let r_j = transcript.challenge_scalar(); + r.push(r_j); + + // bound all tables to the verifier's challenege + rayon::join( + || poly_l.bound_poly_var_top_zero_optimized(&r_j), + || poly_r.bound_poly_var_top_zero_optimized(&r_j), + ); + compressed_polys.push(round_uni_poly.compress()); + } + + assert_eq!(poly_l.len(), 1); + assert_eq!(poly_r.len(), 1); + let final_evals = vec![poly_l[0], poly_r[0]]; + + (SumcheckInstanceProof::new(compressed_polys), r, final_evals) + } + #[inline] #[tracing::instrument( skip_all, @@ -189,6 +286,10 @@ impl SumcheckInstanceProof { where Func: Fn(&F, &F, &F, &F) -> F + Sync, { + let two = F::from_u64(2 as u64).unwrap(); + let one_plus_two = F::one() + two; + let three = F::from_u64(3 as u64).unwrap(); + let one_plus_three = F::one() + three; let len = poly_A.len() / 2; (0..len) .into_par_iter() @@ -196,34 +297,22 @@ impl SumcheckInstanceProof { // eval 0: bound_func is A(low) let eval_point_0 = comb_func(&poly_A[i], &poly_B[i], &poly_C[i], &poly_D[i]); - let m_A = poly_A[len + i] - poly_A[i]; - let m_B = poly_B[len + i] - poly_B[i]; - let m_C = poly_C[len + i] - poly_C[i]; - let m_D = poly_D[len + i] - poly_D[i]; + let low = i; + let high = len + i; - // eval 2: bound_func is -A(low) + 2*A(high) - let poly_A_bound_point = poly_A[len + i] + m_A; - let poly_B_bound_point = poly_B[len + i] + m_B; - let poly_C_bound_point = poly_C[len + i] + m_C; - let poly_D_bound_point = poly_D[len + i] + m_D; let eval_point_2 = comb_func( - &poly_A_bound_point, - &poly_B_bound_point, - &poly_C_bound_point, - &poly_D_bound_point, + &(one_plus_two.mul_01_optimized(poly_A[low]) + two.mul_01_optimized(poly_A[high])), + &(one_plus_two.mul_01_optimized(poly_B[low]) + two.mul_01_optimized(poly_B[high])), + &(one_plus_two.mul_01_optimized(poly_C[low]) + two.mul_01_optimized(poly_C[high])), + &(one_plus_two.mul_01_optimized(poly_D[low]) + two.mul_01_optimized(poly_D[high])), ); - - // eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2) - let poly_A_bound_point = poly_A_bound_point + m_A; - let poly_B_bound_point = poly_B_bound_point + m_B; - let poly_C_bound_point = poly_C_bound_point + m_C; - let poly_D_bound_point = poly_D_bound_point + m_D; let eval_point_3 = comb_func( - &poly_A_bound_point, - &poly_B_bound_point, - &poly_C_bound_point, - &poly_D_bound_point, + &(one_plus_three.mul_01_optimized(poly_A[low]) + three.mul_01_optimized(poly_A[high])), + &(one_plus_three.mul_01_optimized(poly_B[low]) + three.mul_01_optimized(poly_B[high])), + &(one_plus_three.mul_01_optimized(poly_C[low]) + three.mul_01_optimized(poly_C[high])), + &(one_plus_three.mul_01_optimized(poly_D[low]) + three.mul_01_optimized(poly_D[high])), ); + (eval_point_0, eval_point_2, eval_point_3) }) .reduce( @@ -264,6 +353,7 @@ impl SumcheckInstanceProof { eval_point_2, eval_point_3, ]; + // println!("EVALS: {evals:?}"); UniPoly::from_evals(&evals) }; @@ -302,6 +392,10 @@ impl SumcheckInstanceProof { ) } + // Sam plan + // Flatten W + // Evaluate {0,1,2} normally + #[tracing::instrument(skip_all, name = "Spartan2::sumcheck::prove_spartan_quadratic")] // A fork of `prove_quad` with the 0th round unrolled from the rest of the // for loop. This allows us to pass in `W` and `X` as references instead of @@ -314,6 +408,182 @@ impl SumcheckInstanceProof { poly_A: &mut DensePolynomial, W: &P, transcript: &mut ProofTranscript, + ) -> (Self, Vec, Vec) { + const NEW: bool = true; + + if NEW { + let len = poly_A.len() / 2; + let W_iter = (0..W.len()).into_par_iter().map(move |i| &W[i]); + let zero = F::zero(); + let one = [F::one()]; + let Z_iter = W_iter + .chain(one.par_iter()) + .chain(rayon::iter::repeatn(&zero, len - 1)); + let flat_z: Vec = Z_iter.cloned().collect(); + let mut poly_z = DensePolynomial::new(flat_z); + assert_eq!(poly_A.len(), poly_z.len()); + + let comb_func = |inputs: &[F]| -> F { + debug_assert_eq!(inputs.len(), 2); + inputs[0].mul_01_optimized(inputs[1]) + }; + // let mut polys_vec = vec![poly_A.clone(), poly_z]; + // SumcheckInstanceProof::prove_arbitrary( + // claim, + // num_rounds, + // &mut polys_vec, + // comb_func, + // 2, + // transcript, + // ) + SumcheckInstanceProof::prove_special( + claim, + num_rounds, + // &mut polys_vec, + &mut poly_A.clone(), + &mut poly_z, + // comb_func, + // 2, + transcript, + ) + } else { + let mut r: Vec = Vec::with_capacity(num_rounds); + let mut polys: Vec> = Vec::with_capacity(num_rounds); + let mut claim_per_round = *claim; + + /* Round 0 START */ + + let len = poly_A.len() / 2; + assert_eq!(len, W.len()); + + let poly = { + // eval_point_0 = \sum_i A[i] * B[i] + // where B[i] = W[i] for i in 0..len + let eval_point_0: F = (0..len) + .into_par_iter() + .map(|i| { + if poly_A[i].is_zero() || W[i].is_zero() { + F::zero() + } else { + poly_A[i] * W[i] + } + }) + .sum(); + // eval_point_2 = \sum_i (2 * A[len + i] - A[i]) * (2 * B[len + i] - B[i]) + // where B[i] = W[i] for i in 0..len, B[len] = 1, and B[i] = 0 for i > len + let mut eval_point_2: F = (1..len) + .into_par_iter() + .map(|i| { + if W[i].is_zero() { + F::zero() + } else { + let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i]; + let poly_B_bound_point = -W[i]; + mul_0_optimized(&poly_A_bound_point, &poly_B_bound_point) + } + }) + .sum(); + eval_point_2 += mul_0_optimized( + &(poly_A[len] + poly_A[len] - poly_A[0]), + &(F::from_u64(2).unwrap() - W[0]), + ); + + let evals = [eval_point_0, claim_per_round - eval_point_0, eval_point_2]; + // println!("evals: {evals:?}"); + UniPoly::from_evals(&evals) + }; + + // append the prover's message to the transcript + // println!("[0] appending to transcript: {poly:?}"); + poly.append_to_transcript(transcript); + + //derive the verifier's challenge for the next round + let r_i: F = transcript.challenge_scalar(); + r.push(r_i); + polys.push(poly.compress()); + + // Set up next round + claim_per_round = poly.evaluate(&r_i); + + // bound all tables to the verifier's challenge + let (_, mut poly_B) = rayon::join( + || poly_A.bound_poly_var_top_zero_optimized(&r_i), + || { + // Simulates `poly_B.bound_poly_var_top(&r_i)` + // We need to do this because we don't actually have + // a `MultilinearPolynomial` instance for `poly_B` yet, + // only the constituents of its (Lagrange basis) coefficients + // `W` and `X`. + let zero = F::zero(); + let one = [F::one()]; + let W_iter = (0..W.len()).into_par_iter().map(move |i| &W[i]); + let Z_iter = W_iter + .chain(one.par_iter()) + .chain(rayon::iter::repeatn(&zero, len)); + let left_iter = Z_iter.clone().take(len); + let right_iter = Z_iter.skip(len).take(len); + let B = left_iter + .zip(right_iter) + .map(|(a, b)| if *a == *b { *a } else { *a + r_i * (*b - *a) }) + .collect(); + DensePolynomial::new(B) + }, + ); + + /* Round 0 END */ + + for i in 1..num_rounds { + let poly = { + let (eval_point_0, eval_point_2) = + Self::compute_eval_points_spartan_quadratic(poly_A, &poly_B); + + let evals = [eval_point_0, claim_per_round - eval_point_0, eval_point_2]; + UniPoly::from_evals(&evals) + }; + + // append the prover's message to the transcript + // println!("[{i}] appending to transcript: {poly:?}"); + poly.append_to_transcript(transcript); + + //derive the verifier's challenge for the next round + let r_i: F = transcript.challenge_scalar(); + + r.push(r_i); + polys.push(poly.compress()); + + // Set up next round + claim_per_round = poly.evaluate(&r_i); + + // bound all tables to the verifier's challenege + rayon::join( + || poly_A.bound_poly_var_top_zero_optimized(&r_i), + || poly_B.bound_poly_var_top_zero_optimized(&r_i), + ); + + if i == num_rounds - 1 { + assert_eq!(poly.evaluate(&r_i), poly_A[0] * poly_B[0]); + } + } + + let evals = vec![poly_A[0], poly_B[0]]; + drop_in_background_thread(poly_B); + + (SumcheckInstanceProof::new(polys), r, evals) + } + } + + #[tracing::instrument(skip_all, name = "Spartan2::sumcheck::prove_spartan_quadratic")] + // A fork of `prove_quad` with the 0th round unrolled from the rest of the + // for loop. This allows us to pass in `W` and `X` as references instead of + // passing them in as a single `MultilinearPolynomial`, which would require + // an expensive concatenation. We defer the actual instantation of a + // `MultilinearPolynomial` to the end of the 0th round. + pub fn prove_spartan_quadratic_old>( + claim: &F, + num_rounds: usize, + poly_A: &mut DensePolynomial, + W: &P, + transcript: &mut ProofTranscript, ) -> (Self, Vec, Vec) { let mut r: Vec = Vec::with_capacity(num_rounds); let mut polys: Vec> = Vec::with_capacity(num_rounds); @@ -504,6 +774,7 @@ impl SumcheckInstanceProof { // verify that there is a univariate polynomial for each round assert_eq!(self.compressed_polys.len(), num_rounds); for i in 0..self.compressed_polys.len() { + // println!("[{i}] hint = eval(0) + eval(1) = {e:?}"); let poly = self.compressed_polys[i].decompress(&e); // verify degree bound @@ -518,6 +789,7 @@ impl SumcheckInstanceProof { assert_eq!(poly.eval_at_zero() + poly.eval_at_one(), e); // append the prover's message to the transcript + // println!("[{i}] appending to transcript: {poly:?}"); poly.append_to_transcript(transcript); //derive the verifier's challenge for the next round