
Commit 2a53029

cuda.parallel: Add optional stream argument to reduce_into() (#3348)
* Add optional stream argument to reduce_into()
* Add tests to check for reduce_into() stream behavior
* Move protocol related utils to separate file and rework __cuda_stream__ error messages
* Fix synchronization issue in stream test and add one more invalid stream test case
* Rename cuda stream validation function after removing leading underscore
* Unpack values from __cuda_stream__ instead of indexing
* Fix linting errors
* Handle TypeError when unpacking invalid __cuda_stream__ return
* Use stream to allocate cupy memory in new stream test
1 parent 3e1e6e0 commit 2a53029
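For context, the `__cuda_stream__` protocol this commit builds on expects a stream-like object to expose a zero-argument `__cuda_stream__()` method returning a `(version, handle)` tuple, where `handle` is the integer CUDA stream handle. A minimal sketch of a conforming wrapper around a CuPy stream, mirroring the wrapper used in the new test below (the class name here is ours, not part of the commit):

import cupy as cp

class CuPyStreamWrapper:
    # Hypothetical wrapper exposing a CuPy stream through __cuda_stream__
    def __init__(self, cp_stream):
        self.cp_stream = cp_stream

    def __cuda_stream__(self):
        # Protocol version 0: (version, integer stream handle)
        return (0, self.cp_stream.ptr)

wrapped = CuPyStreamWrapper(cp.cuda.Stream())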

File tree

4 files changed (+150 -9 lines)

python/cuda_parallel/cuda/parallel/experimental/_cccl.py (+1 -1)

@@ -11,7 +11,7 @@
 import numpy as np
 from numba import cuda, types

-from ._utils.cai import get_dtype, is_contiguous
+from ._utils.protocols import get_dtype, is_contiguous
 from .iterators._iterators import IteratorBase
 from .typing import DeviceArrayLike, GpuStruct

python/cuda_parallel/cuda/parallel/experimental/_utils/cai.py → python/cuda_parallel/cuda/parallel/experimental/_utils/protocols.py (renamed, +28 -1)

@@ -4,7 +4,7 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

 """
-Utilities for extracting information from `__cuda_array_interface__`.
+Utilities for extracting information from protocols such as `__cuda_array_interface__` and `__cuda_stream__`.
 """

 from typing import Optional, Tuple
@@ -68,3 +68,30 @@ def is_contiguous(arr: DeviceArrayLike) -> bool:
     else:
         # not contiguous
         return False
+
+
+def validate_and_get_stream(stream) -> Optional[int]:
+    # null stream is allowed
+    if stream is None:
+        return None
+
+    try:
+        stream_property = stream.__cuda_stream__()
+    except AttributeError as e:
+        raise TypeError(
+            f"stream argument {stream} does not implement the '__cuda_stream__' protocol"
+        ) from e
+
+    try:
+        version, handle, *_ = stream_property
+    except (TypeError, ValueError) as e:
+        raise TypeError(
+            f"could not obtain __cuda_stream__ protocol version and handle from {stream_property}"
+        ) from e
+
+    if version == 0:
+        if not isinstance(handle, int):
+            raise TypeError(f"invalid stream handle {handle}")
+        return handle
+
+    raise TypeError(f"unsupported __cuda_stream__ version {version}")
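To illustrate the new validator's behavior: `None` passes through (the null stream is allowed), a version-0 result with an integer handle yields that handle, and everything else raises `TypeError`. A small sketch, assuming the module is importable as `cuda.parallel.experimental._utils.protocols` (inferred from the file path above) and using a made-up `FakeStream` class:

from cuda.parallel.experimental._utils.protocols import validate_and_get_stream

class FakeStream:
    def __cuda_stream__(self):
        return (0, 42)  # protocol version 0, integer handle

assert validate_and_get_stream(None) is None  # null stream is allowed
assert validate_and_get_stream(FakeStream()) == 42  # handle is extracted

try:
    validate_and_get_stream(object())  # no __cuda_stream__ method
except TypeError as e:
    print(e)  # "... does not implement the '__cuda_stream__' protocol"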

python/cuda_parallel/cuda/parallel/experimental/algorithms/reduce.py (+16 -7)

@@ -16,7 +16,7 @@
 from .. import _cccl as cccl
 from .._bindings import get_bindings, get_paths
 from .._caching import CachableFunction, cache_with_key
-from .._utils import cai
+from .._utils import protocols
 from ..iterators._iterators import IteratorBase
 from ..typing import DeviceArrayLike, GpuStruct

@@ -63,7 +63,7 @@ def __init__(
         self._ctor_d_in_cccl_type_enum_name = cccl.type_enum_as_name(
             d_in_cccl.value_type.type.value
         )
-        self._ctor_d_out_dtype = cai.get_dtype(d_out)
+        self._ctor_d_out_dtype = protocols.get_dtype(d_out)
         self._ctor_init_dtype = h_init.dtype
         cc_major, cc_minor = cuda.get_current_device().compute_capability
         cub_path, thrust_path, libcudacxx_path, cuda_include_path = get_paths()
@@ -89,7 +89,13 @@ def __init__(
             raise ValueError("Error building reduce")

     def __call__(
-        self, temp_storage, d_in, d_out, num_items: int, h_init: np.ndarray | GpuStruct
+        self,
+        temp_storage,
+        d_in,
+        d_out,
+        num_items: int,
+        h_init: np.ndarray | GpuStruct,
+        stream=None,
     ):
         d_in_cccl = cccl.to_cccl_iter(d_in)
         if d_in_cccl.type.value == cccl.IteratorKind.ITERATOR:
@@ -104,8 +110,9 @@ def __call__(
             self._ctor_d_in_cccl_type_enum_name,
             cccl.type_enum_as_name(d_in_cccl.value_type.type.value),
         )
-        _dtype_validation(self._ctor_d_out_dtype, cai.get_dtype(d_out))
+        _dtype_validation(self._ctor_d_out_dtype, protocols.get_dtype(d_out))
         _dtype_validation(self._ctor_init_dtype, h_init.dtype)
+        stream_handle = protocols.validate_and_get_stream(stream)
         bindings = get_bindings()
         if temp_storage is None:
             temp_storage_bytes = ctypes.c_size_t()
@@ -125,7 +132,7 @@ def __call__(
             ctypes.c_ulonglong(num_items),
             self.op_wrapper.handle(),
             cccl.to_cccl_value(h_init),
-            None,
+            stream_handle,
         )
         if error != enums.CUDA_SUCCESS:
             raise ValueError("Error reducing")
@@ -145,8 +152,10 @@ def make_cache_key(
     op: Callable,
     h_init: np.ndarray,
 ):
-    d_in_key = d_in.kind if isinstance(d_in, IteratorBase) else cai.get_dtype(d_in)
-    d_out_key = cai.get_dtype(d_out)
+    d_in_key = (
+        d_in.kind if isinstance(d_in, IteratorBase) else protocols.get_dtype(d_in)
+    )
+    d_out_key = protocols.get_dtype(d_out)
     op_key = CachableFunction(op)
     h_init_key = h_init.dtype
     return (d_in_key, d_out_key, op_key, h_init_key)
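With the plumbing above in place, the same `__cuda_stream__`-conforming object can be passed to both phases of the call: the temp-storage sizing call (`temp_storage=None`) and the reduction proper. A sketch of the intended call pattern, assuming `algorithms` is importable as in the tests below; the wrapper class and its name are ours:

import cupy as cp
import numpy as np
from cuda.parallel.experimental import algorithms

class StreamWrapper:
    def __init__(self, cp_stream):
        self.cp_stream = cp_stream

    def __cuda_stream__(self):
        return (0, self.cp_stream.ptr)

def add_op(x, y):
    return x + y

h_init = np.zeros(1, dtype=np.int32)
stream = cp.cuda.Stream()
with stream:
    d_in = cp.arange(10, dtype=np.int32)
    d_out = cp.empty(1, dtype=np.int32)

wrapped = StreamWrapper(stream)
reducer = algorithms.reduce_into(d_in, d_out, add_op, h_init)

# First call only queries the required temp storage size.
size = reducer(None, d_in, d_out, d_in.size, h_init, stream=wrapped)
with stream:
    d_temp = cp.empty(size, dtype=np.uint8)

# Second call performs the reduction on the wrapped stream.
reducer(d_temp, d_in, d_out, d_in.size, h_init, stream=wrapped)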

python/cuda_parallel/tests/test_reduce.py (+105 -0)

@@ -550,3 +550,108 @@ def binary_op(x, y):
     d_in = cp.zeros(size)[::2]
     with pytest.raises(ValueError, match="Non-contiguous arrays are not supported."):
         _ = algorithms.reduce_into(d_in, d_out, binary_op, h_init)
+
+
+def test_reduce_with_stream():
+    # Simple cupy stream wrapper that implements the __cuda_stream__ protocol
+    # for the purposes of this test
+    class Stream:
+        def __init__(self, cp_stream):
+            self.cp_stream = cp_stream
+
+        def __cuda_stream__(self):
+            return (0, self.cp_stream.ptr)
+
+    def add_op(x, y):
+        return x + y
+
+    h_init = np.asarray([0], dtype=np.int32)
+    h_in = random_int(5, np.int32)
+
+    stream = cp.cuda.Stream()
+    with stream:
+        d_in = cp.asarray(h_in)
+        d_out = cp.empty(1, dtype=np.int32)
+
+    stream_wrapper = Stream(stream)
+    reduce_into = algorithms.reduce_into(
+        d_in=d_in, d_out=d_out, op=add_op, h_init=h_init
+    )
+    temp_storage_size = reduce_into(
+        None,
+        d_in=d_in,
+        d_out=d_out,
+        num_items=d_in.size,
+        h_init=h_init,
+        stream=stream_wrapper,
+    )
+    with stream:
+        d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)
+
+    reduce_into(d_temp_storage, d_in, d_out, d_in.size, h_init, stream=stream_wrapper)
+    with stream:
+        cp.testing.assert_allclose(d_in.sum().get(), d_out.get())
+
+
+def test_reduce_invalid_stream():
+    # Invalid stream that doesn't implement __cuda_stream__
+    class Stream1:
+        def __init__(self):
+            pass
+
+    # Invalid stream that implements __cuda_stream__ but returns the wrong type
+    class Stream2:
+        def __init__(self):
+            pass
+
+        def __cuda_stream__(self):
+            return None
+
+    # Invalid stream that returns an invalid handle
+    class Stream3:
+        def __init__(self):
+            pass
+
+        def __cuda_stream__(self):
+            return (0, None)
+
+    def add_op(x, y):
+        return x + y
+
+    d_out = cp.empty(1)
+    h_init = np.empty(1)
+    d_in = cp.empty(1)
+    reduce_into = algorithms.reduce_into(d_in, d_out, add_op, h_init)
+
+    with pytest.raises(
+        TypeError, match="does not implement the '__cuda_stream__' protocol"
+    ):
+        _ = reduce_into(
+            None,
+            d_in=d_in,
+            d_out=d_out,
+            num_items=d_in.size,
+            h_init=h_init,
+            stream=Stream1(),
+        )
+
+    with pytest.raises(
+        TypeError, match="could not obtain __cuda_stream__ protocol version and handle"
+    ):
+        _ = reduce_into(
+            None,
+            d_in=d_in,
+            d_out=d_out,
+            num_items=d_in.size,
+            h_init=h_init,
+            stream=Stream2(),
+        )
+
+    with pytest.raises(TypeError, match="invalid stream handle"):
+        _ = reduce_into(
+            None,
+            d_in=d_in,
+            d_out=d_out,
+            num_items=d_in.size,
+            h_init=h_init,
+            stream=Stream3(),
+        )
