diff --git a/docs/cuda_parallel/index.rst b/docs/cuda_parallel/index.rst
index e494fb1e323..c54feb81f85 100644
--- a/docs/cuda_parallel/index.rst
+++ b/docs/cuda_parallel/index.rst
@@ -22,3 +22,9 @@ Iterators
   :members:
   :undoc-members:
   :imported-members:
+
+Utilities
+---------
+
+.. automodule:: cuda.parallel.experimental.struct
+  :members:
diff --git a/python/cuda_parallel/cuda/parallel/experimental/_cccl.py b/python/cuda_parallel/cuda/parallel/experimental/_cccl.py
index e09191dac2c..955274d66e4 100644
--- a/python/cuda_parallel/cuda/parallel/experimental/_cccl.py
+++ b/python/cuda_parallel/cuda/parallel/experimental/_cccl.py
@@ -1,7 +1,8 @@
-# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
+# Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
 #
 #
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from __future__ import annotations
 
 import ctypes
 import functools
@@ -10,8 +11,9 @@
 import numpy as np
 from numba import cuda, types
 
-from ._utils.cai import DeviceArrayLike, get_dtype, is_contiguous
+from ._utils.cai import get_dtype, is_contiguous
 from .iterators._iterators import IteratorBase
+from .typing import DeviceArrayLike, GpuStruct
 
 
 # MUST match `cccl_type_enum` in c/include/cccl/c/types.h
@@ -121,6 +123,10 @@ def _type_to_enum(numba_type: types.Type) -> TypeEnum:
 def _numba_type_to_info(numba_type: types.Type) -> TypeInfo:
     context = cuda.descriptor.cuda_target.target_context
     value_type = context.get_value_type(numba_type)
+    if isinstance(numba_type, types.Record):
+        # then `value_type` is a pointer and we need the
+        # alignment of the pointee.
+        value_type = value_type.pointee
     size = value_type.get_abi_size(context.target_data)
     alignment = value_type.get_abi_alignment(context.target_data)
     return TypeInfo(size, alignment, _type_to_enum(numba_type))
@@ -209,6 +215,11 @@ def to_cccl_iter(array_or_iterator) -> Iterator:
     return _device_array_to_cccl_iter(array_or_iterator)
 
 
-def host_array_to_value(array: np.ndarray) -> Value:
-    info = _numpy_type_to_info(array.dtype)
-    return Value(info, array.ctypes.data)
+def to_cccl_value(array_or_struct: np.ndarray | GpuStruct) -> Value:
+    if isinstance(array_or_struct, np.ndarray):
+        info = _numpy_type_to_info(array_or_struct.dtype)
+        data = ctypes.cast(array_or_struct.ctypes.data, ctypes.c_void_p)
+        return Value(info, data)
+    else:
+        # it's a GpuStruct, use the array underlying it
+        return to_cccl_value(array_or_struct._data)
diff --git a/python/cuda_parallel/cuda/parallel/experimental/_utils/cai.py b/python/cuda_parallel/cuda/parallel/experimental/_utils/cai.py
index 4d435171aad..3a3391f93f2 100644
--- a/python/cuda_parallel/cuda/parallel/experimental/_utils/cai.py
+++ b/python/cuda_parallel/cuda/parallel/experimental/_utils/cai.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
+# Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
 #
 #
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
@@ -15,7 +15,14 @@
 
 
 def get_dtype(arr: DeviceArrayLike) -> np.dtype:
-    return np.dtype(arr.__cuda_array_interface__["typestr"])
+    typestr = arr.__cuda_array_interface__["typestr"]
+
+    if typestr.startswith("|V"):
+        # it's a structured dtype, use the descr field:
+        return np.dtype(arr.__cuda_array_interface__["descr"])
+    else:
+        # a simple dtype, use the typestr field:
+        return np.dtype(typestr)
 
 
 def get_strides(arr: DeviceArrayLike) -> Optional[Tuple]:
diff --git a/python/cuda_parallel/cuda/parallel/experimental/algorithms/reduce.py b/python/cuda_parallel/cuda/parallel/experimental/algorithms/reduce.py
index 7a1a26bbc9f..10a9cf12051 100644
--- a/python/cuda_parallel/cuda/parallel/experimental/algorithms/reduce.py
+++ b/python/cuda_parallel/cuda/parallel/experimental/algorithms/reduce.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
+# Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
 #
 #
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
@@ -18,15 +18,16 @@
 from .._caching import CachableFunction, cache_with_key
 from .._utils import cai
 from ..iterators._iterators import IteratorBase
-from ..typing import DeviceArrayLike
+from ..typing import DeviceArrayLike, GpuStruct
 
 
 class _Op:
-    def __init__(self, dtype: np.dtype, op: Callable):
-        value_type = numba.from_dtype(dtype)
-        self.ltoir, _ = cuda.compile(
-            op, sig=value_type(value_type, value_type), output="ltoir"
-        )
+    def __init__(self, h_init: np.ndarray | GpuStruct, op: Callable):
+        if isinstance(h_init, np.ndarray):
+            value_type = numba.from_dtype(h_init.dtype)
+        else:
+            value_type = numba.typeof(h_init)
+        self.ltoir, _ = cuda.compile(op, sig=(value_type, value_type), output="ltoir")
         self.name = op.__name__.encode("utf-8")
 
     def handle(self) -> cccl.Op:
@@ -53,7 +54,7 @@ def __init__(
         d_in: DeviceArrayLike | IteratorBase,
         d_out: DeviceArrayLike,
         op: Callable,
-        h_init: np.ndarray,
+        h_init: np.ndarray | GpuStruct,
     ):
         d_in_cccl = cccl.to_cccl_iter(d_in)
         self._ctor_d_in_cccl_type_enum_name = cccl.type_enum_as_name(
@@ -64,17 +65,16 @@ def __init__(
         cc_major, cc_minor = cuda.get_current_device().compute_capability
         cub_path, thrust_path, libcudacxx_path, cuda_include_path = get_paths()
         bindings = get_bindings()
-        self.op_wrapper = _Op(h_init.dtype, op)
+        self.op_wrapper = _Op(h_init, op)
         d_out_cccl = cccl.to_cccl_iter(d_out)
         self.build_result = cccl.DeviceReduceBuildResult()
 
-        # TODO Figure out caching
         error = bindings.cccl_device_reduce_build(
             ctypes.byref(self.build_result),
             d_in_cccl,
             d_out_cccl,
             self.op_wrapper.handle(),
-            cccl.host_array_to_value(h_init),
+            cccl.to_cccl_value(h_init),
             cc_major,
             cc_minor,
             ctypes.c_char_p(cub_path),
@@ -85,7 +85,9 @@ def __init__(
         if error != enums.CUDA_SUCCESS:
             raise ValueError("Error building reduce")
 
-    def __call__(self, temp_storage, d_in, d_out, num_items: int, h_init: np.ndarray):
+    def __call__(
+        self, temp_storage, d_in, d_out, num_items: int, h_init: np.ndarray | GpuStruct
+    ):
         d_in_cccl = cccl.to_cccl_iter(d_in)
         if d_in_cccl.type.value == cccl.IteratorKind.ITERATOR:
             assert num_items is not None
@@ -99,7 +101,7 @@ def __call__(self, temp_storage, d_in, d_out, num_items: int, h_init: np.ndarray
             self._ctor_d_in_cccl_type_enum_name,
             cccl.type_enum_as_name(d_in_cccl.value_type.type.value),
         )
-        _dtype_validation(self._ctor_d_out_dtype, d_out.dtype)
+        _dtype_validation(self._ctor_d_out_dtype, cai.get_dtype(d_out))
         _dtype_validation(self._ctor_init_dtype, h_init.dtype)
         bindings = get_bindings()
         if temp_storage is None:
@@ -119,7 +121,7 @@ def __call__(self, temp_storage, d_in, d_out, num_items: int, h_init: np.ndarray
             d_out_cccl,
             ctypes.c_ulonglong(num_items),
             self.op_wrapper.handle(),
-            cccl.host_array_to_value(h_init),
+            cccl.to_cccl_value(h_init),
             None,
         )
         if error != enums.CUDA_SUCCESS:
diff --git a/python/cuda_parallel/cuda/parallel/experimental/struct.py b/python/cuda_parallel/cuda/parallel/experimental/struct.py
new file mode 100644
index 00000000000..3ca09d39676
--- /dev/null
+++ b/python/cuda_parallel/cuda/parallel/experimental/struct.py
@@ -0,0 +1,141 @@
+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
+#
+#
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from dataclasses import dataclass
+from dataclasses import fields as dataclass_fields
+from typing import Type
+
+import numba
+import numpy as np
+from numba.core import cgutils
+from numba.core.extending import (
+    make_attribute_wrapper,
+    models,
+    register_model,
+    typeof_impl,
+)
+from numba.core.typing import signature as nb_signature
+from numba.core.typing.templates import AttributeTemplate, ConcreteTemplate
+from numba.cuda.cudadecl import registry as cuda_registry
+from numba.extending import lower_builtin
+
+from .typing import GpuStruct
+
+
+def gpu_struct(this: type) -> Type[GpuStruct]:
+    """
+    Defines the given class as being a GpuStruct.
+
+    A GpuStruct represents a value composed of one or more other
+    values, and is defined as a class with annotated fields (similar
+    to a dataclass). The type of each field must be a subclass of
+    `np.number`, like `np.int32` or `np.float64`.
+
+    Arrays of GpuStruct objects can be used as inputs to cuda.parallel
+    algorithms.
+
+    Example:
+        The code snippet below shows how to use `gpu_struct` to define
+        a `Pixel` type (composed of `r`, `g` and `b` values), and perform
+        a reduction on an array of `Pixel` objects to identify the one
+        with the largest `g` component:
+
+        .. literalinclude:: ../../python/cuda_parallel/tests/test_reduce_api.py
+            :language: python
+            :dedent:
+            :start-after: example-begin reduce-struct
+            :end-before: example-end reduce-struct
+    """
+    # Implementation-wise, @gpu_struct creates and registers a
+    # corresponding numba type to the given type, so that it can be
+    # used within device functions (e.g., unary and binary operations)
+    # The numba typing/lowering code is largely based on the example
+    # in https://github.com/gmarkall/numba-accelerated-udfs/blob/e78876c5d3ace9e1409d37029bd79b2a8b706c62/filigree/numba_extension.py
+
+    anns = getattr(this, "__annotations__", {})
+
+    # Set a .dtype attribute on the class that returns the
+    # corresponding numpy structure dtype. This makes it convenient to
+    # create CuPy/NumPy arrays of this type.
+    setattr(this, "dtype", np.dtype(list(anns.items())))
+
+    # Define __post_init__ to create a numpy struct from the fields,
+    # and keep a reference to it in the `._data` attribute. The data
+    # underlying this array is what is ultimately passed to the C
+    # library, and we need to keep a reference to it for the lifetime
+    # of the object.
+    def __post_init__(self):
+        self._data = np.array(
+            [tuple(getattr(self, name) for name in anns)], dtype=self.dtype
+        )
+
+    setattr(this, "__post_init__", __post_init__)
+
+    # Wrap `this` in a dataclass for convenience:
+    this = dataclass(this)
+    fields = dataclass_fields(this)
+
+    # Define a numba type corresponding to `this`:
+    class ThisType(numba.types.Type):
+        def __init__(self):
+            super().__init__(name=this.__name__)
+
+    this_type = ThisType()
+
+    @typeof_impl.register(this)
+    def typeof_this(val, c):
+        return ThisType()
+
+    # Data model corresponding to ThisType:
+    @register_model(ThisType)
+    class ThisModel(models.StructModel):
+        def __init__(self, dmm, fe_type):
+            members = [(field.name, numba.from_dtype(field.type)) for field in fields]
+            super().__init__(dmm, fe_type, members)
+
+    # Typing for accessing attributes (fields) of the dataclass:
+    class ThisAttrsTemplate(AttributeTemplate):
+        pass
+
+    for field in fields:
+        typ = field.type
+        name = field.name
+
+        def resolver(self, this, typ=typ):  # bind typ per field (closures bind late)
+            return numba.from_dtype(typ)
+
+        setattr(ThisAttrsTemplate, f"resolve_{name}", resolver)
+
+    @cuda_registry.register_attr
+    class ThisAttrs(ThisAttrsTemplate):
+        key = this_type
+
+    # Lowering for attribute access:
+    for field in fields:
+        make_attribute_wrapper(ThisType, field.name, field.name)
+
+    # Typing for constructor.
+    @cuda_registry.register
+    class ThisConstructor(ConcreteTemplate):
+        key = this
+        cases = [
+            nb_signature(this_type, *[numba.from_dtype(field.type) for field in fields])
+        ]
+
+    cuda_registry.register_global(this, numba.types.Function(ThisConstructor))
+
+    # Lowering for constructor:
+    def this_constructor(context, builder, sig, args):
+        ty = sig.return_type
+        retval = cgutils.create_struct_proxy(ty)(context, builder)
+        for field, val in zip(fields, args):
+            setattr(retval, field.name, val)
+        return retval._getvalue()
+
+    lower_builtin(this, *[numba.from_dtype(field.type) for field in fields])(
+        this_constructor
+    )
+
+    return this
diff --git a/python/cuda_parallel/cuda/parallel/experimental/typing.py b/python/cuda_parallel/cuda/parallel/experimental/typing.py
index 1c4e9c9975f..38b63a60e51 100644
--- a/python/cuda_parallel/cuda/parallel/experimental/typing.py
+++ b/python/cuda_parallel/cuda/parallel/experimental/typing.py
@@ -1,3 +1,10 @@
+# Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
+#
+#
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Any
+
 from typing_extensions import (
     Protocol,
 )  # TODO: typing_extensions required for Python 3.7 docs env
@@ -10,3 +17,11 @@ class DeviceArrayLike(Protocol):
     """
 
     __cuda_array_interface__: dict
+
+
+# TODO: type GpuStruct appropriately. It should be any type that has
+# been decorated with `@gpu_struct`.
+GpuStruct = Any
+GpuStruct.__doc__ = """\
+    Type of instances of classes decorated with @gpu_struct.
+"""
diff --git a/python/cuda_parallel/tests/test_reduce_api.py b/python/cuda_parallel/tests/test_reduce_api.py
index c8c20f51cd7..c920824fa54 100644
--- a/python/cuda_parallel/tests/test_reduce_api.py
+++ b/python/cuda_parallel/tests/test_reduce_api.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
+# Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
 #
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
@@ -178,3 +178,38 @@ def square_op(a):
     )
     assert (d_output == expected_output).all()
     # example-end transform-iterator
+
+
+def test_reduce_struct_type():
+    # example-begin reduce-struct
+    import cupy as cp
+    import numpy as np
+
+    from cuda.parallel.experimental import algorithms
+    from cuda.parallel.experimental.struct import gpu_struct
+
+    @gpu_struct
+    class Pixel:
+        r: np.int32
+        g: np.int32
+        b: np.int32
+
+    def max_g_value(x, y):
+        return x if x.g > y.g else y
+
+    d_rgb = cp.random.randint(0, 256, (10, 3), dtype=np.int32).view(Pixel.dtype)
+    d_out = cp.empty(1, Pixel.dtype)
+
+    h_init = Pixel(0, 0, 0)
+
+    reduce_into = algorithms.reduce_into(d_rgb, d_out, max_g_value, h_init)
+    temp_storage_bytes = reduce_into(None, d_rgb, d_out, len(d_rgb), h_init)
+
+    d_temp_storage = cp.empty(temp_storage_bytes, dtype=np.uint8)
+    _ = reduce_into(d_temp_storage, d_rgb, d_out, len(d_rgb), h_init)
+
+    h_rgb = d_rgb.get()
+    expected = h_rgb[h_rgb.view("int32")[:, 1].argmax()]
+
+    np.testing.assert_equal(expected["g"], d_out.get()["g"])
+    # example-end reduce-struct