Skip to content

Commit

Permalink
[Improvement] cuda.parallel: Don't require value_type when constructi…
Browse files Browse the repository at this point in the history
…ng iterators (#3105)

* Don't require value_type when constructing iterators

* Small fixes

---------

Co-authored-by: Ashwin Srinath <[email protected]>
  • Loading branch information
shwina and shwina authored Dec 9, 2024
1 parent 2cb56c3 commit 8e2d6b2
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 22 deletions.
6 changes: 4 additions & 2 deletions python/cuda_parallel/cuda/parallel/experimental/_iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,8 @@ def alignment(self):


class ConstantIterator:
def __init__(self, val, ntype):
def __init__(self, val):
ntype = numba.from_dtype(val.dtype)
thisty = numba.types.CPointer(ntype)
self.val = _ctypes_type_given_numba_type(ntype)(val)
self.ntype = ntype
Expand Down Expand Up @@ -252,7 +253,8 @@ def alignment(self):


class CountingIterator:
def __init__(self, count, ntype):
def __init__(self, count):
ntype = numba.from_dtype(count.dtype)
thisty = numba.types.CPointer(ntype)
self.count = _ctypes_type_given_numba_type(ntype)(count)
self.ntype = ntype
Expand Down
25 changes: 12 additions & 13 deletions python/cuda_parallel/cuda/parallel/experimental/iterators.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,34 @@
from . import _iterators


def CacheModifiedInputIterator(device_array, value_type, modifier):
"""Python fascade for Random Access Cache Modified Iterator that wraps a native device pointer.
def CacheModifiedInputIterator(device_array, modifier):
"""Python facade for Random Access Cache Modified Iterator that wraps a native device pointer.
Similar to https://nvidia.github.io/cccl/cub/api/classcub_1_1CacheModifiedInputIterator.html
Currently the only supported modifier is "stream" (LOAD_CS).
"""
if modifier != "stream":
raise NotImplementedError("Only stream modifier is supported")
value_type = device_array.dtype
return _iterators.CacheModifiedPointer(
device_array.__cuda_array_interface__["data"][0],
_iterators.numba_type_from_any(value_type),
)


def ConstantIterator(value, value_type):
"""Python fascade (similar to itertools.repeat) for C++ Random Access ConstantIterator."""
return _iterators.ConstantIterator(
value, _iterators.numba_type_from_any(value_type)
)
def ConstantIterator(value):
"""Python facade (similar to itertools.repeat) for C++ Random Access ConstantIterator."""
value_type = value.dtype
return _iterators.ConstantIterator(value)


def CountingIterator(offset, value_type):
"""Python fascade (similar to itertools.count) for C++ Random Access CountingIterator."""
return _iterators.CountingIterator(
offset, _iterators.numba_type_from_any(value_type)
)
def CountingIterator(offset):
"""Python facade (similar to itertools.count) for C++ Random Access CountingIterator."""
value_type = offset.dtype
return _iterators.CountingIterator(offset)


def TransformIterator(op, it):
"""Python fascade (similar to built-in map) mimicking a C++ Random Access TransformIterator."""
"""Python facade (similar to built-in map) mimicking a C++ Random Access TransformIterator."""
return _iterators.TransformIterator(op, it)
12 changes: 5 additions & 7 deletions python/cuda_parallel/tests/test_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def test_device_sum_cache_modified_input_it(
dtype_out = dtype_inp
input_devarr = numba.cuda.to_device(numpy.array(l_varr, dtype=dtype_inp))
i_input = iterators.CacheModifiedInputIterator(
input_devarr, value_type=supported_value_type, modifier="stream"
input_devarr, modifier="stream"
)
_test_device_sum_with_iterator(
l_varr, start_sum_with, i_input, dtype_inp, dtype_out, use_numpy_array
Expand All @@ -210,7 +210,7 @@ def test_device_sum_constant_it(
l_varr = [42 for distance in range(num_items)]
dtype_inp = numpy.dtype(supported_value_type)
dtype_out = dtype_inp
i_input = iterators.ConstantIterator(42, value_type=supported_value_type)
i_input = iterators.ConstantIterator(dtype_inp.type(42))
_test_device_sum_with_iterator(
l_varr, start_sum_with, i_input, dtype_inp, dtype_out, use_numpy_array
)
Expand All @@ -222,9 +222,7 @@ def test_device_sum_counting_it(
l_varr = [start_sum_with + distance for distance in range(num_items)]
dtype_inp = numpy.dtype(supported_value_type)
dtype_out = dtype_inp
i_input = iterators.CountingIterator(
start_sum_with, value_type=supported_value_type
)
i_input = iterators.CountingIterator(dtype_inp.type(start_sum_with))
_test_device_sum_with_iterator(
l_varr, start_sum_with, i_input, dtype_inp, dtype_out, use_numpy_array
)
Expand All @@ -250,7 +248,7 @@ def test_device_sum_map_mul2_count_it(
dtype_out = numpy.dtype(vtn_out)
i_input = iterators.TransformIterator(
mul2,
iterators.CountingIterator(start_sum_with, value_type=vtn_inp),
iterators.CountingIterator(dtype_inp.type(start_sum_with))
)
_test_device_sum_with_iterator(
l_varr, start_sum_with, i_input, dtype_inp, dtype_out, use_numpy_array
Expand Down Expand Up @@ -285,7 +283,7 @@ def test_device_sum_map_mul_map_mul_count_it(
mul_funcs[fac_out],
iterators.TransformIterator(
mul_funcs[fac_mid],
iterators.CountingIterator(start_sum_with, value_type=vtn_inp),
iterators.CountingIterator(dtype_inp.type(start_sum_with)),
)
)
_test_device_sum_with_iterator(
Expand Down

0 comments on commit 8e2d6b2

Please sign in to comment.