From 6fd0bf5bd83c76b008529c5f576af5bf97e0a72d Mon Sep 17 00:00:00 2001 From: Bas Zalmstra Date: Thu, 14 Nov 2024 13:57:48 +0100 Subject: [PATCH] perf: enable using sharded repodata for custom channels (#910) --- crates/rattler-bin/src/commands/create.rs | 78 +++++--- .../src/channel/channel_url.rs | 65 +++++++ crates/rattler_conda_types/src/channel/mod.rs | 95 ++++----- crates/rattler_conda_types/src/lib.rs | 2 +- .../rattler_conda_types/src/repo_data/mod.rs | 9 +- crates/rattler_conda_types/src/utils/mod.rs | 3 + crates/rattler_conda_types/src/utils/url.rs | 14 -- .../src/utils/url_with_trailing_slash.rs | 60 ++++++ .../src/gateway/channel_config.rs | 53 ++++-- .../src/gateway/error.rs | 3 + .../src/gateway/mod.rs | 88 +-------- .../src/gateway/sharded_subdir/index.rs | 180 ++++++++++-------- .../src/gateway/sharded_subdir/mod.rs | 67 +++---- .../src/gateway/sharded_subdir/token.rs | 166 ---------------- .../src/gateway/subdir_builder.rs | 158 +++++++++++++++ .../src/sparse/mod.rs | 3 +- py-rattler/Cargo.lock | 5 +- py-rattler/rattler/repo_data/gateway.py | 10 +- py-rattler/src/lock/mod.rs | 2 +- py-rattler/src/repo_data/gateway.rs | 13 +- 20 files changed, 592 insertions(+), 482 deletions(-) create mode 100644 crates/rattler_conda_types/src/channel/channel_url.rs create mode 100644 crates/rattler_conda_types/src/utils/url_with_trailing_slash.rs delete mode 100644 crates/rattler_repodata_gateway/src/gateway/sharded_subdir/token.rs create mode 100644 crates/rattler_repodata_gateway/src/gateway/subdir_builder.rs diff --git a/crates/rattler-bin/src/commands/create.rs b/crates/rattler-bin/src/commands/create.rs index bfb0461c2..6b3600463 100644 --- a/crates/rattler-bin/src/commands/create.rs +++ b/crates/rattler-bin/src/commands/create.rs @@ -1,29 +1,35 @@ -use crate::global_multi_progress; +use std::{ + borrow::Cow, + env, + future::IntoFuture, + path::PathBuf, + str::FromStr, + sync::Arc, + time::{Duration, Instant}, +}; + use anyhow::Context; use clap::ValueEnum; use indicatif::{ProgressBar, ProgressStyle}; use itertools::Itertools; -use rattler::install::{IndicatifReporter, Installer}; -use rattler::package_cache::PackageCache; use rattler::{ default_cache_dir, - install::{Transaction, TransactionOperation}, + install::{IndicatifReporter, Installer, Transaction, TransactionOperation}, + package_cache::PackageCache, }; use rattler_conda_types::{ Channel, ChannelConfig, GenericVirtualPackage, MatchSpec, ParseStrictness, Platform, PrefixRecord, RepoDataRecord, Version, }; use rattler_networking::{AuthenticationMiddleware, AuthenticationStorage}; -use rattler_repodata_gateway::{Gateway, RepoData}; +use rattler_repodata_gateway::{Gateway, RepoData, SourceConfig}; use rattler_solve::{ libsolv_c::{self}, resolvo, SolverImpl, SolverTask, }; use reqwest::Client; -use std::future::IntoFuture; -use std::sync::Arc; -use std::time::Instant; -use std::{borrow::Cow, env, path::PathBuf, str::FromStr, time::Duration}; + +use crate::global_multi_progress; #[derive(Debug, clap::Parser)] pub struct Opt { @@ -105,8 +111,9 @@ pub async fn create(opt: Opt) -> anyhow::Result<()> { println!("Installing for platform: {install_platform:?}"); - // Parse the specs from the command line. We do this explicitly instead of allow clap to deal - // with this because we need to parse the `channel_config` when parsing matchspecs. + // Parse the specs from the command line. We do this explicitly instead of allow + // clap to deal with this because we need to parse the `channel_config` when + // parsing matchspecs. let specs = opt .specs .iter() @@ -118,8 +125,9 @@ pub async fn create(opt: Opt) -> anyhow::Result<()> { std::fs::create_dir_all(&cache_dir) .map_err(|e| anyhow::anyhow!("could not create cache directory: {}", e))?; - // Determine the channels to use from the command line or select the default. Like matchspecs - // this also requires the use of the `channel_config` so we have to do this manually. + // Determine the channels to use from the command line or select the default. + // Like matchspecs this also requires the use of the `channel_config` so we + // have to do this manually. let channels = opt .channels .unwrap_or_else(|| vec![String::from("conda-forge")]) @@ -130,9 +138,10 @@ pub async fn create(opt: Opt) -> anyhow::Result<()> { // Determine the packages that are currently installed in the environment. let installed_packages = PrefixRecord::collect_from_prefix(&target_prefix)?; - // For each channel/subdirectory combination, download and cache the `repodata.json` that should - // be available from the corresponding Url. The code below also displays a nice CLI progress-bar - // to give users some more information about what is going on. + // For each channel/subdirectory combination, download and cache the + // `repodata.json` that should be available from the corresponding Url. The + // code below also displays a nice CLI progress-bar to give users some more + // information about what is going on. let download_client = Client::builder() .no_gzip() .build() @@ -147,13 +156,21 @@ pub async fn create(opt: Opt) -> anyhow::Result<()> { .with(rattler_networking::GCSMiddleware) .build(); - // Get the package names from the matchspecs so we can only load the package records that we need. + // Get the package names from the matchspecs so we can only load the package + // records that we need. let gateway = Gateway::builder() .with_cache_dir(cache_dir.join(rattler_cache::REPODATA_CACHE_DIR)) .with_package_cache(PackageCache::new( cache_dir.join(rattler_cache::PACKAGE_CACHE_DIR), )) .with_client(download_client.clone()) + .with_channel_config(rattler_repodata_gateway::ChannelConfig { + default: SourceConfig { + sharded_enabled: false, + ..SourceConfig::default() + }, + ..rattler_repodata_gateway::ChannelConfig::default() + }) .finish(); let start_load_repo_data = Instant::now(); @@ -178,9 +195,9 @@ pub async fn create(opt: Opt) -> anyhow::Result<()> { start_load_repo_data.elapsed() ); - // Determine virtual packages of the system. These packages define the capabilities of the - // system. Some packages depend on these virtual packages to indicate compatibility with the - // hardware of the system. + // Determine virtual packages of the system. These packages define the + // capabilities of the system. Some packages depend on these virtual + // packages to indicate compatibility with the hardware of the system. let virtual_packages = wrap_in_progress("determining virtual packages", move || { if let Some(virtual_packages) = opt.virtual_package { Ok(virtual_packages @@ -218,9 +235,10 @@ pub async fn create(opt: Opt) -> anyhow::Result<()> { .format_with("\n", |i, f| f(&format_args!(" - {i}",))) ); - // Now that we parsed and downloaded all information, construct the packaging problem that we - // need to solve. We do this by constructing a `SolverProblem`. This encapsulates all the - // information required to be able to solve the problem. + // Now that we parsed and downloaded all information, construct the packaging + // problem that we need to solve. We do this by constructing a + // `SolverProblem`. This encapsulates all the information required to be + // able to solve the problem. let locked_packages = installed_packages .iter() .map(|record| record.repodata_record.clone()) @@ -235,8 +253,9 @@ pub async fn create(opt: Opt) -> anyhow::Result<()> { ..SolverTask::from_iter(&repo_data) }; - // Next, use a solver to solve this specific problem. This provides us with all the operations - // we need to apply to our environment to bring it up to date. + // Next, use a solver to solve this specific problem. This provides us with all + // the operations we need to apply to our environment to bring it up to + // date. let required_packages = wrap_in_progress("solving", move || match opt.solver.unwrap_or_default() { Solver::Resolvo => resolvo::Solver.solve(solver_task), @@ -340,7 +359,8 @@ fn print_transaction(transaction: &Transaction) { } } -/// Displays a spinner with the given message while running the specified function to completion. +/// Displays a spinner with the given message while running the specified +/// function to completion. fn wrap_in_progress T>(msg: impl Into>, func: F) -> T { let pb = ProgressBar::new_spinner(); pb.enable_steady_tick(Duration::from_millis(100)); @@ -351,7 +371,8 @@ fn wrap_in_progress T>(msg: impl Into>, func result } -/// Displays a spinner with the given message while running the specified function to completion. +/// Displays a spinner with the given message while running the specified +/// function to completion. async fn wrap_in_async_progress>( msg: impl Into>, fut: F, @@ -365,7 +386,8 @@ async fn wrap_in_async_progress>( result } -/// Returns the style to use for a progressbar that is indeterminate and simply shows a spinner. +/// Returns the style to use for a progressbar that is indeterminate and simply +/// shows a spinner. fn long_running_progress_style() -> indicatif::ProgressStyle { ProgressStyle::with_template("{spinner:.green} {msg}").unwrap() } diff --git a/crates/rattler_conda_types/src/channel/channel_url.rs b/crates/rattler_conda_types/src/channel/channel_url.rs new file mode 100644 index 000000000..52e3697b0 --- /dev/null +++ b/crates/rattler_conda_types/src/channel/channel_url.rs @@ -0,0 +1,65 @@ +use std::fmt::{Debug, Display, Formatter}; + +use serde::{Deserialize, Serialize}; +use url::Url; + +use crate::{utils::url_with_trailing_slash::UrlWithTrailingSlash, Platform}; + +/// Represents a channel base url. This is a wrapper around an url that is +/// normalized: +/// +/// * The URL always contains a trailing `/`. +/// +/// This is useful to be able to compare different channels. +#[derive(Clone, Hash, Eq, PartialEq, Serialize, Deserialize)] +#[serde(transparent)] +pub struct ChannelUrl(UrlWithTrailingSlash); + +impl ChannelUrl { + /// Returns the base Url of the channel. + pub fn url(&self) -> &Url { + &self.0 + } + + /// Returns the string representation of the url. + pub fn as_str(&self) -> &str { + self.0.as_str() + } + + /// Append the platform to the base url. + pub fn platform_url(&self, platform: Platform) -> Url { + self.0 + .join(&format!("{}/", platform.as_str())) // trailing slash is important here as this signifies a directory + .expect("platform is a valid url fragment") + } +} + +impl Debug for ChannelUrl { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0.as_str()) + } +} + +impl From for ChannelUrl { + fn from(url: Url) -> Self { + Self(UrlWithTrailingSlash::from(url)) + } +} + +impl From for Url { + fn from(value: ChannelUrl) -> Self { + value.0.into() + } +} + +impl AsRef for ChannelUrl { + fn as_ref(&self) -> &Url { + &self.0 + } +} + +impl Display for ChannelUrl { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", &self.0) + } +} diff --git a/crates/rattler_conda_types/src/channel/mod.rs b/crates/rattler_conda_types/src/channel/mod.rs index b9050ec04..aaf46d844 100644 --- a/crates/rattler_conda_types/src/channel/mod.rs +++ b/crates/rattler_conda_types/src/channel/mod.rs @@ -13,10 +13,11 @@ use typed_path::{Utf8NativePathBuf, Utf8TypedPath, Utf8TypedPathBuf}; use url::Url; use super::{ParsePlatformError, Platform}; -use crate::utils::{ - path::is_path, - url::{add_trailing_slash, parse_scheme}, -}; +use crate::utils::{path::is_path, url::parse_scheme}; + +mod channel_url; + +pub use channel_url::ChannelUrl; const DEFAULT_CHANNEL_ALIAS: &str = "https://conda.anaconda.org"; @@ -105,7 +106,7 @@ impl NamedChannelOrUrl { /// Converts the channel to a base url using the given configuration. /// This method ensures that the base url always ends with a `/`. - pub fn into_base_url(self, config: &ChannelConfig) -> Result { + pub fn into_base_url(self, config: &ChannelConfig) -> Result { let url = match self { NamedChannelOrUrl::Name(name) => { let mut base_url = config.channel_alias.clone(); @@ -114,16 +115,17 @@ impl NamedChannelOrUrl { segments.push(segment); } } - base_url + base_url.into() } - NamedChannelOrUrl::Url(url) => url, + NamedChannelOrUrl::Url(url) => url.into(), NamedChannelOrUrl::Path(path) => { let absolute_path = absolute_path(path.as_str(), &config.root_dir)?; directory_path_to_url(absolute_path.to_path()) .map_err(|_err| ParseChannelError::InvalidPath(path.to_string()))? + .into() } }; - Ok(add_trailing_slash(&url).into_owned()) + Ok(url) } /// Converts this instance into a channel. @@ -136,7 +138,8 @@ impl NamedChannelOrUrl { let base_url = self.into_base_url(config)?; Ok(Channel { name, - ..Channel::from_url(base_url) + base_url, + platforms: None, }) } } @@ -189,7 +192,7 @@ pub struct Channel { pub platforms: Option>, /// Base URL of the channel, everything is relative to this url. - pub base_url: Url, + pub base_url: ChannelUrl, /// The name of the channel pub name: Option, @@ -221,7 +224,7 @@ impl Channel { .map_err(|_err| ParseChannelError::InvalidPath(channel.to_owned()))?; Self { platforms, - base_url: url, + base_url: url.into(), name: Some(channel.to_owned()), } } @@ -252,15 +255,6 @@ impl Channel { // Get the path part of the URL but trim the directory suffix let path = url.path().trim_end_matches('/'); - // Ensure that the base_url does always ends in a `/` - let base_url = if url.path().ends_with('/') { - url.clone() - } else { - let mut url = url.clone(); - url.set_path(&format!("{path}/")); - url - }; - // Case 1: No path give, channel name is "" // Case 2: migrated_custom_channels @@ -268,23 +262,23 @@ impl Channel { // Case 4: custom_channels matches // Case 5: channel_alias match - if base_url.has_host() { + if url.has_host() { // Case 7: Fallback let name = path.trim_start_matches('/'); Self { platforms: None, name: (!name.is_empty()).then_some(name).map(str::to_owned), - base_url, + base_url: url.into(), } } else { // Case 6: non-otherwise-specified file://-type urls let name = path .rsplit_once('/') - .map_or_else(|| base_url.path(), |(_, path_part)| path_part); + .map_or_else(|| path, |(_, path_part)| path_part); Self { platforms: None, name: (!name.is_empty()).then_some(name).map(str::to_owned), - base_url, + base_url: url.into(), } } } @@ -305,7 +299,8 @@ impl Channel { base_url: config .channel_alias .join(dir_name.as_ref()) - .expect("name is not a valid Url"), + .expect("name is not a valid Url") + .into(), name: (!name.is_empty()).then_some(name).map(str::to_owned), } } @@ -329,14 +324,14 @@ impl Channel { let url = Url::from_directory_path(path).expect("path is a valid url"); Self { platforms: None, - base_url: url, + base_url: url.into(), name: None, } } /// Returns the name of the channel pub fn name(&self) -> &str { - match self.base_url().scheme() { + match self.base_url.url().scheme() { // The name of the channel is only defined for http and https channels. // If the name is not defined we return the base url. "https" | "http" => self @@ -347,17 +342,9 @@ impl Channel { } } - /// Returns the base Url of the channel. This does not include the platform - /// part. - pub fn base_url(&self) -> &Url { - &self.base_url - } - /// Returns the Urls for the given platform pub fn platform_url(&self, platform: Platform) -> Url { - self.base_url() - .join(&format!("{}/", platform.as_str())) // trailing slash is important here as this signifies a directory - .expect("platform is a valid url fragment") + self.base_url.platform_url(platform) } /// Returns the Urls for all the supported platforms of this package. @@ -380,7 +367,7 @@ impl Channel { /// Returns the canonical name of the channel pub fn canonical_name(&self) -> String { - self.base_url.clone().redact().to_string() + self.base_url.url().clone().redact().to_string() } } @@ -579,7 +566,7 @@ mod tests { let channel = Channel::from_str("conda-forge", &config).unwrap(); assert_eq!( - channel.base_url, + channel.base_url.url().clone(), Url::from_str("https://conda.anaconda.org/conda-forge/").unwrap() ); assert_eq!(channel.name.as_deref(), Some("conda-forge")); @@ -596,14 +583,14 @@ mod tests { let channel = Channel::from_str("https://conda.anaconda.org/conda-forge/", &config).unwrap(); assert_eq!( - channel.base_url, + channel.base_url.url().clone(), Url::from_str("https://conda.anaconda.org/conda-forge/").unwrap() ); assert_eq!(channel.name.as_deref(), Some("conda-forge")); assert_eq!(channel.name(), "conda-forge"); assert_eq!(channel.platforms, None); assert_eq!( - channel.base_url().to_string(), + channel.base_url.to_string(), "https://conda.anaconda.org/conda-forge/" ); @@ -622,12 +609,12 @@ mod tests { assert_eq!(channel.name.as_deref(), Some("conda-forge")); assert_eq!(channel.name(), "file:///var/channels/conda-forge/"); assert_eq!( - channel.base_url, + channel.base_url.url().clone(), Url::from_str("file:///var/channels/conda-forge/").unwrap() ); assert_eq!(channel.platforms, None); assert_eq!( - channel.base_url().to_string(), + channel.base_url.to_string(), "file:///var/channels/conda-forge/" ); @@ -643,7 +630,7 @@ mod tests { ); assert_eq!(channel.platforms, None); assert_eq!( - channel.base_url().to_file_path().unwrap(), + channel.base_url.url().to_file_path().unwrap(), current_dir.join("dir/does/not_exist") ); } @@ -654,7 +641,7 @@ mod tests { let channel = Channel::from_str("http://localhost:1234", &config).unwrap(); assert_eq!( - channel.base_url, + channel.base_url.url().clone(), Url::from_str("http://localhost:1234/").unwrap() ); assert_eq!(channel.name, None); @@ -681,7 +668,7 @@ mod tests { ) .unwrap(); assert_eq!( - channel.base_url, + channel.base_url.url().clone(), Url::from_str("https://conda.anaconda.org/conda-forge/").unwrap() ); assert_eq!(channel.name.as_deref(), Some("conda-forge")); @@ -693,7 +680,7 @@ mod tests { ) .unwrap(); assert_eq!( - channel.base_url, + channel.base_url.url().clone(), Url::from_str("https://conda.anaconda.org/pkgs/main/").unwrap() ); assert_eq!(channel.name.as_deref(), Some("pkgs/main")); @@ -701,7 +688,7 @@ mod tests { let channel = Channel::from_str("conda-forge/label/rust_dev", &config).unwrap(); assert_eq!( - channel.base_url, + channel.base_url.url().clone(), Url::from_str("https://conda.anaconda.org/conda-forge/label/rust_dev/").unwrap() ); assert_eq!(channel.name.as_deref(), Some("conda-forge/label/rust_dev")); @@ -785,8 +772,8 @@ mod tests { for channel_str in test_channels { let channel = Channel::from_str(channel_str, &channel_config).unwrap(); - assert!(channel.base_url().as_str().ends_with('/')); - assert!(!channel.base_url().as_str().ends_with("//")); + assert!(channel.base_url.as_str().ends_with('/')); + assert!(!channel.base_url.as_str().ends_with("//")); let named_channel = NamedChannelOrUrl::from_str(channel_str).unwrap(); let base_url = named_channel @@ -798,8 +785,8 @@ mod tests { assert!(!base_url_str.ends_with("//")); let channel = named_channel.into_channel(&channel_config).unwrap(); - assert!(channel.base_url().as_str().ends_with('/')); - assert!(!channel.base_url().as_str().ends_with("//")); + assert!(channel.base_url.as_str().ends_with('/')); + assert!(!channel.base_url.as_str().ends_with("//")); } } @@ -813,14 +800,14 @@ mod tests { let channel = Channel::from_str("conda-forge", &channel_config).unwrap(); assert_eq!( &channel.base_url, - named.into_channel(&channel_config).unwrap().base_url() + &named.into_channel(&channel_config).unwrap().base_url ); let named = NamedChannelOrUrl::Name("nvidia/label/cuda-11.8.0".to_string()); let channel = Channel::from_str("nvidia/label/cuda-11.8.0", &channel_config).unwrap(); assert_eq!( - channel.base_url(), - named.into_channel(&channel_config).unwrap().base_url() + channel.base_url, + named.into_channel(&channel_config).unwrap().base_url ); } } diff --git a/crates/rattler_conda_types/src/lib.rs b/crates/rattler_conda_types/src/lib.rs index d79622b17..e53e558c7 100644 --- a/crates/rattler_conda_types/src/lib.rs +++ b/crates/rattler_conda_types/src/lib.rs @@ -28,7 +28,7 @@ pub mod prefix_record; use std::path::{Path, PathBuf}; pub use build_spec::{BuildNumber, BuildNumberSpec, ParseBuildNumberSpecError}; -pub use channel::{Channel, ChannelConfig, NamedChannelOrUrl, ParseChannelError}; +pub use channel::{Channel, ChannelConfig, ChannelUrl, NamedChannelOrUrl, ParseChannelError}; pub use channel_data::{ChannelData, ChannelDataPackage}; pub use environment_yaml::{EnvironmentYaml, MatchSpecOrSubSection}; pub use explicit_environment_spec::{ diff --git a/crates/rattler_conda_types/src/repo_data/mod.rs b/crates/rattler_conda_types/src/repo_data/mod.rs index 56c3ec299..25010af36 100644 --- a/crates/rattler_conda_types/src/repo_data/mod.rs +++ b/crates/rattler_conda_types/src/repo_data/mod.rs @@ -24,7 +24,7 @@ use crate::{ package::{IndexJson, RunExportsJson}, utils::{ serde::{sort_map_alphabetically, DeserializeFromStrUnchecked}, - url::add_trailing_slash, + UrlWithTrailingSlash, }, Channel, MatchSpec, Matches, NoArchType, PackageName, PackageUrl, ParseMatchSpecError, ParseStrictness, Platform, RepoDataRecord, VersionWithSource, @@ -234,7 +234,8 @@ impl RepoData { records.push(RepoDataRecord { url: compute_package_url( &channel - .base_url() + .base_url + .url() .join(&package_record.subdir) .expect("cannot join channel base_url and subdir"), base_url.as_deref(), @@ -259,7 +260,7 @@ pub fn compute_package_url( None => repo_data_base_url.clone(), Some(base_url) => match Url::parse(base_url) { Err(url::ParseError::RelativeUrlWithoutBase) if !base_url.starts_with('/') => { - add_trailing_slash(repo_data_base_url) + UrlWithTrailingSlash::from(repo_data_base_url.clone()) .join(base_url) .expect("failed to join base_url with channel") } @@ -609,7 +610,7 @@ mod test { &ChannelConfig::default_with_root_dir(std::env::current_dir().unwrap()), ) .unwrap(); - let base_url = channel.base_url().join("linux-64/").unwrap(); + let base_url = channel.base_url.url().join("linux-64/").unwrap(); assert_eq!( compute_package_url(&base_url, None, "bla.conda").to_string(), "https://conda.anaconda.org/conda-forge/linux-64/bla.conda" diff --git a/crates/rattler_conda_types/src/utils/mod.rs b/crates/rattler_conda_types/src/utils/mod.rs index 7afe9803d..c2cc669aa 100644 --- a/crates/rattler_conda_types/src/utils/mod.rs +++ b/crates/rattler_conda_types/src/utils/mod.rs @@ -1,3 +1,6 @@ pub(crate) mod path; pub(crate) mod serde; pub(crate) mod url; +pub(crate) mod url_with_trailing_slash; + +pub(crate) use url_with_trailing_slash::UrlWithTrailingSlash; diff --git a/crates/rattler_conda_types/src/utils/url.rs b/crates/rattler_conda_types/src/utils/url.rs index c9d81579a..8c8ecba2a 100644 --- a/crates/rattler_conda_types/src/utils/url.rs +++ b/crates/rattler_conda_types/src/utils/url.rs @@ -1,6 +1,3 @@ -use std::borrow::Cow; -use url::Url; - /// Parses the schema part of the human-readable channel. Returns the scheme part if it exists. pub(crate) fn parse_scheme(channel: &str) -> Option<&str> { let scheme_end = channel.find("://")?; @@ -25,14 +22,3 @@ pub(crate) fn parse_scheme(channel: &str) -> Option<&str> { None } } - -pub(crate) fn add_trailing_slash(url: &Url) -> Cow<'_, Url> { - let path = url.path(); - if path.ends_with('/') { - Cow::Borrowed(url) - } else { - let mut url = url.clone(); - url.set_path(&format!("{path}/")); - Cow::Owned(url) - } -} diff --git a/crates/rattler_conda_types/src/utils/url_with_trailing_slash.rs b/crates/rattler_conda_types/src/utils/url_with_trailing_slash.rs new file mode 100644 index 000000000..470babdf1 --- /dev/null +++ b/crates/rattler_conda_types/src/utils/url_with_trailing_slash.rs @@ -0,0 +1,60 @@ +use serde::{Deserialize, Deserializer, Serialize}; +use std::fmt::{Display, Formatter}; +use std::ops::Deref; +use url::Url; + +/// A URL that always has a trailing slash. A trailing slash in a URL has +/// significance but users often forget to add it. This type is used to +/// normalize the use of the URL. +#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize)] +#[serde(transparent)] +pub struct UrlWithTrailingSlash(Url); + +impl Deref for UrlWithTrailingSlash { + type Target = Url; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl AsRef for UrlWithTrailingSlash { + fn as_ref(&self) -> &Url { + &self.0 + } +} + +impl From for UrlWithTrailingSlash { + fn from(url: Url) -> Self { + let path = url.path(); + if path.ends_with('/') { + Self(url) + } else { + let mut url = url.clone(); + url.set_path(&format!("{path}/")); + Self(url) + } + } +} + +impl<'de> Deserialize<'de> for UrlWithTrailingSlash { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let url = Url::deserialize(deserializer)?; + Ok(url.into()) + } +} + +impl From for Url { + fn from(value: UrlWithTrailingSlash) -> Self { + value.0 + } +} + +impl Display for UrlWithTrailingSlash { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", &self.0) + } +} diff --git a/crates/rattler_repodata_gateway/src/gateway/channel_config.rs b/crates/rattler_repodata_gateway/src/gateway/channel_config.rs index e9affbb42..6b2017921 100644 --- a/crates/rattler_repodata_gateway/src/gateway/channel_config.rs +++ b/crates/rattler_repodata_gateway/src/gateway/channel_config.rs @@ -1,20 +1,29 @@ -use crate::fetch::CacheAction; -use rattler_conda_types::Channel; use std::collections::HashMap; -/// Describes additional properties that influence how the gateway fetches repodata for a specific -/// channel. +use rattler_conda_types::ChannelUrl; +use url::Url; + +use crate::fetch::CacheAction; + +/// Describes additional properties that influence how the gateway fetches +/// repodata for a specific channel. #[derive(Debug, Clone)] pub struct SourceConfig { - /// When enabled repodata can be fetched incrementally using JLAP (defaults to true) + /// When enabled repodata can be fetched incrementally using JLAP (defaults + /// to true) pub jlap_enabled: bool, - /// When enabled, the zstd variant will be used if available (defaults to true) + /// When enabled, the zstd variant will be used if available (defaults to + /// true) pub zstd_enabled: bool, - /// When enabled, the bz2 variant will be used if available (defaults to true) + /// When enabled, the bz2 variant will be used if available (defaults to + /// true) pub bz2_enabled: bool, + /// When enabled, sharded repodata will be used if available. + pub sharded_enabled: bool, + /// Describes fetching repodata from a channel should interact with any /// caches. pub cache_action: CacheAction, @@ -26,6 +35,7 @@ impl Default for SourceConfig { jlap_enabled: true, zstd_enabled: true, bz2_enabled: true, + sharded_enabled: false, cache_action: CacheAction::default(), } } @@ -34,18 +44,31 @@ impl Default for SourceConfig { /// Describes additional information for fetching channels. #[derive(Debug, Default)] pub struct ChannelConfig { - /// The default source configuration. If a channel does not have a specific source configuration - /// this configuration will be used. + /// The default source configuration. If a channel does not have a specific + /// source configuration this configuration will be used. pub default: SourceConfig, - /// Describes per channel properties that influence how the gateway fetches repodata. - pub per_channel: HashMap, + /// Source configuration on a per-URL basis. This URL is used as a prefix, + /// so any channel that starts with the URL uses the configuration. + /// The configuration with the longest matching prefix is used. + pub per_channel: HashMap, } impl ChannelConfig { - /// Returns the source configuration for the given channel. If the channel does not have a - /// specific source configuration the default source configuration will be returned. - pub fn get(&self, channel: &Channel) -> &SourceConfig { - self.per_channel.get(channel).unwrap_or(&self.default) + /// Returns the source configuration for the given channel. Locates the + /// source configuration that best matches the requested channel. + pub fn get(&self, channel: &ChannelUrl) -> &SourceConfig { + self.per_channel + .iter() + .filter_map(|(url, config)| { + let key_url = url.as_str().strip_suffix('/').unwrap_or(url.as_str()); + if channel.as_str().starts_with(key_url) { + Some((key_url.len(), config)) + } else { + None + } + }) + .max_by_key(|(len, _)| *len) + .map_or(&self.default, |(_, config)| config) } } diff --git a/crates/rattler_repodata_gateway/src/gateway/error.rs b/crates/rattler_repodata_gateway/src/gateway/error.rs index af421daa8..0a2d8e053 100644 --- a/crates/rattler_repodata_gateway/src/gateway/error.rs +++ b/crates/rattler_repodata_gateway/src/gateway/error.rs @@ -47,6 +47,9 @@ pub enum GatewayError { #[error(transparent)] InvalidPackageName(#[from] InvalidPackageNameError), + + #[error("{0}")] + CacheError(String), } impl From for GatewayError { diff --git a/crates/rattler_repodata_gateway/src/gateway/mod.rs b/crates/rattler_repodata_gateway/src/gateway/mod.rs index ac93c35e4..d4df74bce 100644 --- a/crates/rattler_repodata_gateway/src/gateway/mod.rs +++ b/crates/rattler_repodata_gateway/src/gateway/mod.rs @@ -9,6 +9,7 @@ mod remote_subdir; mod repo_data; mod sharded_subdir; mod subdir; +mod subdir_builder; use std::{ collections::HashSet, @@ -21,19 +22,17 @@ pub use builder::GatewayBuilder; pub use channel_config::{ChannelConfig, SourceConfig}; use dashmap::{mapref::entry::Entry, DashMap}; pub use error::GatewayError; -use file_url::url_to_path; -use local_subdir::LocalSubdirClient; pub use query::{NamesQuery, RepoDataQuery}; use rattler_cache::package_cache::PackageCache; use rattler_conda_types::{Channel, MatchSpec, Platform}; pub use repo_data::RepoData; use reqwest_middleware::ClientWithMiddleware; -use subdir::{Subdir, SubdirData}; +use subdir::Subdir; use tokio::sync::broadcast; use tracing::instrument; use url::Url; -use crate::{fetch::FetchRepoDataError, gateway::error::SubdirNotFoundError, Reporter}; +use crate::{gateway::subdir_builder::SubdirBuilder, Reporter}; /// Central access point for high level queries about /// [`rattler_conda_types::RepoDataRecord`]s from different channels. @@ -143,7 +142,7 @@ impl Gateway { /// This method does not clear any on-disk cache. pub fn clear_repodata_cache(&self, channel: &Channel, subdirs: SubdirSelection) { self.inner.subdirs.retain(|key, _| { - key.0.base_url() != channel.base_url() || !subdirs.contains(key.1.as_str()) + key.0.base_url != channel.base_url || !subdirs.contains(key.1.as_str()) }); } } @@ -177,7 +176,7 @@ impl GatewayInner { /// coalesced, and they will all receive the same subdir. If an error /// occurs while creating the subdir all waiting tasks will also return an /// error. - #[instrument(skip(self, reporter), err)] + #[instrument(skip(self, reporter, channel), fields(channel = %channel.base_url), err)] async fn get_or_create_subdir( &self, channel: &Channel, @@ -271,76 +270,9 @@ impl GatewayInner { platform: Platform, reporter: Option>, ) -> Result { - let url = channel.platform_url(platform); - let subdir_data = if url.scheme() == "file" { - if let Some(path) = url_to_path(&url) { - LocalSubdirClient::from_channel_subdir( - &path.join("repodata.json"), - channel.clone(), - platform.as_str(), - ) - .await - .map(SubdirData::from_client) - } else { - return Err(GatewayError::UnsupportedUrl( - "unsupported file based url".to_string(), - )); - } - } else if supports_sharded_repodata(&url) { - sharded_subdir::ShardedSubdir::new( - channel.clone(), - platform.to_string(), - self.client.clone(), - self.cache.clone(), - self.concurrent_requests_semaphore.clone(), - reporter.as_deref(), - ) + SubdirBuilder::new(self, channel.clone(), platform, reporter) + .build() .await - .map(SubdirData::from_client) - } else if url.scheme() == "http" - || url.scheme() == "https" - || url.scheme() == "gcs" - || url.scheme() == "oci" - { - remote_subdir::RemoteSubdirClient::new( - channel.clone(), - platform, - self.client.clone(), - self.cache.clone(), - self.channel_config.get(channel).clone(), - reporter, - ) - .await - .map(SubdirData::from_client) - } else { - return Err(GatewayError::UnsupportedUrl(format!( - "'{}' is not a supported scheme", - url.scheme() - ))); - }; - - match subdir_data { - Ok(client) => Ok(Subdir::Found(client)), - Err(GatewayError::SubdirNotFoundError(err)) if platform != Platform::NoArch => { - // If the subdir was not found and the platform is not `noarch` we assume its - // just empty. - tracing::info!( - "subdir {} of channel {} was not found, ignoring", - err.subdir, - err.channel.canonical_name() - ); - Ok(Subdir::NotFound) - } - Err(GatewayError::FetchRepoDataError(FetchRepoDataError::NotFound(err))) => { - Err(SubdirNotFoundError { - subdir: platform.to_string(), - channel: channel.clone(), - source: err.into(), - } - .into()) - } - Err(err) => Err(err), - } } } @@ -351,9 +283,9 @@ enum PendingOrFetched { Fetched(T), } -fn supports_sharded_repodata(url: &Url) -> bool { - (url.scheme() == "http" || url.scheme() == "https") - && (url.host_str() == Some("fast.prefiks.dev") || url.host_str() == Some("fast.prefix.dev")) +fn force_sharded_repodata(url: &Url) -> bool { + matches!(url.scheme(), "http" | "https") + && matches!(url.host_str(), Some("fast.prefiks.dev" | "fast.prefix.dev")) } #[cfg(test)] diff --git a/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/index.rs b/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/index.rs index e6f761e95..223075c13 100644 --- a/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/index.rs +++ b/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/index.rs @@ -15,8 +15,11 @@ use tokio::{ }; use url::Url; -use super::{token::TokenClient, ShardedRepodata}; -use crate::{reporter::ResponseReporterExt, utils::url_to_cache_filename, GatewayError, Reporter}; +use super::ShardedRepodata; +use crate::{ + fetch::CacheAction, reporter::ResponseReporterExt, utils::url_to_cache_filename, GatewayError, + Reporter, +}; /// Magic number that identifies the cache file format. const MAGIC_NUMBER: &[u8] = b"SHARD-CACHE-V1"; @@ -27,8 +30,8 @@ const REPODATA_SHARDS_FILENAME: &str = "repodata_shards.msgpack.zst"; pub async fn fetch_index( client: ClientWithMiddleware, channel_base_url: &Url, - token_client: &TokenClient, cache_dir: &Path, + cache_action: CacheAction, concurrent_requests_semaphore: Arc, reporter: Option<&dyn Reporter>, ) -> Result { @@ -39,6 +42,8 @@ pub async fn fetch_index( response: Response, reporter: Option<(&dyn Reporter, usize)>, ) -> Result { + let response = response.error_for_status()?; + // Read the bytes of the response let response_url = response.url().clone(); let bytes = response.bytes_with_progress(reporter).await?; @@ -124,105 +129,118 @@ pub async fn fetch_index( let canonical_request = SimpleRequest::get(&canonical_shards_url); // Try reading the cached file - if let Ok(cache_header) = read_cached_index(&mut cache_reader).await { - match cache_header - .policy - .before_request(&canonical_request, SystemTime::now()) - { - BeforeRequest::Fresh(_) => { + if cache_action != CacheAction::NoCache { + if let Ok(cache_header) = read_cached_index(&mut cache_reader).await { + // If we are in cache-only mode we can't fetch the index from the server + if cache_action == CacheAction::ForceCacheOnly { if let Ok(shard_index) = read_shard_index_from_reader(&mut cache_reader).await { - tracing::debug!("shard index cache hit"); + tracing::debug!("using locally cached shard index for {channel_base_url}"); return Ok(shard_index); } - } - BeforeRequest::Stale { - request: state_request, - .. - } => { - // Get the token from the token client - let token = token_client.get_token(reporter).await?; - - // Determine the actual URL to use for the request - let shards_url = token - .shard_base_url - .as_ref() - .unwrap_or(channel_base_url) - .join(REPODATA_SHARDS_FILENAME) - .expect("invalid shard base url"); - - // Construct the actual request that we will send - let mut request = client - .get(shards_url.clone()) - .headers(state_request.headers().clone()) - .build() - .expect("failed to build request for shard index"); - token.add_to_headers(request.headers_mut()); - - // Acquire a permit to do a request - let _request_permit = concurrent_requests_semaphore.acquire().await; - - // Send the request - let download_reporter = reporter.map(|r| (r, r.on_download_start(&shards_url))); - let response = client.execute(request).await?; - - match cache_header.policy.after_response( - &state_request, - &response, - SystemTime::now(), - ) { - AfterResponse::NotModified(_policy, _) => { - // The cached file is still valid - match read_shard_index_from_reader(&mut cache_reader).await { - Ok(shard_index) => { - tracing::debug!("shard index cache was not modified"); - // If reading the file failed for some reason we'll just fetch it - // again. - return Ok(shard_index); - } - Err(e) => { - tracing::warn!("the cached shard index has been corrupted: {e}"); - if let Some((reporter, index)) = download_reporter { - reporter.on_download_complete(response.url(), index); + } else { + match cache_header + .policy + .before_request(&canonical_request, SystemTime::now()) + { + BeforeRequest::Fresh(_) => { + if let Ok(shard_index) = + read_shard_index_from_reader(&mut cache_reader).await + { + tracing::debug!("shard index cache hit"); + return Ok(shard_index); + } + } + BeforeRequest::Stale { + request: state_request, + .. + } => { + if cache_action == CacheAction::UseCacheOnly { + return Err(GatewayError::CacheError( + format!("the sharded index cache for {channel_base_url} is stale and cache-only mode is enabled"), + )); + } + + // Determine the actual URL to use for the request + let shards_url = channel_base_url + .join(REPODATA_SHARDS_FILENAME) + .expect("invalid shard base url"); + + // Construct the actual request that we will send + let request = client + .get(shards_url.clone()) + .headers(state_request.headers().clone()) + .build() + .expect("failed to build request for shard index"); + + // Acquire a permit to do a request + let _request_permit = concurrent_requests_semaphore.acquire().await; + + // Send the request + let download_reporter = + reporter.map(|r| (r, r.on_download_start(&shards_url))); + let response = client.execute(request).await?; + + match cache_header.policy.after_response( + &state_request, + &response, + SystemTime::now(), + ) { + AfterResponse::NotModified(_policy, _) => { + // The cached file is still valid + match read_shard_index_from_reader(&mut cache_reader).await { + Ok(shard_index) => { + tracing::debug!("shard index cache was not modified"); + // If reading the file failed for some reason we'll just + // fetch it again. + return Ok(shard_index); + } + Err(e) => { + tracing::warn!( + "the cached shard index has been corrupted: {e}" + ); + if let Some((reporter, index)) = download_reporter { + reporter.on_download_complete(response.url(), index); + } + } } } + AfterResponse::Modified(policy, _) => { + // Close the old file so we can create a new one. + tracing::debug!("shard index cache has become stale"); + return from_response( + cache_reader.into_inner(), + &cache_path, + policy, + response, + download_reporter, + ) + .await; + } } } - AfterResponse::Modified(policy, _) => { - // Close the old file so we can create a new one. - tracing::debug!("shard index cache has become stale"); - return from_response( - cache_reader.into_inner(), - &cache_path, - policy, - response, - download_reporter, - ) - .await; - } } } } - }; + } - tracing::debug!("fetching fresh shard index"); + if cache_action == CacheAction::ForceCacheOnly { + return Err(GatewayError::CacheError(format!( + "the sharded index cache for {channel_base_url} is not available" + ))); + } - // Get the token from the token client - let token = token_client.get_token(reporter).await?; + tracing::debug!("fetching fresh shard index"); // Determine the actual URL to use for the request - let shards_url = token - .shard_base_url - .as_ref() - .unwrap_or(channel_base_url) + let shards_url = channel_base_url .join(REPODATA_SHARDS_FILENAME) .expect("invalid shard base url"); // Construct the actual request that we will send - let mut request = client + let request = client .get(shards_url.clone()) .build() .expect("failed to build request for shard index"); - token.add_to_headers(request.headers_mut()); // Acquire a permit to do a request let _request_permit = concurrent_requests_semaphore.acquire().await; diff --git a/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/mod.rs b/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/mod.rs index 0824e9318..d9b224882 100644 --- a/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/mod.rs +++ b/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/mod.rs @@ -4,27 +4,25 @@ use http::{header::CACHE_CONTROL, HeaderValue, StatusCode}; use rattler_conda_types::{Channel, PackageName, RepoDataRecord, Shard, ShardedRepodata}; use reqwest_middleware::ClientWithMiddleware; use simple_spawn_blocking::tokio::run_blocking_task; -use token::TokenClient; use url::Url; use crate::{ - fetch::FetchRepoDataError, + fetch::{CacheAction, FetchRepoDataError}, gateway::{error::SubdirNotFoundError, subdir::SubdirClient}, reporter::ResponseReporterExt, GatewayError, Reporter, }; mod index; -mod token; pub struct ShardedSubdir { channel: Channel, client: ClientWithMiddleware, shards_base_url: Url, package_base_url: Url, - token_client: TokenClient, sharded_repodata: ShardedRepodata, cache_dir: PathBuf, + cache_action: CacheAction, concurrent_requests_semaphore: Arc, } @@ -34,27 +32,23 @@ impl ShardedSubdir { subdir: String, client: ClientWithMiddleware, cache_dir: PathBuf, + cache_action: CacheAction, concurrent_requests_semaphore: Arc, reporter: Option<&dyn Reporter>, ) -> Result { // Construct the base url for the shards (e.g. `/`). - let index_base_url = add_trailing_slash(channel.base_url()) + let index_base_url = channel + .base_url + .url() .join(&format!("{subdir}/")) .expect("invalid subdir url"); - // Construct a token client to fetch the token when we need it. - let token_client = TokenClient::new( - client.clone(), - index_base_url.clone(), - concurrent_requests_semaphore.clone(), - ); - // Fetch the shard index let sharded_repodata = index::fetch_index( client.clone(), &index_base_url, - &token_client, &cache_dir, + cache_action, concurrent_requests_semaphore.clone(), reporter, ) @@ -101,9 +95,9 @@ impl ShardedSubdir { client, shards_base_url: add_trailing_slash(&shards_base_url).into_owned(), package_base_url: add_trailing_slash(&package_base_url).into_owned(), - token_client, sharded_repodata, cache_dir, + cache_action, concurrent_requests_semaphore, }) } @@ -125,25 +119,35 @@ impl SubdirClient for ShardedSubdir { let shard_cache_path = self.cache_dir.join(format!("{shard:x}.msgpack")); // Read the cached shard - match tokio::fs::read(&shard_cache_path).await { - Ok(cached_bytes) => { - // Decode the cached shard - return parse_records( - cached_bytes, - self.channel.canonical_name(), - self.package_base_url.clone(), - ) - .await - .map(Arc::from); - } - Err(err) if err.kind() == std::io::ErrorKind::NotFound => { - // The file is missing from the cache, we need to download it. + if self.cache_action != CacheAction::NoCache { + match tokio::fs::read(&shard_cache_path).await { + Ok(cached_bytes) => { + // Decode the cached shard + return parse_records( + cached_bytes, + self.channel.canonical_name(), + self.package_base_url.clone(), + ) + .await + .map(Arc::from); + } + Err(err) if err.kind() == std::io::ErrorKind::NotFound => { + // The file is missing from the cache, we need to download + // it. + } + Err(err) => return Err(FetchRepoDataError::IoError(err).into()), } - Err(err) => return Err(FetchRepoDataError::IoError(err).into()), } - // Get the token - let token = self.token_client.get_token(reporter).await?; + if matches!( + self.cache_action, + CacheAction::UseCacheOnly | CacheAction::ForceCacheOnly + ) { + return Err(GatewayError::CacheError(format!( + "the shard for package '{}' is not in the cache", + name.as_source() + ))); + } // Download the shard let shard_url = self @@ -151,13 +155,12 @@ impl SubdirClient for ShardedSubdir { .join(&format!("{shard:x}.msgpack.zst")) .expect("invalid shard url"); - let mut shard_request = self + let shard_request = self .client .get(shard_url.clone()) .header(CACHE_CONTROL, HeaderValue::from_static("no-store")) .build() .expect("failed to build shard request"); - token.add_to_headers(shard_request.headers_mut()); let shard_bytes = { let _permit = self.concurrent_requests_semaphore.acquire(); diff --git a/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/token.rs b/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/token.rs deleted file mode 100644 index bd4a98c51..000000000 --- a/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/token.rs +++ /dev/null @@ -1,166 +0,0 @@ -use crate::reporter::ResponseReporterExt; -use crate::Reporter; -use crate::{fetch::FetchRepoDataError, gateway::PendingOrFetched, GatewayError}; -use chrono::{DateTime, TimeDelta, Utc}; -use http::header::CACHE_CONTROL; -use http::HeaderValue; -use itertools::Either; -use parking_lot::Mutex; -use reqwest_middleware::ClientWithMiddleware; -use serde::{Deserialize, Serialize}; -use std::ops::Add; -use std::sync::Arc; -use url::Url; - -/// A simple client that makes it simple to fetch a token from the token endpoint. -pub struct TokenClient { - client: ClientWithMiddleware, - token_base_url: Url, - token: Arc>>>>, - concurrent_request_semaphore: Arc, -} - -impl TokenClient { - pub fn new( - client: ClientWithMiddleware, - token_base_url: Url, - concurrent_request_semaphore: Arc, - ) -> Self { - Self { - client, - token_base_url, - token: Arc::new(Mutex::new(PendingOrFetched::Fetched(None))), - concurrent_request_semaphore, - } - } - - /// Returns the current token or fetches a new one if the current one is expired. - pub async fn get_token( - &self, - reporter: Option<&dyn Reporter>, - ) -> Result, GatewayError> { - let sender_or_receiver = { - let mut token = self.token.lock(); - match &*token { - PendingOrFetched::Fetched(Some(token)) if token.is_fresh() => { - // The token is still fresh. - return Ok(token.clone()); - } - PendingOrFetched::Fetched(_) => { - let (sender, _) = tokio::sync::broadcast::channel(1); - let sender = Arc::new(sender); - *token = PendingOrFetched::Pending(Arc::downgrade(&sender)); - - Either::Left(sender) - } - PendingOrFetched::Pending(sender) => { - let sender = sender.upgrade(); - if let Some(sender) = sender { - Either::Right(sender.subscribe()) - } else { - let (sender, _) = tokio::sync::broadcast::channel(1); - let sender = Arc::new(sender); - *token = PendingOrFetched::Pending(Arc::downgrade(&sender)); - Either::Left(sender) - } - } - } - }; - - let sender = match sender_or_receiver { - Either::Left(sender) => sender, - Either::Right(mut receiver) => { - return match receiver.recv().await { - Ok(Some(token)) => Ok(token), - _ => { - // If this happens the sender was dropped. - Err(GatewayError::IoError( - "a coalesced request for a token failed".to_string(), - std::io::ErrorKind::Other.into(), - )) - } - }; - } - }; - - let token_url = self - .token_base_url - .join("token") - .expect("invalid token url"); - tracing::debug!("fetching token from {}", &token_url); - - // Fetch the token - let token = { - let _permit = self.concurrent_request_semaphore.acquire().await; - let reporter = reporter.map(|r| (r, r.on_download_start(&token_url))); - let response = self - .client - .get(token_url.clone()) - .header(CACHE_CONTROL, HeaderValue::from_static("max-age=0")) - .send() - .await - .and_then(|r| r.error_for_status().map_err(Into::into)) - .map_err(GatewayError::from)?; - - let bytes = response - .bytes_with_progress(reporter) - .await - .map_err(FetchRepoDataError::from) - .map_err(GatewayError::from)?; - - if let Some((reporter, index)) = reporter { - reporter.on_download_complete(&token_url, index); - } - - let mut token: Token = serde_json::from_slice(&bytes).map_err(|e| { - GatewayError::IoError("failed to parse sharded index token".to_string(), e.into()) - })?; - - // Ensure that the issued_at field is set. - token.issued_at.get_or_insert_with(Utc::now); - - Arc::new(token) - }; - - // Reacquire the token - let mut token_lock = self.token.lock(); - *token_lock = PendingOrFetched::Fetched(Some(token.clone())); - - // Publish the change - let _ = sender.send(Some(token.clone())); - - Ok(token) - } -} - -/// The token endpoint response. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Token { - pub token: Option, - issued_at: Option>, - expires_in: Option, - pub shard_base_url: Option, -} - -impl Token { - /// Returns true if the token is still considered to be valid. - pub fn is_fresh(&self) -> bool { - if let (Some(issued_at), Some(expires_in)) = (&self.issued_at, self.expires_in) { - let now = Utc::now(); - if issued_at.add(TimeDelta::seconds(expires_in as i64)) > now { - return false; - } - } - true - } - - /// Add the token to the headers if its available - pub fn add_to_headers(&self, headers: &mut http::header::HeaderMap) { - if let Some(token) = &self.token { - headers.insert( - http::header::AUTHORIZATION, - HeaderValue::from_str(&format!("Bearer {token}")).unwrap(), - ); - } - } -} diff --git a/crates/rattler_repodata_gateway/src/gateway/subdir_builder.rs b/crates/rattler_repodata_gateway/src/gateway/subdir_builder.rs new file mode 100644 index 000000000..8a101659c --- /dev/null +++ b/crates/rattler_repodata_gateway/src/gateway/subdir_builder.rs @@ -0,0 +1,158 @@ +use std::{path::Path, sync::Arc}; + +use file_url::url_to_path; +use rattler_conda_types::{Channel, Platform}; + +use crate::{ + fetch::FetchRepoDataError, + gateway, + gateway::{ + error::SubdirNotFoundError, + local_subdir::LocalSubdirClient, + remote_subdir, sharded_subdir, + subdir::{Subdir, SubdirData}, + GatewayInner, + }, + GatewayError, Reporter, SourceConfig, +}; + +/// Builder for creating a `Subdir` instance. +pub struct SubdirBuilder<'g> { + channel: Channel, + platform: Platform, + reporter: Option>, + gateway: &'g GatewayInner, +} + +impl<'g> SubdirBuilder<'g> { + pub fn new( + gateway: &'g GatewayInner, + channel: Channel, + platform: Platform, + reporter: Option>, + ) -> Self { + Self { + channel, + platform, + reporter, + gateway, + } + } + + pub async fn build(self) -> Result { + let url = self.channel.platform_url(self.platform); + + let subdir_data = if url.scheme() == "file" { + if let Some(path) = url_to_path(&url) { + self.build_local(&path).await + } else { + return Err(GatewayError::UnsupportedUrl( + "unsupported file based url".to_string(), + )); + } + } else if url.scheme() == "http" + || url.scheme() == "https" + || url.scheme() == "gcs" + || url.scheme() == "oci" + { + let source_config = self.gateway.channel_config.get(&self.channel.base_url); + + // Use sharded repodata if enabled + let subdir_data = if source_config.sharded_enabled + || gateway::force_sharded_repodata(&url) + { + match self.build_sharded(source_config).await { + Ok(client) => Some(client), + Err(GatewayError::SubdirNotFoundError(_)) => { + tracing::info!( + "sharded repodata seems to be missing for {url}, falling back to repodata.json files", + ); + None + } + Err(err) => return Err(err), + } + } else { + None + }; + + // Otherwise fall back to repodata.json files + if let Some(subdir_data) = subdir_data { + Ok(subdir_data) + } else { + self.build_generic(source_config).await + } + } else { + return Err(GatewayError::UnsupportedUrl(format!( + "'{}' is not a supported scheme", + url.scheme() + ))); + }; + + match subdir_data { + Ok(client) => Ok(Subdir::Found(client)), + Err(GatewayError::SubdirNotFoundError(err)) if self.platform != Platform::NoArch => { + // If the subdir was not found and the platform is not `noarch` we assume its + // just empty. + tracing::info!( + "subdir {} of channel {} was not found, ignoring", + err.subdir, + err.channel.canonical_name() + ); + Ok(Subdir::NotFound) + } + Err(GatewayError::FetchRepoDataError(FetchRepoDataError::NotFound(err))) => { + Err(SubdirNotFoundError { + subdir: self.platform.to_string(), + channel: self.channel.clone(), + source: err.into(), + } + .into()) + } + Err(err) => Err(err), + } + } + + async fn build_generic( + &self, + source_config: &SourceConfig, + ) -> Result { + let client = remote_subdir::RemoteSubdirClient::new( + self.channel.clone(), + self.platform, + self.gateway.client.clone(), + self.gateway.cache.clone(), + source_config.clone(), + self.reporter.clone(), + ) + .await?; + Ok(SubdirData::from_client(client)) + } + + async fn build_sharded( + &self, + source_config: &SourceConfig, + ) -> Result { + let client = sharded_subdir::ShardedSubdir::new( + self.channel.clone(), + self.platform.to_string(), + self.gateway.client.clone(), + self.gateway.cache.clone(), + source_config.cache_action, + self.gateway.concurrent_requests_semaphore.clone(), + self.reporter.as_deref(), + ) + .await?; + + Ok(SubdirData::from_client(client)) + } + + async fn build_local(&self, path: &Path) -> Result { + let client = LocalSubdirClient::from_channel_subdir( + &path.join("repodata.json"), + self.channel.clone(), + self.platform.as_str(), + ) + .await?; + Ok(SubdirData::from_client(client)) + } +} diff --git a/crates/rattler_repodata_gateway/src/sparse/mod.rs b/crates/rattler_repodata_gateway/src/sparse/mod.rs index dea1363ab..6800d43f2 100644 --- a/crates/rattler_repodata_gateway/src/sparse/mod.rs +++ b/crates/rattler_repodata_gateway/src/sparse/mod.rs @@ -309,6 +309,7 @@ fn parse_records<'i>( url: compute_package_url( &channel .base_url + .url() .join(&format!("{}/", &package_record.subdir)) .expect("failed determine repo_base_url"), base_url, @@ -477,13 +478,13 @@ mod test { use std::path::{Path, PathBuf}; use bytes::Bytes; + use fs_err as fs; use itertools::Itertools; use rattler_conda_types::{Channel, ChannelConfig, PackageName, RepoData, RepoDataRecord}; use rstest::rstest; use super::{load_repo_data_recursively, PackageFilename, SparseRepoData}; use crate::utils::test::fetch_repo_data; - use fs_err as fs; fn test_dir() -> PathBuf { Path::new(env!("CARGO_MANIFEST_DIR")).join("../../test-data") diff --git a/py-rattler/Cargo.lock b/py-rattler/Cargo.lock index 9301142b0..2c9c92a0e 100644 --- a/py-rattler/Cargo.lock +++ b/py-rattler/Cargo.lock @@ -2890,7 +2890,7 @@ dependencies = [ [[package]] name = "rattler_repodata_gateway" -version = "0.21.19" +version = "0.21.20" dependencies = [ "anyhow", "async-compression", @@ -2946,6 +2946,7 @@ name = "rattler_shell" version = "0.22.5" dependencies = [ "enum_dispatch", + "fs-err 3.0.0", "indexmap 2.6.0", "itertools 0.13.0", "rattler_conda_types", @@ -2958,7 +2959,7 @@ dependencies = [ [[package]] name = "rattler_solve" -version = "1.2.0" +version = "1.2.1" dependencies = [ "chrono", "futures", diff --git a/py-rattler/rattler/repo_data/gateway.py b/py-rattler/rattler/repo_data/gateway.py index 49757db02..cee4612ec 100644 --- a/py-rattler/rattler/repo_data/gateway.py +++ b/py-rattler/rattler/repo_data/gateway.py @@ -33,6 +33,9 @@ class SourceConfig: bz2_enabled: bool = True """Whether the BZ2 compression is enabled or not.""" + sharded_enabled: bool = False + """Whether sharded repodata is enabled or not.""" + cache_action: CacheAction = "cache-or-fetch" """How to interact with the cache. @@ -58,6 +61,7 @@ def _into_py(self) -> PySourceConfig: jlap_enabled=self.jlap_enabled, zstd_enabled=self.zstd_enabled, bz2_enabled=self.bz2_enabled, + sharded_enabled=self.sharded_enabled, cache_action=self.cache_action, ) @@ -84,7 +88,7 @@ def __init__( self, cache_dir: Optional[os.PathLike[str]] = None, default_config: Optional[SourceConfig] = None, - per_channel_config: Optional[dict[Channel | str, SourceConfig]] = None, + per_channel_config: Optional[dict[str, SourceConfig]] = None, max_concurrent_requests: int = 100, ) -> None: """ @@ -92,7 +96,9 @@ def __init__( cache_dir: The directory where the repodata should be cached. If not specified the default cache directory is used. default_config: The default configuration for channels. - per_channel_config: Per channel configuration. + per_channel_config: Source configuration on a per-URL basis. This URL is used as a + prefix, so any channel that starts with the URL uses the configuration. + The configuration with the longest matching prefix is used. max_concurrent_requests: The maximum number of concurrent requests that can be made. Examples diff --git a/py-rattler/src/lock/mod.rs b/py-rattler/src/lock/mod.rs index d2186fafe..53308000c 100644 --- a/py-rattler/src/lock/mod.rs +++ b/py-rattler/src/lock/mod.rs @@ -270,7 +270,7 @@ impl From for PyLockChannel { impl From for PyLockChannel { fn from(value: rattler_conda_types::Channel) -> Self { Self { - inner: Channel::from(value.base_url().to_string()), + inner: Channel::from(value.base_url.to_string()), } } } diff --git a/py-rattler/src/repo_data/gateway.rs b/py-rattler/src/repo_data/gateway.rs index aac914a32..97aee99c0 100644 --- a/py-rattler/src/repo_data/gateway.rs +++ b/py-rattler/src/repo_data/gateway.rs @@ -11,6 +11,7 @@ use rattler_repodata_gateway::fetch::CacheAction; use rattler_repodata_gateway::{ChannelConfig, Gateway, SourceConfig, SubdirSelection}; use std::collections::{HashMap, HashSet}; use std::path::PathBuf; +use url::Url; #[pyclass] #[repr(transparent)] @@ -54,15 +55,18 @@ impl PyGateway { pub fn new( max_concurrent_requests: usize, default_config: PySourceConfig, - per_channel_config: HashMap, + per_channel_config: HashMap, cache_dir: Option, ) -> PyResult { let channel_config = ChannelConfig { default: default_config.into(), per_channel: per_channel_config .into_iter() - .map(|(k, v)| (k.into(), v.into())) - .collect(), + .map(|(k, v)| { + let url = Url::parse(&k).map_err(PyRattlerError::from)?; + Ok((url, v.into())) + }) + .collect::>()?, }; let mut gateway = Gateway::builder() @@ -174,10 +178,12 @@ impl<'py> FromPyObject<'py> for Wrap { #[pymethods] impl PySourceConfig { #[new] + #[allow(clippy::fn_params_excessive_bools)] pub fn new( jlap_enabled: bool, zstd_enabled: bool, bz2_enabled: bool, + sharded_enabled: bool, cache_action: Wrap, ) -> Self { Self { @@ -185,6 +191,7 @@ impl PySourceConfig { jlap_enabled, zstd_enabled, bz2_enabled, + sharded_enabled, cache_action: cache_action.0, }, }