diff --git a/Cargo.lock b/Cargo.lock index 35ff0505a..2c22acb97 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -241,6 +241,7 @@ dependencies = [ "rrs-succinct", "strum", "strum_macros", + "tiny-keccak", "tracing", ] diff --git a/ceno_emul/Cargo.toml b/ceno_emul/Cargo.toml index 74562142a..ca07b070a 100644 --- a/ceno_emul/Cargo.toml +++ b/ceno_emul/Cargo.toml @@ -18,6 +18,7 @@ num-traits.workspace = true rrs_lib = { package = "rrs-succinct", version = "0.1.0" } strum.workspace = true strum_macros.workspace = true +tiny-keccak = { version = "2.0.2", features = ["keccak"] } tracing.workspace = true [dev-dependencies] diff --git a/ceno_emul/src/lib.rs b/ceno_emul/src/lib.rs index 1a855006e..f48bb8a75 100644 --- a/ceno_emul/src/lib.rs +++ b/ceno_emul/src/lib.rs @@ -20,3 +20,5 @@ mod elf; pub use elf::Program; pub mod disassemble; + +mod syscalls; diff --git a/ceno_emul/src/syscalls.rs b/ceno_emul/src/syscalls.rs new file mode 100644 index 000000000..5980506cb --- /dev/null +++ b/ceno_emul/src/syscalls.rs @@ -0,0 +1,61 @@ +use crate::{RegIdx, Tracer, VMState, Word, WordAddr, WriteOp}; +use anyhow::Result; +use itertools::chain; + +mod keccak_permute; + +// Using the same function codes as sp1: +// https://github.com/succinctlabs/sp1/blob/013c24ea2fa15a0e7ed94f7d11a7ada4baa39ab9/crates/core/executor/src/syscalls/code.rs + +pub const KECCAK_PERMUTE: u32 = 0x00_01_01_09; + +/// Trace the inputs and effects of a syscall. +pub fn handle_syscall(vm: &VMState, function_code: u32) -> Result { + match function_code { + KECCAK_PERMUTE => Ok(keccak_permute::keccak_permute(vm)), + _ => Err(anyhow::anyhow!("Unknown syscall: {}", function_code)), + } +} + +/// A syscall event, available to the circuit witness generators. +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct SyscallWitness { + pub mem_writes: Vec, + pub reg_accesses: Vec, +} + +/// The effects of a syscall to apply on the VM. 
+#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct SyscallEffects { + /// The witness being built. Get it with `finalize`. + witness: SyscallWitness, + + /// The next PC after the syscall. Defaults to the next instruction. + pub next_pc: Option, +} + +impl SyscallEffects { + /// Iterate over the register values after the syscall. + pub fn iter_reg_values(&self) -> impl Iterator + '_ { + self.witness + .reg_accesses + .iter() + .map(|op| (op.register_index(), op.value.after)) + } + + /// Iterate over the memory values after the syscall. + pub fn iter_mem_values(&self) -> impl Iterator + '_ { + self.witness + .mem_writes + .iter() + .map(|op| (op.addr, op.value.after)) + } + + /// Keep track of the cycles of registers and memory accesses. + pub fn finalize(mut self, tracer: &mut Tracer) -> SyscallWitness { + for op in chain(&mut self.witness.reg_accesses, &mut self.witness.mem_writes) { + op.previous_cycle = tracer.track_access(op.addr, 0); + } + self.witness + } +} diff --git a/ceno_emul/src/syscalls/keccak_permute.rs b/ceno_emul/src/syscalls/keccak_permute.rs new file mode 100644 index 000000000..63decd3eb --- /dev/null +++ b/ceno_emul/src/syscalls/keccak_permute.rs @@ -0,0 +1,68 @@ +use itertools::{Itertools, izip}; +use tiny_keccak::keccakf; + +use crate::{Change, EmuContext, Platform, VMState, WORD_SIZE, WordAddr, WriteOp}; + +use super::{SyscallEffects, SyscallWitness}; + +const KECCAK_CELLS: usize = 25; // u64 cells +const KECCAK_WORDS: usize = KECCAK_CELLS * 2; // u32 words + +/// Trace the execution of a Keccak permutation. +/// +/// Compatible with: +/// https://github.com/succinctlabs/sp1/blob/013c24ea2fa15a0e7ed94f7d11a7ada4baa39ab9/crates/core/executor/src/syscalls/precompiles/keccak256/permute.rs +/// +/// TODO: test compatibility. +pub fn keccak_permute(vm: &VMState) -> SyscallEffects { + let state_ptr = vm.peek_register(Platform::reg_arg0()); + + // Read the argument `state_ptr`. 
+ let reg_accesses = vec![WriteOp::new_register_op( + Platform::reg_arg0(), + Change::new(state_ptr, state_ptr), + 0, // Cycle set later in finalize(). + )]; + + let addrs = (state_ptr..) + .step_by(WORD_SIZE) + .take(KECCAK_WORDS) + .map(WordAddr::from) + .collect_vec(); + + // Read Keccak state. + let input = addrs + .iter() + .map(|&addr| vm.peek_memory(addr)) + .collect::>(); + + // Compute Keccak permutation. + let output = { + let mut state = [0_u64; KECCAK_CELLS]; + for (cell, (&lo, &hi)) in izip!(&mut state, input.iter().tuples()) { + *cell = lo as u64 | (hi as u64) << 32; + } + + keccakf(&mut state); + + state.into_iter().flat_map(|c| [c as u32, (c >> 32) as u32]) + }; + + // Write permuted state. + let mem_writes = izip!(addrs, input, output) + .map(|(addr, before, after)| WriteOp { + addr, + value: Change { before, after }, + previous_cycle: 0, // Cycle set later in finalize(). + }) + .collect_vec(); + + assert_eq!(mem_writes.len(), KECCAK_WORDS); + SyscallEffects { + witness: SyscallWitness { + mem_writes, + reg_accesses, + }, + next_pc: None, + } +} diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index 01ad33e01..ebb8a19e8 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -4,6 +4,7 @@ use crate::{ CENO_PLATFORM, InsnKind, Instruction, PC_STEP_SIZE, Platform, addr::{ByteAddr, Cycle, RegIdx, Word, WordAddr}, encode_rv32, + syscalls::{SyscallEffects, SyscallWitness}, }; /// An instruction and its context in an execution trace. That is concrete values of registers and memory. @@ -29,6 +30,8 @@ pub struct StepRecord { rd: Option, memory_op: Option, + + syscall: Option, } #[derive(Clone, Debug, PartialEq, Eq, Hash)] @@ -43,6 +46,14 @@ pub struct MemOp { } impl MemOp { + pub fn new_register_op(idx: RegIdx, value: T, previous_cycle: Cycle) -> MemOp { + MemOp { + addr: Platform::register_vma(idx).into(), + value, + previous_cycle, + } + } + /// Get the register index of this operation. 
pub fn register_index(&self) -> RegIdx { Platform::register_index(self.addr.into()) @@ -240,6 +251,7 @@ impl StepRecord { }), insn, memory_op, + syscall: None, } } @@ -275,6 +287,10 @@ impl StepRecord { pub fn is_busy_loop(&self) -> bool { self.pc.before == self.pc.after } + + pub fn syscall(&self) -> Option<&SyscallWitness> { + self.syscall.as_ref() + } } #[derive(Debug)] @@ -376,11 +392,18 @@ impl Tracer { }); } + pub fn track_syscall(&mut self, effects: SyscallEffects) { + let witness = effects.finalize(self); + + assert!(self.record.syscall.is_none(), "Only one syscall per step"); + self.record.syscall = Some(witness); + } + /// - Return the cycle when an address was last accessed. /// - Return 0 if this is the first access. /// - Record the current instruction as the origin of the latest access. /// - Accesses within the same instruction are distinguished by `subcycle ∈ [0, 3]`. - fn track_access(&mut self, addr: WordAddr, subcycle: Cycle) -> Cycle { + pub fn track_access(&mut self, addr: WordAddr, subcycle: Cycle) -> Cycle { self.latest_accesses .insert(addr, self.record.cycle + subcycle) .unwrap_or(0) diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index 838779979..4c0adddd1 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -6,6 +6,7 @@ use crate::{ addr::{ByteAddr, RegIdx, Word, WordAddr}, platform::Platform, rv32im::{Instruction, TrapCause}, + syscalls::{SyscallEffects, handle_syscall}, tracer::{Change, StepRecord, Tracer}, }; use anyhow::{Result, anyhow}; @@ -104,30 +105,57 @@ impl VMState { self.set_pc(0.into()); self.halted = true; } + + fn apply_syscall(&mut self, effects: SyscallEffects) -> Result<()> { + for (addr, value) in effects.iter_mem_values() { + self.memory.insert(addr, value); + } + + for (idx, value) in effects.iter_reg_values() { + self.registers[idx] = value; + } + + let next_pc = effects.next_pc.unwrap_or(self.pc + PC_STEP_SIZE as u32); + self.set_pc(next_pc.into()); + + 
self.tracer.track_syscall(effects); + Ok(()) + } } impl EmuContext for VMState { // Expect an ecall to terminate the program: function HALT with argument exit_code. fn ecall(&mut self) -> Result { let function = self.load_register(Platform::reg_ecall())?; - let arg0 = self.load_register(Platform::reg_arg0())?; if function == Platform::ecall_halt() { - tracing::debug!("halt with exit_code={}", arg0); - + let exit_code = self.load_register(Platform::reg_arg0())?; + tracing::debug!("halt with exit_code={}", exit_code); self.halt(); Ok(true) - } else if self.platform.unsafe_ecall_nop { - // Treat unknown ecalls as all powerful instructions: - // Read two registers, write one register, write one memory word, and branch. - tracing::warn!("ecall ignored: syscall_id={}", function); - self.store_register(Instruction::RD_NULL as RegIdx, 0)?; - // Example ecall effect - any writable address will do. - let addr = (self.platform.stack_top - WORD_SIZE as u32).into(); - self.store_memory(addr, self.peek_memory(addr))?; - self.set_pc(ByteAddr(self.pc) + PC_STEP_SIZE); - Ok(true) } else { - self.trap(TrapCause::EcallError) + match handle_syscall(self, function) { + Ok(effects) => { + self.apply_syscall(effects)?; + Ok(true) + } + Err(err) if self.platform.unsafe_ecall_nop => { + tracing::warn!("ecall ignored with unsafe_ecall_nop: {:?}", err); + // TODO: remove this example. + // Treat unknown ecalls as all powerful instructions: + // Read two registers, write one register, write one memory word, and branch. + let _arg0 = self.load_register(Platform::reg_arg0())?; + self.store_register(Instruction::RD_NULL as RegIdx, 0)?; + // Example ecall effect - any writable address will do. 
+ let addr = (self.platform.stack_top - WORD_SIZE as u32).into(); + self.store_memory(addr, self.peek_memory(addr))?; + self.set_pc(ByteAddr(self.pc) + PC_STEP_SIZE); + Ok(true) + } + Err(err) => { + tracing::error!("ecall error: {:?}", err); + self.trap(TrapCause::EcallError) + } + } } } diff --git a/ceno_emul/tests/test_elf.rs b/ceno_emul/tests/test_elf.rs index 7448d4508..bf3056309 100644 --- a/ceno_emul/tests/test_elf.rs +++ b/ceno_emul/tests/test_elf.rs @@ -1,5 +1,7 @@ use anyhow::Result; use ceno_emul::{ByteAddr, CENO_PLATFORM, EmuContext, InsnKind, Platform, StepRecord, VMState}; +use itertools::{Itertools, izip}; +use tiny_keccak::keccakf; #[test] fn test_ceno_rt_mini() -> Result<()> { @@ -72,6 +74,75 @@ fn test_ceno_rt_io() -> Result<()> { Ok(()) } +#[test] +fn test_ceno_rt_keccak() -> Result<()> { + let program_elf = ceno_examples::ceno_rt_keccak; + let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; + let steps = run(&mut state)?; + + // Expect the program to have written successive states between Keccak permutations. + const ITERATIONS: usize = 3; + let keccak_outs = sample_keccak_f(ITERATIONS); + + let all_messages = read_all_messages(&state); + assert_eq!(all_messages.len(), ITERATIONS); + for (got, expect) in izip!(&all_messages, &keccak_outs) { + let got = got + .chunks_exact(8) + .map(|chunk| u64::from_le_bytes(chunk.try_into().unwrap())) + .collect_vec(); + assert_eq!(&got, expect); + } + + // Find the syscall records. + let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + assert_eq!(syscalls.len(), ITERATIONS); + + // Check the syscall effects. 
+ for (witness, expect) in izip!(syscalls, keccak_outs) { + assert_eq!(witness.reg_accesses.len(), 1); + assert_eq!( + witness.reg_accesses[0].register_index(), + Platform::reg_arg0() + ); + + assert_eq!(witness.mem_writes.len(), expect.len() * 2); + let got = witness + .mem_writes + .chunks_exact(2) + .map(|write_ops| { + assert_eq!( + write_ops[1].addr.baddr(), + write_ops[0].addr.baddr() + WORD_SIZE as u32 + ); + let lo = write_ops[0].value.after as u64; + let hi = write_ops[1].value.after as u64; + lo | (hi << 32) + }) + .collect_vec(); + assert_eq!(got, expect); + } + + Ok(()) +} + +fn unsafe_platform() -> Platform { + let mut platform = CENO_PLATFORM; + platform.unsafe_ecall_nop = true; + platform +} + +fn sample_keccak_f(count: usize) -> Vec> { + let mut state = [0_u64; 25]; + + (0..count) + .map(|_| { + keccakf(&mut state); + state.into() + }) + .collect_vec() +} + fn run(state: &mut VMState) -> Result> { let steps = state.iter_until_halt().collect::>>()?; eprintln!("Emulator ran for {} steps.", steps.len()); diff --git a/examples-builder/build.rs b/examples-builder/build.rs index 0e8d90441..8fc96ef2d 100644 --- a/examples-builder/build.rs +++ b/examples-builder/build.rs @@ -14,6 +14,7 @@ const EXAMPLES: &[&str] = &[ "ceno_rt_mem", "ceno_rt_mini", "ceno_rt_panic", + "ceno_rt_keccak", ]; const CARGO_MANIFEST_DIR: &str = env!("CARGO_MANIFEST_DIR"); diff --git a/guest/ceno_rt/src/lib.rs b/guest/ceno_rt/src/lib.rs index 8de456c41..6e2623ee0 100644 --- a/guest/ceno_rt/src/lib.rs +++ b/guest/ceno_rt/src/lib.rs @@ -12,6 +12,9 @@ pub use io::info_out; mod params; pub use params::*; +mod syscalls; +pub use syscalls::*; + #[cfg(not(test))] mod panic_handler { use core::panic::PanicInfo; diff --git a/guest/ceno_rt/src/syscalls.rs b/guest/ceno_rt/src/syscalls.rs new file mode 100644 index 000000000..90ace85da --- /dev/null +++ b/guest/ceno_rt/src/syscalls.rs @@ -0,0 +1,24 @@ +// Based on 
https://github.com/succinctlabs/sp1/blob/013c24ea2fa15a0e7ed94f7d11a7ada4baa39ab9/crates/zkvm/entrypoint/src/syscalls/keccak_permute.rs + +const KECCAK_PERMUTE: u32 = 0x00_01_01_09; + +use core::arch::asm; + +/// Executes the Keccak256 permutation on the given state. +/// +/// ### Safety +/// +/// The caller must ensure that `state` is a valid pointer to data that is aligned along a four +/// byte boundary. +#[allow(unused_variables)] +#[no_mangle] +pub extern "C" fn syscall_keccak_permute(state: &mut [u64; 25]) { + unsafe { + asm!( + "ecall", + in("t0") KECCAK_PERMUTE, + in("a0") state as *mut [u64; 25], + in("a1") 0 + ); + } +} diff --git a/guest/examples/examples/ceno_rt_keccak.rs b/guest/examples/examples/ceno_rt_keccak.rs new file mode 100644 index 000000000..cd68d17db --- /dev/null +++ b/guest/examples/examples/ceno_rt_keccak.rs @@ -0,0 +1,28 @@ +//! Compute the Keccak permutation using a syscall. +//! +//! Iterate multiple times and log the state after each iteration. + +#![no_main] +#![no_std] +extern crate ceno_rt; +use ceno_rt::{info_out, syscall_keccak_permute}; +use core::slice; + +const ITERATIONS: usize = 3; + +ceno_rt::entry!(main); +fn main() { + let mut state = [0_u64; 25]; + + for _ in 0..ITERATIONS { + syscall_keccak_permute(&mut state); + log_state(&state); + } +} + +fn log_state(state: &[u64; 25]) { + let out = unsafe { + slice::from_raw_parts(state.as_ptr() as *const u8, state.len() * size_of::()) + }; + info_out().write_frame(out); +}