Skip to content

Commit

Permalink
vm runner: output execution resources (#396)
Browse files Browse the repository at this point in the history
  • Loading branch information
noam-starkware authored Jan 31, 2025
1 parent 466480e commit 0768635
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 31 deletions.
1 change: 1 addition & 0 deletions stwo_cairo_prover/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 33 additions & 0 deletions stwo_cairo_prover/crates/prover/src/input/builtin_segments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,39 @@ impl BuiltinSegments {
}
}

/// Returns the number of instances for each builtin.
pub fn get_counts(&self) -> HashMap<BuiltinName, usize> {
let builtin_names = &[
BuiltinName::range_check,
BuiltinName::pedersen,
BuiltinName::ecdsa,
BuiltinName::keccak,
BuiltinName::bitwise,
BuiltinName::ec_op,
BuiltinName::poseidon,
BuiltinName::range_check96,
BuiltinName::add_mod,
BuiltinName::mul_mod,
];

builtin_names
.iter()
.filter_map(|&builtin_name| {
let segment = self.get_segment(builtin_name).as_ref()?;
Some((
builtin_name,
Self::get_memory_segment_size(segment)
/ Self::builtin_memory_cells_per_instance(builtin_name),
))
})
.collect()
}

/// Return the size of a memory segment.
fn get_memory_segment_size(segment: &MemorySegmentAddresses) -> usize {
segment.stop_ptr - segment.begin_addr
}

/// Pads a builtin segment with copies of its last instance if that segment isn't None, in
/// which case at least one instance is guaranteed to exist.
/// The segment is padded to the next power of 2 number of instances.
Expand Down
46 changes: 46 additions & 0 deletions stwo_cairo_prover/crates/prover/src/input/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::collections::HashMap;

use builtin_segments::BuiltinSegments;
use cairo_vm::types::builtin_name::BuiltinName;
use memory::Memory;
use prover_types::cpu::M31;
use serde::{Deserialize, Serialize};
Expand All @@ -25,3 +26,48 @@ pub struct ProverInput {
pub public_memory_addresses: Vec<u32>,
pub builtins_segments: BuiltinSegments,
}

/// Sizes of memory address to ID and ID to value tables.
#[derive(Debug, Serialize, Deserialize)]
pub struct MemoryTablesSizes {
/// Size of memory address to ID table.
pub address_to_id: usize,
/// Size of memory ID to big value table.
pub id_to_big: usize,
/// Size of memory ID to small value table.
pub id_to_small: usize,
}

/// Execution resources required to compute trace size.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExecutionResources {
/// Map opcode to the number of invocations.
pub opcode_instance_counts: HashMap<String, usize>,
/// Map builtin to the number of invocations.
pub builtin_instance_counts: HashMap<BuiltinName, usize>,
/// Sizes of memory tables.
pub memory_tables_sizes: MemoryTablesSizes,
/// Number of verify instructions, corresponds to the number of unique pc values.
pub verify_instructions_count: usize,
}

impl ExecutionResources {
/// Create execution resources from prover input.
pub fn from_prover_input(input: &ProverInput) -> Self {
ExecutionResources {
opcode_instance_counts: input
.state_transitions
.casm_states_by_opcode
.counts()
.into_iter()
.collect(),
builtin_instance_counts: input.builtins_segments.get_counts(),
memory_tables_sizes: MemoryTablesSizes {
address_to_id: input.memory.address_to_id.len(),
id_to_big: input.memory.f252_values.len(),
id_to_small: input.memory.small_values.len(),
},
verify_instructions_count: input.instruction_by_pc.len(),
}
}
}
74 changes: 46 additions & 28 deletions stwo_cairo_prover/crates/prover/src/input/state_transitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,50 +52,68 @@ pub struct CasmStatesByOpcode {
}

impl CasmStatesByOpcode {
pub fn counts(&self) -> Vec<(&str, usize)> {
pub fn counts(&self) -> Vec<(String, usize)> {
vec![
("generic_opcode", self.generic_opcode.len()),
("add_ap_opcode", self.add_ap_opcode.len()),
("add_ap_opcode_imm", self.add_ap_opcode_imm.len()),
("generic_opcode".to_string(), self.generic_opcode.len()),
("add_ap_opcode".to_string(), self.add_ap_opcode.len()),
(
"add_ap_opcode_op_1_base_fp",
"add_ap_opcode_imm".to_string(),
self.add_ap_opcode_imm.len(),
),
(
"add_ap_opcode_op_1_base_fp".to_string(),
self.add_ap_opcode_op_1_base_fp.len(),
),
("add_opcode_small_imm", self.add_opcode_small_imm.len()),
("add_opcode", self.add_opcode.len()),
("add_opcode_small", self.add_opcode_small.len()),
("add_opcode_imm", self.add_opcode_imm.len()),
("assert_eq_opcode", self.assert_eq_opcode.len()),
(
"assert_eq_opcode_double_deref",
"add_opcode_small_imm".to_string(),
self.add_opcode_small_imm.len(),
),
("add_opcode".to_string(), self.add_opcode.len()),
("add_opcode_small".to_string(), self.add_opcode_small.len()),
("add_opcode_imm".to_string(), self.add_opcode_imm.len()),
("assert_eq_opcode".to_string(), self.assert_eq_opcode.len()),
(
"assert_eq_opcode_double_deref".to_string(),
self.assert_eq_opcode_double_deref.len(),
),
("assert_eq_opcode_imm", self.assert_eq_opcode_imm.len()),
("call_opcode", self.call_opcode.len()),
("call_opcode_rel", self.call_opcode_rel.len()),
(
"call_opcode_op_1_base_fp",
"assert_eq_opcode_imm".to_string(),
self.assert_eq_opcode_imm.len(),
),
("call_opcode".to_string(), self.call_opcode.len()),
("call_opcode_rel".to_string(), self.call_opcode_rel.len()),
(
"call_opcode_op_1_base_fp".to_string(),
self.call_opcode_op_1_base_fp.len(),
),
(
"jnz_opcode_taken_dst_base_fp",
"jnz_opcode_taken_dst_base_fp".to_string(),
self.jnz_opcode_taken_dst_base_fp.len(),
),
("jnz_opcode", self.jnz_opcode.len()),
("jnz_opcode_taken", self.jnz_opcode_taken.len()),
("jnz_opcode_dst_base_fp", self.jnz_opcode_dst_base_fp.len()),
("jump_opcode_rel_imm", self.jump_opcode_rel_imm.len()),
("jump_opcode_rel", self.jump_opcode_rel.len()),
("jnz_opcode".to_string(), self.jnz_opcode.len()),
("jnz_opcode_taken".to_string(), self.jnz_opcode_taken.len()),
(
"jump_opcode_double_deref",
"jnz_opcode_dst_base_fp".to_string(),
self.jnz_opcode_dst_base_fp.len(),
),
(
"jump_opcode_rel_imm".to_string(),
self.jump_opcode_rel_imm.len(),
),
("jump_opcode_rel".to_string(), self.jump_opcode_rel.len()),
(
"jump_opcode_double_deref".to_string(),
self.jump_opcode_double_deref.len(),
),
("jump_opcode", self.jump_opcode.len()),
("mul_opcode_small_imm", self.mul_opcode_small_imm.len()),
("mul_opcode_small", self.mul_opcode_small.len()),
("mul_opcode", self.mul_opcode.len()),
("mul_opcode_imm", self.mul_opcode_imm.len()),
("ret_opcode", self.ret_opcode.len()),
("jump_opcode".to_string(), self.jump_opcode.len()),
(
"mul_opcode_small_imm".to_string(),
self.mul_opcode_small_imm.len(),
),
("mul_opcode_small".to_string(), self.mul_opcode_small.len()),
("mul_opcode".to_string(), self.mul_opcode.len()),
("mul_opcode_imm".to_string(), self.mul_opcode_imm.len()),
("ret_opcode".to_string(), self.ret_opcode.len()),
]
}
}
Expand Down
1 change: 1 addition & 0 deletions stwo_cairo_prover/crates/vm_runner/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ edition = "2021"

[dependencies]
clap.workspace = true
log.workspace = true
num-traits.workspace = true
serde_json.workspace = true
stwo_cairo_prover.workspace = true
Expand Down
7 changes: 4 additions & 3 deletions stwo_cairo_prover/crates/vm_runner/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::process::ExitCode;

use clap::Parser;
use stwo_cairo_prover::input::plain::adapt_finished_runner;
use stwo_cairo_prover::input::ProverInput;
use stwo_cairo_prover::input::{ExecutionResources, ProverInput};
use stwo_cairo_utils::binary_utils::run_binary;
use stwo_cairo_utils::vm_utils::{run_vm, VmArgs, VmError};
use thiserror::Error;
Expand Down Expand Up @@ -47,10 +47,11 @@ fn run(args: impl Iterator<Item = String>) -> Result<ProverInput, Error> {
let cairo_runner = run_vm(&args.vm_args)?;
let cairo_input = adapt_finished_runner(cairo_runner, false);

let execution_resources = &cairo_input.state_transitions.casm_states_by_opcode.counts();
let execution_resources = ExecutionResources::from_prover_input(&cairo_input);
log::info!("Execution resources: {:#?}", execution_resources);
std::fs::write(
args.output_path,
serde_json::to_string(execution_resources)?,
serde_json::to_string(&execution_resources)?,
)?;

Ok(cairo_input)
Expand Down

0 comments on commit 0768635

Please sign in to comment.