Skip to content

Commit

Permalink
Replace ethereum_types with alloy::primitives in smt_trie crate (
Browse files Browse the repository at this point in the history
  • Loading branch information
sergerad authored Nov 11, 2024
1 parent 4b4a9df commit dd17ff3
Show file tree
Hide file tree
Showing 10 changed files with 109 additions and 88 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion evm_arithmetization/src/witness/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ pub(crate) fn generate_poseidon_general<F: RichField, T: Transition<F>>(

let hash = hashout2u(poseidon_hash_padded_byte_vec(input.clone()));

push_no_write(generation_state, hash);
push_no_write(generation_state, hash.into());

state.push_poseidon(poseidon_op);

Expand Down
20 changes: 12 additions & 8 deletions evm_arithmetization/src/world.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub struct KeccakHash;

impl Hasher for PoseidonHash {
fn hash(bytes: &[u8]) -> H256 {
hash_bytecode_h256(bytes)
hash_bytecode_h256(bytes).compat()
}
}

Expand Down Expand Up @@ -367,9 +367,8 @@ impl World for Type2World {
Ok(())
}
fn root(&mut self) -> H256 {
let mut it = [0; 32];
smt_trie::utils::hashout2u(self.as_smt().root).to_big_endian(&mut it);
H256(it)
let root = smt_trie::utils::hashout2u(self.as_smt().root);
H256::from(root.to_be_bytes())
}
}

Expand Down Expand Up @@ -411,7 +410,7 @@ impl Type2World {
);
}
for (
addr,
&addr,
Type2Entry {
balance,
nonce,
Expand All @@ -430,11 +429,16 @@ impl Type2World {
(code_length, key_code_length),
] {
if let Some(value) = value {
smt.set(key_fn(*addr), *value);
let addr = addr.compat();
let value = (*value).compat();
smt.set(key_fn(addr), value);
}
}
for (slot, value) in storage {
smt.set(key_storage(*addr, *slot), *value);
for (&slot, &value) in storage {
let addr = addr.compat();
let slot = slot.compat();
let value = value.compat();
smt.set(key_storage(addr, slot), value);
}
}
smt
Expand Down
2 changes: 1 addition & 1 deletion smt_trie/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ homepage.workspace = true
keywords.workspace = true

[dependencies]
ethereum-types.workspace = true
alloy.workspace = true
plonky2.workspace = true
rand.workspace = true
serde = { workspace = true, features = ["derive", "rc"] }
Expand Down
18 changes: 9 additions & 9 deletions smt_trie/src/bits.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::ops::Add;

use ethereum_types::{BigEndianHash, H256, U256};
use alloy::primitives::{B256, U256};
use serde::{Deserialize, Serialize};

pub type Bit = bool;
Expand All @@ -22,11 +22,11 @@ impl From<U256> for Bits {
}
}

impl From<H256> for Bits {
fn from(packed: H256) -> Self {
impl From<B256> for Bits {
fn from(packed: B256) -> Self {
Bits {
count: 256,
packed: packed.into_uint(),
packed: packed.into(),
}
}
}
Expand All @@ -38,7 +38,7 @@ impl Add for Bits {
assert!(self.count + rhs.count <= 256, "Overflow");
Self {
count: self.count + rhs.count,
packed: self.packed * (U256::one() << rhs.count) + rhs.packed,
packed: self.packed * (U256::from(1) << rhs.count) + rhs.packed,
}
}
}
Expand All @@ -47,7 +47,7 @@ impl Bits {
pub const fn empty() -> Self {
Bits {
count: 0,
packed: U256::zero(),
packed: U256::ZERO,
}
}

Expand All @@ -57,19 +57,19 @@ impl Bits {

pub fn pop_next_bit(&mut self) -> Bit {
assert!(!self.is_empty(), "Cannot pop from empty bits");
let b = !(self.packed & U256::one()).is_zero();
let b = !(self.packed & U256::from(1)).is_zero();
self.packed >>= 1;
self.count -= 1;
b
}

pub fn get_bit(&self, i: usize) -> Bit {
assert!(i < self.count, "Index out of bounds");
!(self.packed & (U256::one() << (self.count - 1 - i))).is_zero()
!(self.packed & (U256::from(1) << (self.count - 1 - i))).is_zero()
}

pub fn push_bit(&mut self, bit: Bit) {
self.packed = self.packed * 2 + U256::from(bit as u64);
self.packed = self.packed * U256::from(2) + U256::from(bit as u64);
self.count += 1;
}

Expand Down
4 changes: 2 additions & 2 deletions smt_trie/src/code.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/// Functions to hash contract bytecode using Poseidon.
/// See `hashContractBytecode()` in https://github.com/0xPolygonHermez/zkevm-commonjs/blob/main/src/smt-utils.js for reference implementation.
use ethereum_types::H256;
use alloy::primitives::B256;
use plonky2::field::types::Field;
use plonky2::hash::poseidon::{self, Poseidon};

Expand Down Expand Up @@ -43,7 +43,7 @@ pub fn poseidon_pad_byte_vec(bytes: &mut Vec<u8>) {
*bytes.last_mut().unwrap() |= 0x80;
}

pub fn hash_bytecode_h256(code: &[u8]) -> H256 {
pub fn hash_bytecode_h256(code: &[u8]) -> B256 {
hashout2h(hash_contract_bytecode(code.to_vec()))
}

Expand Down
7 changes: 4 additions & 3 deletions smt_trie/src/keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

/// This module contains functions to generate keys for the SMT.
/// See https://github.com/0xPolygonHermez/zkevm-commonjs/blob/main/src/smt-utils.js for reference implementation.
use ethereum_types::{Address, U256};
use alloy::primitives::{Address, U256};
use plonky2::{field::types::Field, hash::poseidon::Poseidon};

use crate::smt::{Key, F};
Expand Down Expand Up @@ -74,8 +74,9 @@ pub fn key_storage(addr: Address, slot: U256) -> Key {
let capacity: [F; 4] = {
let mut arr = [F::ZERO; 12];
for i in 0..4 {
arr[2 * i] = F::from_canonical_u32(slot.0[i] as u32);
arr[2 * i + 1] = F::from_canonical_u32((slot.0[i] >> 32) as u32);
let limbs = slot.as_limbs()[i];
arr[2 * i] = F::from_canonical_u32(limbs as u32);
arr[2 * i + 1] = F::from_canonical_u32((limbs >> 32) as u32);
}
F::poseidon(arr)[0..4].try_into().unwrap()
};
Expand Down
45 changes: 28 additions & 17 deletions smt_trie/src/smt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use std::borrow::Borrow;
use std::collections::{HashMap, HashSet};

use ethereum_types::U256;
use alloy::primitives::U256;
use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::field::types::{Field, PrimeField64};
use plonky2::hash::poseidon::{Poseidon, PoseidonHash};
Expand Down Expand Up @@ -145,7 +145,7 @@ impl<D: Db> Smt<D> {
.copied()
.unwrap_or_default()
.is_zero());
U256::zero()
U256::ZERO
};
} else {
let b = keys.get_bit(level as usize);
Expand Down Expand Up @@ -347,7 +347,7 @@ impl<D: Db> Smt<D> {
/// Delete the key in the SMT.
pub fn delete(&mut self, key: Key) {
self.kv_store.remove(&key);
self.set(key, U256::zero());
self.set(key, U256::ZERO);
}

/// Set the key to the hash in the SMT.
Expand Down Expand Up @@ -416,7 +416,7 @@ impl<D: Db> Smt<D> {
&self,
keys: I,
) -> Vec<U256> {
let mut v = vec![U256::zero(); 2]; // For empty hash node.
let mut v = vec![U256::ZERO; 2]; // For empty hash node.
let key = Key(self.root.elements);

let mut keys_to_include = HashSet::new();
Expand All @@ -433,7 +433,7 @@ impl<D: Db> Smt<D> {

serialize(self, key, &mut v, Bits::empty(), &keys_to_include);
if v.len() == 2 {
v.extend([U256::zero(); 2]);
v.extend([U256::ZERO; 2]);
}
v
}
Expand All @@ -457,7 +457,7 @@ fn serialize<D: Db>(

if !keys_to_include.contains(&cur_bits) || smt.db.get_node(&key).is_none() {
let index = v.len();
v.push(HASH_TYPE.into());
v.push(U256::from(HASH_TYPE));
v.push(key2u(key));
index
} else if let Some(node) = smt.db.get_node(&key) {
Expand All @@ -473,22 +473,32 @@ fn serialize<D: Db>(
let rem_key = Key(node.0[0..4].try_into().unwrap());
let val = limbs2f(val_a);
let index = v.len();
v.push(LEAF_TYPE.into());
v.push(U256::from(LEAF_TYPE));
v.push(key2u(rem_key));
v.push(val);
index
} else {
let key_left = Key(node.0[0..4].try_into().unwrap());
let key_right = Key(node.0[4..8].try_into().unwrap());
let index = v.len();
v.push(INTERNAL_TYPE.into());
v.push(U256::zero());
v.push(U256::zero());
let i_left =
serialize(smt, key_left, v, cur_bits.add_bit(false), keys_to_include).into();
v.push(U256::from(INTERNAL_TYPE));
v.push(U256::ZERO);
v.push(U256::ZERO);
let i_left = U256::from(serialize(
smt,
key_left,
v,
cur_bits.add_bit(false),
keys_to_include,
));
v[index + 1] = i_left;
let i_right =
serialize(smt, key_right, v, cur_bits.add_bit(true), keys_to_include).into();
let i_right = U256::from(serialize(
smt,
key_right,
v,
cur_bits.add_bit(true),
keys_to_include,
));
v[index + 2] = i_right;
index
}
Expand All @@ -507,15 +517,16 @@ pub fn hash_serialize_u256(v: &[U256]) -> U256 {
}

fn _hash_serialize(v: &[U256], ptr: usize) -> HashOut {
assert!(v[ptr] <= u8::MAX.into());
match v[ptr].as_u64() as u8 {
let byte: u8 = v[ptr].try_into().expect("U256 should have been <= u8::MAX");
match byte {
HASH_TYPE => u2h(v[ptr + 1]),

INTERNAL_TYPE => {
let mut node = Node([F::ZERO; 12]);
for b in 0..2 {
let child_index = v[ptr + 1 + b];
let child_hash = _hash_serialize(v, child_index.as_usize());
let child_index = *(child_index.as_limbs().first().unwrap()) as usize;
let child_hash = _hash_serialize(v, child_index);
node.0[b * 4..(b + 1) * 4].copy_from_slice(&child_hash.elements);
}
F::poseidon(node.0)[0..4].try_into().unwrap()
Expand Down
Loading

0 comments on commit dd17ff3

Please sign in to comment.