From dc2f2a37b293e4f83507414f0cac3b30c0862e01 Mon Sep 17 00:00:00 2001 From: Xinzhao Xu Date: Fri, 26 Jul 2024 11:18:23 +0800 Subject: [PATCH] Allow preset data for In-Memory provider, rename allow_hosts to allow_redis_hosts --- crates/test-programs/src/bin/keyvalue_main.rs | 27 +++--- crates/wasi-keyvalue/src/lib.rs | 86 ++++++++++++------- crates/wasi-keyvalue/src/provider/inmemory.rs | 8 ++ crates/wasi-keyvalue/tests/main.rs | 16 ++-- 4 files changed, 92 insertions(+), 45 deletions(-) diff --git a/crates/test-programs/src/bin/keyvalue_main.rs b/crates/test-programs/src/bin/keyvalue_main.rs index 635a89e69da9..4fccb2c81ff0 100644 --- a/crates/test-programs/src/bin/keyvalue_main.rs +++ b/crates/test-programs/src/bin/keyvalue_main.rs @@ -1,22 +1,29 @@ use test_programs::keyvalue::wasi::keyvalue::{atomics, batch, store}; fn main() { - let bucket = store::open(std::env::var_os("IDENTIFIER").unwrap().to_str().unwrap()).unwrap(); + let identifier = std::env::var_os("IDENTIFIER") + .unwrap() + .into_string() + .unwrap(); + let bucket = store::open(&identifier).unwrap(); + + if identifier != "" { + // for In-Memory provider, we have preset this data + assert_eq!(atomics::increment(&bucket, "atomics_key", 5).unwrap(), 5); + } + assert_eq!(atomics::increment(&bucket, "atomics_key", 1).unwrap(), 6); + + let resp = bucket.list_keys(None).unwrap(); + assert_eq!(resp.keys, vec!["atomics_key".to_string()]); + bucket.set("hello", "world".as_bytes()).unwrap(); let v = bucket.get("hello").unwrap(); assert_eq!(String::from_utf8(v.unwrap()).unwrap(), "world"); + assert_eq!(bucket.exists("hello").unwrap(), true); bucket.delete("hello").unwrap(); - let exists = bucket.exists("hello").unwrap(); - assert_eq!(exists, false); - - bucket.set("aa", "bb".as_bytes()).unwrap(); - let resp = bucket.list_keys(None).unwrap(); - assert_eq!(resp.keys, vec!["aa".to_string()]); - - assert_eq!(atomics::increment(&bucket, "atomics_key", 5).unwrap(), 5); - assert_eq!(atomics::increment(&bucket, "atomics_key", 1).unwrap(), 6); + assert_eq!(bucket.exists("hello").unwrap(), false); batch::set_many( &bucket, diff --git a/crates/wasi-keyvalue/src/lib.rs b/crates/wasi-keyvalue/src/lib.rs index a02759d613ca..c1149c370eb8 100644 --- a/crates/wasi-keyvalue/src/lib.rs +++ b/crates/wasi-keyvalue/src/lib.rs @@ -87,6 +87,7 @@ mod generated { use self::generated::wasi::keyvalue; use anyhow::Result; use async_trait::async_trait; +use std::collections::HashMap; use std::fmt::Display; use url::Url; use wasmtime::component::{Resource, ResourceTable, ResourceTableError}; @@ -143,7 +144,9 @@ trait Host { /// Builder-style structure used to create a [`WasiKeyValueCtx`]. #[derive(Default)] pub struct WasiKeyValueCtxBuilder { - allowed_hosts: Vec, + in_memory_data: HashMap>, + #[cfg(feature = "redis")] + allowed_redis_hosts: Vec, #[cfg(feature = "redis")] redis_connection_timeout: Option, #[cfg(feature = "redis")] @@ -156,7 +159,21 @@ impl WasiKeyValueCtxBuilder { Default::default() } - /// Appends a list of hosts to the allow-listed set each component gets + /// Preset data for the In-Memory provider. + pub fn in_memory_data(mut self, data: I) -> Self + where + I: IntoIterator, + K: Into, + V: Into>, + { + self.in_memory_data = data + .into_iter() + .map(|(k, v)| (k.into(), v.into())) + .collect(); + self + } + + /// Appends a list of Redis hosts to the allow-listed set each component gets /// access to. It can be in the format `[:port]` or a unix domain /// socket path. /// @@ -167,13 +184,14 @@ impl WasiKeyValueCtxBuilder { /// /// # fn main() { /// let ctx = WasiKeyValueCtxBuilder::new() - /// .allow_hosts(&["localhost:1234", "/var/run/redis.sock"]) + /// .allow_redis_hosts(&["localhost:1234", "/var/run/redis.sock"]) /// .build(); /// # } /// ``` - pub fn allow_hosts(mut self, hosts: &[impl AsRef]) -> Self { - self.allowed_hosts - .extend(hosts.iter().map(|a| a.as_ref().to_owned())); + #[cfg(feature = "redis")] + pub fn allow_redis_hosts(mut self, hosts: &[impl AsRef]) -> Self { + self.allowed_redis_hosts + .extend(hosts.iter().map(|h| h.as_ref().to_owned())); self } @@ -194,7 +212,9 @@ impl WasiKeyValueCtxBuilder { /// Uses the configured context so far to construct the final [`WasiKeyValueCtx`]. pub fn build(self) -> WasiKeyValueCtx { WasiKeyValueCtx { - allowed_hosts: self.allowed_hosts, + in_memory_data: self.in_memory_data, + #[cfg(feature = "redis")] + allowed_redis_hosts: self.allowed_redis_hosts, #[cfg(feature = "redis")] redis_connection_timeout: self.redis_connection_timeout, #[cfg(feature = "redis")] @@ -205,7 +225,9 @@ impl WasiKeyValueCtxBuilder { /// Capture the state necessary for use in the `wasi-keyvalue` API implementation. pub struct WasiKeyValueCtx { - allowed_hosts: Vec, + in_memory_data: HashMap>, + #[cfg(feature = "redis")] + allowed_redis_hosts: Vec, #[cfg(feature = "redis")] redis_connection_timeout: Option, #[cfg(feature = "redis")] @@ -218,7 +240,8 @@ impl WasiKeyValueCtx { WasiKeyValueCtxBuilder::new() } - fn allow_host(&self, u: &Url) -> bool { + #[cfg(feature = "redis")] + fn allow_redis_host(&self, u: &Url) -> bool { let host = match u.host() { Some(h) => match u.port() { Some(port) => format!("{}:{}", h, port), @@ -227,7 +250,7 @@ impl WasiKeyValueCtx { // unix domain socket path None => u.path().to_string(), }; - self.allowed_hosts.contains(&host) + self.allowed_redis_hosts.contains(&host) } } @@ -249,18 +272,13 @@ impl keyvalue::store::Host for WasiKeyValue<'_> { async fn open(&mut self, identifier: String) -> Result, Error> { if identifier == "" { return Ok(self.table.push(Bucket { - inner: Box::new(provider::inmemory::InMemory::default()), + inner: Box::new(provider::inmemory::InMemory::new( + self.ctx.in_memory_data.clone(), + )), })?); } let u = Url::parse(&identifier).map_err(to_other_error)?; - if !self.ctx.allow_host(&u) { - return Err(Error::Other(format!( - "the identifier {} is not in the allowed list", - identifier - ))); - } - match u.scheme() { "redis" | "redis+unix" => { #[cfg(not(feature = "redis"))] @@ -272,6 +290,13 @@ impl keyvalue::store::Host for WasiKeyValue<'_> { } #[cfg(feature = "redis")] { + if !self.ctx.allow_redis_host(&u) { + return Err(Error::Other(format!( + "the identifier {} is not in the allowed list", + identifier + ))); + } + let host = provider::redis::open( identifier, self.ctx.redis_response_timeout, @@ -398,19 +423,20 @@ pub fn add_to_linker( #[cfg(test)] mod tests { - use super::*; - #[test] - fn test_allow_host() -> Result<()> { - let ctx = WasiKeyValueCtx::builder() - .allow_hosts(&["127.0.0.1:1234", "localhost", "/var/run/redis.sock"]) + #[cfg(feature = "redis")] + fn test_allow_redis_host() { + let ctx = super::WasiKeyValueCtx::builder() + .allow_redis_hosts(&["127.0.0.1:1234", "localhost", "/var/run/redis.sock"]) .build(); - assert!(ctx.allow_host(&Url::parse("redis://127.0.0.1:1234/db")?)); - assert!(ctx.allow_host(&Url::parse("redis://localhost")?)); - assert!(!ctx.allow_host(&Url::parse("redis://192.168.0.1")?)); - assert!(ctx.allow_host(&Url::parse("redis+unix:///var/run/redis.sock?db=db")?)); - assert!(!ctx.allow_host(&Url::parse("redis+unix:///var/local/redis.sock?db=db")?)); - - Ok(()) + assert!(ctx.allow_redis_host(&super::Url::parse("redis://127.0.0.1:1234/db").unwrap())); + assert!(ctx.allow_redis_host(&super::Url::parse("redis://localhost").unwrap())); + assert!(!ctx.allow_redis_host(&super::Url::parse("redis://192.168.0.1").unwrap())); + assert!(ctx.allow_redis_host( + &super::Url::parse("redis+unix:///var/run/redis.sock?db=db").unwrap() + )); + assert!(!ctx.allow_redis_host( + &super::Url::parse("redis+unix:///var/local/redis.sock?db=db").unwrap() + )); } } diff --git a/crates/wasi-keyvalue/src/provider/inmemory.rs b/crates/wasi-keyvalue/src/provider/inmemory.rs index cf4749944760..f0ba022369dc 100644 --- a/crates/wasi-keyvalue/src/provider/inmemory.rs +++ b/crates/wasi-keyvalue/src/provider/inmemory.rs @@ -8,6 +8,14 @@ pub(crate) struct InMemory { store: Arc>>>, } +impl InMemory { + pub(crate) fn new(data: HashMap>) -> Self { + Self { + store: Arc::new(Mutex::new(data)), + } + } +} + #[async_trait] impl Host for InMemory { async fn get(&mut self, key: String) -> Result>, Error> { diff --git a/crates/wasi-keyvalue/tests/main.rs b/crates/wasi-keyvalue/tests/main.rs index e86db87df781..fa103d63e6c4 100644 --- a/crates/wasi-keyvalue/tests/main.rs +++ b/crates/wasi-keyvalue/tests/main.rs @@ -5,7 +5,7 @@ use wasmtime::{ Store, }; use wasmtime_wasi::{bindings::Command, WasiCtx, WasiCtxBuilder, WasiView}; -use wasmtime_wasi_keyvalue::{WasiKeyValue, WasiKeyValueCtx}; +use wasmtime_wasi_keyvalue::{WasiKeyValue, WasiKeyValueCtx, WasiKeyValueCtxBuilder}; struct Ctx { table: ResourceTable, @@ -59,8 +59,13 @@ async fn keyvalue_main() -> Result<()> { KEYVALUE_MAIN_COMPONENT, Ctx { table: ResourceTable::new(), - wasi_ctx: WasiCtxBuilder::new().env("IDENTIFIER", "").build(), - wasi_keyvalue_ctx: WasiKeyValueCtx::builder().build(), + wasi_ctx: WasiCtxBuilder::new() + .inherit_stderr() + .env("IDENTIFIER", "") + .build(), + wasi_keyvalue_ctx: WasiKeyValueCtxBuilder::new() + .in_memory_data([("atomics_key", "5")]) + .build(), }, ) .await @@ -74,10 +79,11 @@ async fn keyvalue_redis() -> Result<()> { Ctx { table: ResourceTable::new(), wasi_ctx: WasiCtxBuilder::new() + .inherit_stderr() .env("IDENTIFIER", "redis://127.0.0.1/") .build(), - wasi_keyvalue_ctx: WasiKeyValueCtx::builder() - .allow_hosts(&["127.0.0.1"]) + wasi_keyvalue_ctx: WasiKeyValueCtxBuilder::new() + .allow_redis_hosts(&["127.0.0.1"]) .redis_connection_timeout(std::time::Duration::from_secs(5)) .redis_response_timeout(std::time::Duration::from_secs(5)) .build(),