From 26d82195a8a1fbc153a8caf5af733c2a1453bc77 Mon Sep 17 00:00:00 2001
From: Spencer Bartholomew <38776747+spencerbart@users.noreply.github.com>
Date: Thu, 29 May 2025 11:11:12 -0600
Subject: [PATCH] replace backoff with backon

Swap the `backoff` crate for `backon` as the retry engine behind
`Client::execute_raw`. `Client::build` and `Client::with_backoff` now take a
`backon::ExponentialBuilder`.

Retry eligibility is decided inside the request closure, where the HTTP
status is still known: 5xx responses and 429 responses other than
`insufficient_quota` are retried; request-construction errors, transport
errors and every other non-2xx response fail immediately.

The tests exercise the retry policy against a loopback stub server instead
of the real API.
---
 async-openai/Cargo.toml    |   2 +-
 async-openai/src/client.rs | 309 +++++++++++++++++++++++++++++++------
 2 files changed, 257 insertions(+), 54 deletions(-)

diff --git a/async-openai/Cargo.toml b/async-openai/Cargo.toml
index 578c25e8..fb2b3a99 100644
--- a/async-openai/Cargo.toml
+++ b/async-openai/Cargo.toml
@@ -28,7 +28,7 @@ byot = []
 
 [dependencies]
 async-openai-macros = { path = "../async-openai-macros", version = "0.1.0" }
-backoff = { version = "0.4.0", features = ["tokio"] }
+backon = { version = "1.5.1", features = ["tokio"] }
 base64 = "0.22.1"
 futures = "0.3.31"
 rand = "0.8.5"
diff --git a/async-openai/src/client.rs b/async-openai/src/client.rs
index 1ff90aac..1947e1d1 100644
--- a/async-openai/src/client.rs
+++ b/async-openai/src/client.rs
@@ -23,7 +23,7 @@ use crate::{
 pub struct Client {
     http_client: reqwest::Client,
     config: C,
-    backoff: backoff::ExponentialBackoff,
+    backoff: backon::ExponentialBuilder,
 }
 
 impl Client {
@@ -38,7 +38,7 @@ impl Client {
     pub fn build(
         http_client: reqwest::Client,
         config: C,
-        backoff: backoff::ExponentialBackoff,
+        backoff: backon::ExponentialBuilder,
     ) -> Self {
         Self {
             http_client,
@@ -65,7 +65,7 @@ impl Client {
     }
 
     /// Exponential backoff for retrying [rate limited](https://platform.openai.com/docs/guides/rate-limits) requests.
-    pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
+    pub fn with_backoff(mut self, backoff: backon::ExponentialBuilder) -> Self {
         self.backoff = backoff;
         self
     }
@@ -319,58 +319,77 @@ impl Client {
         M: Fn() -> Fut,
         Fut: core::future::Future>,
     {
+        use backon::Retryable;
+        use std::sync::Arc;
+
+        // Retry eligibility is decided while the HTTP status is still known and is
+        // carried alongside the error: the error payload alone cannot distinguish a
+        // rate-limited 429 from any other 4xx response.
+        enum Attempt {
+            Retry(OpenAIError),
+            Abort(OpenAIError),
+        }
+
         let client = self.http_client.clone();
+        let backoff = self.backoff.clone();
+        let request_maker = Arc::new(request_maker);
 
-        backoff::future::retry(self.backoff.clone(), || async {
-            let request = request_maker().await.map_err(backoff::Error::Permanent)?;
-            let response = client
-                .execute(request)
-                .await
-                .map_err(OpenAIError::Reqwest)
-                .map_err(backoff::Error::Permanent)?;
-
-            let status = response.status();
-            let bytes = response
-                .bytes()
-                .await
-                .map_err(OpenAIError::Reqwest)
-                .map_err(backoff::Error::Permanent)?;
-
-            if status.is_server_error() {
-                // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
-                let message: String = String::from_utf8_lossy(&bytes).into_owned();
-                tracing::warn!("Server error: {status} - {message}");
-                return Err(backoff::Error::Transient {
-                    err: OpenAIError::ApiError(ApiError { message, r#type: None, param: None, code: None }),
-                    retry_after: None,
-                });
-            }
-
-            // Deserialize response body from either error object or actual response object
-            if !status.is_success() {
-                let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
-                    .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
-                    .map_err(backoff::Error::Permanent)?;
-
-                if status.as_u16() == 429
-                    // API returns 429 also when:
-                    // "You exceeded your current quota, please check your plan and billing details."
-                    && wrapped_error.error.r#type != Some("insufficient_quota".to_string())
-                {
-                    // Rate limited retry...
-                    tracing::warn!("Rate limited: {}", wrapped_error.error.message);
-                    return Err(backoff::Error::Transient {
-                        err: OpenAIError::ApiError(wrapped_error.error),
-                        retry_after: None,
-                    });
-                } else {
-                    return Err(backoff::Error::Permanent(OpenAIError::ApiError(
-                        wrapped_error.error,
-                    )));
-                }
-            }
-
-            Ok(bytes)
+        (move || {
+            let client = client.clone();
+            let request_maker = request_maker.clone();
+            async move {
+                let request = request_maker().await.map_err(Attempt::Abort)?;
+                let response = client
+                    .execute(request)
+                    .await
+                    .map_err(OpenAIError::Reqwest)
+                    .map_err(Attempt::Abort)?;
+
+                let status = response.status();
+                let bytes = response
+                    .bytes()
+                    .await
+                    .map_err(OpenAIError::Reqwest)
+                    .map_err(Attempt::Abort)?;
+
+                if status.is_server_error() {
+                    // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
+                    let message: String = String::from_utf8_lossy(&bytes).into_owned();
+                    tracing::warn!("Server error: {status} - {message}");
+                    return Err(Attempt::Retry(OpenAIError::ApiError(ApiError {
+                        message,
+                        r#type: None,
+                        param: None,
+                        code: None,
+                    })));
+                }
+
+                // Deserialize response body from either error object or actual response object
+                if !status.is_success() {
+                    let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
+                        .map_err(|e| map_deserialization_error(e, bytes.as_ref()))
+                        .map_err(Attempt::Abort)?;
+
+                    if status.as_u16() == 429
+                        // API returns 429 also when:
+                        // "You exceeded your current quota, please check your plan and billing details."
+                        && wrapped_error.error.r#type != Some("insufficient_quota".to_string())
+                    {
+                        // Rate limited retry...
+                        tracing::warn!("Rate limited: {}", wrapped_error.error.message);
+                        return Err(Attempt::Retry(OpenAIError::ApiError(wrapped_error.error)));
+                    } else {
+                        return Err(Attempt::Abort(OpenAIError::ApiError(wrapped_error.error)));
+                    }
+                }
+
+                Ok(bytes)
+            }
         })
+        .retry(backoff)
+        .when(|attempt| matches!(attempt, Attempt::Retry(_)))
         .await
+        .map_err(|attempt| match attempt {
+            Attempt::Retry(err) | Attempt::Abort(err) => err,
+        })
     }
@@ -507,6 +526,190 @@ where
     Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
 }
 
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::config::OpenAIConfig;
+    use crate::error::{ApiError, OpenAIError};
+    use backon::ExponentialBuilder;
+    use std::io::{Read, Write};
+    use std::net::TcpListener;
+    use std::sync::atomic::{AtomicUsize, Ordering};
+    use std::sync::Arc;
+    use std::time::Duration;
+
+    /// Backoff short enough to keep the retry tests fast.
+    fn fast_backoff() -> ExponentialBuilder {
+        ExponentialBuilder::default()
+            .with_min_delay(Duration::from_millis(10))
+            .with_max_delay(Duration::from_millis(100))
+            .with_max_times(3)
+    }
+
+    /// Error payload in the `{"error": {...}}` envelope used for non-2xx responses.
+    fn error_body(kind: &str) -> String {
+        format!(r#"{{"error": {{"message": "stub", "type": "{kind}"}}}}"#)
+    }
+
+    /// Starts a loopback HTTP stub that answers each new connection with the next
+    /// `(status, body)` pair, so these tests never touch the real API.
+    ///
+    /// Returns the stub URL and a counter of the requests received so far.
+    fn spawn_stub(responses: Vec<(u16, String)>) -> (String, Arc<AtomicUsize>) {
+        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
+        let url = format!("http://{}/v1/test", listener.local_addr().unwrap());
+        let hits = Arc::new(AtomicUsize::new(0));
+        let served = hits.clone();
+
+        std::thread::spawn(move || {
+            for (status, body) in responses {
+                let mut stream = match listener.accept() {
+                    Ok((stream, _)) => stream,
+                    Err(_) => return,
+                };
+
+                // Read the whole request head before replying.
+                let mut head = Vec::new();
+                let mut chunk = [0u8; 1024];
+                while !head.windows(4).any(|w| w == b"\r\n\r\n") {
+                    match stream.read(&mut chunk) {
+                        Ok(0) | Err(_) => break,
+                        Ok(n) => head.extend_from_slice(&chunk[..n]),
+                    }
+                }
+
+                served.fetch_add(1, Ordering::SeqCst);
+                let reply = format!(
+                    "HTTP/1.1 {status} Stub\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
+                    body.len()
+                );
+                let _ = stream.write_all(reply.as_bytes());
+            }
+        });
+
+        (url, hits)
+    }
+
+    /// Issues a GET to `url` through `execute_raw` with the fast test backoff.
+    async fn fetch(url: String) -> Result<Vec<u8>, OpenAIError> {
+        let client = Client::new().with_backoff(fast_backoff());
+        let request_maker = || {
+            let url = url.clone();
+            async move {
+                Ok::<_, OpenAIError>(reqwest::Request::new(
+                    reqwest::Method::GET,
+                    url.parse().unwrap(),
+                ))
+            }
+        };
+
+        client
+            .execute_raw(request_maker)
+            .await
+            .map(|bytes| bytes.to_vec())
+    }
+
+    #[tokio::test]
+    async fn test_retry_on_server_error() {
+        let (url, hits) = spawn_stub(vec![
+            (500, "boom".to_string()),
+            (500, "boom".to_string()),
+            (200, "ok".to_string()),
+        ]);
+
+        let result = fetch(url).await;
+
+        assert_eq!(result.unwrap(), b"ok".to_vec());
+        assert_eq!(hits.load(Ordering::SeqCst), 3);
+    }
+
+    #[tokio::test]
+    async fn test_retry_on_rate_limit() {
+        let (url, hits) = spawn_stub(vec![(429, error_body("requests")), (200, "ok".into())]);
+
+        let result = fetch(url).await;
+
+        assert_eq!(result.unwrap(), b"ok".to_vec());
+        assert_eq!(hits.load(Ordering::SeqCst), 2);
+    }
+
+    #[tokio::test]
+    async fn test_no_retry_on_quota_exceeded() {
+        let body = error_body("insufficient_quota");
+        let (url, hits) = spawn_stub(vec![(429, body.clone()), (429, body)]);
+
+        let result = fetch(url).await;
+
+        assert_eq!(hits.load(Ordering::SeqCst), 1, "quota errors must not be retried");
+        match result {
+            Err(OpenAIError::ApiError(api_error)) => {
+                assert_eq!(api_error.r#type.as_deref(), Some("insufficient_quota"));
+            }
+            _ => panic!("Expected ApiError with insufficient_quota type"),
+        }
+    }
+
+    #[tokio::test]
+    async fn test_no_retry_on_client_error() {
+        let body = error_body("invalid_request_error");
+        let (url, hits) = spawn_stub(vec![(400, body.clone()), (400, body)]);
+
+        let result = fetch(url).await;
+
+        assert_eq!(hits.load(Ordering::SeqCst), 1, "non-429 4xx must not be retried");
+        assert!(matches!(result, Err(OpenAIError::ApiError(_))));
+    }
+
+    #[tokio::test]
+    async fn test_no_retry_on_request_maker_error() {
+        let calls = AtomicUsize::new(0);
+        let client = Client::new().with_backoff(fast_backoff());
+
+        let result = client
+            .execute_raw(|| {
+                calls.fetch_add(1, Ordering::SeqCst);
+                async {
+                    Err(OpenAIError::ApiError(ApiError {
+                        message: "could not build request".to_string(),
+                        r#type: None,
+                        param: None,
+                        code: None,
+                    }))
+                }
+            })
+            .await;
+
+        assert!(result.is_err());
+        assert_eq!(calls.load(Ordering::SeqCst), 1);
+    }
+
+    #[test]
+    fn test_client_with_custom_backoff() {
+        let custom_backoff = ExponentialBuilder::default()
+            .with_min_delay(Duration::from_millis(100))
+            .with_max_delay(Duration::from_secs(1))
+            .with_max_times(5);
+
+        // Smoke test: `with_backoff` accepts a backon `ExponentialBuilder`.
+        let _client = Client::new().with_backoff(custom_backoff);
+    }
+
+    #[test]
+    fn test_client_build_with_backoff() {
+        let custom_backoff = ExponentialBuilder::default()
+            .with_min_delay(Duration::from_millis(50))
+            .with_max_delay(Duration::from_millis(500))
+            .with_max_times(3);
+
+        // Smoke test: `build` accepts a backon `ExponentialBuilder`.
+        let _client = Client::build(
+            reqwest::Client::new(),
+            OpenAIConfig::default(),
+            custom_backoff,
+        );
+    }
+}
+
 pub(crate) async fn stream_mapped_raw_events(
     mut event_source: EventSource,
     event_mapper: impl Fn(eventsource_stream::Event) -> Result + Send + 'static,