Skip to content

Commit

Permalink
Allow preset data for In-Memory provider, rename allow_hosts to allow…
Browse files Browse the repository at this point in the history
…_redis_hosts
  • Loading branch information
iawia002 committed Jul 26, 2024
1 parent a355b1a commit c7cc706
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 45 deletions.
27 changes: 17 additions & 10 deletions crates/test-programs/src/bin/keyvalue_main.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,29 @@
use test_programs::keyvalue::wasi::keyvalue::{atomics, batch, store};

fn main() {
let bucket = store::open(std::env::var_os("IDENTIFIER").unwrap().to_str().unwrap()).unwrap();
let identifier = std::env::var_os("IDENTIFIER")
.unwrap()
.into_string()
.unwrap();
let bucket = store::open(&identifier).unwrap();

if identifier != "" {
// for In-Memory provider, we have preset this data
assert_eq!(atomics::increment(&bucket, "atomics_key", 5).unwrap(), 5);
}
assert_eq!(atomics::increment(&bucket, "atomics_key", 1).unwrap(), 6);

let resp = bucket.list_keys(None).unwrap();
assert_eq!(resp.keys, vec!["atomics_key".to_string()]);

bucket.set("hello", "world".as_bytes()).unwrap();

let v = bucket.get("hello").unwrap();
assert_eq!(String::from_utf8(v.unwrap()).unwrap(), "world");

assert_eq!(bucket.exists("hello").unwrap(), true);
bucket.delete("hello").unwrap();
let exists = bucket.exists("hello").unwrap();
assert_eq!(exists, false);

bucket.set("aa", "bb".as_bytes()).unwrap();
let resp = bucket.list_keys(None).unwrap();
assert_eq!(resp.keys, vec!["aa".to_string()]);

assert_eq!(atomics::increment(&bucket, "atomics_key", 5).unwrap(), 5);
assert_eq!(atomics::increment(&bucket, "atomics_key", 1).unwrap(), 6);
assert_eq!(bucket.exists("hello").unwrap(), false);

batch::set_many(
&bucket,
Expand Down
86 changes: 56 additions & 30 deletions crates/wasi-keyvalue/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ mod generated {
use self::generated::wasi::keyvalue;
use anyhow::Result;
use async_trait::async_trait;
use std::collections::HashMap;
use std::fmt::Display;
use url::Url;
use wasmtime::component::{Resource, ResourceTable, ResourceTableError};
Expand Down Expand Up @@ -143,7 +144,9 @@ trait Host {
/// Builder-style structure used to create a [`WasiKeyValueCtx`].
#[derive(Default)]
pub struct WasiKeyValueCtxBuilder {
allowed_hosts: Vec<String>,
in_memory_data: HashMap<String, Vec<u8>>,
#[cfg(feature = "redis")]
allowed_redis_hosts: Vec<String>,
#[cfg(feature = "redis")]
redis_connection_timeout: Option<std::time::Duration>,
#[cfg(feature = "redis")]
Expand All @@ -156,7 +159,21 @@ impl WasiKeyValueCtxBuilder {
Default::default()
}

/// Appends a list of hosts to the allow-listed set each component gets
/// Preset data for the In-Memory provider.
pub fn in_memory_data<I, K, V>(mut self, data: I) -> Self
where
I: IntoIterator<Item = (K, V)>,
K: Into<String>,
V: Into<Vec<u8>>,
{
self.in_memory_data = data
.into_iter()
.map(|(k, v)| (k.into(), v.into()))
.collect();
self
}

/// Appends a list of Redis hosts to the allow-listed set each component gets
/// access to. It can be in the format `<hostname>[:port]` or a unix domain
/// socket path.
///
Expand All @@ -167,13 +184,14 @@ impl WasiKeyValueCtxBuilder {
///
/// # fn main() {
/// let ctx = WasiKeyValueCtxBuilder::new()
/// .allow_hosts(&["localhost:1234", "/var/run/redis.sock"])
/// .allow_redis_hosts(&["localhost:1234", "/var/run/redis.sock"])
/// .build();
/// # }
/// ```
pub fn allow_hosts(mut self, hosts: &[impl AsRef<str>]) -> Self {
self.allowed_hosts
.extend(hosts.iter().map(|a| a.as_ref().to_owned()));
#[cfg(feature = "redis")]
pub fn allow_redis_hosts(mut self, hosts: &[impl AsRef<str>]) -> Self {
self.allowed_redis_hosts
.extend(hosts.iter().map(|h| h.as_ref().to_owned()));
self
}

Expand All @@ -194,7 +212,9 @@ impl WasiKeyValueCtxBuilder {
/// Uses the configured context so far to construct the final [`WasiKeyValueCtx`].
pub fn build(self) -> WasiKeyValueCtx {
WasiKeyValueCtx {
allowed_hosts: self.allowed_hosts,
in_memory_data: self.in_memory_data,
#[cfg(feature = "redis")]
allowed_redis_hosts: self.allowed_redis_hosts,
#[cfg(feature = "redis")]
redis_connection_timeout: self.redis_connection_timeout,
#[cfg(feature = "redis")]
Expand All @@ -205,7 +225,9 @@ impl WasiKeyValueCtxBuilder {

/// Capture the state necessary for use in the `wasi-keyvalue` API implementation.
pub struct WasiKeyValueCtx {
allowed_hosts: Vec<String>,
in_memory_data: HashMap<String, Vec<u8>>,
#[cfg(feature = "redis")]
allowed_redis_hosts: Vec<String>,
#[cfg(feature = "redis")]
redis_connection_timeout: Option<std::time::Duration>,
#[cfg(feature = "redis")]
Expand All @@ -218,7 +240,8 @@ impl WasiKeyValueCtx {
WasiKeyValueCtxBuilder::new()
}

fn allow_host(&self, u: &Url) -> bool {
#[cfg(feature = "redis")]
fn allow_redis_host(&self, u: &Url) -> bool {
let host = match u.host() {
Some(h) => match u.port() {
Some(port) => format!("{}:{}", h, port),
Expand All @@ -227,7 +250,7 @@ impl WasiKeyValueCtx {
// unix domain socket path
None => u.path().to_string(),
};
self.allowed_hosts.contains(&host)
self.allowed_redis_hosts.contains(&host)
}
}

Expand All @@ -249,18 +272,13 @@ impl keyvalue::store::Host for WasiKeyValue<'_> {
async fn open(&mut self, identifier: String) -> Result<Resource<Bucket>, Error> {
if identifier == "" {
return Ok(self.table.push(Bucket {
inner: Box::new(provider::inmemory::InMemory::default()),
inner: Box::new(provider::inmemory::InMemory::new(
self.ctx.in_memory_data.clone(),
)),
})?);
}

let u = Url::parse(&identifier).map_err(to_other_error)?;
if !self.ctx.allow_host(&u) {
return Err(Error::Other(format!(
"the identifier {} is not in the allowed list",
identifier
)));
}

match u.scheme() {
"redis" | "redis+unix" => {
#[cfg(not(feature = "redis"))]
Expand All @@ -272,6 +290,13 @@ impl keyvalue::store::Host for WasiKeyValue<'_> {
}
#[cfg(feature = "redis")]
{
if !self.ctx.allow_redis_host(&u) {
return Err(Error::Other(format!(
"the identifier {} is not in the allowed list",
identifier
)));
}

let host = provider::redis::open(
identifier,
self.ctx.redis_response_timeout,
Expand Down Expand Up @@ -398,19 +423,20 @@ pub fn add_to_linker<T: Send>(

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_allow_host() -> Result<()> {
let ctx = WasiKeyValueCtx::builder()
.allow_hosts(&["127.0.0.1:1234", "localhost", "/var/run/redis.sock"])
#[cfg(feature = "redis")]
fn test_allow_redis_host() {
let ctx = super::WasiKeyValueCtx::builder()
.allow_redis_hosts(&["127.0.0.1:1234", "localhost", "/var/run/redis.sock"])
.build();
assert!(ctx.allow_host(&Url::parse("redis://127.0.0.1:1234/db")?));
assert!(ctx.allow_host(&Url::parse("redis://localhost")?));
assert!(!ctx.allow_host(&Url::parse("redis://192.168.0.1")?));
assert!(ctx.allow_host(&Url::parse("redis+unix:///var/run/redis.sock?db=db")?));
assert!(!ctx.allow_host(&Url::parse("redis+unix:///var/local/redis.sock?db=db")?));

Ok(())
assert!(ctx.allow_redis_host(&super::Url::parse("redis://127.0.0.1:1234/db").unwrap()));
assert!(ctx.allow_redis_host(&super::Url::parse("redis://localhost").unwrap()));
assert!(!ctx.allow_redis_host(&super::Url::parse("redis://192.168.0.1").unwrap()));
assert!(ctx.allow_redis_host(
&super::Url::parse("redis+unix:///var/run/redis.sock?db=db").unwrap()
));
assert!(!ctx.allow_redis_host(
&super::Url::parse("redis+unix:///var/local/redis.sock?db=db").unwrap()
));
}
}
8 changes: 8 additions & 0 deletions crates/wasi-keyvalue/src/provider/inmemory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@ pub(crate) struct InMemory {
store: Arc<Mutex<HashMap<String, Vec<u8>>>>,
}

impl InMemory {
pub(crate) fn new(data: HashMap<String, Vec<u8>>) -> Self {
Self {
store: Arc::new(Mutex::new(data)),
}
}
}

#[async_trait]
impl Host for InMemory {
async fn get(&mut self, key: String) -> Result<Option<Vec<u8>>, Error> {
Expand Down
16 changes: 11 additions & 5 deletions crates/wasi-keyvalue/tests/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use wasmtime::{
Store,
};
use wasmtime_wasi::{bindings::Command, WasiCtx, WasiCtxBuilder, WasiView};
use wasmtime_wasi_keyvalue::{WasiKeyValue, WasiKeyValueCtx};
use wasmtime_wasi_keyvalue::{WasiKeyValue, WasiKeyValueCtx, WasiKeyValueCtxBuilder};

struct Ctx {
table: ResourceTable,
Expand Down Expand Up @@ -59,8 +59,13 @@ async fn keyvalue_main() -> Result<()> {
KEYVALUE_MAIN_COMPONENT,
Ctx {
table: ResourceTable::new(),
wasi_ctx: WasiCtxBuilder::new().env("IDENTIFIER", "").build(),
wasi_keyvalue_ctx: WasiKeyValueCtx::builder().build(),
wasi_ctx: WasiCtxBuilder::new()
.inherit_stderr()
.env("IDENTIFIER", "")
.build(),
wasi_keyvalue_ctx: WasiKeyValueCtxBuilder::new()
.in_memory_data([("atomics_key", "5")])
.build(),
},
)
.await
Expand All @@ -74,10 +79,11 @@ async fn keyvalue_redis() -> Result<()> {
Ctx {
table: ResourceTable::new(),
wasi_ctx: WasiCtxBuilder::new()
.inherit_stderr()
.env("IDENTIFIER", "redis://127.0.0.1/")
.build(),
wasi_keyvalue_ctx: WasiKeyValueCtx::builder()
.allow_hosts(&["127.0.0.1"])
wasi_keyvalue_ctx: WasiKeyValueCtxBuilder::new()
.allow_redis_hosts(&["127.0.0.1"])
.redis_connection_timeout(std::time::Duration::from_secs(5))
.redis_response_timeout(std::time::Duration::from_secs(5))
.build(),
Expand Down

0 comments on commit c7cc706

Please sign in to comment.