Skip to content

Commit

Permalink
cuda.parallel: Add optional stream argument to reduce_into() (NVIDIA#3348)
Browse files
Browse the repository at this point in the history

* Add optional stream argument to reduce_into()

* Add tests to check for reduce_into() stream behavior

* Move protocol-related utils to a separate file and rework __cuda_stream__ error messages

* Fix synchronization issue in stream test and add one more invalid stream test case

* Rename cuda stream validation function after removing leading underscore

* Unpack values from __cuda_stream__ instead of indexing

* Fix linting errors

* Handle TypeError when unpacking an invalid __cuda_stream__ return value

* Use stream to allocate cupy memory in new stream test
  • Loading branch information
NaderAlAwar authored and davebayer committed Jan 23, 2025
1 parent aa1ca79 commit 7187e18
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python/cuda_parallel/cuda/parallel/experimental/_cccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy as np
from numba import cuda, types

from ._utils.cai import get_dtype, is_contiguous
from ._utils.protocols import get_dtype, is_contiguous
from .iterators._iterators import IteratorBase
from .typing import DeviceArrayLike, GpuStruct

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,13 @@ def __init__(
raise ValueError("Error building reduce")

def __call__(
self, temp_storage, d_in, d_out, num_items: int, h_init: np.ndarray | GpuStruct
self,
temp_storage,
d_in,
d_out,
num_items: int,
h_init: np.ndarray | GpuStruct,
stream=None,
):
d_in_cccl = cccl.to_cccl_iter(d_in)
if d_in_cccl.type.value == cccl.IteratorKind.ITERATOR:
Expand All @@ -104,7 +110,7 @@ def __call__(
self._ctor_d_in_cccl_type_enum_name,
cccl.type_enum_as_name(d_in_cccl.value_type.type.value),
)
_dtype_validation(self._ctor_d_out_dtype, cai.get_dtype(d_out))
_dtype_validation(self._ctor_d_out_dtype, protocols.get_dtype(d_out))
_dtype_validation(self._ctor_init_dtype, h_init.dtype)
stream_handle = protocols.validate_and_get_stream(stream)
bindings = get_bindings()
Expand All @@ -126,7 +132,7 @@ def __call__(
ctypes.c_ulonglong(num_items),
self.op_wrapper.handle(),
cccl.to_cccl_value(h_init),
None,
stream_handle,
)
if error != enums.CUDA_SUCCESS:
raise ValueError("Error reducing")
Expand Down

0 comments on commit 7187e18

Please sign in to comment.