Skip to content

Commit

Permalink
Add initial tests for scan API
Browse files Browse the repository at this point in the history
  • Loading branch information
Ashwin Srinath authored and Ashwin Srinath committed Jan 30, 2025
1 parent 07861f0 commit 40127c8
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 0 deletions.
38 changes: 38 additions & 0 deletions python/cuda_parallel/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import cupy as cp
import numpy as np
import pytest


# Define a pytest fixture that returns random arrays with different dtypes
@pytest.fixture(
params=[
np.int8,
np.int16,
np.int32,
np.int64,
np.uint8,
np.uint16,
np.uint32,
np.uint64,
np.float32,
np.float64,
np.complex128,
]
)
def input_array(request):
dtype = request.param

# Generate random values based on the dtype
if np.issubdtype(dtype, np.integer):
# For integer types, use np.random.randint for random integers
array = cp.random.randint(low=0, high=100, size=10, dtype=dtype)
elif np.issubdtype(dtype, np.floating):
# For floating-point types, use np.random.random and cast to the required dtype
array = cp.random.random(10).astype(dtype)
elif np.issubdtype(dtype, np.complexfloating):
# For complex types, generate random real and imaginary parts
real_part = cp.random.random(10)
imag_part = cp.random.random(10)
array = (real_part + 1j * imag_part).astype(dtype)

return array
42 changes: 42 additions & 0 deletions python/cuda_parallel/tests/test_scan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception


import cupy as cp
import numba.cuda
import numba.types
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms


def exclusive_scan(inp: np.ndarray, op, init=0):
result = inp.copy()
result[0] = init
for i in range(1, len(result)):
result[i] = op(result[i - 1], inp[i - 1])
return result


def scan_test_helper(d_input, d_output, num_items, op, h_init):
scan = algorithms.scan(d_input, d_output, op, h_init)
temp_storage_size = scan(None, d_input, d_output, None, h_init)
d_temp_storage = numba.cuda.device_array(temp_storage_size, dtype=np.uint8)

scan(d_temp_storage, d_input, d_output, None, h_init)

expected = exclusive_scan(d_input.get(), op, init=h_init)
got = d_output.get()
np.testing.assert_allclose(expected, got)


def test_device_scan(input_array):
def op(a, b):
return a + b

d_input = input_array
dtype = d_input.dtype
h_init = np.array([42], dtype=dtype)
d_output = cp.empty(len(d_input), dtype=dtype)
scan_test_helper(d_input, d_output, len(d_input), op, h_init)

0 comments on commit 40127c8

Please sign in to comment.