diff --git a/Cargo.lock b/Cargo.lock index 140ad3196..80dd93c7e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2018,6 +2018,7 @@ dependencies = [ "plonky2", "plonky2_maybe_rayon 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", "plonky2_util", + "primitive-types 0.12.2", "rand", "rand_chacha", "ripemd", diff --git a/evm_arithmetization/Cargo.toml b/evm_arithmetization/Cargo.toml index 62133a099..8a6fa6ef2 100644 --- a/evm_arithmetization/Cargo.toml +++ b/evm_arithmetization/Cargo.toml @@ -15,6 +15,7 @@ homepage.workspace = true keywords.workspace = true [dependencies] +__compat_primitive_types = { workspace = true } anyhow = { workspace = true } bytes = { workspace = true } env_logger = { workspace = true } diff --git a/evm_arithmetization/benches/fibonacci_25m_gas.rs b/evm_arithmetization/benches/fibonacci_25m_gas.rs index ca2b74e04..3a237e6fa 100644 --- a/evm_arithmetization/benches/fibonacci_25m_gas.rs +++ b/evm_arithmetization/benches/fibonacci_25m_gas.rs @@ -194,6 +194,7 @@ fn prepare_setup() -> anyhow::Result { prev_hashes: vec![H256::default(); 256], cur_hash: H256::default(), }, + batch_jumpdest_table: None, }) } diff --git a/evm_arithmetization/src/cpu/kernel/interpreter.rs b/evm_arithmetization/src/cpu/kernel/interpreter.rs index a42bc3a1e..fe679651c 100644 --- a/evm_arithmetization/src/cpu/kernel/interpreter.rs +++ b/evm_arithmetization/src/cpu/kernel/interpreter.rs @@ -19,7 +19,9 @@ use crate::cpu::columns::CpuColumnsView; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::generation::debug_inputs; +use crate::generation::jumpdest::{ContextJumpDests, JumpDestTableProcessed, JumpDestTableWitness}; use crate::generation::mpt::{load_linked_lists_and_txn_and_receipt_mpts, TrieRootPtrs}; +use crate::generation::prover_input::{get_proofs_and_jumpdests, CodeDb}; use crate::generation::rlp::all_rlp_prover_inputs_reversed; use crate::generation::state::{ all_ger_prover_inputs_reversed, all_withdrawals_prover_inputs_reversed, GenerationState, @@ -56,6 +58,7 @@ pub(crate) struct Interpreter { /// Counts the number of appearances of each opcode. For debugging purposes. #[allow(unused)] pub(crate) opcode_count: [usize; 0x100], + /// A table of contexts and their reached JUMPDESTs. jumpdest_table: HashMap>, /// `true` if the we are currently carrying out a jumpdest analysis. pub(crate) is_jumpdest_analysis: bool, @@ -71,9 +74,9 @@ pub(crate) struct Interpreter { pub(crate) fn simulate_cpu_and_get_user_jumps( final_label: &str, state: &GenerationState, -) -> Option>> { +) -> (Option, ContextJumpDests) { match state.jumpdest_table { - Some(_) => None, + Some(_) => (None, Default::default()), None => { let halt_pc = KERNEL.global_labels[final_label]; let initial_context = state.registers.context; @@ -94,14 +97,16 @@ pub(crate) fn simulate_cpu_and_get_user_jumps( interpreter .generation_state - .set_jumpdest_analysis_inputs(interpreter.jumpdest_table); + .set_jumpdest_analysis_inputs(interpreter.jumpdest_table.clone()); log::debug!( "Simulated CPU for jumpdest analysis halted after {:?} cycles.", clock ); - - interpreter.generation_state.jumpdest_table + ( + interpreter.generation_state.jumpdest_table, + ContextJumpDests(interpreter.jumpdest_table), + ) } } } @@ -114,7 +119,7 @@ pub(crate) struct ExtraSegmentData { pub(crate) withdrawal_prover_inputs: Vec, pub(crate) ger_prover_inputs: Vec, pub(crate) trie_root_ptrs: TrieRootPtrs, - pub(crate) jumpdest_table: Option>>, + pub(crate) jumpdest_table: Option, pub(crate) next_txn_index: usize, } @@ -148,6 +153,49 @@ pub(crate) fn set_registers_and_run( interpreter.run() } +/// Computes the JUMPDEST proofs for each context. +/// +/// # Arguments +/// +/// - `jumpdest_table_rpc`: The raw table received from RPC. +/// - `code_db`: The corresponding database of contract code used in the trace. +pub(crate) fn set_jumpdest_analysis_inputs_rpc( + jumpdest_table_rpc: &JumpDestTableWitness, + code_db: &CodeDb, +) -> JumpDestTableProcessed { + let ctx_proofs = jumpdest_table_rpc + .0 + .iter() + .flat_map(|(code_addr, ctx_jumpdests)| { + prove_context_jumpdests(&code_db[code_addr], ctx_jumpdests) + }) + .collect(); + JumpDestTableProcessed(ctx_proofs) +} + +/// Orchestrates the proving of all contexts in a specific bytecode. +/// +/// # Arguments +/// +/// - `ctx_jumpdests`: Map from `ctx` to its list of offsets to reached +/// `JUMPDEST`s. +/// - `code`: The bytecode for the contexts. This is the same for all contexts. +fn prove_context_jumpdests( + code: &[u8], + ctx_jumpdests: &ContextJumpDests, +) -> HashMap> { + ctx_jumpdests + .0 + .iter() + .map(|(&ctx, jumpdests)| { + let proofs = jumpdests.last().map_or(Vec::default(), |&largest_address| { + get_proofs_and_jumpdests(code, largest_address, jumpdests.clone()) + }); + (ctx, proofs) + }) + .collect() +} + impl Interpreter { /// Returns an instance of `Interpreter` given `GenerationInputs`, and /// assuming we are initializing with the `KERNEL` code. diff --git a/evm_arithmetization/src/cpu/kernel/tests/add11.rs b/evm_arithmetization/src/cpu/kernel/tests/add11.rs index 89fbdec80..ba1710328 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/add11.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/add11.rs @@ -195,6 +195,7 @@ fn test_add11_yml() { prev_hashes: vec![H256::default(); 256], cur_hash: H256::default(), }, + batch_jumpdest_table: None, }; let initial_stack = vec![]; @@ -376,6 +377,7 @@ fn test_add11_yml_with_exception() { prev_hashes: vec![H256::default(); 256], cur_hash: H256::default(), }, + batch_jumpdest_table: None, }; let initial_stack = vec![]; diff --git a/evm_arithmetization/src/cpu/kernel/tests/core/jumpdest_analysis.rs b/evm_arithmetization/src/cpu/kernel/tests/core/jumpdest_analysis.rs index b0ef17033..be886893a 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/core/jumpdest_analysis.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/core/jumpdest_analysis.rs @@ -7,6 +7,7 @@ use plonky2::field::goldilocks_field::GoldilocksField as F; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::interpreter::Interpreter; use crate::cpu::kernel::opcodes::{get_opcode, get_push_opcode}; +use crate::generation::jumpdest::JumpDestTableProcessed; use crate::witness::operation::CONTEXT_SCALING_FACTOR; #[test] @@ -67,7 +68,10 @@ fn test_jumpdest_analysis() -> Result<()> { interpreter.generation_state.jumpdest_table, // Context 3 has jumpdest 1, 5, 7. All have proof 0 and hence // the list [proof_0, jumpdest_0, ... ] is [0, 1, 0, 5, 0, 7, 8, 40] - Some(HashMap::from([(3, vec![0, 1, 0, 5, 0, 7, 8, 40])])) + Some(JumpDestTableProcessed(HashMap::from([( + 3, + vec![0, 1, 0, 5, 0, 7, 8, 40] + )]))) ); // Run jumpdest analysis with context = 3 @@ -89,6 +93,7 @@ fn test_jumpdest_analysis() -> Result<()> { .jumpdest_table .as_mut() .unwrap() + .0 .get_mut(&CONTEXT) .unwrap() .pop(); @@ -136,7 +141,8 @@ fn test_packed_verification() -> Result<()> { let mut interpreter: Interpreter = Interpreter::new(write_table_if_jumpdest, initial_stack.clone(), None); interpreter.set_code(CONTEXT, code.clone()); - interpreter.generation_state.jumpdest_table = Some(HashMap::from([(3, vec![1, 33])])); + interpreter.generation_state.jumpdest_table = + Some(JumpDestTableProcessed(HashMap::from([(3, vec![1, 33])]))); interpreter.run()?; @@ -149,7 +155,8 @@ fn test_packed_verification() -> Result<()> { let mut interpreter: Interpreter = Interpreter::new(write_table_if_jumpdest, initial_stack.clone(), None); interpreter.set_code(CONTEXT, code.clone()); - interpreter.generation_state.jumpdest_table = Some(HashMap::from([(3, vec![1, 33])])); + interpreter.generation_state.jumpdest_table = + Some(JumpDestTableProcessed(HashMap::from([(3, vec![1, 33])]))); assert!(interpreter.run().is_err()); diff --git a/evm_arithmetization/src/cpu/kernel/tests/init_exc_stop.rs b/evm_arithmetization/src/cpu/kernel/tests/init_exc_stop.rs index e2d5fb41d..48e2c6639 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/init_exc_stop.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/init_exc_stop.rs @@ -110,6 +110,7 @@ fn test_init_exc_stop() { cur_hash: H256::default(), }, global_exit_roots: vec![], + batch_jumpdest_table: None, }; let initial_stack = vec![]; let initial_offset = KERNEL.global_labels["init"]; diff --git a/evm_arithmetization/src/generation/jumpdest.rs b/evm_arithmetization/src/generation/jumpdest.rs new file mode 100644 index 000000000..84e241d6a --- /dev/null +++ b/evm_arithmetization/src/generation/jumpdest.rs @@ -0,0 +1,64 @@ +use std::{ + collections::{BTreeSet, HashMap}, + fmt::Display, +}; + +use keccak_hash::H256; +use serde::{Deserialize, Serialize}; + +/// Each `CodeAddress` can be called one or more times, each time creating a new +/// `Context`. Each `Context` will one or more `JumpDests`. +#[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize, Default)] +pub struct ContextJumpDests(pub HashMap>); + +/// The result after proving a `JumpDestTableWitness`. +#[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize, Default)] +pub(crate) struct JumpDestTableProcessed(pub HashMap>); + +/// Map `CodeAddress -> (Context -> [JumpDests])` +#[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize, Default)] +pub struct JumpDestTableWitness(pub HashMap); + +impl JumpDestTableWitness { + pub fn insert(&mut self, code_hash: &H256, ctx: usize, offset: usize) { + self.0.entry(*code_hash).or_default(); + + self.0.get_mut(code_hash).unwrap().0.entry(ctx).or_default(); + + self.0 + .get_mut(code_hash) + .unwrap() + .0 + .get_mut(&ctx) + .unwrap() + .insert(offset); + + assert!(self.0.contains_key(code_hash)); + assert!(self.0[code_hash].0.contains_key(&ctx)); + assert!(self.0[code_hash].0[&ctx].contains(&offset)); + } +} + +impl Display for JumpDestTableWitness { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "=== JumpDest table ===")?; + + for (code, ctxtbls) in &self.0 { + write!(f, "codehash: {:?}\n{}", code, ctxtbls)?; + } + Ok(()) + } +} + +impl Display for ContextJumpDests { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for (ctx, offsets) in &self.0 { + write!(f, " ctx: {}, offsets: [", ctx)?; + for offset in offsets { + write!(f, "{:#10x} ", offset)?; + } + writeln!(f, "]")?; + } + Ok(()) + } +} diff --git a/evm_arithmetization/src/generation/mod.rs b/evm_arithmetization/src/generation/mod.rs index 9f1bd05c8..af7c903c9 100644 --- a/evm_arithmetization/src/generation/mod.rs +++ b/evm_arithmetization/src/generation/mod.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use anyhow::anyhow; use ethereum_types::{Address, BigEndianHash, H256, U256}; +use jumpdest::JumpDestTableWitness; use keccak_hash::keccak; use log::log_enabled; use mpt_trie::partial_trie::{HashedPartialTrie, PartialTrie}; @@ -33,6 +34,7 @@ use crate::util::{h2u, u256_to_usize}; use crate::witness::memory::{MemoryAddress, MemoryChannel, MemoryState}; use crate::witness::state::RegistersState; +pub mod jumpdest; pub(crate) mod linked_list; pub mod mpt; pub(crate) mod prover_input; @@ -93,6 +95,10 @@ pub struct GenerationInputs { /// The hash of the current block, and a list of the 256 previous block /// hashes. pub block_hashes: BlockHashes, + + /// A table listing each JUMPDESTs reached in each call context under + /// associated code hash. + pub batch_jumpdest_table: Option, } /// A lighter version of [`GenerationInputs`], which have been trimmed @@ -135,6 +141,10 @@ pub struct TrimmedGenerationInputs { /// The hash of the current block, and a list of the 256 previous block /// hashes. pub block_hashes: BlockHashes, + + /// A list of tables listing each JUMPDESTs reached in each call context + /// under associated code hash. + pub batch_jumpdest_table: Option, } #[derive(Clone, Debug, Deserialize, Serialize, Default)] @@ -207,6 +217,7 @@ impl GenerationInputs { contract_code: self.contract_code.clone(), block_metadata: self.block_metadata.clone(), block_hashes: self.block_hashes.clone(), + batch_jumpdest_table: self.batch_jumpdest_table.clone(), } } } diff --git a/evm_arithmetization/src/generation/prover_input.rs b/evm_arithmetization/src/generation/prover_input.rs index 601e1c525..643c59836 100644 --- a/evm_arithmetization/src/generation/prover_input.rs +++ b/evm_arithmetization/src/generation/prover_input.rs @@ -10,6 +10,7 @@ use num_bigint::BigUint; use plonky2::field::types::Field; use serde::{Deserialize, Serialize}; +use super::jumpdest::JumpDestTableProcessed; use super::linked_list::LinkedList; use super::mpt::load_state_mpt; use crate::cpu::kernel::cancun_constants::KZG_VERSIONED_HASH; @@ -18,7 +19,9 @@ use crate::cpu::kernel::constants::cancun_constants::{ POINT_EVALUATION_PRECOMPILE_RETURN_VALUE, }; use crate::cpu::kernel::constants::context_metadata::ContextMetadata; -use crate::cpu::kernel::interpreter::simulate_cpu_and_get_user_jumps; +use crate::cpu::kernel::interpreter::{ + set_jumpdest_analysis_inputs_rpc, simulate_cpu_and_get_user_jumps, +}; use crate::curve_pairings::{bls381, CurveAff, CyclicGroup}; use crate::extension_tower::{FieldExt, Fp12, Fp2, BLS381, BLS_BASE, BLS_SCALAR, BN254, BN_BASE}; use crate::generation::prover_input::EvmField::{ @@ -35,6 +38,8 @@ use crate::witness::memory::MemoryAddress; use crate::witness::operation::CONTEXT_SCALING_FACTOR; use crate::witness::util::{current_context_peek, stack_peek}; +pub type CodeDb = HashMap<__compat_primitive_types::H256, Vec>; + /// Prover input function represented as a scoped function name. /// Example: `PROVER_INPUT(ff::bn254_base::inverse)` is represented as /// `ProverInputFn([ff, bn254_base, inverse])`. @@ -392,12 +397,12 @@ impl GenerationState { )); }; - if let Some(ctx_jumpdest_table) = jumpdest_table.get_mut(&context) + if let Some(ctx_jumpdest_table) = jumpdest_table.0.get_mut(&context) && let Some(next_jumpdest_address) = ctx_jumpdest_table.pop() { Ok((next_jumpdest_address + 1).into()) } else { - jumpdest_table.remove(&context); + jumpdest_table.0.remove(&context); Ok(U256::zero()) } } @@ -411,7 +416,7 @@ impl GenerationState { )); }; - if let Some(ctx_jumpdest_table) = jumpdest_table.get_mut(&context) + if let Some(ctx_jumpdest_table) = jumpdest_table.0.get_mut(&context) && let Some(next_jumpdest_proof) = ctx_jumpdest_table.pop() { Ok(next_jumpdest_proof.into()) @@ -797,7 +802,17 @@ impl GenerationState { fn generate_jumpdest_table(&mut self) -> Result<(), ProgramError> { // Simulate the user's code and (unnecessarily) part of the kernel code, // skipping the validate table call - self.jumpdest_table = simulate_cpu_and_get_user_jumps("terminate_common", self); + + // log::info!("{:?} Generating JUMPDEST tables", tx_hash); + self.jumpdest_table = if let Some(jumpdest_table_rpc) = &self.inputs.batch_jumpdest_table { + let jumpdest_table_processed = + set_jumpdest_analysis_inputs_rpc(jumpdest_table_rpc, &self.inputs.contract_code); + Some(jumpdest_table_processed) + } else { + let (jumpdest_table_processed, _) = + simulate_cpu_and_get_user_jumps("terminate_common", self); + jumpdest_table_processed + }; Ok(()) } @@ -809,8 +824,8 @@ impl GenerationState { &mut self, jumpdest_table: HashMap>, ) { - self.jumpdest_table = Some(HashMap::from_iter(jumpdest_table.into_iter().map( - |(ctx, jumpdest_table)| { + self.jumpdest_table = Some(JumpDestTableProcessed(HashMap::from_iter( + jumpdest_table.into_iter().map(|(ctx, jumpdest_table)| { let code = self.get_code(ctx).unwrap(); if let Some(&largest_address) = jumpdest_table.last() { let proofs = get_proofs_and_jumpdests(&code, largest_address, jumpdest_table); @@ -818,7 +833,7 @@ impl GenerationState { } else { (ctx, vec![]) } - }, + }), ))); } @@ -890,7 +905,7 @@ impl GenerationState { /// for which none of the previous 32 bytes in the code (including opcodes /// and pushed bytes) is a PUSHXX and the address is in its range. It returns /// a vector of even size containing proofs followed by their addresses. -fn get_proofs_and_jumpdests( +pub(crate) fn get_proofs_and_jumpdests( code: &[u8], largest_address: usize, jumpdest_table: std::collections::BTreeSet, diff --git a/evm_arithmetization/src/generation/state.rs b/evm_arithmetization/src/generation/state.rs index 96865806a..063bdc251 100644 --- a/evm_arithmetization/src/generation/state.rs +++ b/evm_arithmetization/src/generation/state.rs @@ -8,6 +8,7 @@ use keccak_hash::keccak; use log::Level; use plonky2::field::types::Field; +use super::jumpdest::JumpDestTableProcessed; use super::mpt::TrieRootPtrs; use super::segments::GenerationSegmentData; use super::{TrieInputs, TrimmedGenerationInputs, NUM_EXTRA_CYCLES_AFTER}; @@ -365,7 +366,7 @@ pub struct GenerationState { /// "proof" for a jump destination is either 0 or an address i > 32 in /// the code (not necessarily pointing to an opcode) such that for every /// j in [i, i+32] it holds that code[j] < 0x7f - j + i. - pub(crate) jumpdest_table: Option>>, + pub(crate) jumpdest_table: Option, } impl GenerationState { diff --git a/evm_arithmetization/src/lib.rs b/evm_arithmetization/src/lib.rs index b76953311..3004c567e 100644 --- a/evm_arithmetization/src/lib.rs +++ b/evm_arithmetization/src/lib.rs @@ -206,6 +206,9 @@ pub mod verifier; pub mod generation; pub mod witness; +pub use generation::jumpdest; +pub use generation::prover_input::CodeDb; + // Utility modules pub mod curve_pairings; pub mod extension_tower; diff --git a/evm_arithmetization/tests/add11_yml.rs b/evm_arithmetization/tests/add11_yml.rs index 8de4e36ae..fe0cddc42 100644 --- a/evm_arithmetization/tests/add11_yml.rs +++ b/evm_arithmetization/tests/add11_yml.rs @@ -201,6 +201,7 @@ fn get_generation_inputs() -> GenerationInputs { prev_hashes: vec![H256::default(); 256], cur_hash: H256::default(), }, + batch_jumpdest_table: None, } } /// The `add11_yml` test case from https://github.com/ethereum/tests diff --git a/evm_arithmetization/tests/erc20.rs b/evm_arithmetization/tests/erc20.rs index 13ef8ee21..3def641f2 100644 --- a/evm_arithmetization/tests/erc20.rs +++ b/evm_arithmetization/tests/erc20.rs @@ -195,6 +195,7 @@ fn test_erc20() -> anyhow::Result<()> { prev_hashes: vec![H256::default(); 256], cur_hash: H256::default(), }, + batch_jumpdest_table: None, }; let max_cpu_len_log = 20; diff --git a/evm_arithmetization/tests/erc721.rs b/evm_arithmetization/tests/erc721.rs index 4cf347afc..a60dd9d27 100644 --- a/evm_arithmetization/tests/erc721.rs +++ b/evm_arithmetization/tests/erc721.rs @@ -198,6 +198,7 @@ fn test_erc721() -> anyhow::Result<()> { prev_hashes: vec![H256::default(); 256], cur_hash: H256::default(), }, + batch_jumpdest_table: Default::default(), }; let max_cpu_len_log = 20; diff --git a/evm_arithmetization/tests/global_exit_root.rs b/evm_arithmetization/tests/global_exit_root.rs index 302b1a143..97befac32 100644 --- a/evm_arithmetization/tests/global_exit_root.rs +++ b/evm_arithmetization/tests/global_exit_root.rs @@ -97,6 +97,7 @@ fn test_global_exit_root() -> anyhow::Result<()> { prev_hashes: vec![H256::default(); 256], cur_hash: H256::default(), }, + batch_jumpdest_table: None, }; let max_cpu_len_log = 20; diff --git a/evm_arithmetization/tests/log_opcode.rs b/evm_arithmetization/tests/log_opcode.rs index 5ac537c4e..0ae2821d7 100644 --- a/evm_arithmetization/tests/log_opcode.rs +++ b/evm_arithmetization/tests/log_opcode.rs @@ -254,6 +254,7 @@ fn test_log_opcodes() -> anyhow::Result<()> { prev_hashes: vec![H256::default(); 256], cur_hash: H256::default(), }, + batch_jumpdest_table: Default::default(), }; let max_cpu_len_log = 20; diff --git a/evm_arithmetization/tests/selfdestruct.rs b/evm_arithmetization/tests/selfdestruct.rs index 97f41b78d..a62fab7fc 100644 --- a/evm_arithmetization/tests/selfdestruct.rs +++ b/evm_arithmetization/tests/selfdestruct.rs @@ -169,6 +169,7 @@ fn test_selfdestruct() -> anyhow::Result<()> { prev_hashes: vec![H256::default(); 256], cur_hash: H256::default(), }, + batch_jumpdest_table: None, }; let max_cpu_len_log = 20; diff --git a/evm_arithmetization/tests/simple_transfer.rs b/evm_arithmetization/tests/simple_transfer.rs index 81cd62113..ea0697891 100644 --- a/evm_arithmetization/tests/simple_transfer.rs +++ b/evm_arithmetization/tests/simple_transfer.rs @@ -161,6 +161,7 @@ fn test_simple_transfer() -> anyhow::Result<()> { prev_hashes: vec![H256::default(); 256], cur_hash: H256::default(), }, + batch_jumpdest_table: Default::default(), }; let max_cpu_len_log = 20; diff --git a/evm_arithmetization/tests/withdrawals.rs b/evm_arithmetization/tests/withdrawals.rs index b1ad4c715..7522e669b 100644 --- a/evm_arithmetization/tests/withdrawals.rs +++ b/evm_arithmetization/tests/withdrawals.rs @@ -105,6 +105,7 @@ fn test_withdrawals() -> anyhow::Result<()> { prev_hashes: vec![H256::default(); 256], cur_hash: H256::default(), }, + batch_jumpdest_table: Default::default(), }; let max_cpu_len_log = 20; diff --git a/trace_decoder/src/decoding.rs b/trace_decoder/src/decoding.rs index 8cab99b7c..cd982475c 100644 --- a/trace_decoder/src/decoding.rs +++ b/trace_decoder/src/decoding.rs @@ -1,4 +1,8 @@ -use std::{cmp::min, collections::HashMap, ops::Range}; +use std::{ + cmp::{max, min}, + collections::HashMap, + ops::Range, +}; use anyhow::{anyhow, Context as _}; use ethereum_types::{Address, BigEndianHash, H256, U256, U512}; @@ -7,6 +11,7 @@ use evm_arithmetization::{ mpt::{decode_receipt, AccountRlp}, GenerationInputs, TrieInputs, }, + jumpdest::{ContextJumpDests, JumpDestTableWitness}, proof::{BlockMetadata, ExtraBlockData, TrieRoots}, testing_utils::{ BEACON_ROOTS_CONTRACT_ADDRESS, BEACON_ROOTS_CONTRACT_ADDRESS_HASHED, HISTORY_BUFFER_LENGTH, @@ -79,6 +84,7 @@ pub fn into_txn_proof_gen_ir( .into_iter() .enumerate() .map(|(txn_idx, txn_info)| { + // batch start and end let txn_range = min(txn_idx * batch_size, num_txs)..min(txn_idx * batch_size + batch_size, num_txs); let is_initial_payload = txn_range.start == 0; @@ -581,6 +587,9 @@ fn process_txn_info( delta_out, )?; + let jdts = txn_info.meta.iter().map(|tx| &tx.jumpdest_table); + let batch_jumpdest_table: Option = merge_batch_jumpdest_tables(jdts); + let gen_inputs = GenerationInputs { txn_number_before: extra_data.txn_number_before, gas_used_before: extra_data.gas_used_before, @@ -608,6 +617,7 @@ fn process_txn_info( block_metadata: other_data.b_data.b_meta.clone(), block_hashes: other_data.b_data.b_hashes.clone(), global_exit_roots: vec![], + batch_jumpdest_table, }; // After processing a transaction, we update the remaining accumulators @@ -618,6 +628,47 @@ fn process_txn_info( Ok(gen_inputs) } +fn merge_batch_jumpdest_tables<'t, T>(jdts: T) -> Option +where + T: Iterator>, +{ + let mut merged_table = JumpDestTableWitness::default(); + + let mut max_batch_ctx = 0; + for jdt in jdts { + let tx_offset = max_batch_ctx; + // abort if any transaction in the batch came without RPC JumpDestTable. + if jdt.is_none() { + return None; + } + for (code_hash, ctx_tbl) in jdt.as_ref().unwrap().0.iter() { + for (ctx, jumpsdests) in ctx_tbl.0.iter() { + let batch_ctx = tx_offset + ctx; + max_batch_ctx = max(max_batch_ctx, batch_ctx); + + merged_table + .0 + .entry(*code_hash) + .or_insert(ContextJumpDests::default()); + + merged_table + .0 + .get_mut(code_hash) + .unwrap() + .0 + .entry(batch_ctx) + .or_insert(jumpsdests.clone()); + + debug_assert!(merged_table.0.contains_key(code_hash)); + debug_assert!(merged_table.0[code_hash].0.contains_key(&batch_ctx)); + } + } + max_batch_ctx = tx_offset; + } + + Some(merged_table) +} + impl StateWrite { fn apply_writes_to_state_node( &self, diff --git a/trace_decoder/src/lib.rs b/trace_decoder/src/lib.rs index 14a94d29b..96688af73 100644 --- a/trace_decoder/src/lib.rs +++ b/trace_decoder/src/lib.rs @@ -99,6 +99,7 @@ mod wire; use std::collections::HashMap; use ethereum_types::{Address, U256}; +use evm_arithmetization::jumpdest::JumpDestTableWitness; use evm_arithmetization::proof::{BlockHashes, BlockMetadata}; use evm_arithmetization::GenerationInputs; use keccak_hash::keccak as hash; @@ -207,6 +208,9 @@ pub struct TxnMeta { /// Gas used by this txn (Note: not cumulative gas used). pub gas_used: u64, + + /// JumpDest table + pub jumpdest_table: Option, } /// A "trace" specific to an account for a txn. @@ -258,6 +262,17 @@ pub enum ContractCodeUsage { Write(#[serde(with = "crate::hex")] Vec), } +// TODO: Whyt has this has been removed upstream. +impl ContractCodeUsage { + /// Get code hash from a read or write operation of contract code. + pub fn get_code_hash(&self) -> H256 { + match self { + ContractCodeUsage::Read(hash) => *hash, + ContractCodeUsage::Write(bytes) => hash(bytes), + } + } +} + /// Other data that is needed for proof gen. #[derive(Clone, Debug, Deserialize, Serialize)] pub struct OtherBlockData { diff --git a/trace_decoder/src/processed_block_trace.rs b/trace_decoder/src/processed_block_trace.rs index 480928444..33f9dbf69 100644 --- a/trace_decoder/src/processed_block_trace.rs +++ b/trace_decoder/src/processed_block_trace.rs @@ -3,6 +3,7 @@ use std::collections::{BTreeSet, HashMap, HashSet}; use anyhow::{bail, Context as _}; use ethereum_types::{Address, H256, U256}; use evm_arithmetization::generation::mpt::{AccountRlp, LegacyReceiptRlp}; +use evm_arithmetization::jumpdest::JumpDestTableWitness; use itertools::Itertools; use zk_evm_common::EMPTY_TRIE_HASH; @@ -27,6 +28,7 @@ pub(crate) struct ProcessedBlockTracePreImages { pub extra_code_hash_mappings: Option>>, } +/// batch info #[derive(Debug, Default)] pub(crate) struct ProcessedTxnInfo { pub nodes_used_by_txn: NodesUsedByTxn, @@ -243,6 +245,7 @@ impl TxnInfo { )?, gas_used: txn.meta.gas_used, created_accounts, + jumpdest_table: txn.meta.jumpdest_table.clone(), }); } @@ -293,4 +296,5 @@ pub(crate) struct TxnMetaState { pub receipt_node_bytes: Vec, pub gas_used: u64, pub created_accounts: BTreeSet
, + pub jumpdest_table: Option, } diff --git a/zero_bin/rpc/src/main.rs b/zero_bin/rpc/src/main.rs index b878b8cf1..673a4cb81 100644 --- a/zero_bin/rpc/src/main.rs +++ b/zero_bin/rpc/src/main.rs @@ -217,6 +217,7 @@ async fn main() -> anyhow::Result<()> { tracing_subscriber::Registry::default() .with( tracing_subscriber::fmt::layer() + .with_writer(std::io::stderr) .with_ansi(false) .compact() .with_filter(EnvFilter::from_default_env()), diff --git a/zero_bin/rpc/src/native/mod.rs b/zero_bin/rpc/src/native/mod.rs index 95b38d22f..765e4da93 100644 --- a/zero_bin/rpc/src/native/mod.rs +++ b/zero_bin/rpc/src/native/mod.rs @@ -1,4 +1,3 @@ -use std::collections::HashMap; use std::ops::Deref; use std::sync::Arc; @@ -17,8 +16,6 @@ use crate::provider::CachedProvider; mod state; mod txn; -type CodeDb = HashMap<__compat_primitive_types::H256, Vec>; - /// Fetches the prover input for the given BlockId. pub async fn block_prover_input( provider: Arc>, diff --git a/zero_bin/rpc/src/native/txn.rs b/zero_bin/rpc/src/native/txn.rs index 0385b8def..6f2178f4e 100644 --- a/zero_bin/rpc/src/native/txn.rs +++ b/zero_bin/rpc/src/native/txn.rs @@ -1,8 +1,12 @@ -use std::collections::{HashMap, HashSet}; +use std::{ + collections::{HashMap, HashSet}, + ops::Not, + sync::OnceLock, +}; use __compat_primitive_types::{H256, U256}; use alloy::{ - primitives::{keccak256, Address, B256}, + primitives::{keccak256, Address, B256, U160}, providers::{ ext::DebugApi as _, network::{eip2718::Encodable2718, Ethereum, Network}, @@ -12,20 +16,37 @@ use alloy::{ eth::Transaction, eth::{AccessList, Block}, trace::geth::{ - AccountState, DiffMode, GethDebugBuiltInTracerType, GethTrace, PreStateConfig, + AccountState, DefaultFrame, DiffMode, GethDebugBuiltInTracerType, GethDebugTracerType, + GethDebugTracingOptions, GethDefaultTracingOptions, GethTrace, PreStateConfig, PreStateFrame, PreStateMode, }, - trace::geth::{GethDebugTracerType, GethDebugTracingOptions}, }, transports::Transport, }; -use anyhow::Context as _; +use anyhow::{ensure, Context as _}; +use evm_arithmetization::{jumpdest::JumpDestTableWitness, CodeDb}; use futures::stream::{FuturesOrdered, TryStreamExt}; use trace_decoder::{ContractCodeUsage, TxnInfo, TxnMeta, TxnTrace}; +use tracing::trace; -use super::CodeDb; use crate::Compat; +/// Provides a way to check in constant time if an address points to a +/// precompile. +fn precompiles() -> &'static HashSet
{ + static PRECOMPILES: OnceLock> = OnceLock::new(); + PRECOMPILES.get_or_init(|| { + HashSet::
::from_iter((1..=0xa).map(|x| Address::from(U160::from(x)))) + }) +} + +/// Provides a way to check in constant time if `op` is in the set of normal +/// halting states. They are defined in the Yellowpaper, 9.4.4. Normal Halting. +fn normal_halting() -> &'static HashSet<&'static str> { + static NORMAL_HALTING: OnceLock> = OnceLock::new(); + NORMAL_HALTING.get_or_init(|| HashSet::<&str>::from_iter(["RETURN", "REVERT", "STOP"])) +} + /// Processes the transactions in the given block and updates the code db. pub(super) async fn process_transactions( block: &Block, @@ -63,17 +84,12 @@ where ProviderT: Provider, TransportT: Transport + Clone, { - let (tx_receipt, pre_trace, diff_trace) = fetch_tx_data(provider, &tx.hash).await?; + let (tx_receipt, pre_trace, diff_trace, structlog_trace) = + fetch_tx_data(provider, &tx.hash).await?; let tx_status = tx_receipt.status(); let tx_receipt = tx_receipt.map_inner(rlp::map_receipt_envelope); let access_list = parse_access_list(tx.access_list.as_ref()); - let tx_meta = TxnMeta { - byte_code: ::TxEnvelope::try_from(tx.clone())?.encoded_2718(), - new_receipt_trie_node_byte: alloy::rlp::encode(tx_receipt.inner), - gas_used: tx_receipt.gas_used as u64, - }; - let (code_db, mut tx_traces) = match (pre_trace, diff_trace) { ( GethTrace::PreStateTracer(PreStateFrame::Default(read)), @@ -86,6 +102,22 @@ where if !tx_status && tx_receipt.contract_address.is_some() { tx_traces.insert(tx_receipt.contract_address.unwrap(), TxnTrace::default()); } + let jumpdest_table: Option = + if let GethTrace::Default(structlog_frame) = structlog_trace { + generate_jumpdest_table(tx, &structlog_frame, &tx_traces) + .await + .map(Some) + .unwrap_or_default() + } else { + unreachable!() + }; + + let tx_meta = TxnMeta { + byte_code: ::TxEnvelope::try_from(tx.clone())?.encoded_2718(), + new_receipt_trie_node_byte: alloy::rlp::encode(tx_receipt.inner), + gas_used: tx_receipt.gas_used as u64, + jumpdest_table, + }; Ok(( code_db, @@ -103,7 +135,12 @@ where async fn fetch_tx_data( provider: &ProviderT, tx_hash: &B256, -) -> anyhow::Result<(::ReceiptResponse, GethTrace, GethTrace), anyhow::Error> +) -> anyhow::Result<( + ::ReceiptResponse, + GethTrace, + GethTrace, + GethTrace, +)> where ProviderT: Provider, TransportT: Transport + Clone, @@ -111,14 +148,21 @@ where let tx_receipt_fut = provider.get_transaction_receipt(*tx_hash); let pre_trace_fut = provider.debug_trace_transaction(*tx_hash, prestate_tracing_options(false)); let diff_trace_fut = provider.debug_trace_transaction(*tx_hash, prestate_tracing_options(true)); + let structlog_trace_fut = + provider.debug_trace_transaction(*tx_hash, structlog_tracing_options()); - let (tx_receipt, pre_trace, diff_trace) = - futures::try_join!(tx_receipt_fut, pre_trace_fut, diff_trace_fut,)?; + let (tx_receipt, pre_trace, diff_trace, structlog_trace) = futures::try_join!( + tx_receipt_fut, + pre_trace_fut, + diff_trace_fut, + structlog_trace_fut + )?; Ok(( tx_receipt.context("Transaction receipt not found.")?, pre_trace, diff_trace, + structlog_trace, )) } @@ -357,3 +401,222 @@ fn prestate_tracing_options(diff_mode: bool) -> GethDebugTracingOptions { ..GethDebugTracingOptions::default() } } + +/// Tracing options for the `debug_traceTransaction` call to get structlog. +/// Used for filling JUMPDEST table. +fn structlog_tracing_options() -> GethDebugTracingOptions { + GethDebugTracingOptions { + config: GethDefaultTracingOptions { + disable_stack: Some(false), + disable_memory: Some(true), + disable_storage: Some(true), + ..GethDefaultTracingOptions::default() + }, + tracer: None, + ..GethDebugTracingOptions::default() + } +} + +/// Generate at JUMPDEST table by simulating the call stack in EVM, +/// using a Geth structlog as input. +async fn generate_jumpdest_table( + tx: &Transaction, + structlog_trace: &DefaultFrame, + tx_traces: &HashMap, +) -> anyhow::Result { + trace!("Generating JUMPDEST table for tx: {}", tx.hash); + ensure!( + structlog_trace.struct_logs.is_empty().not(), + "Structlog is empty." + ); + + let mut jumpdest_table = JumpDestTableWitness::default(); + + let callee_addr_to_code_hash: HashMap = tx_traces + .iter() + .map(|(callee_addr, trace)| (callee_addr, &trace.code_usage)) + .filter(|(_callee_addr, code_usage)| code_usage.is_some()) + .map(|(callee_addr, code_usage)| { + (*callee_addr, code_usage.as_ref().unwrap().get_code_hash()) + }) + .collect(); + + ensure!( + tx.to.is_some(), + format!("No `to`-address for tx: {}.", tx.hash) + ); + let to_address: Address = tx.to.unwrap(); + + // Guard against transactions to a non-contract address. + ensure!( + callee_addr_to_code_hash.contains_key(&to_address), + format!("Callee addr {} is not at contract address", to_address) + ); + let entrypoint_code_hash: H256 = callee_addr_to_code_hash[&to_address]; + + // `None` encodes that previous `entry`` was not a JUMP or JUMPI with true + // condition, `Some(jump_target)` encodes we came from a JUMP or JUMPI with + // true condition and target `jump_target`. + let mut prev_jump = None; + + // Contains the previous op. + let mut prev_op = ""; + + // Call depth of the previous `entry`. We initialize to 0 as this compares + // smaller to 1. + let mut prev_depth = 0; + // The next available context. Starts at 1. Never decrements. + let mut next_ctx_available = 1; + // Immediately use context 1; + let mut call_stack = vec![(entrypoint_code_hash, next_ctx_available)]; + next_ctx_available += 1; + + for (step, entry) in structlog_trace.struct_logs.iter().enumerate() { + let op = entry.op.as_str(); + let curr_depth = entry.depth; + + let exception_occurred = prev_entry_caused_exception(prev_op, prev_depth, curr_depth); + if exception_occurred { + ensure!( + call_stack.is_empty().not(), + "Call stack was empty after exception." + ); + // discard callee frame and return control to caller. + call_stack.pop().unwrap(); + } + + debug_assert!(entry.depth as usize <= next_ctx_available); + ensure!(call_stack.is_empty().not(), "Call stack was empty."); + let (code_hash, ctx) = call_stack.last().unwrap(); + trace!("TX: {:?}", tx.hash); + trace!("STEP: {:?}", step); + trace!("STEPS: {:?}", structlog_trace.struct_logs.len()); + trace!("PREV OPCODE: {}", prev_op); + trace!("OPCODE: {}", entry.op.as_str()); + trace!("CODE: {:?}", code_hash); + trace!("CTX: {:?}", ctx); + trace!("EXCEPTION OCCURED: {:?}", exception_occurred); + trace!("PREV_DEPTH: {:?}", prev_depth); + trace!("CURR_DEPTH: {:?}", curr_depth); + trace!("{:#?}\n", entry); + + match op { + "CALL" | "CALLCODE" | "DELEGATECALL" | "STATICCALL" => { + let callee_address = { + // This is the same stack index (i.e. 2nd) for all four opcodes. See https://ethervm.io/#F1 + ensure!(entry.stack.as_ref().is_some(), "No evm stack found."); + let mut evm_stack = entry.stack.as_ref().unwrap().iter().rev(); + + let callee_raw_opt = evm_stack.nth(1); + ensure!( + callee_raw_opt.is_some(), + "Stack must contain at least two values for a CALL instruction." + ); + let callee_raw = *callee_raw_opt.unwrap(); + + let lower_bytes = U160::from(callee_raw); + Address::from(lower_bytes) + }; + + if precompiles().contains(&callee_address) { + trace!("Called precompile at address {}.", &callee_address); + } else if callee_addr_to_code_hash.contains_key(&callee_address) { + let code_hash = callee_addr_to_code_hash[&callee_address]; + call_stack.push((code_hash, next_ctx_available)); + } else { + // This case happens if calling an EOA. This is described + // under opcode `STOP`: https://www.evm.codes/#00?fork=cancun + trace!( + "Callee address {} has no associated `code_hash`.", + &callee_address + ); + } + next_ctx_available += 1; + prev_jump = None; + } + "JUMP" => { + ensure!(entry.stack.as_ref().is_some(), "No evm stack found."); + let mut evm_stack = entry.stack.as_ref().unwrap().iter().rev(); + + let jump_target_opt = evm_stack.next(); + ensure!( + jump_target_opt.is_some(), + "Stack must contain at least one value for a JUMP instruction." + ); + let jump_target = jump_target_opt.unwrap().to::(); + + prev_jump = Some(jump_target); + } + "JUMPI" => { + ensure!(entry.stack.as_ref().is_some(), "No evm stack found."); + let mut evm_stack = entry.stack.as_ref().unwrap().iter().rev(); + + let jump_target_opt = evm_stack.next(); + ensure!( + jump_target_opt.is_some(), + "Stack must contain at least one value for a JUMPI instruction." + ); + let jump_target = jump_target_opt.unwrap().to::(); + + let jump_condition_opt = evm_stack.next(); + ensure!( + jump_condition_opt.is_some(), + "Stack must contain at least two values for a JUMPI instruction." + ); + let jump_condition = jump_condition_opt.unwrap().is_zero().not(); + + prev_jump = if jump_condition { + Some(jump_target) + } else { + None + }; + } + "JUMPDEST" => { + ensure!( + call_stack.is_empty().not(), + "Call stack was empty when a JUMPDEST was encountered." + ); + let (code_hash, ctx) = call_stack.last().unwrap(); + let jumped_here = if let Some(jmp_target) = prev_jump { + ensure!( + jmp_target == entry.pc, + "The structlog seems to make improper JUMPs." + ); + true + } else { + false + }; + let jumpdest_offset = entry.pc as usize; + if jumped_here { + jumpdest_table.insert(code_hash, *ctx, jumpdest_offset); + } + // else: we do not care about JUMPDESTs reached through fall-through. + prev_jump = None; + } + "EXTCODECOPY" | "EXTCODESIZE" => { + next_ctx_available += 1; + prev_jump = None; + } + "RETURN" | "REVERT" | "STOP" => { + ensure!(call_stack.is_empty().not(), "Call stack was empty at POP."); + call_stack.pop().unwrap(); + prev_jump = None; + } + _ => { + prev_jump = None; + } + } + + prev_depth = curr_depth; + prev_op = op; + } + Ok(jumpdest_table) +} + +/// Check if an exception occurred. An exception will cause the current call +/// context at `depth` to yield control to the caller context at `depth-1`. +/// Returning statements, viz. RETURN, REVERT, STOP, do this too, so we need to +/// exclude them. +fn prev_entry_caused_exception(prev_entry: &str, prev_depth: u64, curr_depth: u64) -> bool { + prev_depth > curr_depth && normal_halting().contains(&prev_entry).not() +}