Skip to content

Commit 7187e18

Browse files
NaderAlAwar authored and davebayer committed
cuda.parallel: Add optional stream argument to reduce_into() (NVIDIA#3348)
* Add optional stream argument to reduce_into()
* Add tests to check for reduce_into() stream behavior
* Move protocol related utils to separate file and rework __cuda_stream__ error messages
* Fix synchronization issue in stream test and add one more invalid stream test case
* Rename cuda stream validation function after removing leading underscore
* Unpack values from __cuda_stream__ instead of indexing
* Fix linting errors
* Handle TypeError when unpacking invalid __cuda_stream__ return
* Use stream to allocate cupy memory in new stream test
1 parent aa1ca79 commit 7187e18

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

python/cuda_parallel/cuda/parallel/experimental/_cccl.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
 11  11   import numpy as np
 12  12   from numba import cuda, types
 13  13
 14     - from ._utils.cai import get_dtype, is_contiguous
     14 + from ._utils.protocols import get_dtype, is_contiguous
 15  15   from .iterators._iterators import IteratorBase
 16  16   from .typing import DeviceArrayLike, GpuStruct
 17  17

python/cuda_parallel/cuda/parallel/experimental/algorithms/reduce.py

+9 -3
Original file line number | Diff line number | Diff line change
@@ -89,7 +89,13 @@ def __init__(
 89  89               raise ValueError("Error building reduce")
 90  90
 91  91       def __call__(
 92     -         self, temp_storage, d_in, d_out, num_items: int, h_init: np.ndarray | GpuStruct
     92 +         self,
     93 +         temp_storage,
     94 +         d_in,
     95 +         d_out,
     96 +         num_items: int,
     97 +         h_init: np.ndarray | GpuStruct,
     98 +         stream=None,
 93  99       ):
 94 100           d_in_cccl = cccl.to_cccl_iter(d_in)
 95 101           if d_in_cccl.type.value == cccl.IteratorKind.ITERATOR:
if d_in_cccl.type.value == cccl.IteratorKind.ITERATOR:
@@ -104,7 +110,7 @@ def __call__(
104 110               self._ctor_d_in_cccl_type_enum_name,
105 111               cccl.type_enum_as_name(d_in_cccl.value_type.type.value),
106 112           )
107     -         _dtype_validation(self._ctor_d_out_dtype, cai.get_dtype(d_out))
    113 +         _dtype_validation(self._ctor_d_out_dtype, protocols.get_dtype(d_out))
108 114           _dtype_validation(self._ctor_init_dtype, h_init.dtype)
109 115           stream_handle = protocols.validate_and_get_stream(stream)
110 116           bindings = get_bindings()
@@ -126,7 +132,7 @@ def __call__(
126 132               ctypes.c_ulonglong(num_items),
127 133               self.op_wrapper.handle(),
128 134               cccl.to_cccl_value(h_init),
129     -             None,
    135 +             stream_handle,
130 136           )
131 137           if error != enums.CUDA_SUCCESS:
132 138               raise ValueError("Error reducing")

0 commit comments

Comments
 (0)