[WIP] Support fancy iterators in cuda.parallel (#2788)
* New python/cuda_parallel/cuda/parallel/experimental/random_access_iterators.py (with unit tests).

* Use TransformRAI to implement constant, counting, arbitrary RAIs.

* Transfer `num_item`-related changes from cuda_parallel_itertools branch.

The other branch is: https://github.com/rwgk/cccl/tree/cuda_parallel_itertools

* Rename `op` to `reduction_op` in `cccl_device_reduce_build()`

* Transfer test_device_sum_repeat_1_equals_num_items() from cuda_parallel_itertools branch.

The other branch is: https://github.com/rwgk/cccl/tree/cuda_parallel_itertools

* Add `class _TransformRAIUnaryOp`. Make `constant_op()` less trivial (to see if numba can still compile it).

* Rename `class _Op` to `_ReductionOp` for clarity.

* WIP Use TransformRAIUnaryOp here

* INCOMPLETE SNAPSHOT

* This links the input_unary_op() successfully into the nvrtc cubin, but then fails with: Fatal Python error: Floating point exception

* passing `_type_to_info_from_numba_type(numba.int32)` as `value_type` resolves the Floating point exception (but the `cccl_device_reduce()` call still does not succeed)

* More debug output.

LOOOK single_tile_kernel CALL /home/coder/cccl/c/parallel/src/reduce.cu:116

LOOOK EXCEPTION CUDA error: invalid argument  /home/coder/cccl/c/parallel/src/reduce.cu:703

* Substituting `fake_in_ptr` if `in_ptr == nullptr`: All tests pass.

* Rename new test function to `test_device_sum_input_unary_op()` and parametrize: `use_numpy_array`: `[True, False]`, `input_generator`: `["constant", "counting", "arbitrary", "nested"]`

* Remove python/cuda_parallel/cuda/parallel/experimental/random_access_iterators.py (because numba.cuda cannot JIT classes).

* Add `"nested_global"` test, but disable.

* `cu_repeat()`, `cu_count()`, `cu_arbitrary()` functions that return a `unary_op`, which is then compiled with `numba.cuda.compile()`
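
For orientation, a minimal sketch of this factory-plus-compile pattern (the function bodies, the signature, and the `device=True` / `output="ltoir"` arguments are assumptions based on current Numba releases, not the exact code in this commit):

```python
import numba
import numba.cuda


def cu_repeat(value):
    # The returned unary_op captures `value` and ignores its index argument.
    def unary_op(index):
        return value

    return unary_op


op = cu_repeat(42)
# Compile the closure as a device function emitting LTO-IR, which can later be
# linked into the reduction kernel by the C library.
ltoir, return_type = numba.cuda.compile(
    op, sig=numba.types.int32(numba.types.int32), device=True, output="ltoir"
)
```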

* Snapshot DOES NOT WORK, from 2024-10-20+1646

* Files copy-pasted from 2479 comment, then: notabs.py kernel.cpp main.py

* Commands to build and run the POC.

* `RawPointer(..., dtype)`

* `assert_expected_output_array()`, `more_nested_map`

* Add `@register_jitable` to `cu_repeat()`, ..., `cu_map()`: this fixes the `"map_mul2"` test and the added `"map_add10_map_mul2"` test works, too.
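
A sketch of what `@register_jitable` buys here (assumed shapes, not the PR's exact code): a plain Python helper decorated this way can be called from code that Numba compiles, which is what the nested `"map_add10_map_mul2"` case needs.

```python
from numba.extending import register_jitable


@register_jitable
def add10(x):
    return x + 10


def make_mul2_then_add10():
    def unary_op(x):
        # Calling add10 from a function that is later numba.cuda.compile()d
        # works because add10 is register_jitable.
        return add10(2 * x)

    return unary_op
```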

* Restore original c/parallel/src/reduce.cu (to avoid `git merge main` conflicts).

* Transfer a few minor `printf("\nLOOOK ...);` from python_random_access_iterators branch.

* clang-format auto-fixes

* Change `ADVANCE(this, diff)` to `ADVANCE(data, diff)`

* Remove `input_unary_op`-related code.

* Plug Georgii's POC ConstantIterator, CountingIterator into cuda_parallel test_reduce.py

* Change `map(it, op)` to `cu_map(op, it)`, to be compatible with the built-in `map(func, *iterables)`
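
A tiny host-side illustration of the argument-order rationale; `cu_map` below is a stand-in generator stub, not the PR's device iterator implementation.

```python
def cu_map(op, it):
    return (op(x) for x in it)


assert list(map(lambda x: 2 * x, range(4))) == [0, 2, 4, 6]     # built-in: func first
assert list(cu_map(lambda x: 2 * x, range(4))) == [0, 2, 4, 6]  # same convention
```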

* Plug Georgii's POC map code into cuda_parallel test_reduce.py; test FAILS:

ptxas fatal   : Unresolved extern function 'transform_count_int32_mul2_advance'
ERROR NVJITLINK_ERROR_PTX_COMPILE: JIT the PTX (ltoPtx)
LOOOK EXCEPTION nvJitLink error: 4  /home/coder/cccl/c/parallel/src/reduce.cu:339
=========================== short test summary info ============================
FAILED tests/test_reduce.py::test_device_sum_sentinel_iterator[map_mul2-False] - ValueError: Error building reduce
======================= 1 failed, 11 deselected in 1.33s =======================

* Add `cccl_string_views* ltoirs` to `cccl_iterator_t`

* Populate `cccl_iterator_t::ltoirs` in `_itertools_iter_as_cccl_iter()`

* Add `nvrtc_sm_top_level::add_link_list()`

* Fix `_extract_ctypes_ltoirs()` implementation

* Add `extract_ltoirs(const cccl_iterator_t&)` implementation. All tests pass, including the `cu_map()` test.

* Plug Georgii's POC RawPointer, CacheModifiedPointer into cuda_parallel test_reduce.py

* Copy all sentinel iterators wholesale from georgii_poc_2479/pocenv/main.py to new python/cuda_parallel/cuda/parallel/experimental/sentinel_iterators.py

* Remove all sentinel iterator implementations from test_reduce.py and use cuda.parallel.experimental.sentinel_iterators instead.

* Cleanup

* Change `DEREF(this)` to `DEREF(data)`

* Make RawPointer, CacheModifiedPointer numba.cuda.compile() `sig` expressions more readable and fix `int32` vs `uint64` mixup for `advance` methods.

* Add `cccl_iterator_t::advance_multiply_diff_with_sizeof_value_t` feature.

* Make discovery mechanism for cuda/_include directory compatible with `pip install --editable`

* Make discovery mechanism for cuda/_include directory compatible with `pip install --editable`

* Add pytest-xdist to test requirements.

* Revert "Add `cccl_iterator_t::advance_multiply_diff_with_sizeof_value_t` feature."

This reverts commit eaee196.

* pointer_advance_sizeof(), cache_advance_sizeof(): Use a Python capture for the `distance * sizeof_dtype` calculation.
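
A hedged sketch of the capture approach (the function body and the pointer representation are assumptions): the element size is known on the host, so it is baked into the advance function as a captured constant instead of being passed through the C interface.

```python
import numpy as np


def pointer_advance_sizeof(dtype):
    sizeof_dtype = np.dtype(dtype).itemsize  # captured when the iterator is built

    def pointer_advance(state, distance):
        # state[0] is assumed to hold the raw device address as an integer.
        state[0] = state[0] + distance * sizeof_dtype

    return pointer_advance
```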

* Rename sentinel_iterators.py -> iterators.py

* ConstantIterator(..., dtype)

* Change `dtype`, `numba_type` variable names to `ntype`.

* CountingIterator(..., ntype)

* Add `ntype.name` to `self.prefix` for `RawPointer`, `CacheModifiedPointer`. Introduce `_ncc()` helper function.

* Trivial: Shuffle order of tests to match code in iterators.py

* ruff format iterators.py (NO manual changes).

* Add "raw_pointer_int16" test.

* Generalize `ldcs()` `@intrinsic` and add "streamed_input_int16" test.
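
For reference, one way a cache-streaming load can be expressed as a Numba `@intrinsic` with PTX inline assembly; this is a generic int32-only sketch, not the generalized `ldcs()` from this commit, and the constraint strings and lowering details are assumptions.

```python
from llvmlite import ir
from numba import types
from numba.extending import intrinsic


@intrinsic
def load_cs_int32(typingctx, ptr):
    # This sketch only types a pointer-to-int32 argument.
    if not (isinstance(ptr, types.CPointer) and ptr.dtype == types.int32):
        return None
    sig = types.int32(ptr)

    def codegen(context, builder, signature, args):
        i32 = ir.IntType(32)
        fnty = ir.FunctionType(i32, [ir.PointerType(i32)])
        # ld.global.cs: "cache streaming" load, as in cub's LOAD_CS.
        asm = ir.InlineAsm(fnty, "ld.global.cs.u32 $0, [$1];", "=r,l")
        return builder.call(asm, args)

    return sig, codegen
```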

* Expand tests for raw_pointer, streamed_input: (int, uint) x (16, 32, 64)

* _numba_type_as_ir_type() experiment: does not make a difference

Compared to the `int32` case, this line is inserted:

```
        cvt.u64.u32     %rd3, %r1;
```

This is the generated ptx code (including that line) and the error message:

```
        // .globl       cacheint64_dereference
.common .global .align 8 .u64 _ZN08NumbaEnv4cuda8parallel12experimental9iterators20CacheModifiedPointer32cache_cache_dereference_bitwidth12_3clocals_3e17cache_dereferenceB3v48B96cw51cXTLSUwv1sCUt9Uw1VEw0NRRQPKiLTj0gIGIFp_2b2oLQFEYYkHSQB1OQAk0Bynm21OizQ1K0UoIGvDpQE8oxrNQE_3dE11int64_2a_2a;

.visible .func  (.param .b64 func_retval0) cacheint64_dereference(
        .param .b64 cacheint64_dereference_param_0
)
{
        .reg .b32       %r<2>;
        .reg .b64       %rd<4>;

        ld.param.u64    %rd2, [cacheint64_dereference_param_0];
        ld.u64  %rd1, [%rd2];
        // begin inline asm
        ld.global.cs.s64 %r1, [%rd1];
        // end inline asm
        cvt.u64.u32     %rd3, %r1;
        st.param.b64    [func_retval0+0], %rd3;
        ret;

}
ptxas application ptx input, line 6672; error   : Arguments mismatch for instruction 'ld'
ptxas fatal   : Ptx assembly aborted due to errors
ERROR NVJITLINK_ERROR_PTX_COMPILE: JIT the PTX (ltoPtx)
```

* Revert "_numba_type_as_ir_type() experiment: does not make a difference"

This reverts commit c32e7ed.

* Plug in numba pointer arithmetic code provided by @gmarkall (with a small bug fix):

@gmarkall's POC code: #2861

The small bug fix:

```diff
 def sizeof_pointee(context, ptr):
-    size = context.get_abi_sizeof(ptr.type)
+    size = context.get_abi_sizeof(ptr.type.pointee)
     return ir.Constant(ir.IntType(64), size)
```

* WIP

* Fix streamed_input_int64, streamed_input_uint64 issue using suggestion provided by @gevtushenko:

#2788 (comment)

* Simplified and more readable `ldcs()` implementation.

* Add raw_pointer_floatXX, streamed_input_floatXX tests. Those run but the reduction results are incorrect. Needs debugging.

* Use `numpy.dtype(intty)` to obtain an actual dtype object (rather than a numpy scalar object).
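
The distinction in one snippet:

```python
import numpy as np

intty = np.int32
np.dtype(intty)           # dtype('int32'): an actual dtype object
np.dtype(intty).itemsize  # 4
intty(0)                  # a numpy scalar object, not a dtype
```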

* Fix `_itertools_iter_as_cccl_iter()` helper function to use `d_in.ntype` instead of hard-wired `int32`. This fixes the raw_pointer_floatXX, streamed_input_floatXX tests.

* Add raw_pointer_float16, streamed_input_float16 tests. Needs a change in iterators.py, not sure why, or if there is a better solution.

* Add constant_XXX tests. constant_float16 does not work because ctypes does not support float16.

* Add counting_XXX tests. counting_float32,64 do not work, needs debugging.

* Change cu_map() make_advance_codegen() distty to uint64, for consistency with all other iterator types.

* Replace all cu_map() int32 with either it.ntype or op_return_ntype

* Fix ntype vs types.uint64 mixup for distance in ConstantIterator, CountingIterator. This fixes the counting_float32, counting_float64 tests.

* Remove float16 code and tests. Not needed at this stage.

* Fix trivial (but slightly confusing) typo.

* Remove unused host_address(self) methods.

* Introduce _DEVICE_POINTER_SIZE, _DEVICE_POINTER_BITWIDTH, _DISTANCE_NUMBA_TYPE, _DISTANCE_IR_TYPE. Resolve most TODOs.

* Generalize map_mul2 test as map_mul2_int32_int32, but still only test with int32.

* Add more map_mul2 tests that all pass with the production code as-is. map_mul2_float32_int32 does not pass (commented out in this commit).

* Make cu_map() code more readable by introducing _CHAR_PTR_NUMBA_TYPE, _CHAR_PTR_IR_TYPE

* Archiving debugging code including bug fix in make_dereference_codegen()

* Revert "Archiving debugging code including bug fix in make_dereference_codegen()"

This reverts commit b525b80.

* FIX but does not work

* Add more map_mul2 tests. For unknown reasons these all pass when run separately, or when running all tests in parallel (pytest -v -n 32).

* Add type information to cu_map() op abi_name.

* Simplify() helper function, reuse for compiling cu_map() op

* Change map_mul2_xxx test names to map_mul2_count_xxx

* Add map_mul3_map_mul2_count_int32_int32_int32, map_mul2_map_mul2_count_float64_float32_int16 tests. map_mul* tests are flaky, TBD why.

* Introduce op_caller() in cu_map(): this solves the test failures (TBH, I cannot explain why/how).

* Move `num_items` argument to just before `h_init`, for compatibility with cub `DeviceReduce` API.

* Move iterators.py -> _iterators.py, add new iterators.py as public interface.

* Pull repeat(), count() out from _iterators.py into iterators.py; add map() forwarding function (to the renamed cumap()).

* Move `class TransformIterator` out to module scope. This makes it clearer where Python captures are used. It might also be a minor performance improvement because the `class` definition code is only processed once.
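
A sketch of the structural change (class/function names follow the commit message, the bodies are placeholders): at module scope, everything the class needs is passed in explicitly, so the captures are visible at the construction site instead of being implicit closure state inside `cu_map()`.

```python
class TransformIterator:
    def __init__(self, op, inner_it):
        self.op = op              # explicit argument rather than a closure capture
        self.inner_it = inner_it


def cu_map(op, it):
    return TransformIterator(op, it)
```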

* Replace the very awkward _Reduce.__handle_d_in() method with the new _d_in_as_cccl_iter() function.

* Move _iterators.cache() to iterators.cache_load_modifier()

* Fix minor oversights in docstrings.

* Enable passing device array as `it` to `map()` (testing with cupy array).

* Rename `l_input` -> `l_varr` (in test_reduce.py).

* Rename `ldcs()` -> `load_cs()` and add reference to cub `CacheModifiedInputIterator` `LOAD_CS`

* Move `*_advance` and `*_dereference` functions out of class scopes, to make it obvious that they are never used as Python methods, but exclusively as source for `numba.cuda.compile()`

* Turn state_c_void_p, size, alignment methods into properties.
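
A sketch of the method-to-property change (attribute names from the commit message; the class internals are placeholders):

```python
import ctypes


class RawPointer:
    def __init__(self, ptr, itemsize, align):
        self._state = ctypes.c_void_p(ptr)
        self._itemsize = itemsize
        self._align = align

    @property
    def state_c_void_p(self):
        return self._state

    @property
    def size(self):
        return self._itemsize

    @property
    def alignment(self):
        return self._align


it = RawPointer(0x1000, 4, 4)
it.size  # attribute access now, instead of it.size()
```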

* Improved organization of newly added tests. NO functional changes. Applied ruff format to newly added code.

* Add comments pointing to issue #3064

* Change `ntype` to `value_type` in public APIs. Introduce `numba_type_from_any(value_type)` function.
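
A minimal sketch of what `numba_type_from_any()` could accept (the real helper in this PR may differ): pass Numba types through, and route everything else through `numpy.dtype()` into `numba.from_dtype()`.

```python
import numba
import numpy as np


def numba_type_from_any(value_type):
    if isinstance(value_type, numba.types.Type):
        return value_type
    return numba.from_dtype(np.dtype(value_type))


numba_type_from_any(numba.types.int32)  # Numba type passes through
numba_type_from_any(np.float64)         # numpy type -> float64
numba_type_from_any("uint16")           # string spelling -> uint16
```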

* Change names of functions in the public API.

* Effectively undo commit 708c341: Move `*_advance` and `*_dereference` functions back to class scope, with comments to explicitly state that these are not actual methods.

* Use `@pytest.fixture` to replace most `@pytest.mark.parametrize`, as suggested by @shwina
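
Illustrative only; the fixture name and parameter values below are placeholders, not the PR's exact test suite.

```python
import pytest


@pytest.fixture(params=["int32", "uint64", "float32"])
def value_type(request):
    return request.param


def test_device_sum(value_type):
    # Runs once per fixture param, replacing a per-test
    # @pytest.mark.parametrize("value_type", [...]) stack.
    assert value_type in ("int32", "uint64", "float32")
```
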
rwgk authored Dec 6, 2024
1 parent 8536b91 commit 199f2a5
Showing 8 changed files with 866 additions and 49 deletions.
13 changes: 13 additions & 0 deletions c/parallel/include/cccl/c/types.h
@@ -71,6 +71,18 @@
void* state;
};

struct cccl_string_view
{
const char* begin;
int size;
};

struct cccl_string_views
{
cccl_string_view* views;
int size;
};

struct cccl_iterator_t
{
int size;
@@ -80,4 +92,5 @@
cccl_op_t dereference;
cccl_type_info value_type;
void* state;
cccl_string_views* ltoirs = nullptr;
};
4 changes: 2 additions & 2 deletions c/parallel/src/kernels/iterators.cpp
@@ -51,9 +51,9 @@ struct __align__(OP_ALIGNMENT) input_iterator_state_t {
using difference_type = DIFF_T;
using pointer = VALUE_T*;
using reference = VALUE_T&;
__device__ inline value_type operator*() const { return DEREF(this); }
__device__ inline value_type operator*() const { return DEREF(data); }
__device__ inline input_iterator_state_t& operator+=(difference_type diff) {
ADVANCE(this, diff);
ADVANCE(data, diff);
return *this;
}
__device__ inline value_type operator[](difference_type diff) const {
48 changes: 35 additions & 13 deletions c/parallel/src/reduce.cu
@@ -263,6 +263,12 @@ extern "C" CCCL_C_API CUresult cccl_device_reduce_build(
op_src, // 6
policy.vector_load_length); // 7

#if false // CCCL_DEBUGGING_SWITCH
fflush(stderr);
printf("\nCODE4NVRTC BEGIN\n%sCODE4NVRTC END\n", src.c_str());
fflush(stdout);
#endif

std::string single_tile_kernel_name = get_single_tile_kernel_name(input_it, output_it, op, init, false);
std::string single_tile_second_kernel_name = get_single_tile_kernel_name(input_it, output_it, op, init, true);
std::string reduction_kernel_name = get_device_reduce_kernel_name(op, input_it, init);
@@ -287,16 +293,23 @@ extern "C" CCCL_C_API CUresult cccl_device_reduce_build(
}
};
ltoir_list_append({op.ltoir, op.ltoir_size});
if (cccl_iterator_kind_t::iterator == input_it.type)
{
ltoir_list_append({input_it.advance.ltoir, input_it.advance.ltoir_size});
ltoir_list_append({input_it.dereference.ltoir, input_it.dereference.ltoir_size});
}
if (cccl_iterator_kind_t::iterator == output_it.type)
{
ltoir_list_append({output_it.advance.ltoir, output_it.advance.ltoir_size});
ltoir_list_append({output_it.dereference.ltoir, output_it.dereference.ltoir_size});
}
auto extract_ltoirs = [ltoir_list_append](const cccl_iterator_t& it) {
if (cccl_iterator_kind_t::iterator == it.type)
{
ltoir_list_append({it.advance.ltoir, it.advance.ltoir_size});
ltoir_list_append({it.dereference.ltoir, it.dereference.ltoir_size});
if (it.ltoirs != nullptr)
{
for (int i = 0; i < it.ltoirs->size; i++)
{
auto view = it.ltoirs->views[i];
ltoir_list_append({view.begin, view.size});
}
}
}
};
extract_ltoirs(input_it);
extract_ltoirs(output_it);

nvrtc_cubin result =
make_nvrtc_command_list()
@@ -323,8 +336,11 @@ extern "C" CCCL_C_API CUresult cccl_device_reduce_build(
build->cubin_size = result.size;
build->accumulator_size = accum_t.size;
}
catch (...)
catch (const std::exception& exc)
{
fflush(stderr);
printf("\nEXCEPTION in cccl_device_reduce_build(): %s\n", exc.what());
fflush(stdout);
error = CUDA_ERROR_UNKNOWN;
}

@@ -411,8 +427,11 @@ extern "C" CCCL_C_API CUresult cccl_device_reduce(
cub::CudaDriverLauncherFactory{cu_device, build.cc},
{get_accumulator_type(op, d_in, init)});
}
catch (...)
catch (const std::exception& exc)
{
fflush(stderr);
printf("\nEXCEPTION in cccl_device_reduce(): %s\n", exc.what());
fflush(stdout);
error = CUDA_ERROR_UNKNOWN;
}

@@ -437,8 +456,11 @@ extern "C" CCCL_C_API CUresult cccl_device_reduce_cleanup(cccl_device_reduce_bui
std::unique_ptr<char[]> cubin(reinterpret_cast<char*>(bld_ptr->cubin));
check(cuLibraryUnload(bld_ptr->library));
}
catch (...)
catch (const std::exception& exc)
{
fflush(stderr);
printf("\nEXCEPTION in cccl_device_reduce_cleanup(): %s\n", exc.what());
fflush(stdout);
return CUDA_ERROR_UNKNOWN;
}

108 changes: 90 additions & 18 deletions python/cuda_parallel/cuda/parallel/experimental/__init__.py
@@ -12,7 +12,7 @@
from numba.cuda.cudadrv import enums


# Should match C++
# MUST match `cccl_type_enum` in c/include/cccl/c/types.h
class _TypeEnum(ctypes.c_int):
INT8 = 0
INT16 = 1
@@ -27,13 +27,30 @@ class _TypeEnum(ctypes.c_int):
STORAGE = 10


# Should match C++
def _cccl_type_enum_as_name(enum_value):
assert isinstance(enum_value, int)
return (
"int8",
"int16",
"int32",
"int64",
"uint8",
"uint16",
"uint32",
"uint64",
"float32",
"float64",
"STORAGE",
)[enum_value]


# MUST match `cccl_op_kind_t` in c/include/cccl/c/types.h
class _CCCLOpKindEnum(ctypes.c_int):
STATELESS = 0
STATEFUL = 1


# Should match C++
# MUST match `cccl_iterator_kind_t` in c/include/cccl/c/types.h
class _CCCLIteratorKindEnum(ctypes.c_int):
POINTER = 0
ITERATOR = 1
@@ -57,13 +74,14 @@ def _type_to_enum(numba_type):
return _TypeEnum.STORAGE


# TODO Extract into reusable module
# MUST match `cccl_type_info` in c/include/cccl/c/types.h
class _TypeInfo(ctypes.Structure):
_fields_ = [("size", ctypes.c_int),
("alignment", ctypes.c_int),
("type", _TypeEnum)]


# MUST match `cccl_op_t` in c/include/cccl/c/types.h
class _CCCLOp(ctypes.Structure):
_fields_ = [("type", _CCCLOpKindEnum),
("name", ctypes.c_char_p),
@@ -74,36 +92,55 @@ class _CCCLOp(ctypes.Structure):
("state", ctypes.c_void_p)]


# MUST match `cccl_string_view` in c/include/cccl/c/types.h
class _CCCLStringView(ctypes.Structure):
_fields_ = [("begin", ctypes.c_char_p),
("size", ctypes.c_int)]


# MUST match `cccl_string_views` in c/include/cccl/c/types.h
class _CCCLStringViews(ctypes.Structure):
_fields_ = [("views", ctypes.POINTER(_CCCLStringView)),
("size", ctypes.c_int)]


# MUST match `cccl_iterator_t` in c/include/cccl/c/types.h
class _CCCLIterator(ctypes.Structure):
_fields_ = [("size", ctypes.c_int),
("alignment", ctypes.c_int),
("type", _CCCLIteratorKindEnum),
("advance", _CCCLOp),
("dereference", _CCCLOp),
("value_type", _TypeInfo),
("state", ctypes.c_void_p)]
("state", ctypes.c_void_p),
("ltoirs", ctypes.POINTER(_CCCLStringViews))]


# MUST match `cccl_value_t` in c/include/cccl/c/types.h
class _CCCLValue(ctypes.Structure):
_fields_ = [("type", _TypeInfo),
("state", ctypes.c_void_p)]


def _type_to_info(numpy_type):
numba_type = numba.from_dtype(numpy_type)
def _type_to_info_from_numba_type(numba_type):
context = cuda.descriptor.cuda_target.target_context
size = context.get_value_type(numba_type).get_abi_size(context.target_data)
alignment = context.get_value_type(
numba_type).get_abi_alignment(context.target_data)
return _TypeInfo(size, alignment, _type_to_enum(numba_type))


def _type_to_info(numpy_type):
numba_type = numba.from_dtype(numpy_type)
return _type_to_info_from_numba_type(numba_type)


def _device_array_to_pointer(array):
dtype = array.dtype
info = _type_to_info(dtype)
# Note: this is slightly slower, but supports all ndarray-like objects as long as they support CAI
# TODO: switch to use gpumemoryview once it's ready
return _CCCLIterator(1, 1, _CCCLIteratorKindEnum.POINTER, _CCCLOp(), _CCCLOp(), info, array.__cuda_array_interface__["data"][0])
return _CCCLIterator(1, 1, _CCCLIteratorKindEnum.POINTER, _CCCLOp(), _CCCLOp(), info, array.__cuda_array_interface__["data"][0], None)


def _host_array_to_value(array):
@@ -123,6 +160,32 @@ def handle(self):
return _CCCLOp(_CCCLOpKindEnum.STATELESS, self.name, ctypes.c_char_p(self.ltoir), len(self.ltoir), 1, 1, None)


def _extract_ctypes_ltoirs(numba_cuda_compile_results):
view_lst = [_CCCLStringView(ltoir, len(ltoir))
for ltoir, _ in numba_cuda_compile_results]
view_arr = (_CCCLStringView * len(view_lst))(*view_lst)
return ctypes.pointer(_CCCLStringViews(view_arr, len(view_arr)))


def _facade_iter_as_cccl_iter(d_in):
def prefix_name(name):
return (d_in.prefix + "_" + name).encode('utf-8')
# type name ltoi ltoir_size size alignment state
adv = _CCCLOp(_CCCLOpKindEnum.STATELESS, prefix_name("advance"), None, 0, 1, 1, None)
drf = _CCCLOp(_CCCLOpKindEnum.STATELESS, prefix_name("dereference"), None, 0, 1, 1, None)
info = _type_to_info_from_numba_type(d_in.ntype)
ltoirs = _extract_ctypes_ltoirs(d_in.ltoirs)
# size alignment type advance dereference value_type state ltoirs
return _CCCLIterator(d_in.size, d_in.alignment, _CCCLIteratorKindEnum.ITERATOR, adv, drf, info, d_in.state_c_void_p, ltoirs)


def _d_in_as_cccl_iter(d_in):
if hasattr(d_in, 'ntype'):
return _facade_iter_as_cccl_iter(d_in)
assert hasattr(d_in, 'dtype')
return _device_array_to_pointer(d_in)


def _get_cuda_path():
cuda_path = os.environ.get('CUDA_PATH', '')
if os.path.exists(cuda_path):
Expand Down Expand Up @@ -177,6 +240,7 @@ def _get_paths():
return _paths


# MUST match `cccl_device_reduce_build_result_t` in c/include/cccl/c/reduce.h
class _CCCLDeviceReduceBuildResult(ctypes.Structure):
_fields_ = [("cc", ctypes.c_int),
("cubin", ctypes.c_void_p),
@@ -195,21 +259,22 @@ def _dtype_validation(dt1, dt2):

class _Reduce:
def __init__(self, d_in, d_out, op, init):
self._ctor_d_in_dtype = d_in.dtype
d_in_cccl = _d_in_as_cccl_iter(d_in)
self._ctor_d_in_cccl_type_enum_name = _cccl_type_enum_as_name(
d_in_cccl.value_type.type.value)
self._ctor_d_out_dtype = d_out.dtype
self._ctor_init_dtype = init.dtype
cc_major, cc_minor = cuda.get_current_device().compute_capability
cub_path, thrust_path, libcudacxx_path, cuda_include_path = _get_paths()
bindings = _get_bindings()
accum_t = init.dtype
self.op_wrapper = _Op(accum_t, op)
d_in_ptr = _device_array_to_pointer(d_in)
d_out_ptr = _device_array_to_pointer(d_out)
self.build_result = _CCCLDeviceReduceBuildResult()

# TODO Figure out caching
error = bindings.cccl_device_reduce_build(ctypes.byref(self.build_result),
d_in_ptr,
d_in_cccl,
d_out_ptr,
self.op_wrapper.handle(),
_host_array_to_value(init),
@@ -223,9 +288,18 @@ def __init__(self, d_in, d_out, op, init):
if error != enums.CUDA_SUCCESS:
raise ValueError('Error building reduce')

def __call__(self, temp_storage, d_in, d_out, init):
# TODO validate POINTER vs ITERATOR when iterator support is added
_dtype_validation(self._ctor_d_in_dtype, d_in.dtype)
def __call__(self, temp_storage, d_in, d_out, num_items, init):
d_in_cccl = _d_in_as_cccl_iter(d_in)
if d_in_cccl.type.value == _CCCLIteratorKindEnum.ITERATOR:
assert num_items is not None
else:
assert d_in_cccl.type.value == _CCCLIteratorKindEnum.POINTER
if num_items is None:
num_items = d_in.size
else:
assert num_items == d_in.size
_dtype_validation(self._ctor_d_in_cccl_type_enum_name,
_cccl_type_enum_as_name(d_in_cccl.value_type.type.value))
_dtype_validation(self._ctor_d_out_dtype, d_out.dtype)
_dtype_validation(self._ctor_init_dtype, init.dtype)
bindings = _get_bindings()
@@ -237,15 +311,13 @@ def __call__(self, temp_storage, d_in, d_out, init):
# Note: this is slightly slower, but supports all ndarray-like objects as long as they support CAI
# TODO: switch to use gpumemoryview once it's ready
d_temp_storage = temp_storage.__cuda_array_interface__["data"][0]
d_in_ptr = _device_array_to_pointer(d_in)
d_out_ptr = _device_array_to_pointer(d_out)
num_items = ctypes.c_ulonglong(d_in.size)
error = bindings.cccl_device_reduce(self.build_result,
d_temp_storage,
ctypes.byref(temp_storage_bytes),
d_in_ptr,
d_in_cccl,
d_out_ptr,
num_items,
ctypes.c_ulonglong(num_items),
self.op_wrapper.handle(),
_host_array_to_value(init),
None)