From c9cf40eccd0df95c8101f793d96f618180512dcb Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Thu, 30 Jan 2025 16:29:41 -0800 Subject: [PATCH] Add test for iterator --- python/cuda_parallel/tests/test_scan.py | 47 ++++++++++++++++++------- 1 file changed, 34 insertions(+), 13 deletions(-) diff --git a/python/cuda_parallel/tests/test_scan.py b/python/cuda_parallel/tests/test_scan.py index 7a3858347f6..c50a37a0696 100644 --- a/python/cuda_parallel/tests/test_scan.py +++ b/python/cuda_parallel/tests/test_scan.py @@ -9,29 +9,26 @@ import numpy as np import cuda.parallel.experimental.algorithms as algorithms +import cuda.parallel.experimental.iterators as iterators -def exclusive_scan(inp: np.ndarray, op, init=0): - result = inp.copy() - result[0] = init +def exclusive_scan_host(h_input: np.ndarray, op, h_init=0): + result = h_input.copy() + result[0] = h_init[0] for i in range(1, len(result)): - result[i] = op(result[i - 1], inp[i - 1]) + result[i] = op(result[i - 1], h_input[i - 1]) return result -def scan_test_helper(d_input, d_output, num_items, op, h_init): +def exclusive_scan_device(d_input, d_output, num_items, op, h_init): scan = algorithms.scan(d_input, d_output, op, h_init) - temp_storage_size = scan(None, d_input, d_output, None, h_init) + temp_storage_size = scan(None, d_input, d_output, num_items, h_init) d_temp_storage = numba.cuda.device_array(temp_storage_size, dtype=np.uint8) - scan(d_temp_storage, d_input, d_output, None, h_init) - - expected = exclusive_scan(d_input.get(), op, init=h_init) - got = d_output.get() - np.testing.assert_allclose(expected, got) + scan(d_temp_storage, d_input, d_output, num_items, h_init) -def test_device_scan(input_array): +def test_scan_array_input(input_array): def op(a, b): return a + b @@ -39,4 +36,28 @@ def op(a, b): dtype = d_input.dtype h_init = np.array([42], dtype=dtype) d_output = cp.empty(len(d_input), dtype=dtype) - scan_test_helper(d_input, d_output, len(d_input), op, h_init) + + exclusive_scan_device(d_input, d_output, len(d_input), op, h_init) + + got = d_output.get() + expected = exclusive_scan_host(d_input.get(), op, h_init) + + np.testing.assert_allclose(expected, got) + + +def test_scan_iterator_input(): + def op(a, b): + return a + b + + d_input = iterators.CountingIterator(np.int32(1)) + num_items = 1024 + dtype = np.dtype("int32") + h_init = np.array([42], dtype=dtype) + d_output = cp.empty(num_items, dtype=dtype) + + exclusive_scan_device(d_input, d_output, num_items, op, h_init) + + got = d_output.get() + expected = exclusive_scan_host(np.arange(1, num_items + 1, dtype=dtype), op, h_init) + + np.testing.assert_allclose(expected, got)