diff --git a/src/db.rs b/src/db.rs index 2859ae15..7ab55aab 100644 --- a/src/db.rs +++ b/src/db.rs @@ -3,8 +3,8 @@ use anyhow::Context as _; use chrono::Utc; use native_tls::{Certificate, TlsConnector}; use postgres_native_tls::MakeTlsConnector; -use std::sync::{Arc, LazyLock, Mutex}; -use tokio::sync::{OwnedSemaphorePermit, Semaphore}; +use std::sync::{Arc, LazyLock}; +use tokio::sync::{Mutex, OwnedSemaphorePermit, Semaphore}; use tokio_postgres::Client as DbClient; pub mod issue_data; @@ -37,8 +37,14 @@ pub struct PooledClient { impl Drop for PooledClient { fn drop(&mut self) { - let mut clients = self.pool.lock().unwrap_or_else(|e| e.into_inner()); - clients.push(self.client.take().unwrap()); + // We can't await in drop, so we need to spawn a task to handle async lock + if let Some(client) = self.client.take() { + let pool = self.pool.clone(); + tokio::spawn(async move { + let mut clients = pool.lock().await; + clients.push(client); + }); + } } } @@ -68,17 +74,21 @@ impl ClientPool { pub async fn get(&self) -> PooledClient { let permit = self.permits.clone().acquire_owned().await.unwrap(); { - let mut slots = self.connections.lock().unwrap_or_else(|e| e.into_inner()); + let mut slots = self.connections.lock().await; // Pop connections until we hit a non-closed connection (or there are no // "possibly open" connections left). while let Some(c) = slots.pop() { - if !c.is_closed() { + // drop the lock + drop(slots); + if !c.is_closed() && validate_connection(&c).await { return PooledClient { client: Some(c), permit, pool: self.connections.clone(), }; } + // re-lock + slots = self.connections.lock().await; } } @@ -90,6 +100,11 @@ impl ClientPool { } } +/// validate connection via query +async fn validate_connection(conn: &tokio_postgres::Client) -> bool { + conn.query_one("SELECT 1", &[]).await.is_ok() +} + pub async fn make_client(db_url: &str) -> anyhow::Result { if db_url.contains("rds.amazonaws.com") { let mut builder = TlsConnector::builder();