Skip to content

Commit

Permalink
cuda.parallel: Support structured types as algorithm inputs (NVIDIA#3218
Browse files Browse the repository at this point in the history
)

* Introduce gpu_struct decorator and typing

* Enable `reduce` to accept arrays of structs as inputs

* Add test for reducing arrays-of-struct

* Update documentation

* Use a numpy array rather than ctypes object

* Change zeros -> empty for output array and temp storage

* Add a TODO for typing GpuStruct

* Documentation updates

* Remove test_reduce_struct_type from test_reduce.py

* Revert to `to_cccl_value()` accepting ndarray + GpuStruct

* Bump copyrights

---------

Co-authored-by: Ashwin Srinath <[email protected]>
  • Loading branch information
2 people authored and davebayer committed Jan 22, 2025
1 parent 5a0094c commit 34dffe5
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 10 deletions.
2 changes: 1 addition & 1 deletion python/cuda_parallel/cuda/parallel/experimental/_cccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy as np
from numba import cuda, types

from ._utils.protocols import get_dtype, is_contiguous
from ._utils.cai import get_dtype, is_contiguous
from .iterators._iterators import IteratorBase
from .typing import DeviceArrayLike, GpuStruct

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,7 @@ def __init__(
raise ValueError("Error building reduce")

def __call__(
self,
temp_storage,
d_in,
d_out,
num_items: int,
h_init: np.ndarray | GpuStruct,
stream=None,
self, temp_storage, d_in, d_out, num_items: int, h_init: np.ndarray | GpuStruct
):
d_in_cccl = cccl.to_cccl_iter(d_in)
if d_in_cccl.type.value == cccl.IteratorKind.ITERATOR:
Expand All @@ -110,7 +104,7 @@ def __call__(
self._ctor_d_in_cccl_type_enum_name,
cccl.type_enum_as_name(d_in_cccl.value_type.type.value),
)
_dtype_validation(self._ctor_d_out_dtype, protocols.get_dtype(d_out))
_dtype_validation(self._ctor_d_out_dtype, cai.get_dtype(d_out))
_dtype_validation(self._ctor_init_dtype, h_init.dtype)
stream_handle = protocols.validate_and_get_stream(stream)
bindings = get_bindings()
Expand All @@ -132,7 +126,7 @@ def __call__(
ctypes.c_ulonglong(num_items),
self.op_wrapper.handle(),
cccl.to_cccl_value(h_init),
stream_handle,
None,
)
if error != enums.CUDA_SUCCESS:
raise ValueError("Error reducing")
Expand Down

0 comments on commit 34dffe5

Please sign in to comment.