From 33b6a96d6658c9dce75dd9b04b1a5ea29efe8daa Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Tue, 7 Jan 2025 14:11:01 -0500 Subject: [PATCH] Switch to using gpu_struct decorator --- .../cuda/parallel/experimental/_cccl.py | 7 +- .../parallel/experimental/_structwrapper.py | 86 ---------------- .../experimental/algorithms/reduce.py | 20 ++-- .../cuda/parallel/experimental/gpu_struct.py | 99 +++++++++++++++++++ .../cuda/parallel/experimental/typing.py | 6 ++ python/cuda_parallel/tests/test_reduce.py | 17 ++-- 6 files changed, 130 insertions(+), 105 deletions(-) delete mode 100644 python/cuda_parallel/cuda/parallel/experimental/_structwrapper.py create mode 100644 python/cuda_parallel/cuda/parallel/experimental/gpu_struct.py diff --git a/python/cuda_parallel/cuda/parallel/experimental/_cccl.py b/python/cuda_parallel/cuda/parallel/experimental/_cccl.py index 99acee8fda2..905f2733974 100644 --- a/python/cuda_parallel/cuda/parallel/experimental/_cccl.py +++ b/python/cuda_parallel/cuda/parallel/experimental/_cccl.py @@ -215,4 +215,9 @@ def to_cccl_iter(array_or_iterator) -> Iterator: def host_array_to_value(array: np.ndarray) -> Value: info = _numpy_type_to_info(array.dtype) - return Value(info, array.ctypes.data) + if isinstance(array, np.ndarray): + data = array.ctypes.data + else: + # it's a gpudataclass: + data = ctypes.cast(ctypes.pointer(array._data), ctypes.c_void_p) + return Value(info, data) diff --git a/python/cuda_parallel/cuda/parallel/experimental/_structwrapper.py b/python/cuda_parallel/cuda/parallel/experimental/_structwrapper.py deleted file mode 100644 index 198b4744716..00000000000 --- a/python/cuda_parallel/cuda/parallel/experimental/_structwrapper.py +++ /dev/null @@ -1,86 +0,0 @@ -import dataclasses -import operator - -import numba -import numpy as np -from numba import types -from numba.core import cgutils -from numba.core.extending import ( - make_attribute_wrapper, - models, - register_model, - typeof_impl, -) -from numba.core.typing import signature -from numba.core.typing.templates import AttributeTemplate, CallableTemplate -from numba.cuda.cudadecl import registry as cuda_registry -from numba.cuda.cudaimpl import lower as cuda_lower - - -def wrap_struct(dtype: np.dtype) -> numba.types.Type: - """ - Wrap the given numpy structure dtype in a numba type. - """ - StructWrapper = dataclasses.make_dataclass( - "StructWrapper", - [(name, dt) for name, (dt, _) in dtype.fields.items()], # type: ignore - ) - - class StructWrapperType(types.Type): - def __init__(self): - super().__init__(name="StructWrapper") - - struct_wrapper_type = StructWrapperType() - - @typeof_impl.register(StructWrapper) - def typeof_struct_wrapper(val, c): - return StructWrapperType() - - class StructWrapperAttrsTemplate(AttributeTemplate): - pass - - fields = dataclasses.fields(StructWrapper) - for f in fields: - name = f.name - typ = f.type - - def resolver(self, wrapper): - return numba.from_dtype(typ) - - setattr(StructWrapperAttrsTemplate, f"resolve_{name}", resolver) - - @cuda_registry.register_attr - class StructWrapperAttrs(StructWrapperAttrsTemplate): - key = struct_wrapper_type - - @register_model(StructWrapperType) - class StructWrapperModel(models.StructModel): - def __init__(self, dmm, fe_type): - members = [(f.name, numba.from_dtype(f.type)) for f in fields] - super().__init__(dmm, fe_type, members) - - for f in fields: - make_attribute_wrapper(StructWrapperType, f.name, f.name) - - @cuda_registry.register_global(operator.getitem) - class StructWrapperGetitem(CallableTemplate): - def generic(self): - def typer(obj, index): - if not isinstance(obj, StructWrapperType): - return None - if not isinstance(index, types.StringLiteral): - return None - retty = numba.from_dtype(dtype[index.literal_value]) - return signature(retty, obj, index) - - return typer - - @cuda_lower(operator.getitem, struct_wrapper_type, types.StringLiteral) - def struct_wrapper_getitem(context, builder, sig, args): - obj_arg, index_arg = args - obj = cgutils.create_struct_proxy(struct_wrapper_type)( - context, builder, value=obj_arg - ) - return getattr(obj, sig.args[1].literal_value) - - return struct_wrapper_type diff --git a/python/cuda_parallel/cuda/parallel/experimental/algorithms/reduce.py b/python/cuda_parallel/cuda/parallel/experimental/algorithms/reduce.py index 84a2ed65474..41c03fe3498 100644 --- a/python/cuda_parallel/cuda/parallel/experimental/algorithms/reduce.py +++ b/python/cuda_parallel/cuda/parallel/experimental/algorithms/reduce.py @@ -16,19 +16,14 @@ from .. import _cccl as cccl from .._bindings import get_bindings, get_paths from .._caching import CachableFunction, cache_with_key -from .._structwrapper import wrap_struct from .._utils import cai from ..iterators._iterators import IteratorBase -from ..typing import DeviceArrayLike +from ..typing import DeviceArrayLike, GpuStruct class _Op: - def __init__(self, dtype: np.dtype, op: Callable): - # if h_init is a struct, wrap it in a Record type: - if dtype.names is not None: - value_type = wrap_struct(dtype) - else: - value_type = numba.from_dtype(dtype) + def __init__(self, h_init: np.ndarray | GpuStruct, op: Callable): + value_type = numba.typeof(h_init) self.ltoir, _ = cuda.compile(op, sig=(value_type, value_type), output="ltoir") self.name = op.__name__.encode("utf-8") @@ -56,7 +51,7 @@ def __init__( d_in: DeviceArrayLike | IteratorBase, d_out: DeviceArrayLike, op: Callable, - h_init: np.ndarray, + h_init: np.ndarray | GpuStruct, ): d_in_cccl = cccl.to_cccl_iter(d_in) self._ctor_d_in_cccl_type_enum_name = cccl.type_enum_as_name( @@ -67,11 +62,10 @@ def __init__( cc_major, cc_minor = cuda.get_current_device().compute_capability cub_path, thrust_path, libcudacxx_path, cuda_include_path = get_paths() bindings = get_bindings() - self.op_wrapper = _Op(h_init.dtype, op) + self.op_wrapper = _Op(h_init, op) d_out_cccl = cccl.to_cccl_iter(d_out) self.build_result = cccl.DeviceReduceBuildResult() - # TODO Figure out caching error = bindings.cccl_device_reduce_build( ctypes.byref(self.build_result), d_in_cccl, @@ -88,7 +82,9 @@ def __init__( if error != enums.CUDA_SUCCESS: raise ValueError("Error building reduce") - def __call__(self, temp_storage, d_in, d_out, num_items: int, h_init: np.ndarray): + def __call__( + self, temp_storage, d_in, d_out, num_items: int, h_init: np.ndarray | GpuStruct + ): d_in_cccl = cccl.to_cccl_iter(d_in) if d_in_cccl.type.value == cccl.IteratorKind.ITERATOR: assert num_items is not None diff --git a/python/cuda_parallel/cuda/parallel/experimental/gpu_struct.py b/python/cuda_parallel/cuda/parallel/experimental/gpu_struct.py new file mode 100644 index 00000000000..fdd944488b7 --- /dev/null +++ b/python/cuda_parallel/cuda/parallel/experimental/gpu_struct.py @@ -0,0 +1,99 @@ +from dataclasses import dataclass +from dataclasses import fields as dataclass_fields + +import numba +import numpy as np +from numba.core import cgutils +from numba.core.extending import ( + make_attribute_wrapper, + models, + register_model, + typeof_impl, +) +from numba.core.typing import signature as nb_signature +from numba.core.typing.templates import AttributeTemplate, ConcreteTemplate +from numba.cuda.cudadecl import registry as cuda_registry +from numba.extending import lower_builtin + +from .typing import GpuStruct + + +def gpu_struct(this: type) -> GpuStruct: + anns = getattr(this, "__annotations__", {}) + + # set the .dtype attribute on the class for numpy compatibility: + setattr(this, "dtype", np.dtype(list(anns.items()))) + + # define __post_init__ to create a ctypes object from the fields, + # and keep a reference to it in the `._data` attribute. + def __post_init__(self): + ctypes_typ = np.ctypeslib.as_ctypes_type(this.dtype) + self._data = ctypes_typ(*(getattr(self, name) for name in this.dtype.names)) + + setattr(this, "__post_init__", __post_init__) + + # create a dataclass: + this = dataclass(this) + fields = dataclass_fields(this) + + # define a numba type corresponding to the dataclass: + class ThisType(numba.types.Type): + def __init__(self): + super().__init__(name=this.__name__) + + this_type = ThisType() + + @typeof_impl.register(this) + def typeof_this(val, c): + return ThisType() + + # Data model corresponding to ThisType: + @register_model(ThisType) + class ThisModel(models.StructModel): + def __init__(self, dmm, fe_type): + members = [(field.name, numba.from_dtype(field.type)) for field in fields] + super().__init__(dmm, fe_type, members) + + # Typing for accessing attributes (fields) of the dataclass: + class ThisAttrsTemplate(AttributeTemplate): + pass + + for field in fields: + typ = field.type + name = field.name + + def resolver(self, this): + return numba.from_dtype(typ) + + setattr(ThisAttrsTemplate, f"resolve_{name}", resolver) + + @cuda_registry.register_attr + class ThisAttrs(ThisAttrsTemplate): + key = this_type + + # Lowering for attribute access: + for field in fields: + make_attribute_wrapper(ThisType, field.name, field.name) + + # Register typing for constructor. + @cuda_registry.register + class TypeConstructor(ConcreteTemplate): + key = this + cases = [ + nb_signature(this_type, *[numba.from_dtype(field.type) for field in fields]) + ] + + cuda_registry.register_global(this, numba.types.Function(TypeConstructor)) + + def type_constructor(context, builder, sig, args): + ty = sig.return_type + retval = cgutils.create_struct_proxy(ty)(context, builder) + for field, val in zip(fields, args): + setattr(retval, field.name, val) + return retval._getvalue() + + lower_builtin(this, *[numba.from_dtype(field.type) for field in fields])( + type_constructor + ) + + return this diff --git a/python/cuda_parallel/cuda/parallel/experimental/typing.py b/python/cuda_parallel/cuda/parallel/experimental/typing.py index 1c4e9c9975f..7c8da5fe577 100644 --- a/python/cuda_parallel/cuda/parallel/experimental/typing.py +++ b/python/cuda_parallel/cuda/parallel/experimental/typing.py @@ -1,3 +1,5 @@ +from typing import Any + from typing_extensions import ( Protocol, ) # TODO: typing_extensions required for Python 3.7 docs env @@ -10,3 +12,7 @@ class DeviceArrayLike(Protocol): """ __cuda_array_interface__: dict + + +# return type of @gpu_struct +GpuStruct = Any diff --git a/python/cuda_parallel/tests/test_reduce.py b/python/cuda_parallel/tests/test_reduce.py index 50fd5e0ad2b..1b56550eb8a 100644 --- a/python/cuda_parallel/tests/test_reduce.py +++ b/python/cuda_parallel/tests/test_reduce.py @@ -12,6 +12,7 @@ import cuda.parallel.experimental.algorithms as algorithms import cuda.parallel.experimental.iterators as iterators +from cuda.parallel.experimental.gpu_struct import gpu_struct def random_int(shape, dtype): @@ -553,15 +554,19 @@ def binary_op(x, y): def test_reduce_struct_type(): - def max_g_value(x, y): - return x if x["g"] > y["g"] else y + @gpu_struct + class Pixel: + r: np.int32 + g: np.int32 + b: np.int32 - dtype = np.dtype([("r", "int32"), ("g", "int32"), ("b", "int32")]) - d_rgb = cp.random.randint(0, 256, (10, 3), dtype=cp.int32).view(dtype) + def max_g_value(x, y): + return x if x.g > y.g else y - d_out = cp.zeros(1, dtype) + d_rgb = cp.random.randint(0, 256, (10, 3), dtype=np.int32).view(Pixel.dtype) + d_out = cp.zeros(1, Pixel.dtype) - h_init = np.asarray([(0, 0, 0)], dtype=dtype) + h_init = Pixel(0, 0, 0) reduce_into = algorithms.reduce_into(d_rgb, d_out, max_g_value, h_init) temp_storage_bytes = reduce_into(None, d_rgb, d_out, len(d_rgb), h_init)