Skip to content

Commit

Permalink
Add optional prover parameters file path to the adapted stwo CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
m-kus committed Jan 29, 2025
1 parent 09d76f7 commit df1d498
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 53 deletions.
38 changes: 30 additions & 8 deletions stwo_cairo_prover/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion stwo_cairo_prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ serde_json = "1.0.1"
stwo_cairo_prover = { path = "crates/prover", version = "~0.1.0" }
stwo_cairo_utils = { path = "crates/utils", version = "~0.1.0" }
# TODO(ShaharS): take stwo version from the source repository.
stwo-prover = { git = "https://github.com/starkware-libs/stwo", rev = "90b3e55", features = [
stwo-prover = { git = "https://github.com/starkware-libs/stwo", rev = "678d4a7", features = [
"parallel",
], default-features = false }
thiserror = { version = "2.0.10", default-features = false }
Expand Down
1 change: 1 addition & 0 deletions stwo_cairo_prover/crates/adapted_prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ stwo_cairo_utils.workspace = true
stwo-prover.workspace = true
thiserror.workspace = true
tracing.workspace = true
serde.workspace = true
75 changes: 66 additions & 9 deletions stwo_cairo_prover/crates/adapted_prover/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,22 @@ use std::path::PathBuf;
use std::process::ExitCode;

use clap::Parser;
use serde::Serialize;
use stwo_cairo_prover::cairo_air::{
prove_cairo, verify_cairo, CairoVerificationError, ConfigBuilder,
prove_cairo, verify_cairo, CairoVerificationError, ChannelHash, ConfigBuilder, ProverConfig,
ProverParameters,
};
use stwo_cairo_prover::input::vm_import::{adapt_vm_output, VmImportError};
use stwo_cairo_prover::input::ProverInput;
use stwo_cairo_utils::binary_utils::run_binary;
use stwo_prover::core::backend::simd::SimdBackend;
use stwo_prover::core::backend::BackendForChannel;
use stwo_prover::core::channel::MerkleChannel;
use stwo_prover::core::fri::FriConfig;
use stwo_prover::core::pcs::PcsConfig;
use stwo_prover::core::prover::ProvingError;
use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel;
use stwo_prover::core::vcs::poseidon252_merkle::Poseidon252MerkleChannel;
use thiserror::Error;
use tracing::{span, Level};

Expand All @@ -25,6 +33,9 @@ struct Args {
pub_json: PathBuf,
#[structopt(long = "priv_json")]
priv_json: PathBuf,
/// The path to the JSON file containing the prover parameters (optional).
#[structopt(long = "params_json")]
params_json: Option<PathBuf>,
/// The output file path for the proof.
#[structopt(long = "proof_path")]
proof_path: PathBuf,
Expand Down Expand Up @@ -57,6 +68,30 @@ fn main() -> ExitCode {
run_binary(run, "adapted_stwo")
}

/// Generates proof given the Cairo VM output and prover config/parameters.
/// Serializes the proof as JSON and write to the output path.
/// Verifies the proof in case the respective flag is set.
fn run_inner<MC: MerkleChannel>(
vm_output: ProverInput,
prover_config: ProverConfig,
pcs_config: PcsConfig,
verify: bool,
proof_path: PathBuf,
) -> Result<(), Error>
where
SimdBackend: BackendForChannel<MC>,
MC::H: Serialize,
{
let proof = prove_cairo::<MC>(vm_output, prover_config, pcs_config)?;
std::fs::write(&proof_path, serde_json::to_string(&proof)?)?;

if verify {
verify_cairo::<MC>(proof, pcs_config)?;
log::info!("Proof verified successfully");
}
Ok(())
}

fn run(args: impl Iterator<Item = String>) -> Result<(), Error> {
let _span = span!(Level::INFO, "run").entered();
let args = Args::try_parse_from(args)?;
Expand All @@ -73,15 +108,37 @@ fn run(args: impl Iterator<Item = String>) -> Result<(), Error> {
vm_output.state_transitions.casm_states_by_opcode
);

// TODO(Ohad): Propagate hash from CLI args.
let proof = prove_cairo::<Blake2sMerkleChannel>(vm_output, prover_config)?;

std::fs::write(args.proof_path, serde_json::to_string(&proof)?)?;
let params: ProverParameters = match args.params_json {
Some(path) => serde_json::from_str(&std::fs::read_to_string(path)?)?,
None => default_parameters(),
};

if args.verify {
verify_cairo::<Blake2sMerkleChannel>(proof)?;
log::info!("Proof verified successfully");
}
let run_inner_fn = match params.channel_hash {
ChannelHash::Blake2s => run_inner::<Blake2sMerkleChannel>,
ChannelHash::Poseidon252 => run_inner::<Poseidon252MerkleChannel>,
};

run_inner_fn(
vm_output,
prover_config,
params.pcs_config,
args.verify,
args.proof_path,
)?;
Ok(())
}

/// The default prover paramters (96 bits of security).
pub fn default_parameters() -> ProverParameters {
ProverParameters {
channel_hash: ChannelHash::Blake2s,
pcs_config: PcsConfig {
pow_bits: 26,
fri_config: FriConfig {
log_last_layer_degree_bound: 7,
log_blowup_factor: 1,
n_queries: 70,
},
},
}
}
79 changes: 48 additions & 31 deletions stwo_cairo_prover/crates/prover/src/cairo_air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ use air::{lookup_sum, CairoClaimGenerator, CairoComponents, CairoInteractionElem
use debug_tools::track_cairo_relations;
use num_traits::Zero;
use preprocessed::preprocessed_trace_columns;
use serde::{Deserialize, Serialize};
use stwo_prover::constraint_framework::relation_tracker::RelationSummary;
use stwo_prover::core::backend::simd::SimdBackend;
use stwo_prover::core::backend::BackendForChannel;
use stwo_prover::core::channel::MerkleChannel;
use stwo_prover::core::fields::qm31::SecureField;
use stwo_prover::core::fri::FriConfig;
use stwo_prover::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig};
use stwo_prover::core::poly::circle::{CanonicCoset, PolyOps};
use stwo_prover::core::prover::{prove, verify, ProvingError, VerificationError};
Expand All @@ -32,29 +32,23 @@ pub fn prove_cairo<MC: MerkleChannel>(
track_relations,
display_components,
}: ProverConfig,
pcs_config: PcsConfig,
) -> Result<CairoProof<MC::H>, ProvingError>
where
SimdBackend: BackendForChannel<MC>,
{
let _span = span!(Level::INFO, "prove_cairo").entered();
// TODO(Ohad): Propogate config from CLI args.
let config = PcsConfig {
pow_bits: 0,
fri_config: FriConfig {
log_last_layer_degree_bound: 2,
log_blowup_factor: 1,
n_queries: 15,
},
};

let twiddles = SimdBackend::precompute_twiddles(
CanonicCoset::new(LOG_MAX_ROWS + config.fri_config.log_blowup_factor + 2)
CanonicCoset::new(LOG_MAX_ROWS + pcs_config.fri_config.log_blowup_factor + 2)
.circle_domain()
.half_coset,
);

// Setup protocol.
let channel = &mut MC::C::default();
let mut commitment_scheme = CommitmentSchemeProver::<SimdBackend, MC>::new(config, &twiddles);
let mut commitment_scheme =
CommitmentSchemeProver::<SimdBackend, MC>::new(pcs_config, &twiddles);

// Preprocessed trace.
let mut tree_builder = commitment_scheme.tree_builder();
Expand Down Expand Up @@ -125,19 +119,12 @@ pub fn verify_cairo<MC: MerkleChannel>(
interaction_claim,
stark_proof,
}: CairoProof<MC::H>,
pcs_config: PcsConfig,
) -> Result<(), CairoVerificationError> {
// Verify.
// TODO(Ohad): Propogate config from CLI args.
let config = PcsConfig {
pow_bits: 0,
fri_config: FriConfig {
log_last_layer_degree_bound: 2,
log_blowup_factor: 1,
n_queries: 15,
},
};
let _span = span!(Level::INFO, "verify_cairo").entered();

let channel = &mut MC::C::default();
let commitment_scheme_verifier = &mut CommitmentSchemeVerifier::<MC>::new(config);
let commitment_scheme_verifier = &mut CommitmentSchemeVerifier::<MC>::new(pcs_config);

let log_sizes = claim.log_sizes();

Expand Down Expand Up @@ -207,6 +194,25 @@ impl ConfigBuilder {
}
}

#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize)]
pub struct ProverParameters {
pub channel_hash: ChannelHash,
pub pcs_config: PcsConfig,
}

#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ChannelHash {
Blake2s,
Poseidon252,
}

impl Default for ChannelHash {
fn default() -> Self {
Self::Blake2s
}
}

#[derive(Error, Debug)]
pub enum CairoVerificationError {
#[error("Invalid logup sum")]
Expand All @@ -218,6 +224,7 @@ pub enum CairoVerificationError {
#[cfg(test)]
mod tests {
use cairo_lang_casm::casm;
use stwo_prover::core::pcs::PcsConfig;
use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel;

use super::ProverConfig;
Expand Down Expand Up @@ -259,8 +266,10 @@ mod tests {

#[test]
fn test_basic_cairo_air() {
let cairo_proof = prove_cairo::<Blake2sMerkleChannel>(test_input(), test_cfg()).unwrap();
verify_cairo::<Blake2sMerkleChannel>(cairo_proof).unwrap();
let cairo_proof =
prove_cairo::<Blake2sMerkleChannel>(test_input(), test_cfg(), PcsConfig::default())
.unwrap();
verify_cairo::<Blake2sMerkleChannel>(cairo_proof, PcsConfig::default()).unwrap();
}

#[cfg(feature = "slow-tests")]
Expand All @@ -277,21 +286,29 @@ mod tests {

#[test]
fn generate_and_serialise_proof() {
let cairo_proof =
prove_cairo::<Poseidon252MerkleChannel>(test_input(), test_cfg()).unwrap();
let cairo_proof = prove_cairo::<Poseidon252MerkleChannel>(
test_input(),
test_cfg(),
PcsConfig::default(),
)
.unwrap();
let mut output = Vec::new();
CairoSerialize::serialize(&cairo_proof, &mut output);
let proof_str = output.iter().map(|v| v.to_string()).join(",");
let mut file = std::fs::File::create("proof.cairo").unwrap();
file.write_all(proof_str.as_bytes()).unwrap();
verify_cairo::<Poseidon252MerkleChannel>(cairo_proof).unwrap();
verify_cairo::<Poseidon252MerkleChannel>(cairo_proof, PcsConfig::default()).unwrap();
}

#[test]
fn test_full_cairo_air() {
let cairo_proof =
prove_cairo::<Blake2sMerkleChannel>(small_cairo_input(), test_cfg()).unwrap();
verify_cairo::<Blake2sMerkleChannel>(cairo_proof).unwrap();
let cairo_proof = prove_cairo::<Blake2sMerkleChannel>(
small_cairo_input(),
test_cfg(),
PcsConfig::default(),
)
.unwrap();
verify_cairo::<Blake2sMerkleChannel>(cairo_proof, PcsConfig::default()).unwrap();
}
}
}
Loading

0 comments on commit df1d498

Please sign in to comment.