Skip to content

Commit

Permalink
Update reduce tests to use cuda_stream fixture
Browse files Browse the repository at this point in the history
  • Loading branch information
shwina committed Feb 3, 2025
1 parent 215a6fd commit 6294d93
Showing 1 changed file with 8 additions and 17 deletions.
25 changes: 8 additions & 17 deletions python/cuda_parallel/tests/test_reduce.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
# Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

Expand Down Expand Up @@ -552,27 +552,18 @@ def binary_op(x, y):
_ = algorithms.reduce_into(d_in, d_out, binary_op, h_init)


def test_reduce_with_stream():
# Simple cupy stream wrapper that implements the __cuda_stream__ protocol for the purposes of this test
class Stream:
def __init__(self, cp_stream):
self.cp_stream = cp_stream

def __cuda_stream__(self):
return (0, self.cp_stream.ptr)

def test_reduce_with_stream(cuda_stream):
def add_op(x, y):
return x + y

h_init = np.asarray([0], dtype=np.int32)
h_in = random_int(5, np.int32)

stream = cp.cuda.Stream()
with stream:
cp_stream = cp.cuda.ExternalStream(cuda_stream.ptr)
with cp_stream:
d_in = cp.asarray(h_in)
d_out = cp.empty(1, dtype=np.int32)

stream_wrapper = Stream(stream)
reduce_into = algorithms.reduce_into(
d_in=d_in, d_out=d_out, op=add_op, h_init=h_init
)
Expand All @@ -582,13 +573,13 @@ def add_op(x, y):
d_out=d_out,
num_items=d_in.size,
h_init=h_init,
stream=stream_wrapper,
stream=cuda_stream,
)
with stream:
with cp_stream:
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

reduce_into(d_temp_storage, d_in, d_out, d_in.size, h_init, stream=stream_wrapper)
with stream:
reduce_into(d_temp_storage, d_in, d_out, d_in.size, h_init, stream=cuda_stream)
with cp_stream:
cp.testing.assert_allclose(d_in.sum().get(), d_out.get())


Expand Down

0 comments on commit 6294d93

Please sign in to comment.