Skip to content

Commit ea3775c

Browse files
committed
feat: GuestMemory trait (#1574)
Not merging to main Add `GuestMemory` trait and implement for `AddressMap`. We are moving more towards a trait based style to re-use code when different types of memory might be swapped out.
1 parent 706fbdf commit ea3775c

File tree

3 files changed

+94
-19
lines changed

3 files changed

+94
-19
lines changed

crates/vm/src/system/memory/online.rs

+46-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,50 @@ use crate::{
88
system::memory::{offline::INITIAL_TIMESTAMP, MemoryImage, RecordId},
99
};
1010

11+
/// API for guest memory conforming to OpenVM ISA
12+
pub trait GuestMemory {
13+
/// Returns `[pointer:BLOCK_SIZE]_{address_space}`
14+
///
15+
/// # Safety
16+
/// The type `T` must be stack-allocated `repr(C)` or `repr(transparent)`,
17+
/// and it must be the exact type used to represent a single memory cell in
18+
/// address space `address_space`. For standard usage,
19+
/// `T` is either `u8` or `F` where `F` is the base field of the ZK backend.
20+
unsafe fn read<T: Copy, const BLOCK_SIZE: usize>(
21+
&mut self, // &mut potentially for logs?
22+
address_space: u32,
23+
pointer: u32,
24+
) -> [T; BLOCK_SIZE];
25+
26+
/// Writes `values` to `[pointer:BLOCK_SIZE]_{address_space}`
27+
///
28+
/// # Safety
29+
/// See [`GuestMemory::read`].
30+
unsafe fn write<T: Copy, const BLOCK_SIZE: usize>(
31+
&mut self,
32+
address_space: u32,
33+
pointer: u32,
34+
values: &[T; BLOCK_SIZE],
35+
);
36+
37+
/// Writes `values` to `[pointer:BLOCK_SIZE]_{address_space}` and returns
38+
/// the previous values.
39+
///
40+
/// # Safety
41+
/// See [`GuestMemory::read`].
42+
#[inline(always)]
43+
unsafe fn replace<T: Copy, const BLOCK_SIZE: usize>(
44+
&mut self,
45+
address_space: u32,
46+
pointer: u32,
47+
values: &[T; BLOCK_SIZE],
48+
) -> [T; BLOCK_SIZE] {
49+
let prev = self.read(address_space, pointer);
50+
self.write(address_space, pointer, values);
51+
prev
52+
}
53+
}
54+
1155
// TO BE DELETED
1256
#[derive(Debug, Clone, Serialize, Deserialize)]
1357
pub enum MemoryLogEntry<T> {
@@ -80,7 +124,7 @@ impl Memory {
80124
) -> (RecordId, [T; BLOCK_SIZE]) {
81125
debug_assert!(BLOCK_SIZE.is_power_of_two());
82126

83-
let prev_data = self.data.set_range((address_space, pointer), values);
127+
let prev_data = self.data.replace(address_space, pointer, values);
84128

85129
// self.log.push(MemoryLogEntry::Write {
86130
// address_space,
@@ -113,7 +157,7 @@ impl Memory {
113157
// len: N,
114158
// });
115159

116-
let values = self.data.get_range((address_space, pointer));
160+
let values = self.data.read(address_space, pointer);
117161
self.timestamp += 1;
118162
(self.last_record_id(), values)
119163
}

crates/vm/src/system/memory/paged_vec.rs

+46-15
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use openvm_instructions::exe::SparseMemoryImage;
55
use openvm_stark_backend::p3_field::PrimeField32;
66
use serde::{Deserialize, Serialize};
77

8+
use super::online::GuestMemory;
89
use crate::arch::MemoryConfig;
910

1011
/// (address_space, pointer)
@@ -72,6 +73,7 @@ impl<const PAGE_SIZE: usize> PagedVec<PAGE_SIZE> {
7273
ptr::copy_nonoverlapping(page.as_ptr().add(offset), dst, len);
7374
ptr::copy_nonoverlapping(new, page.as_mut_ptr().add(offset), len);
7475
} else {
76+
assert_eq!(start_page + 1, end_page);
7577
let offset = start % PAGE_SIZE;
7678
let first_part = PAGE_SIZE - offset;
7779
{
@@ -120,11 +122,41 @@ impl<const PAGE_SIZE: usize> PagedVec<PAGE_SIZE> {
120122
unsafe { result.assume_init() }
121123
}
122124

125+
/// # Panics
126+
/// If `start..start + size_of<BLOCK>()` is out of bounds.
127+
#[inline(always)]
128+
pub fn set<BLOCK: Copy>(&mut self, start: usize, values: &BLOCK) {
129+
let len = size_of::<BLOCK>();
130+
let start_page = start / PAGE_SIZE;
131+
let end_page = (start + len - 1) / PAGE_SIZE;
132+
let src = values as *const _ as *const u8;
133+
unsafe {
134+
if start_page == end_page {
135+
let offset = start % PAGE_SIZE;
136+
let page = self.pages[start_page].get_or_insert_with(|| vec![0u8; PAGE_SIZE]);
137+
ptr::copy_nonoverlapping(src, page.as_mut_ptr().add(offset), len);
138+
} else {
139+
assert_eq!(start_page + 1, end_page);
140+
let offset = start % PAGE_SIZE;
141+
let first_part = PAGE_SIZE - offset;
142+
{
143+
let page = self.pages[start_page].get_or_insert_with(|| vec![0u8; PAGE_SIZE]);
144+
ptr::copy_nonoverlapping(src, page.as_mut_ptr().add(offset), first_part);
145+
}
146+
let second_part = len - first_part;
147+
{
148+
let page = self.pages[end_page].get_or_insert_with(|| vec![0u8; PAGE_SIZE]);
149+
ptr::copy_nonoverlapping(src.add(first_part), page.as_mut_ptr(), second_part);
150+
}
151+
}
152+
}
153+
}
154+
123155
/// memcpy of new `values` into pages, memcpy of old existing values into new returned value.
124156
/// # Panics
125157
/// If `from..from + size_of<BLOCK>()` is out of bounds.
126158
#[inline(always)]
127-
pub fn set<BLOCK: Copy>(&mut self, from: usize, values: &BLOCK) -> BLOCK {
159+
pub fn replace<BLOCK: Copy>(&mut self, from: usize, values: &BLOCK) -> BLOCK {
128160
// Create an uninitialized array for old values.
129161
let mut result: MaybeUninit<BLOCK> = MaybeUninit::uninit();
130162
self.set_range_generic(
@@ -278,7 +310,7 @@ impl<const PAGE_SIZE: usize> AddressMap<PAGE_SIZE> {
278310
);
279311
self.paged_vecs
280312
.get_unchecked_mut((addr_space - self.as_offset) as usize)
281-
.set((ptr as usize) * size_of::<T>(), &data)
313+
.replace((ptr as usize) * size_of::<T>(), &data)
282314
}
283315
pub fn is_empty(&self) -> bool {
284316
self.paged_vecs.iter().all(|page| page.is_empty())
@@ -302,11 +334,12 @@ impl<const PAGE_SIZE: usize> AddressMap<PAGE_SIZE> {
302334
}
303335
}
304336

305-
impl<const PAGE_SIZE: usize> AddressMap<PAGE_SIZE> {
306-
/// # Safety
307-
/// - `T` **must** be the correct type for a single memory cell for `addr_space`
308-
/// - Assumes `addr_space` is within the configured memory and not out of bounds
309-
pub unsafe fn get_range<T: Copy, const N: usize>(&self, (addr_space, ptr): Address) -> [T; N] {
337+
impl<const PAGE_SIZE: usize> GuestMemory for AddressMap<PAGE_SIZE> {
338+
unsafe fn read<T: Copy, const BLOCK_SIZE: usize>(
339+
&mut self,
340+
addr_space: u32,
341+
ptr: u32,
342+
) -> [T; BLOCK_SIZE] {
310343
debug_assert_eq!(
311344
size_of::<T>(),
312345
self.cell_size[(addr_space - self.as_offset) as usize]
@@ -316,22 +349,20 @@ impl<const PAGE_SIZE: usize> AddressMap<PAGE_SIZE> {
316349
.get((ptr as usize) * size_of::<T>())
317350
}
318351

319-
/// # Safety
320-
/// - `T` **must** be the correct type for a single memory cell for `addr_space`
321-
/// - Assumes `addr_space` is within the configured memory and not out of bounds
322-
pub unsafe fn set_range<T: Copy, const N: usize>(
352+
unsafe fn write<T: Copy, const BLOCK_SIZE: usize>(
323353
&mut self,
324-
(addr_space, ptr): Address,
325-
values: &[T; N],
326-
) -> [T; N] {
354+
addr_space: u32,
355+
ptr: u32,
356+
values: &[T; BLOCK_SIZE],
357+
) {
327358
debug_assert_eq!(
328359
size_of::<T>(),
329360
self.cell_size[(addr_space - self.as_offset) as usize],
330361
"addr_space={addr_space}"
331362
);
332363
self.paged_vecs
333364
.get_unchecked_mut((addr_space - self.as_offset) as usize)
334-
.set((ptr as usize) * size_of::<T>(), values)
365+
.set((ptr as usize) * size_of::<T>(), values);
335366
}
336367
}
337368

crates/vm/src/system/memory/tree/public_values.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ mod tests {
206206
use super::{UserPublicValuesProof, PUBLIC_VALUES_ADDRESS_SPACE_OFFSET};
207207
use crate::{
208208
arch::{hasher::poseidon2::vm_poseidon2_hasher, SystemConfig},
209-
system::memory::{paged_vec::AddressMap, tree::MemoryNode, CHUNK},
209+
system::memory::{online::GuestMemory, paged_vec::AddressMap, tree::MemoryNode, CHUNK},
210210
};
211211

212212
type F = BabyBear;
@@ -224,7 +224,7 @@ mod tests {
224224
1 << memory_dimensions.address_height,
225225
);
226226
unsafe {
227-
memory.set_range::<F, 1>((pv_as, 15), &[F::ONE]);
227+
memory.write::<F, 1>(pv_as, 15, &[F::ONE]);
228228
}
229229
let mut expected_pvs = F::zero_vec(num_public_values);
230230
expected_pvs[15] = F::ONE;

0 commit comments

Comments
 (0)