From ec11ba216875feefb669105968f018532fbe27d3 Mon Sep 17 00:00:00 2001 From: Brian H Date: Thu, 9 Jan 2025 16:27:11 -0700 Subject: [PATCH] fix #2974 -- use redis::aio::ConnectionManager Signed-off-by: Brian H --- Cargo.lock | 38 ++++++++++++- crates/factor-key-value/src/host.rs | 19 ++++--- crates/factor-key-value/src/util.rs | 4 ++ crates/key-value-redis/Cargo.toml | 2 +- crates/key-value-redis/src/store.rs | 83 ++++++++++------------------- 5 files changed, 82 insertions(+), 64 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index afc7f7fe56..b38a767f58 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -968,6 +968,15 @@ dependencies = [ "time", ] +[[package]] +name = "backon" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5289ec98f68f28dd809fd601059e6aa908bb8f6108620930828283d4ee23d7" +dependencies = [ + "fastrand 2.2.0", +] + [[package]] name = "backtrace" version = "0.3.74" @@ -4464,7 +4473,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" dependencies = [ "cfg-if", - "windows-targets 0.48.5", + "windows-targets 0.52.6", ] [[package]] @@ -6493,6 +6502,31 @@ dependencies = [ "combine", "futures-util", "itoa", + "num-bigint", + "percent-encoding", + "pin-project-lite", + "ryu", + "sha1_smol", + "socket2 0.5.7", + "tokio", + "tokio-util", + "url", +] + +[[package]] +name = "redis" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff21dd025d2d3d2a6ad6788c0f7153f82d063216a7638f70367aac5790fea5da" +dependencies = [ + "arc-swap", + "backon", + "bytes", + "combine", + "futures-channel", + "futures-util", + "itertools 0.13.0", + "itoa", "native-tls", "num-bigint", "percent-encoding", @@ -8059,7 +8093,7 @@ name = "spin-key-value-redis" version = "3.2.0-pre0" dependencies = [ "anyhow", - "redis 0.27.5", + "redis 0.28.0", "serde", "spin-core", "spin-factor-key-value", diff --git a/crates/factor-key-value/src/host.rs b/crates/factor-key-value/src/host.rs index 3b6fdfb5e9..98fda701d8 100644 --- a/crates/factor-key-value/src/host.rs +++ b/crates/factor-key-value/src/host.rs @@ -27,6 +27,9 @@ pub trait StoreManager: Sync + Send { #[async_trait] pub trait Store: Sync + Send { + async fn after_open(&self) -> Result<(), Error> { + Ok(()) + } async fn get(&self, key: &str) -> Result>, Error>; async fn set(&self, key: &str, value: &[u8]) -> Result<(), Error>; async fn delete(&self, key: &str) -> Result<(), Error>; @@ -109,11 +112,13 @@ impl key_value::HostStore for KeyValueDispatch { async fn open(&mut self, name: String) -> Result, Error>> { Ok(async { if self.allowed_stores.contains(&name) { - let store = self + let store = self.manager.get(&name).await?; + store.after_open().await?; + let store_idx = self .stores - .push(self.manager.get(&name).await?) + .push(store) .map_err(|()| Error::StoreTableFull)?; - Ok(Resource::new_own(store)) + Ok(Resource::new_own(store_idx)) } else { Err(Error::AccessDenied) } @@ -193,11 +198,13 @@ impl wasi_keyvalue::store::Host for KeyValueDispatch { identifier: String, ) -> Result, wasi_keyvalue::store::Error> { if self.allowed_stores.contains(&identifier) { - let store = self + let store = self.manager.get(&identifier).await.map_err(to_wasi_err)?; + store.after_open().await.map_err(to_wasi_err)?; + let store_idx = self .stores - .push(self.manager.get(&identifier).await.map_err(to_wasi_err)?) + .push(store) .map_err(|()| wasi_keyvalue::store::Error::Other("store table full".to_string()))?; - Ok(Resource::new_own(store)) + Ok(Resource::new_own(store_idx)) } else { Err(wasi_keyvalue::store::Error::AccessDenied) } diff --git a/crates/factor-key-value/src/util.rs b/crates/factor-key-value/src/util.rs index 9fec7e4348..523a61f9e7 100644 --- a/crates/factor-key-value/src/util.rs +++ b/crates/factor-key-value/src/util.rs @@ -156,6 +156,10 @@ struct CachingStore { #[async_trait] impl Store for CachingStore { + async fn after_open(&self) -> Result<(), Error> { + self.inner.after_open().await + } + async fn get(&self, key: &str) -> Result>, Error> { // Retrieve the specified value from the cache, lazily populating the cache as necessary. diff --git a/crates/key-value-redis/Cargo.toml b/crates/key-value-redis/Cargo.toml index 35c1050173..7e8119ebc4 100644 --- a/crates/key-value-redis/Cargo.toml +++ b/crates/key-value-redis/Cargo.toml @@ -6,7 +6,7 @@ edition = { workspace = true } [dependencies] anyhow = { workspace = true } -redis = { version = "0.27", features = ["tokio-comp", "tokio-native-tls-comp"] } +redis = { version = "0.28", features = ["tokio-comp", "tokio-native-tls-comp", "connection-manager"] } serde = { workspace = true } spin-core = { path = "../core" } spin-factor-key-value = { path = "../factor-key-value" } diff --git a/crates/key-value-redis/src/store.rs b/crates/key-value-redis/src/store.rs index 1fbbabc6c0..6c8628841c 100644 --- a/crates/key-value-redis/src/store.rs +++ b/crates/key-value-redis/src/store.rs @@ -1,15 +1,14 @@ use anyhow::{Context, Result}; -use redis::{aio::MultiplexedConnection, parse_redis_url, AsyncCommands, Client, RedisError}; +use redis::{aio::ConnectionManager, parse_redis_url, AsyncCommands, Client, RedisError}; use spin_core::async_trait; use spin_factor_key_value::{log_error, Cas, Error, Store, StoreManager, SwapError}; -use std::ops::DerefMut; use std::sync::Arc; -use tokio::sync::{Mutex, OnceCell}; +use tokio::sync::OnceCell; use url::Url; pub struct KeyValueRedis { database_url: Url, - connection: OnceCell>>, + connection: OnceCell, } impl KeyValueRedis { @@ -30,10 +29,8 @@ impl StoreManager for KeyValueRedis { .connection .get_or_try_init(|| async { Client::open(self.database_url.clone())? - .get_multiplexed_async_connection() + .get_connection_manager() .await - .map(Mutex::new) - .map(Arc::new) }) .await .map_err(log_error)?; @@ -55,90 +52,69 @@ impl StoreManager for KeyValueRedis { } struct RedisStore { - connection: Arc>, + connection: ConnectionManager, database_url: Url, } struct CompareAndSwap { key: String, - connection: Arc>, + connection: ConnectionManager, bucket_rep: u32, } #[async_trait] impl Store for RedisStore { + async fn after_open(&self) -> Result<(), Error> { + if let Err(_error) = self.connection.clone().ping::<()>().await { + // If an IO error happens, ConnectionManager will start reconnection in the background + // so we do not take any action and just pray re-connection will be successful. + } + Ok(()) + } + async fn get(&self, key: &str) -> Result>, Error> { - let mut conn = self.connection.lock().await; - conn.get(key).await.map_err(log_error) + self.connection.clone().get(key).await.map_err(log_error) } async fn set(&self, key: &str, value: &[u8]) -> Result<(), Error> { self.connection - .lock() - .await + .clone() .set(key, value) .await .map_err(log_error) } async fn delete(&self, key: &str) -> Result<(), Error> { - self.connection - .lock() - .await - .del(key) - .await - .map_err(log_error) + self.connection.clone().del(key).await.map_err(log_error) } async fn exists(&self, key: &str) -> Result { - self.connection - .lock() - .await - .exists(key) - .await - .map_err(log_error) + self.connection.clone().exists(key).await.map_err(log_error) } async fn get_keys(&self) -> Result, Error> { - self.connection - .lock() - .await - .keys("*") - .await - .map_err(log_error) + self.connection.clone().keys("*").await.map_err(log_error) } async fn get_many(&self, keys: Vec) -> Result>)>, Error> { - self.connection - .lock() - .await - .keys(keys) - .await - .map_err(log_error) + self.connection.clone().keys(keys).await.map_err(log_error) } async fn set_many(&self, key_values: Vec<(String, Vec)>) -> Result<(), Error> { self.connection - .lock() - .await + .clone() .mset(&key_values) .await .map_err(log_error) } async fn delete_many(&self, keys: Vec) -> Result<(), Error> { - self.connection - .lock() - .await - .del(keys) - .await - .map_err(log_error) + self.connection.clone().del(keys).await.map_err(log_error) } async fn increment(&self, key: String, delta: i64) -> Result { self.connection - .lock() - .await + .clone() .incr(key, delta) .await .map_err(log_error) @@ -154,10 +130,8 @@ impl Store for RedisStore { ) -> Result, Error> { let cx = Client::open(self.database_url.clone()) .map_err(log_error)? - .get_multiplexed_async_connection() + .get_connection_manager() .await - .map(Mutex::new) - .map(Arc::new) .map_err(log_error)?; Ok(Arc::new(CompareAndSwap { @@ -175,12 +149,11 @@ impl Cas for CompareAndSwap { async fn current(&self) -> Result>, Error> { redis::cmd("WATCH") .arg(&self.key) - .exec_async(self.connection.lock().await.deref_mut()) + .exec_async(&mut self.connection.clone()) .await .map_err(log_error)?; self.connection - .lock() - .await + .clone() .get(&self.key) .await .map_err(log_error) @@ -194,12 +167,12 @@ impl Cas for CompareAndSwap { let res: Result<(), RedisError> = transaction .atomic() .set(&self.key, value) - .query_async(self.connection.lock().await.deref_mut()) + .query_async(&mut self.connection.clone()) .await; redis::cmd("UNWATCH") .arg(&self.key) - .exec_async(self.connection.lock().await.deref_mut()) + .exec_async(&mut self.connection.clone()) .await .map_err(|err| SwapError::CasFailed(format!("{err:?}")))?;