Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add link function to py-rattler #364

Merged
merged 4 commits into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/rattler/src/install/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ mod transaction;
pub use crate::install::entry_point::python_entry_point_template;
pub use driver::InstallDriver;
pub use link::{link_file, LinkFileError};
pub use transaction::{Transaction, TransactionOperation};
pub use transaction::{Transaction, TransactionError, TransactionOperation};

use crate::install::entry_point::{
create_unix_python_entry_point, create_windows_python_entry_point,
Expand Down
1 change: 1 addition & 0 deletions crates/rattler/src/install/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::install::python::PythonInfoError;
use crate::install::PythonInfo;
use rattler_conda_types::{PackageRecord, Platform};

/// Error that occurred during creation of a Transaction
#[derive(Debug, thiserror::Error)]
pub enum TransactionError {
/// An error that happens if the python version could not be parsed.
Expand Down
2 changes: 2 additions & 0 deletions py-rattler/rattler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from rattler.prefix import PrefixRecord, PrefixPaths
from rattler.solver import solve
from rattler.platform import Platform
from rattler.linker import link

__all__ = [
"Version",
Expand All @@ -36,4 +37,5 @@
"SparseRepoData",
"solve",
"Platform",
"link",
]
3 changes: 3 additions & 0 deletions py-rattler/rattler/linker/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from rattler.linker.linker import link

__all__ = ["link"]
48 changes: 48 additions & 0 deletions py-rattler/rattler/linker/linker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from __future__ import annotations
import os
from typing import List, Optional

from rattler.networking.authenticated_client import AuthenticatedClient
from rattler.platform.platform import Platform
from rattler.prefix.prefix_record import PrefixRecord
from rattler.repo_data.record import RepoDataRecord

from rattler.rattler import py_link


async def link(
dependencies: List[RepoDataRecord],
target_prefix: os.PathLike[str],
cache_dir: os.PathLike[str],
installed_packages: Optional[List[PrefixRecord]] = None,
platform: Optional[Platform] = None,
) -> None:
"""
Create an environment by downloading and linking the `dependencies` in
the `target_prefix` directory.

Arguments:
dependencies: A list of solved `RepoDataRecord`s.
target_prefix: Path to the directory where the environment should
be created.
cache_dir: Path to directory where the dependencies will be
downloaded and cached.
installed_packages(optional): A list of `PrefixRecord`s which are
already installed in the
`target_prefix`. This can be obtained
by loading `PrefixRecord`s from
`{target_prefix}/conda-meta/`.
platform(optional): Target platform to create and link the
environment. Defaults to current platform.
"""
platform = platform or Platform.current()
client = AuthenticatedClient()

await py_link(
dependencies,
target_prefix,
cache_dir,
installed_packages or [],
platform._inner,
client._client,
)
15 changes: 9 additions & 6 deletions py-rattler/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::io;

use pyo3::exceptions::PyException;
use pyo3::{create_exception, PyErr};
use rattler::install::TransactionError;
use rattler_conda_types::{
InvalidPackageNameError, ParseArchError, ParseChannelError, ParseMatchSpecError,
ParsePlatformError, ParseVersionError,
Expand Down Expand Up @@ -41,8 +42,10 @@ pub enum PyRattlerError {
IoError(#[from] io::Error),
#[error(transparent)]
SolverError(#[from] SolveError),
#[error("invalid 'SparseRepoData' object found")]
InvalidSparseDataError,
#[error(transparent)]
TransactionError(#[from] TransactionError),
#[error("{0}")]
LinkError(String),
}

impl From<PyRattlerError> for PyErr {
Expand Down Expand Up @@ -75,9 +78,8 @@ impl From<PyRattlerError> for PyErr {
}
PyRattlerError::IoError(err) => IoException::new_err(err.to_string()),
PyRattlerError::SolverError(err) => SolverException::new_err(err.to_string()),
PyRattlerError::InvalidSparseDataError => InvalidSparseDataException::new_err(
PyRattlerError::InvalidSparseDataError.to_string(),
),
PyRattlerError::TransactionError(err) => TransactionException::new_err(err.to_string()),
PyRattlerError::LinkError(err) => LinkException::new_err(err),
}
}
}
Expand All @@ -95,4 +97,5 @@ create_exception!(exceptions, CacheDirException, PyException);
create_exception!(exceptions, DetectVirtualPackageException, PyException);
create_exception!(exceptions, IoException, PyException);
create_exception!(exceptions, SolverException, PyException);
create_exception!(exceptions, InvalidSparseDataException, PyException);
create_exception!(exceptions, TransactionException, PyException);
create_exception!(exceptions, LinkException, PyException);
5 changes: 5 additions & 0 deletions py-rattler/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod channel;
mod error;
mod generic_virtual_package;
mod linker;
mod match_spec;
mod nameless_match_spec;
mod networking;
Expand Down Expand Up @@ -33,6 +34,7 @@ use version::PyVersion;

use pyo3::prelude::*;

use linker::py_link;
use platform::{PyArch, PyPlatform};
use shell::{PyActivationResult, PyActivationVariables, PyActivator, PyShellEnum};
use solver::py_solve;
Expand Down Expand Up @@ -76,6 +78,9 @@ fn rattler(py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(py_solve, m).unwrap())
.unwrap();

m.add_function(wrap_pyfunction!(py_link, m).unwrap())
.unwrap();

// Exceptions
m.add(
"InvalidVersionError",
Expand Down
231 changes: 231 additions & 0 deletions py-rattler/src/linker.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
use std::{future::ready, io::ErrorKind, path::PathBuf};

use futures::{stream, FutureExt, StreamExt, TryFutureExt, TryStreamExt};
use pyo3::{pyfunction, PyAny, PyResult, Python};
use pyo3_asyncio::tokio::future_into_py;
use rattler::{
install::{link_package, InstallDriver, InstallOptions, Transaction, TransactionOperation},
package_cache::PackageCache,
};
use rattler_conda_types::{PackageRecord, PrefixRecord, RepoDataRecord};
use rattler_networking::{retry_policies::default_retry_policy, AuthenticatedClient};

use crate::{
error::PyRattlerError, networking::authenticated_client::PyAuthenticatedClient,
platform::PyPlatform, prefix_record::PyPrefixRecord,
repo_data::repo_data_record::PyRepoDataRecord,
};

// TODO: Accept functions to report progress
#[pyfunction]
pub fn py_link<'a>(
py: Python<'a>,
dependencies: Vec<&'a PyAny>,
target_prefix: PathBuf,
cache_dir: PathBuf,
installed_packages: Vec<&'a PyAny>,
platform: &PyPlatform,
client: PyAuthenticatedClient,
) -> PyResult<&'a PyAny> {
let dependencies = dependencies
.into_iter()
.map(|rdr| Ok(PyRepoDataRecord::try_from(rdr)?.into()))
.collect::<PyResult<Vec<RepoDataRecord>>>()?;

let installed_packages = installed_packages
.iter()
.map(|&rdr| Ok(PyPrefixRecord::try_from(rdr)?.into()))
.collect::<PyResult<Vec<PrefixRecord>>>()?;

let txn = py.allow_threads(move || {
let reqired_packages = PackageRecord::sort_topologically(dependencies);

Transaction::from_current_and_desired(installed_packages, reqired_packages, platform.inner)
.map_err(PyRattlerError::from)
})?;

future_into_py(py, async move {
Ok(execute_transaction(txn, target_prefix, cache_dir, client.inner).await?)
})
}

async fn execute_transaction(
transaction: Transaction<PrefixRecord, RepoDataRecord>,
target_prefix: PathBuf,
cache_dir: PathBuf,
client: AuthenticatedClient,
) -> Result<(), PyRattlerError> {
let package_cache = PackageCache::new(cache_dir.join("pkgs"));

let install_driver = InstallDriver::default();

let install_options = InstallOptions {
python_info: transaction.python_info.clone(),
platform: Some(transaction.platform),
..Default::default()
};

stream::iter(transaction.operations)
.map(Ok)
.try_for_each_concurrent(50, |op| {
let target_prefix = target_prefix.clone();
let client = client.clone();
let package_cache = &package_cache;
let install_driver = &install_driver;
let install_options = &install_options;
async move {
execute_operation(
op,
target_prefix,
package_cache,
client,
install_driver,
install_options,
)
.await
}
})
.await?;

Ok(())
}

pub async fn execute_operation(
op: TransactionOperation<PrefixRecord, RepoDataRecord>,
target_prefix: PathBuf,
package_cache: &PackageCache,
client: AuthenticatedClient,
install_driver: &InstallDriver,
install_options: &InstallOptions,
) -> Result<(), PyRattlerError> {
let install_record = op.record_to_install();
let remove_record = op.record_to_remove();

let remove_future = if let Some(remove_record) = remove_record {
remove_package_from_environment(target_prefix.clone(), remove_record).left_future()
} else {
ready(Ok(())).right_future()
};

let cached_package_dir_fut = if let Some(install_record) = install_record {
async {
package_cache
.get_or_fetch_from_url_with_retry(
&install_record.package_record,
install_record.url.clone(),
client.clone(),
default_retry_policy(),
)
.map_ok(|cache_dir| Some((install_record.clone(), cache_dir)))
.map_err(|e| PyRattlerError::LinkError(e.to_string()))
.await
}
.left_future()
} else {
ready(Ok(None)).right_future()
};

let (_, install_package) = tokio::try_join!(remove_future, cached_package_dir_fut)?;

if let Some((record, package_dir)) = install_package {
install_package_to_environment(
target_prefix,
package_dir,
record.clone(),
install_driver,
install_options,
)
.await?;
}

Ok(())
}

// TODO: expose as python seperate function
pub async fn install_package_to_environment(
target_prefix: PathBuf,
package_dir: PathBuf,
repodata_record: RepoDataRecord,
install_driver: &InstallDriver,
install_options: &InstallOptions,
) -> Result<(), PyRattlerError> {
let paths = link_package(
&package_dir,
target_prefix.as_path(),
install_driver,
install_options.clone(),
)
.await
.map_err(|e| PyRattlerError::LinkError(e.to_string()))?;

let prefix_record = PrefixRecord {
repodata_record,
package_tarball_full_path: None,
extracted_package_dir: Some(package_dir),
files: paths
.iter()
.map(|entry| entry.relative_path.clone())
.collect(),
paths_data: paths.into(),
requested_spec: None,
link: None,
};

let target_prefix = target_prefix.to_path_buf();
match tokio::task::spawn_blocking(move || {
let conda_meta_path = target_prefix.join("conda-meta");
std::fs::create_dir_all(&conda_meta_path)?;

let pkg_meta_path = conda_meta_path.join(format!(
"{}-{}-{}.json",
prefix_record
.repodata_record
.package_record
.name
.as_normalized(),
prefix_record.repodata_record.package_record.version,
prefix_record.repodata_record.package_record.build
));
prefix_record.write_to_path(pkg_meta_path, true)
})
.await
{
Ok(result) => Ok(result?),
Err(err) => {
if let Ok(panic) = err.try_into_panic() {
std::panic::resume_unwind(panic);
}
Ok(())
}
}
}

// TODO: expose as python seperate function
async fn remove_package_from_environment(
target_prefix: PathBuf,
package: &PrefixRecord,
) -> Result<(), PyRattlerError> {
for paths in package.paths_data.paths.iter() {
match tokio::fs::remove_file(target_prefix.join(&paths.relative_path)).await {
Ok(_) => {}
Err(e) if e.kind() == ErrorKind::NotFound => {}
Err(_) => {
return Err(PyRattlerError::LinkError(format!(
"failed to delete {}",
paths.relative_path.display()
)))
}
}
}

let conda_meta_path = target_prefix.join("conda-meta").join(format!(
"{}-{}-{}.json",
package.repodata_record.package_record.name.as_normalized(),
package.repodata_record.package_record.version,
package.repodata_record.package_record.build
));

tokio::fs::remove_file(&conda_meta_path).await.map_err(|_| {
PyRattlerError::LinkError(format!("failed to delete {}", conda_meta_path.display()))
})
}
Loading