Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dlpack support #1306

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ matrixmultiply = { version = "0.3.2", default-features = false, features=["cgemm
serde = { version = "1.0", optional = true, default-features = false, features = ["alloc"] }
rawpointer = { version = "0.2" }

dlpark = { version = "0.3.0", optional = true }

[dev-dependencies]
defmac = "0.2"
quickcheck = { version = "1.0", default-features = false }
Expand Down Expand Up @@ -73,6 +75,8 @@ rayon = ["rayon_", "std"]

matrixmultiply-threading = ["matrixmultiply/threading"]

dlpack = ["dep:dlpark"]

[profile.bench]
debug = true
[profile.dev.package.numeric-tests]
Expand Down
45 changes: 44 additions & 1 deletion src/data_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@ use alloc::sync::Arc;
use alloc::vec::Vec;

use crate::{
ArcArray, Array, ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr,
ArcArray, Array, ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr
};

#[cfg(feature = "dlpack")]
use crate::ManagedRepr;

/// Array representation trait.
///
/// For an array that meets the invariants of the `ArrayBase` type. This trait
Expand Down Expand Up @@ -346,6 +349,24 @@ unsafe impl<A> RawData for OwnedRepr<A> {
private_impl! {}
}

#[cfg(feature = "dlpack")]
unsafe impl<A> RawData for ManagedRepr<A> {
type Elem = A;

fn _data_slice(&self) -> Option<&[A]> {
Some(self.as_slice())
}

fn _is_pointer_inbounds(&self, self_ptr: *const Self::Elem) -> bool {
let slc = self.as_slice();
let ptr = slc.as_ptr() as *mut A;
let end = unsafe { ptr.add(slc.len()) };
self_ptr >= ptr && self_ptr <= end
}

private_impl! {}
}

unsafe impl<A> RawDataMut for OwnedRepr<A> {
#[inline]
fn try_ensure_unique<D>(_: &mut ArrayBase<Self, D>)
Expand Down Expand Up @@ -382,6 +403,28 @@ unsafe impl<A> Data for OwnedRepr<A> {
}
}

#[cfg(feature = "dlpack")]
unsafe impl<A> Data for ManagedRepr<A> {
#[inline]
fn into_owned<D>(self_: ArrayBase<Self, D>) -> Array<Self::Elem, D>
where
A: Clone,
D: Dimension,
{
self_.to_owned()
}

#[inline]
fn try_into_owned_nocopy<D>(
self_: ArrayBase<Self, D>,
) -> Result<Array<Self::Elem, D>, ArrayBase<Self, D>>
where
D: Dimension,
{
Err(self_)
}
}

unsafe impl<A> DataMut for OwnedRepr<A> {}

unsafe impl<A> RawDataClone for OwnedRepr<A>
Expand Down
94 changes: 94 additions & 0 deletions src/dlpack.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
use core::ptr::NonNull;
use std::marker::PhantomData;

use dlpark::prelude::*;

use crate::{ArrayBase, Dimension, IntoDimension, IxDyn, ManagedArray, RawData};

impl<A, S, D> ToTensor for ArrayBase<S, D>
where
A: InferDtype,
S: RawData<Elem = A>,
D: Dimension,
{
fn data_ptr(&self) -> *mut std::ffi::c_void {
self.as_ptr() as *mut std::ffi::c_void
}

fn byte_offset(&self) -> u64 {
0
}

fn device(&self) -> Device {
Device::CPU
}

fn dtype(&self) -> DataType {
A::infer_dtype()
}

fn shape(&self) -> CowIntArray {
dlpark::prelude::CowIntArray::from_owned(
self.shape().into_iter().map(|&x| x as i64).collect(),
)
}

fn strides(&self) -> Option<CowIntArray> {
Some(dlpark::prelude::CowIntArray::from_owned(
self.strides().into_iter().map(|&x| x as i64).collect(),
))
}
}

pub struct ManagedRepr<A> {
managed_tensor: ManagedTensor,
_ty: PhantomData<A>,
}

impl<A> ManagedRepr<A> {
pub fn new(managed_tensor: ManagedTensor) -> Self {
Self {
managed_tensor,
_ty: PhantomData,
}
}

pub fn as_slice(&self) -> &[A] {
self.managed_tensor.as_slice()
}

pub fn as_ptr(&self) -> *const A {
self.managed_tensor.data_ptr() as *const A
}
}

unsafe impl<A> Sync for ManagedRepr<A> where A: Sync {}
unsafe impl<A> Send for ManagedRepr<A> where A: Send {}

impl<A> FromDLPack for ManagedArray<A, IxDyn> {
fn from_dlpack(dlpack: NonNull<dlpark::ffi::DLManagedTensor>) -> Self {
Copy link
Member

@bluss bluss Mar 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this function takes a raw pointer (wrapped in NonNull) and it must be an unsafe function, otherwise we can trivially violate memory safety unfortunately.

The only way to remove this requirement - the requirement of using unsafe - would be if you have a "magical" function that can take an arbitrary pointer and say whether it's a valid, live, non-mutably aliased pointer to a tensor.

Here's how to create a dangling bad pointer: NonNull::new(1 as *mut u8 as *mut dlpark::ffi::DLManagedTensor) does this code crash if we run with this pointer? I think it would..

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with you. from_dlpack should be unsafe, and users should use it at their own risk.

Copy link
Member

@bluss bluss Mar 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say, we normally don't commit to public dependencies that are not stable (yes, not a very fair policy since ndarray itself is not so stable.), and dlpark is a public dependency here because it becomes part of our API. It could mean it takes a long time between version bumps.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we don't need to include dlpark as a dependency. We can create an ArrayView using ArrayView::from_shape_ptr and ManagedTensor. I can implement ToTensor for ArrayD in dlpark with a new feature ndarray. I'll do some quick experiments.

let managed_tensor = ManagedTensor::new(dlpack);
let shape: Vec<usize> = managed_tensor
.shape()
.into_iter()
.map(|x| *x as _)
.collect();

let strides: Vec<usize> = match (managed_tensor.strides(), managed_tensor.is_contiguous()) {
(Some(s), _) => s.into_iter().map(|&x| x as _).collect(),
(None, true) => managed_tensor
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Later work, check compatibility of dlpack and ndarray strides, how they work, their domains etc.

.calculate_contiguous_strides()
.into_iter()
.map(|x| x as _)
.collect(),
(None, false) => panic!("dlpack: invalid strides"),
};
let ptr = managed_tensor.data_ptr() as *mut A;

let managed_repr = ManagedRepr::<A>::new(managed_tensor);
unsafe {
ArrayBase::from_data_ptr(managed_repr, NonNull::new_unchecked(ptr))
.with_strides_dim(strides.into_dimension(), shape.into_dimension())
}
}
}
13 changes: 13 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,9 @@ mod zip;

mod dimension;

#[cfg(feature = "dlpack")]
mod dlpack;

pub use crate::zip::{FoldWhile, IntoNdProducer, NdProducer, Zip};

pub use crate::layout::Layout;
Expand Down Expand Up @@ -1346,6 +1349,12 @@ pub type Array<A, D> = ArrayBase<OwnedRepr<A>, D>;
/// instead of either a view or a uniquely owned copy.
pub type CowArray<'a, A, D> = ArrayBase<CowRepr<'a, A>, D>;


/// An array from managed memory
#[cfg(feature = "dlpack")]
pub type ManagedArray<A, D> = ArrayBase<ManagedRepr<A>, D>;


/// A read-only array view.
///
/// An array view represents an array or a part of it, created from
Expand Down Expand Up @@ -1420,6 +1429,10 @@ pub type RawArrayViewMut<A, D> = ArrayBase<RawViewRepr<*mut A>, D>;

pub use data_repr::OwnedRepr;

#[cfg(feature = "dlpack")]
pub use dlpack::ManagedRepr;


/// ArcArray's representation.
///
/// *Don’t use this type directly—use the type alias
Expand Down
17 changes: 17 additions & 0 deletions tests/dlpack.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#![cfg(feature = "dlpack")]

use dlpark::prelude::*;
use ndarray::ManagedArray;

#[test]
fn test_dlpack() {
let arr = ndarray::arr1(&[1i32, 2, 3]);
let ptr = arr.as_ptr();
let dlpack = arr.into_dlpack();
let arr2 = ManagedArray::<i32, _>::from_dlpack(dlpack);
let ptr2 = arr2.as_ptr();
assert_eq!(ptr, ptr2);
let arr3 = arr2.to_owned();
let ptr3 = arr3.as_ptr();
assert_ne!(ptr2, ptr3);
}