Skip to content

Commit

Permalink
Add struct dtype scan test
Browse files (browse the repository at this point in the history)
  • Loading branch information
shwina committed Feb 2, 2025
1 parent c9cf40e commit d126cba
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 1 deletion.
6 changes: 5 additions & 1 deletion python/cuda_parallel/cuda/parallel/experimental/_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ def __init__(self, func):
self._identity = (
self._func.__code__.co_code,
self._func.__code__.co_consts,
self._func.__closure__,
tuple(
cell.cell_contents
for cell in self._func.__closure__
if cell is not None
),
)

def __eq__(self, other):
Expand Down
29 changes: 29 additions & 0 deletions python/cuda_parallel/tests/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import cuda.parallel.experimental.algorithms as algorithms
import cuda.parallel.experimental.iterators as iterators
from cuda.parallel.experimental.struct import gpu_struct


def exclusive_scan_host(h_input: np.ndarray, op, h_init=0):
Expand Down Expand Up @@ -61,3 +62,31 @@ def op(a, b):
expected = exclusive_scan_host(np.arange(1, num_items + 1, dtype=dtype), op, h_init)

np.testing.assert_allclose(expected, got)


def test_scan_struct_type():
    """Exclusive scan over a two-field struct dtype matches per-field host scans.

    The device scan runs on ``XY`` records with a field-wise addition
    operator; the reference result is computed on the host by scanning the
    ``x`` and ``y`` fields independently with plain addition.
    """

    @gpu_struct
    class XY:
        x: np.int32
        y: np.int32

    def op(a, b):
        # Combine two XY values by adding their fields pairwise.
        return XY(a.x + b.x, a.y + b.y)

    # A (10, 2) int32 array reinterpreted through the struct's dtype.
    raw_pairs = cp.random.randint(0, 256, (10, 2), dtype=np.int32)
    d_input = raw_pairs.view(XY.dtype)
    d_output = cp.empty_like(d_input)

    h_init = XY(0, 0)
    num_items = len(d_input)

    exclusive_scan_device(d_input, d_output, num_items, op, h_init)

    got = d_output.get()

    field_names = ("x", "y")

    # Build every reference scan first, then compare, so that the order of
    # host-side work mirrors "compute all expectations, then assert".
    expected = {
        name: exclusive_scan_host(
            d_input.get()[name],
            lambda a, b: a + b,
            np.asarray([getattr(h_init, name)]),
        )
        for name in field_names
    }

    for name in field_names:
        np.testing.assert_allclose(expected[name], got[name])

0 comments on commit d126cba

Please sign in to comment.