Skip to content

feat: enable jax backend for virtual arrays #3451

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/awkward/_backends/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,11 @@ def finder(obj: VirtualArray):
return _name_to_backend_cls["cpu"].instance()
elif isinstance(obj.nplike, ak._nplikes.cupy.Cupy):
return _name_to_backend_cls["cuda"].instance()
elif isinstance(obj.nplike, ak._nplikes.jax.Jax):
return _name_to_backend_cls["jax"].instance()
else:
raise TypeError("A virtual array can only have numpy or cupy backends")
raise TypeError(
f"The nplike {type(obj.nplike)} does not support virtual arrays"
)

return finder
30 changes: 19 additions & 11 deletions src/awkward/_connect/jax/reducers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from awkward import _reducers
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._nplikes.shape import ShapeItem
from awkward._nplikes.virtual import materialize_if_virtual
from awkward._reducers import Reducer
from awkward._typing import Final, Self, TypeVar

Expand Down Expand Up @@ -88,7 +89,7 @@ def apply(
outlength: ShapeItem,
) -> ak.contents.NumpyArray:
assert isinstance(array, ak.contents.NumpyArray)
result = segment_argmin(array.data, parents.data)
result = segment_argmin(*materialize_if_virtual(array.data, parents.data))
result = jax.numpy.asarray(result, dtype=array.dtype)

return ak.contents.NumpyArray(result, backend=array.backend)
Expand Down Expand Up @@ -145,7 +146,7 @@ def apply(
outlength: ShapeItem,
) -> ak.contents.NumpyArray:
assert isinstance(array, ak.contents.NumpyArray)
result = segment_argmax(array.data, parents.data)
result = segment_argmax(*materialize_if_virtual(array.data, parents.data))
result = jax.numpy.asarray(result, dtype=array.dtype)

return ak.contents.NumpyArray(result, backend=array.backend)
Expand Down Expand Up @@ -175,8 +176,10 @@ def apply(
outlength: ShapeItem,
) -> ak.contents.NumpyArray:
assert isinstance(array, ak.contents.NumpyArray)
result = jax.numpy.ones_like(array.data, dtype=array.dtype)
result = jax.ops.segment_sum(result, parents.data)
result = jax.numpy.ones_like(
*materialize_if_virtual(array.data), dtype=array.dtype
)
result = jax.ops.segment_sum(result, *materialize_if_virtual(parents.data))

if np.issubdtype(array.dtype, np.complexfloating):
return ak.contents.NumpyArray(
Expand Down Expand Up @@ -232,7 +235,9 @@ def apply(
outlength: ShapeItem,
) -> ak.contents.NumpyArray:
assert isinstance(array, ak.contents.NumpyArray)
result = segment_count_nonzero(array.data, parents.data)
result = segment_count_nonzero(
*materialize_if_virtual(array.data, parents.data)
)
result = jax.numpy.asarray(result, dtype=self.preferred_dtype)

return ak.contents.NumpyArray(result, backend=array.backend)
Expand Down Expand Up @@ -261,7 +266,7 @@ def apply(
if array.dtype.kind == "M":
raise TypeError(f"cannot compute the sum (ak.sum) of {array.dtype!r}")

result = jax.ops.segment_sum(array.data, parents.data)
result = jax.ops.segment_sum(*materialize_if_virtual(array.data, parents.data))

if array.dtype.kind == "m":
return ak.contents.NumpyArray(
Expand Down Expand Up @@ -295,7 +300,10 @@ def apply(
assert isinstance(array, ak.contents.NumpyArray)
# See issue https://github.com/google/jax/issues/9296
result = jax.numpy.exp(
jax.ops.segment_sum(jax.numpy.log(array.data), parents.data)
jax.ops.segment_sum(
jax.numpy.log(*materialize_if_virtual(array.data)),
*materialize_if_virtual(parents.data),
)
)

if np.issubdtype(array.dtype, np.complexfloating):
Expand Down Expand Up @@ -330,7 +338,7 @@ def apply(
outlength: ShapeItem,
) -> ak.contents.NumpyArray:
assert isinstance(array, ak.contents.NumpyArray)
result = jax.ops.segment_max(array.data, parents.data)
result = jax.ops.segment_max(*materialize_if_virtual(array.data, parents.data))
result = jax.numpy.asarray(result, dtype=bool)

return ak.contents.NumpyArray(result, backend=array.backend)
Expand Down Expand Up @@ -360,7 +368,7 @@ def apply(
outlength: ShapeItem,
) -> ak.contents.NumpyArray:
assert isinstance(array, ak.contents.NumpyArray)
result = jax.ops.segment_min(array.data, parents.data)
result = jax.ops.segment_min(*materialize_if_virtual(array.data, parents.data))
result = jax.numpy.asarray(result, dtype=bool)

return ak.contents.NumpyArray(result, backend=array.backend)
Expand Down Expand Up @@ -413,7 +421,7 @@ def apply(
) -> ak.contents.NumpyArray:
assert isinstance(array, ak.contents.NumpyArray)

result = jax.ops.segment_min(array.data, parents.data)
result = jax.ops.segment_min(*materialize_if_virtual(array.data, parents.data))
result = jax.numpy.minimum(result, self._min_initial(self.initial, array.dtype))

if np.issubdtype(array.dtype, np.complexfloating):
Expand Down Expand Up @@ -474,7 +482,7 @@ def apply(
) -> ak.contents.NumpyArray:
assert isinstance(array, ak.contents.NumpyArray)

result = jax.ops.segment_max(array.data, parents.data)
result = jax.ops.segment_max(*materialize_if_virtual(array.data, parents.data))

result = jax.numpy.maximum(result, self._max_initial(self.initial, array.dtype))
if np.issubdtype(array.dtype, np.complexfloating):
Expand Down
4 changes: 2 additions & 2 deletions src/awkward/_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,9 @@ def _cast(self, x, t):
def __call__(self, *args) -> None:
assert len(args) == len(self._impl.argtypes)

if not any(Jax.is_tracer_type(type(arg)) for arg in args):
args = materialize_if_virtual(*args)
args = materialize_if_virtual(*args)

if not any(Jax.is_tracer_type(type(arg)) for arg in args):
return self._impl(
*(self._cast(x, t) for x, t in zip(args, self._impl.argtypes))
)
Expand Down
5 changes: 2 additions & 3 deletions src/awkward/_nplikes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,10 @@ def to_nplike(

# We can always convert virtual arrays to typetracers
# but can only convert virtual arrays to other backends with known data if they are intentionally materialized
# Only numpy and cupy nplikes are allowed for virtual arrays
if isinstance(array, awkward._nplikes.virtual.VirtualArray):
if not array.is_materialized and nplike.known_data:
raise TypeError(
"Cannot convert a VirtualArray to a different nplike with known data without materializing it first. Use ak.materialize on the array to do so."
"Cannot convert a VirtualArray to a different nplike with known data without materializing it first. Use ak.materialize on the array to do so"
)
else:
if nplike.supports_virtual_arrays:
Expand All @@ -43,7 +42,7 @@ def to_nplike(
pass
else:
raise TypeError(
f"Can only convert a VirtualArray to numpy, cupy or typetracer nplikes. Received {type(nplike)}"
f"The target nplike {type(nplike)} does not support virtual arrays"
)

if nplike.known_data and not from_nplike.known_data:
Expand Down
9 changes: 6 additions & 3 deletions src/awkward/_nplikes/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
from awkward._nplikes.dispatch import register_nplike
from awkward._nplikes.numpy_like import UfuncLike
from awkward._nplikes.placeholder import PlaceholderArray
from awkward._nplikes.virtual import materialize_if_virtual
from awkward._nplikes.virtual import VirtualArray, materialize_if_virtual
from awkward._typing import Final, cast


@register_nplike
class Jax(ArrayModuleNumpyLike):
is_eager: Final = True
supports_structured_dtypes: Final = False
supports_virtual_arrays: Final = False
supports_virtual_arrays: Final = True

def __init__(self):
jax = ak.jax.import_jax()
Expand Down Expand Up @@ -87,7 +87,10 @@ def is_c_contiguous(self, x: ArrayLike) -> bool:
return True

def ascontiguousarray(self, x: ArrayLike) -> ArrayLike:
return x
if isinstance(x, VirtualArray) and x.is_materialized:
return x.materialize()
else:
return x

def strides(self, x: ArrayLike) -> tuple[int, ...]:
out: tuple[int, ...] = (x.dtype.itemsize,)
Expand Down
17 changes: 8 additions & 9 deletions src/awkward/_nplikes/virtual.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ def materialize_if_virtual(*args: Any) -> tuple[Any, ...]:
class VirtualArray(NDArrayOperatorsMixin, ArrayLike):
"""
Implements a virtual array to be used as a buffer inside layouts.
Virtual arrays are tied to specific nplikes and only numpy and cupy nplikes are allowed.
Therefore, virtual arrays are only allowed to generate `numpy.ndarray`s or `cupy.ndarray`s when materialized.
Virtual arrays are tied to specific nplikes.
The arrays are generated via a generator function that is passed to the constructor.
All virtual arrays are also required to have a known dtype and shape, and `unknown_length` is not currently allowed in the shape.
Some operations (such as trivial slicing) maintain virtualness and return a new virtual array.
Expand All @@ -53,11 +52,11 @@ def __init__(
) -> None:
if not nplike.supports_virtual_arrays:
raise TypeError(
f"Only numpy and cupy nplikes are supported for {type(self).__name__}. Received {type(nplike)}"
f"The nplike {type(nplike)} does not support virtual arrays"
)
if any(not is_integer(dim) for dim in shape):
raise TypeError(
f"Only shapes of integer dimensions are supported for {type(self).__name__}. Received shape {shape}."
f"Only shapes of integer dimensions are supported for {type(self).__name__}. Received shape {shape}"
)

# array metadata
Expand Down Expand Up @@ -160,7 +159,7 @@ def generator(self) -> Callable:
def nplike(self) -> NumpyLike:
if not self._nplike.supports_virtual_arrays:
raise TypeError(
f"Only numpy and cupy nplikes are supported for {type(self).__name__}. Received {type(self._nplike)}"
f"The nplike {type(self._nplike)} does not support virtual arrays"
)
return self._nplike

Expand All @@ -173,7 +172,7 @@ def tolist(self) -> list:
@property
def ctypes(self):
if isinstance((self._nplike), ak._nplikes.cupy.Cupy):
raise AttributeError("Cupy ndarrays do not have a ctypes attribute.")
raise AttributeError("Cupy ndarrays do not have a ctypes attribute")
return self.materialize().ctypes

@property
Expand Down Expand Up @@ -245,7 +244,7 @@ def __getitem__(self, index):
or index.step is unknown_length
):
raise TypeError(
f"{type(self).__name__} does not support slicing with unknown_length while slice {index} was provided."
f"{type(self).__name__} does not support slicing with unknown_length while slice {index} was provided"
)
else:
start, stop, step = index.indices(length)
Expand Down Expand Up @@ -274,13 +273,13 @@ def __int__(self) -> int:
array = self.materialize()
if len(array.shape) == 0:
return int(array)
raise TypeError("Only scalar arrays can be converted to an int.")
raise TypeError("Only scalar arrays can be converted to an int")

def __index__(self) -> int:
array = self.materialize()
if len(array.shape) == 0:
return int(array)
raise TypeError("Only scalar arrays can be used as an index.")
raise TypeError("Only scalar arrays can be used as an index")

def __len__(self) -> int:
if len(self._shape) == 0:
Expand Down
15 changes: 12 additions & 3 deletions tests/test_3364_virtualarray.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import numpy as np
Expand Down Expand Up @@ -2354,6 +2356,13 @@ def test_numpyarray_argmax(numpyarray, virtual_numpyarray):
assert virtual_numpyarray.is_all_materialized


def test_numpyarray_nanargmax(numpyarray, virtual_numpyarray):
    """Check that ak.nanargmax along axis 0 agrees for virtual and plain input.

    The virtual layout is expected to be untouched before the reduction and
    to report itself as fully materialized after it.
    """
    assert not virtual_numpyarray.is_any_materialized

    # Evaluate the virtual input first, mirroring the left-to-right order of
    # the comparison.
    from_virtual = ak.nanargmax(virtual_numpyarray, axis=0)
    from_plain = ak.nanargmax(numpyarray, axis=0)
    assert from_virtual == from_plain

    assert virtual_numpyarray.is_any_materialized
    assert virtual_numpyarray.is_all_materialized


def test_numpyarray_sort(numpyarray, virtual_numpyarray):
assert not virtual_numpyarray.is_any_materialized
assert ak.array_equal(
Expand All @@ -2376,7 +2385,7 @@ def test_numpyarray_argsort(numpyarray, virtual_numpyarray):

def test_numpyarray_is_none(numpyarray, virtual_numpyarray):
assert not virtual_numpyarray.is_any_materialized
assert np.all(ak.is_none(virtual_numpyarray) == ak.is_none(numpyarray))
assert ak.all(ak.is_none(virtual_numpyarray) == ak.is_none(numpyarray))
assert not virtual_numpyarray.is_any_materialized
assert not virtual_numpyarray.is_all_materialized

Expand Down Expand Up @@ -3088,7 +3097,7 @@ def test_listoffsetarray_argsort(listoffsetarray, virtual_listoffsetarray):

def test_listoffsetarray_is_none(listoffsetarray, virtual_listoffsetarray):
assert not virtual_listoffsetarray.is_any_materialized
assert np.all(ak.is_none(virtual_listoffsetarray) == ak.is_none(listoffsetarray))
assert ak.all(ak.is_none(virtual_listoffsetarray) == ak.is_none(listoffsetarray))
assert virtual_listoffsetarray.is_any_materialized
assert not virtual_listoffsetarray.is_all_materialized

Expand Down Expand Up @@ -4073,7 +4082,7 @@ def test_listarray_argsort(listarray, virtual_listarray):

def test_listarray_is_none(listarray, virtual_listarray):
assert not virtual_listarray.is_any_materialized
assert np.all(ak.is_none(virtual_listarray) == ak.is_none(listarray))
assert ak.all(ak.is_none(virtual_listarray) == ak.is_none(listarray))
assert virtual_listarray.is_any_materialized
assert not virtual_listarray.is_all_materialized

Expand Down
Loading
Loading