cuda.parallel: Support structured types as algorithm inputs (NVIDIA#3218)

* Introduce gpu_struct decorator and typing

* Enable `reduce` to accept arrays of structs as inputs

* Add test for reducing arrays of structs

* Update documentation

* Use a numpy array rather than ctypes object

* Change zeros -> empty for output array and temp storage

* Add a TODO for typing GpuStruct

* Documentation updates

* Remove test_reduce_struct_type from test_reduce.py

* Revert to `to_cccl_value()` accepting ndarray + GpuStruct

* Bump copyrights

---------

Co-authored-by: Ashwin Srinath <[email protected]>
shwina committed Jan 16, 2025
1 parent 4d49a97 commit 8428c3a
Showing 7 changed files with 239 additions and 22 deletions.
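
At a glance, the user-facing piece of this change is the `gpu_struct` decorator. Below is a condensed sketch of how it is meant to be used; this is a hypothetical, trimmed-down variant of the test_reduce_struct_type example added at the bottom of this diff, which is the full runnable version.

    import numpy as np
    from cuda.parallel.experimental.struct import gpu_struct

    @gpu_struct
    class Pixel:
        r: np.int32
        g: np.int32
        b: np.int32

    # The decorator attaches a structured NumPy dtype, which is convenient
    # for allocating CuPy/NumPy arrays of Pixel elements.
    assert Pixel.dtype == np.dtype([("r", np.int32), ("g", np.int32), ("b", np.int32)])

    # Instances are plain host values and can be passed as h_init to
    # algorithms such as algorithms.reduce_into.
    h_init = Pixel(0, 0, 0)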
6 changes: 6 additions & 0 deletions docs/cuda_parallel/index.rst
@@ -22,3 +22,9 @@ Iterators
:members:
:undoc-members:
:imported-members:

Utilities
---------

.. automodule:: cuda.parallel.experimental.struct
:members:
21 changes: 16 additions & 5 deletions python/cuda_parallel/cuda/parallel/experimental/_cccl.py
@@ -1,7 +1,8 @@
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
# Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from __future__ import annotations

import ctypes
import functools
@@ -10,8 +11,9 @@
import numpy as np
from numba import cuda, types

from ._utils.cai import DeviceArrayLike, get_dtype, is_contiguous
from ._utils.cai import get_dtype, is_contiguous
from .iterators._iterators import IteratorBase
from .typing import DeviceArrayLike, GpuStruct


# MUST match `cccl_type_enum` in c/include/cccl/c/types.h
@@ -121,6 +123,10 @@ def _type_to_enum(numba_type: types.Type) -> TypeEnum:
def _numba_type_to_info(numba_type: types.Type) -> TypeInfo:
context = cuda.descriptor.cuda_target.target_context
value_type = context.get_value_type(numba_type)
if isinstance(numba_type, types.Record):
# then `value_type` is a pointer and we need the
# alignment of the pointee.
value_type = value_type.pointee
size = value_type.get_abi_size(context.target_data)
alignment = value_type.get_abi_alignment(context.target_data)
return TypeInfo(size, alignment, _type_to_enum(numba_type))
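
A minimal sketch of what this Record branch computes, mirroring the calls above (illustrative only; it assumes a numba installation that exposes `cuda.descriptor` as used in this module):

    import numba
    import numpy as np
    from numba import cuda

    rec = numba.from_dtype(np.dtype([("r", np.int32), ("g", np.int32), ("b", np.int32)]))
    ctx = cuda.descriptor.cuda_target.target_context
    value_type = ctx.get_value_type(rec)   # for Record types this is a pointer type
    value_type = value_type.pointee        # the pointee carries the struct's ABI layout
    size = value_type.get_abi_size(ctx.target_data)         # expected 12 for three int32 fields
    align = value_type.get_abi_alignment(ctx.target_data)   # expected 4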
@@ -209,6 +215,11 @@ def to_cccl_iter(array_or_iterator) -> Iterator:
return _device_array_to_cccl_iter(array_or_iterator)


def host_array_to_value(array: np.ndarray) -> Value:
info = _numpy_type_to_info(array.dtype)
return Value(info, array.ctypes.data)
def to_cccl_value(array_or_struct: np.ndarray | GpuStruct) -> Value:
if isinstance(array_or_struct, np.ndarray):
info = _numpy_type_to_info(array_or_struct.dtype)
data = ctypes.cast(array_or_struct.ctypes.data, ctypes.c_void_p)
return Value(info, data)
else:
# it's a GpuStruct, use the array underlying it
return to_cccl_value(array_or_struct._data)
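
The host-side plumbing behind `to_cccl_value` is plain NumPy plus ctypes; a small illustrative sketch (the names below are hypothetical, not part of the library):

    import ctypes
    import numpy as np

    pixel_dtype = np.dtype([("r", np.int32), ("g", np.int32), ("b", np.int32)])
    h_init = np.array([(0, 0, 0)], dtype=pixel_dtype)   # the kind of array a GpuStruct keeps in ._data

    # The Value handed to the C library pairs the dtype's TypeInfo with a
    # host pointer to this one-element structured array:
    data_ptr = ctypes.cast(h_init.ctypes.data, ctypes.c_void_p)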
11 changes: 9 additions & 2 deletions python/cuda_parallel/cuda/parallel/experimental/_utils/cai.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
# Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
@@ -15,7 +15,14 @@


def get_dtype(arr: DeviceArrayLike) -> np.dtype:
return np.dtype(arr.__cuda_array_interface__["typestr"])
typestr = arr.__cuda_array_interface__["typestr"]

if typestr.startswith("|V"):
# it's a structured dtype, use the descr field:
return np.dtype(arr.__cuda_array_interface__["descr"])
else:
# a simple dtype, use the typestr field:
return np.dtype(typestr)


def get_strides(arr: DeviceArrayLike) -> Optional[Tuple]:
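
The typestr-vs-descr distinction that `get_dtype` now handles can be reproduced with NumPy alone, since `__cuda_array_interface__` uses the same fields as NumPy's `__array_interface__` (illustrative sketch):

    import numpy as np

    s = np.zeros(4, dtype=np.dtype([("x", np.int32), ("y", np.float32)]))
    iface = s.__array_interface__
    iface["typestr"]            # '|V8'  -- an opaque void blob; field names and types are lost
    iface["descr"]              # [('x', '<i4'), ('y', '<f4')] -- enough to rebuild the dtype
    np.dtype(iface["descr"])    # what get_dtype returns for structured inputs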
30 changes: 16 additions & 14 deletions python/cuda_parallel/cuda/parallel/experimental/algorithms/reduce.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
# Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
@@ -18,15 +18,16 @@
from .._caching import CachableFunction, cache_with_key
from .._utils import cai
from ..iterators._iterators import IteratorBase
from ..typing import DeviceArrayLike
from ..typing import DeviceArrayLike, GpuStruct


class _Op:
def __init__(self, dtype: np.dtype, op: Callable):
value_type = numba.from_dtype(dtype)
self.ltoir, _ = cuda.compile(
op, sig=value_type(value_type, value_type), output="ltoir"
)
def __init__(self, h_init: np.ndarray | GpuStruct, op: Callable):
if isinstance(h_init, np.ndarray):
value_type = numba.from_dtype(h_init.dtype)
else:
value_type = numba.typeof(h_init)
self.ltoir, _ = cuda.compile(op, sig=(value_type, value_type), output="ltoir")
self.name = op.__name__.encode("utf-8")

def handle(self) -> cccl.Op:
@@ -53,7 +54,7 @@ def __init__(
d_in: DeviceArrayLike | IteratorBase,
d_out: DeviceArrayLike,
op: Callable,
h_init: np.ndarray,
h_init: np.ndarray | GpuStruct,
):
d_in_cccl = cccl.to_cccl_iter(d_in)
self._ctor_d_in_cccl_type_enum_name = cccl.type_enum_as_name(
@@ -64,17 +65,16 @@ def __init__(
cc_major, cc_minor = cuda.get_current_device().compute_capability
cub_path, thrust_path, libcudacxx_path, cuda_include_path = get_paths()
bindings = get_bindings()
self.op_wrapper = _Op(h_init.dtype, op)
self.op_wrapper = _Op(h_init, op)
d_out_cccl = cccl.to_cccl_iter(d_out)
self.build_result = cccl.DeviceReduceBuildResult()

# TODO Figure out caching
error = bindings.cccl_device_reduce_build(
ctypes.byref(self.build_result),
d_in_cccl,
d_out_cccl,
self.op_wrapper.handle(),
cccl.host_array_to_value(h_init),
cccl.to_cccl_value(h_init),
cc_major,
cc_minor,
ctypes.c_char_p(cub_path),
@@ -85,7 +85,9 @@ def __init__(
if error != enums.CUDA_SUCCESS:
raise ValueError("Error building reduce")

def __call__(self, temp_storage, d_in, d_out, num_items: int, h_init: np.ndarray):
def __call__(
self, temp_storage, d_in, d_out, num_items: int, h_init: np.ndarray | GpuStruct
):
d_in_cccl = cccl.to_cccl_iter(d_in)
if d_in_cccl.type.value == cccl.IteratorKind.ITERATOR:
assert num_items is not None
@@ -99,7 +101,7 @@ def __call__(self, temp_storage, d_in, d_out, num_items: int, h_init: np.ndarray
self._ctor_d_in_cccl_type_enum_name,
cccl.type_enum_as_name(d_in_cccl.value_type.type.value),
)
_dtype_validation(self._ctor_d_out_dtype, d_out.dtype)
_dtype_validation(self._ctor_d_out_dtype, cai.get_dtype(d_out))
_dtype_validation(self._ctor_init_dtype, h_init.dtype)
bindings = get_bindings()
if temp_storage is None:
@@ -119,7 +121,7 @@ def __call__(self, temp_storage, d_in, d_out, num_items: int, h_init: np.ndarray
d_out_cccl,
ctypes.c_ulonglong(num_items),
self.op_wrapper.handle(),
cccl.host_array_to_value(h_init),
cccl.to_cccl_value(h_init),
None,
)
if error != enums.CUDA_SUCCESS:
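
The `_Op` change above amounts to picking the numba value type from `h_init` and letting `cuda.compile` infer the return type; a minimal sketch of the ndarray path (illustrative; requires numba's CUDA target):

    import numba
    import numpy as np
    from numba import cuda

    h_init = np.asarray([0], dtype=np.int32)

    def add_op(a, b):
        return a + b

    # GpuStruct inputs would go through numba.typeof(h_init) instead.
    value_type = numba.from_dtype(h_init.dtype)

    # sig=(value_type, value_type) fixes only the argument types; the return
    # type is inferred, which is what allows struct-valued operators to compile.
    ltoir, _ = cuda.compile(add_op, sig=(value_type, value_type), output="ltoir")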
141 changes: 141 additions & 0 deletions python/cuda_parallel/cuda/parallel/experimental/struct.py
@@ -0,0 +1,141 @@
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from dataclasses import dataclass
from dataclasses import fields as dataclass_fields
from typing import Type

import numba
import numpy as np
from numba.core import cgutils
from numba.core.extending import (
make_attribute_wrapper,
models,
register_model,
typeof_impl,
)
from numba.core.typing import signature as nb_signature
from numba.core.typing.templates import AttributeTemplate, ConcreteTemplate
from numba.cuda.cudadecl import registry as cuda_registry
from numba.extending import lower_builtin

from .typing import GpuStruct


def gpu_struct(this: type) -> Type[GpuStruct]:
"""
Defines the given class as being a GpuStruct.
A GpuStruct represents a value composed of one or more other
values, and is defined as a class with annotated fields (similar
to a dataclass). The type of each field must be a subclass of
`np.number`, like `np.int32` or `np.float64`.
Arrays of GpuStruct objects can be used as inputs to cuda.parallel
algorithms.
Example:
The code snippet below shows how to use `gpu_struct` to define
a `Pixel` type (composed of `r`, `g` and `b` values), and perform
a reduction on an array of `Pixel` objects to identify the one
with the largest `g` component:
.. literalinclude:: ../../python/cuda_parallel/tests/test_reduce_api.py
:language: python
:dedent:
:start-after: example-begin reduce-struct
:end-before: example-end reduce-struct
"""
# Implementation-wise, @gpu_struct creates and registers a
# corresponding numba type to the given type, so that it can be
# used within device functions (e.g., unary and binary operations)
# The numba typing/lowering code is largely based on the example
# in https://github.com/gmarkall/numba-accelerated-udfs/blob/e78876c5d3ace9e1409d37029bd79b2a8b706c62/filigree/numba_extension.py

anns = getattr(this, "__annotations__", {})

# Set a .dtype attribute on the class that returns the
# corresponding numpy structure dtype. This makes it convenient to
# create CuPy/NumPy arrays of this type.
setattr(this, "dtype", np.dtype(list(anns.items())))

# Define __post_init__ to create a numpy struct from the fields,
# and keep a reference to it in the `._data` attribute. The data
# underlying this array is what is ultimately passed to the C
# library, and we need to keep a reference to it for the lifetime
# of the object.
def __post_init__(self):
self._data = np.array(
[tuple(getattr(self, name) for name in anns)], dtype=self.dtype
)

setattr(this, "__post_init__", __post_init__)

# Wrap `this` in a dataclass for convenience:
this = dataclass(this)
fields = dataclass_fields(this)

# Define a numba type corresponding to `this`:
class ThisType(numba.types.Type):
def __init__(self):
super().__init__(name=this.__name__)

this_type = ThisType()

@typeof_impl.register(this)
def typeof_this(val, c):
return ThisType()

# Data model corresponding to ThisType:
@register_model(ThisType)
class ThisModel(models.StructModel):
def __init__(self, dmm, fe_type):
members = [(field.name, numba.from_dtype(field.type)) for field in fields]
super().__init__(dmm, fe_type, members)

# Typing for accessing attributes (fields) of the dataclass:
class ThisAttrsTemplate(AttributeTemplate):
pass

for field in fields:
typ = field.type
name = field.name

def resolver(self, this):
return numba.from_dtype(typ)

setattr(ThisAttrsTemplate, f"resolve_{name}", resolver)

@cuda_registry.register_attr
class ThisAttrs(ThisAttrsTemplate):
key = this_type

# Lowering for attribute access:
for field in fields:
make_attribute_wrapper(ThisType, field.name, field.name)

# Typing for constructor.
@cuda_registry.register
class ThisConstructor(ConcreteTemplate):
key = this
cases = [
nb_signature(this_type, *[numba.from_dtype(field.type) for field in fields])
]

cuda_registry.register_global(this, numba.types.Function(ThisConstructor))

# Lowering for constructor:
def this_constructor(context, builder, sig, args):
ty = sig.return_type
retval = cgutils.create_struct_proxy(ty)(context, builder)
for field, val in zip(fields, args):
setattr(retval, field.name, val)
return retval._getvalue()

lower_builtin(this, *[numba.from_dtype(field.type) for field in fields])(
this_constructor
)

return this
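
The host-side half of the decorator (the `dtype` attribute and the `_data` array built in `__post_init__`) uses only standard NumPy; a sketch of that part in isolation, with hypothetical field names:

    import numpy as np

    # What np.dtype(list(anns.items())) yields for three annotated int32 fields:
    anns = {"r": np.int32, "g": np.int32, "b": np.int32}
    dtype = np.dtype(list(anns.items()))
    dtype            # dtype([('r', '<i4'), ('g', '<i4'), ('b', '<i4')])
    dtype.itemsize   # 12 -- contiguous, C-struct-like layout

    # What __post_init__ stores in ._data for an instance constructed as (10, 20, 30):
    data = np.array([(10, 20, 30)], dtype=dtype)
    data.ctypes.data   # host pointer later wrapped by to_cccl_value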
15 changes: 15 additions & 0 deletions python/cuda_parallel/cuda/parallel/experimental/typing.py
@@ -1,3 +1,10 @@
# Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Any

from typing_extensions import (
Protocol,
) # TODO: typing_extensions required for Python 3.7 docs env
@@ -10,3 +17,11 @@ class DeviceArrayLike(Protocol):
"""

__cuda_array_interface__: dict


# TODO: type GpuStruct appropriately. It should be any type that has
# been decorated with `@gpu_struct`.
GpuStruct = Any
GpuStruct.__doc__ = """\
Type of instances of classes decorated with @gpu_struct.
"""
37 changes: 36 additions & 1 deletion python/cuda_parallel/tests/test_reduce_api.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
# Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

@@ -178,3 +178,38 @@ def square_op(a):
)
assert (d_output == expected_output).all()
# example-end transform-iterator


def test_reduce_struct_type():
# example-begin reduce-struct
import cupy as cp
import numpy as np

from cuda.parallel.experimental import algorithms
from cuda.parallel.experimental.struct import gpu_struct

@gpu_struct
class Pixel:
r: np.int32
g: np.int32
b: np.int32

def max_g_value(x, y):
return x if x.g > y.g else y

d_rgb = cp.random.randint(0, 256, (10, 3), dtype=np.int32).view(Pixel.dtype)
d_out = cp.empty(1, Pixel.dtype)

h_init = Pixel(0, 0, 0)

reduce_into = algorithms.reduce_into(d_rgb, d_out, max_g_value, h_init)
temp_storage_bytes = reduce_into(None, d_rgb, d_out, len(d_rgb), h_init)

d_temp_storage = cp.empty(temp_storage_bytes, dtype=np.uint8)
_ = reduce_into(d_temp_storage, d_rgb, d_out, len(d_rgb), h_init)

h_rgb = d_rgb.get()
expected = h_rgb[h_rgb.view("int32")[:, 1].argmax()]

np.testing.assert_equal(expected["g"], d_out.get()["g"])
# example-end reduce-struct
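
For intuition, the expected value computed at the end of the test is just a host-side argmax over the `g` field; an equivalent pure-NumPy check on made-up data (illustrative):

    import numpy as np

    pixel_dtype = np.dtype([("r", np.int32), ("g", np.int32), ("b", np.int32)])
    h_rgb = np.array([(1, 5, 2), (3, 9, 4), (7, 0, 8)], dtype=pixel_dtype)
    expected = h_rgb[h_rgb["g"].argmax()]   # the pixel the device reduction should return
    assert expected["g"] == 9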
