diff --git a/Cargo.toml b/Cargo.toml index 0f0f2de..5ce6e51 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,3 +25,11 @@ reqwest-retry = "0.7.0" tokio = { version = "1.42.0", features = ["full"] } tokio-util = { version = "0.7.13", features = ["compat"] } url = "2.5.4" + +[dev-dependencies] +serde = { version = "1.0.217", features = ["derive"] } +reqwest = { version = "0.12.9", default-features = false, features = [ + "stream", + "rustls-tls", + "json", +] } diff --git a/src/download.rs b/src/download.rs index bfd8d7d..19ef273 100644 --- a/src/download.rs +++ b/src/download.rs @@ -21,6 +21,22 @@ use crate::errors::DownloadError; const BASE_URL: &str = "https://data.commoncrawl.org/"; +static APP_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),); + +fn new_client(max_retries: usize) -> Result { + let retry_policy = ExponentialBackoff::builder() + .retry_bounds(Duration::from_secs(1), Duration::from_secs(3600)) + .jitter(Jitter::Bounded) + .base(2) + .build_with_max_retries(u32::try_from(max_retries).unwrap()); + + let client_base = Client::builder().user_agent(APP_USER_AGENT).build()?; + + Ok(ClientBuilder::new(client_base) + .with(RetryTransientMiddleware::new_with_policy(retry_policy)) + .build()) +} + pub async fn download_paths( snapshot: &String, data_type: &str, @@ -30,7 +46,7 @@ pub async fn download_paths( println!("Downloading paths from: {}", paths); let url = Url::parse(&paths)?; - let client = Client::new(); + let client = new_client(1000)?; let filename = url .path_segments() // Splits into segments of the URL @@ -227,15 +243,7 @@ pub async fn download( main_pb.tick(); } - let retry_policy = ExponentialBackoff::builder() - .retry_bounds(Duration::from_secs(1), Duration::from_secs(3600)) - .jitter(Jitter::Bounded) - .base(2) - .build_with_max_retries(u32::try_from(max_retries).unwrap()); - - let client = ClientBuilder::new(reqwest::Client::new()) - .with(RetryTransientMiddleware::new_with_policy(retry_policy)) - .build(); + let client = new_client(max_retries)?; let semaphore = Arc::new(Semaphore::new(threads)); let mut set = JoinSet::new(); @@ -298,3 +306,33 @@ pub async fn download( } Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use serde::Deserialize; + use std::collections::HashMap; + + #[derive(Deserialize, Debug)] + pub struct HeadersEcho { + pub headers: HashMap, + } + + #[test] + fn user_agent_format() { + assert_eq!( + APP_USER_AGENT, + concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),) + ); + } + + #[tokio::test] + async fn user_agent_test() -> Result<(), DownloadError> { + let client = new_client(1000)?; + let response = client.get("http://httpbin.org/headers").send().await?; + + let out: HeadersEcho = response.json().await?; + assert_eq!(out.headers["User-Agent"], APP_USER_AGENT); + Ok(()) + } +}