From c04fc4d7e6bad78e7fd729394074026e95d45cac Mon Sep 17 00:00:00 2001 From: xizheyin Date: Thu, 24 Apr 2025 20:04:57 +0800 Subject: [PATCH 1/2] Check if connection is valid while get client from poll Signed-off-by: xizheyin --- src/db.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/db.rs b/src/db.rs index 2859ae15e..8fa8f91b3 100644 --- a/src/db.rs +++ b/src/db.rs @@ -72,7 +72,7 @@ impl ClientPool { // Pop connections until we hit a non-closed connection (or there are no // "possibly open" connections left). while let Some(c) = slots.pop() { - if !c.is_closed() { + if !c.is_closed() && validate_connection(&c).await { return PooledClient { client: Some(c), permit, @@ -90,6 +90,11 @@ impl ClientPool { } } +/// validate connection via query +async fn validate_connection(conn: &tokio_postgres::Client) -> bool { + conn.query_one("SELECT 1", &[]).await.is_ok() +} + pub async fn make_client(db_url: &str) -> anyhow::Result { if db_url.contains("rds.amazonaws.com") { let mut builder = TlsConnector::builder(); From 472489437d55f87f08638f6cbb39eac06ee1e467 Mon Sep 17 00:00:00 2001 From: xizheyin Date: Thu, 24 Apr 2025 21:26:11 +0800 Subject: [PATCH 2/2] Use tokio::sync::Mutex instead of std::sync::Mutex, and make optimization Signed-off-by: xizheyin --- src/db.rs | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/db.rs b/src/db.rs index 8fa8f91b3..7ab55aabc 100644 --- a/src/db.rs +++ b/src/db.rs @@ -3,8 +3,8 @@ use anyhow::Context as _; use chrono::Utc; use native_tls::{Certificate, TlsConnector}; use postgres_native_tls::MakeTlsConnector; -use std::sync::{Arc, LazyLock, Mutex}; -use tokio::sync::{OwnedSemaphorePermit, Semaphore}; +use std::sync::{Arc, LazyLock}; +use tokio::sync::{Mutex, OwnedSemaphorePermit, Semaphore}; use tokio_postgres::Client as DbClient; pub mod issue_data; @@ -37,8 +37,14 @@ pub struct PooledClient { impl Drop for PooledClient { fn drop(&mut self) { - let mut clients = self.pool.lock().unwrap_or_else(|e| e.into_inner()); - clients.push(self.client.take().unwrap()); + // We can't await in drop, so we need to spawn a task to handle async lock + if let Some(client) = self.client.take() { + let pool = self.pool.clone(); + tokio::spawn(async move { + let mut clients = pool.lock().await; + clients.push(client); + }); + } } } @@ -68,10 +74,12 @@ impl ClientPool { pub async fn get(&self) -> PooledClient { let permit = self.permits.clone().acquire_owned().await.unwrap(); { - let mut slots = self.connections.lock().unwrap_or_else(|e| e.into_inner()); + let mut slots = self.connections.lock().await; // Pop connections until we hit a non-closed connection (or there are no // "possibly open" connections left). while let Some(c) = slots.pop() { + // drop the lock + drop(slots); if !c.is_closed() && validate_connection(&c).await { return PooledClient { client: Some(c), @@ -79,6 +87,8 @@ impl ClientPool { pool: self.connections.clone(), }; } + // re-lock + slots = self.connections.lock().await; } }