From 199f2a5230280ea0a296274068f750b6fff8d842 Mon Sep 17 00:00:00 2001 From: "Ralf W. Grosse-Kunstleve" Date: Fri, 6 Dec 2024 15:32:49 -0800 Subject: [PATCH] [WIP] Support fancy iterators in cuda.parallel (#2788) * New python/cuda_parallel/cuda/parallel/experimental/random_access_iterators.py (with unit tests). * Use TransformRAI to implement constant, counting, arbitrary RAIs. * Transfer `num_item`-related changes from cuda_parallel_itertools branch. The other branch is: https://github.com/rwgk/cccl/tree/cuda_parallel_itertools * Rename `op` to `reduction_op` in `cccl_device_reduce_build()` * Transfer test_device_sum_repeat_1_equals_num_items() from cuda_parallel_itertools branch. The other branch is: https://github.com/rwgk/cccl/tree/cuda_parallel_itertools * Add `class _TransformRAIUnaryOp`. Make `constant_op()` less trivial (to see if numba can still compile it). * Rename `class _Op` to `_ReductionOp` for clarity. * WIP Use TransformRAIUnaryOp here * INCOMPLETE SNAPSHOT * This links the input_unary_op() successfully into the nvrtc cubin, but then fails with: Fatal Python error: Floating point exception * passing `_type_to_info_from_numba_type(numba.int32)` as `value_type` resolves the Floating point exception (but the `cccl_device_reduce()` call still does not succeed) * More debug output. LOOOK single_tile_kernel CALL /home/coder/cccl/c/parallel/src/reduce.cu:116 LOOOK EXCEPTION CUDA error: invalid argument /home/coder/cccl/c/parallel/src/reduce.cu:703 * Substituting `fake_in_ptr` if `in_ptr == nullptr`: All tests pass. * Rename new test function to `test_device_sum_input_unary_op()` and parametrize: `use_numpy_array`: `[True, False]`, `input_generator`: `["constant", "counting", "arbitrary", "nested"]` * Remove python/cuda_parallel/cuda/parallel/experimental/random_access_iterators.py (because numba.cuda cannot JIT classes). * Add `"nested_global"` test, but disable. * `cu_repeat()`, `cu_count()`, `cu_arbitrary()` functions that return a `unary_op`, which is then compiled with `numba.cuda.compile()` * Snapshot DOES NOT WORK, from 2024-10-20+1646 * Files copy-pasted from 2479 comment, then: notabs.py kernel.cpp main.py * Commands to build and run the POC. * `RawPointer(..., dtype)` * `assert_expected_output_array(, `more_nested_map` * Add `@register_jitable` to `cu_repeat()`, ..., `cu_map()`: this fixes the `"map_mul2"` test and the added `"map_add10_map_mul2"` test works, too. * Restore original c/parallel/src/reduce.cu (to avoid `git merge main` conflicts). * Transfer a few minor `printf("\nLOOOK ...);` from python_random_access_iterators branch. * clang-format auto-fixes * Change `ADVANCE(this, diff)` to `ADVANCE(data, diff)` * Remove `input_unary_op`-related code. 
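For reference, the `numba.cuda.compile()` workflow referred to in the bullets above boils down to the following minimal sketch (assuming a CUDA-capable environment; the ABI name shown is illustrative and mirrors the `_ncc()` helper added later in this patch):

```python
import numba
import numba.cuda

def mul2(val):
    return 2 * val

# Compile the Python op to LTO-IR under a fixed ABI name, so the C library
# can later link it into the reduction kernel by symbol name.
ltoir, return_type = numba.cuda.compile(
    pyfunc=mul2,
    sig=numba.types.int32(numba.types.int32),
    abi_info={"abi_name": "mul2_int32_int32"},
    output="ltoir",
)
assert isinstance(ltoir, bytes)  # raw LTO-IR blob, passed through ctypes later
```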
* Plug Georgii's POC ConstantIterator, CountingIterator into cuda_parallel test_reduce.py * Change `map(it, op)` to `cu_map(op, it)`, to be compatible with the built-in `map(func, *iterables)` * Plug Georgii's POC map code into cuda_parallel test_reduce.py; test FAILS: ptxas fatal : Unresolved extern function 'transform_count_int32_mul2_advance' ERROR NVJITLINK_ERROR_PTX_COMPILE: JIT the PTX (ltoPtx) LOOOK EXCEPTION nvJitLink error: 4 /home/coder/cccl/c/parallel/src/reduce.cu:339 =========================== short test summary info ============================ FAILED tests/test_reduce.py::test_device_sum_sentinel_iterator[map_mul2-False] - ValueError: Error building reduce ======================= 1 failed, 11 deselected in 1.33s ======================= * Add `cccl_string_views* ltoirs` to `cccl_iterator_t` * Populate `cccl_iterator_t::ltoirs` in `_itertools_iter_as_cccl_iter()` * Add `nvrtc_sm_top_level::add_link_list()` * Fix `_extract_ctypes_ltoirs()` implementation * Add `extract_ltoirs(const cccl_iterator_t&)` implementation. All tests pass, including the `cu_map()` test. * Plug Georgii's POC RawPointer, CacheModifiedPointer into cuda_parallel test_reduce.py * Copy all sentinel iterators wholesale from georgii_poc_2479/pocenv/main.py to new python/cuda_parallel/cuda/parallel/experimental/sentinel_iterators.py * Remove all sentinel iterator implementations from test_reduce.py and use cuda.parallel.experimental.sentinel_iterators instead. * Cleanup * Change `DEREF(this)` to `DEREF(data)` * Make RawPointer, CacheModifiedPointer nnuba.cuda.compile() `sig` expressions more readable and fix `int32` vs `uint64` mixup for `advance` methods. * Add `cccl_iterator_t::advance_multiply_diff_with_sizeof_value_t` feature. * Make discovery mechanism for cuda/_include directory compatible with `pip install --editable` * Make discovery mechanism for cuda/_include directory compatible with `pip install --editable` * Add pytest-xdist to test requirements. * Revert "Add `cccl_iterator_t::advance_multiply_diff_with_sizeof_value_t` feature." This reverts commit eaee1962506326163f70e23726126b335ddc2f6a. * pointer_advance_sizeof(), cache_advance_sizeof(): Use Python capture to achieve `distance * sizeof_dtype` calculation. * Rename sentinel_iterators.py -> iterators.py * ConstantIterator(..., dtype) * Change `dtype`, `numba_type` variable names to `ntype`. * CountingIterator(..., ntype) * Add `ntype.name` to `self.prefix` for `RawPointer`, `CacheModifiedPointer`. Introduce `_ncc()` helper function. * Trivial: Shuffle order of tests to match code in iterators.py * ruff format iterators.py (NO manual changes). * Add "raw_pointer_int16" test. * Generalize `ldcs()` `@intrinsic` and add "streamed_input_int16" test. 
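The `pointer_advance_sizeof()` / `cache_advance_sizeof()` bullet above describes a closure-capture trick; here is a hypothetical sketch of the idea (the single-`uint64`-address state layout and the ABI name are assumptions for illustration, not the layout used by the final code in this patch):

```python
import numba
import numba.cuda

def make_pointer_advance(sizeof_dtype):
    # sizeof_dtype is captured from Python; Numba folds it into the device code,
    # so the advance step becomes `address += distance * sizeof(value_type)`.
    def pointer_advance_sizeof(this, distance):
        this[0] = this[0] + distance * sizeof_dtype
    return pointer_advance_sizeof

# Hypothetical state layout: a single uint64 holding the raw device address.
state_type = numba.types.CPointer(numba.types.uint64)
ltoir, _ = numba.cuda.compile(
    make_pointer_advance(4),  # e.g. 4 bytes per int32 element
    numba.types.void(state_type, numba.types.uint64),
    abi_info={"abi_name": "pointer_int32_advance"},
    output="ltoir",
)
```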
* Expand tests for raw_pointer, streamed_input: (int, uint) x (16, 32, 64) * _numba_type_as_ir_type() experiment: does not make a difference Compared to the `int32` case, this line is inserted: ``` cvt.u64.u32 %rd3, %r1; ``` This is the generated ptx code (including that line) and the error message: ``` // .globl cacheint64_dereference .common .global .align 8 .u64 _ZN08NumbaEnv4cuda8parallel12experimental9iterators20CacheModifiedPointer32cache_cache_dereference_bitwidth12_3clocals_3e17cache_dereferenceB3v48B96cw51cXTLSUwv1sCUt9Uw1VEw0NRRQPKiLTj0gIGIFp_2b2oLQFEYYkHSQB1OQAk0Bynm21OizQ1K0UoIGvDpQE8oxrNQE_3dE11int64_2a_2a; .visible .func (.param .b64 func_retval0) cacheint64_dereference( .param .b64 cacheint64_dereference_param_0 ) { .reg .b32 %r<2>; .reg .b64 %rd<4>; ld.param.u64 %rd2, [cacheint64_dereference_param_0]; ld.u64 %rd1, [%rd2]; // begin inline asm ld.global.cs.s64 %r1, [%rd1]; // end inline asm cvt.u64.u32 %rd3, %r1; st.param.b64 [func_retval0+0], %rd3; ret; } ptxas application ptx input, line 6672; error : Arguments mismatch for instruction 'ld' ptxas fatal : Ptx assembly aborted due to errors ERROR NVJITLINK_ERROR_PTX_COMPILE: JIT the PTX (ltoPtx) ``` * Revert "_numba_type_as_ir_type() experiment: does not make a difference" This reverts commit c32e7ed90ab451c367a149bf6107b31460f0a4bb. * Plug in numba pointer arithmetic code provided by @gmarkall (with a small bug fix): @gmarkall's POC code: https://github.com/NVIDIA/cccl/issues/2861 The small bug fix: ```diff def sizeof_pointee(context, ptr): - size = context.get_abi_sizeof(ptr.type) + size = context.get_abi_sizeof(ptr.type.pointee) return ir.Constant(ir.IntType(64), size) ``` * WIP * Fix streamed_input_int64, streamed_input_uint64 issue using suggestion provided by @gevtushenko: https://github.com/NVIDIA/cccl/pull/2788#issuecomment-2487412798 * Simplified and more readable `ldcs()` implementation. * Add raw_pointer_floatXX, streamed_input_floatXX tests. Those run but the reduction results are incorrect. Needs debugging. * Use `numpy.dtype(intty)` to obtain an actual dtype object (rather than a numpy scalar object). * Fix `_itertools_iter_as_cccl_iter()` helper function do use `d_in.ntype` instead of hard-wired `int32`. This fixes the raw_pointer_floatXX, streamed_input_floatXX tests. * Add raw_pointer_float16, streamed_input_float16 tests. Needs a change in iterators.py, not sure why, or if there is a better solution. * Add constant_XXX tests. constant_float16 does not work because ctypes does not support float16. * Add counting_XXX tests. counting_float32,64 do not work, needs debugging. * Change cu_map() make_advance_codegen() distty to uint64, for consistency with all other iterator types. * Replace all cu_map() int32 with either it.ntype or op_return_ntype * Fix ntype vs types.uint64 mixup for distance in ConstantIterator, CountingIterator. This fixes the counting_float32, counting_float64 tests. * Remove float16 code and tests. Not needed at this stage. * Fix trivial (but slightly confusing) typo. * Remove unused host_address(self) methods. * Introduce _DEVICE_POINTER_SIZE, _DEVICE_POINTER_BITWIDTH, _DISTANCE_NUMBA_TYPE, _DISTANCE_IR_TYPE. Resolve most TODOs. * Generalize map_mul2 test as map_mul2_int32_int32, but still only test with int32. * Add more map_mul2 tests that all pass with the production code as-is. map_mul2_float32_int32 does not pass (commented out in this commit). 
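A small illustration of the `numpy.dtype(intty)` normalization mentioned above, and of the type-name round trip that the later `numba_type_from_any()` helper relies on (no CUDA required):

```python
import numba
import numpy

# numpy.int32 is a scalar *type*; numpy.dtype() normalizes it to a dtype instance,
# which is what the size/alignment helpers expect.
dt = numpy.dtype(numpy.int32)
assert isinstance(dt, numpy.dtype) and dt.itemsize == 4
assert not isinstance(numpy.int32, numpy.dtype)

# The string round trip used to map value-type names onto numba types.
assert str(dt) == "int32"
assert getattr(numba.types, str(dt)) == numba.types.int32
```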
* Make cu_map() code more readable by introducing _CHAR_PTR_NUMBA_TYPE, _CHAR_PTR_IR_TYPE * Archiving debugging code including bug fix in make_dereference_codegen() * Revert "Archiving debugging code including bug fix in make_dereference_codegen()" This reverts commit b525b80f3a9355eef0455f950b9c5fa9b92ae450. * FIX but does not work * Add more map_mul2 tests. For unknown reasons these all pass when run separately, or when running all tests in parallel (pytest -v -n 32). * Add type information to cu_map() op abi_name. * Simplify() helper function, reuse for compiling cu_map() op * Change map_mul2_xxx test names to map_mul2_count_xxx * Add map_mul3_map_mul2_count_int32_int32_int32, map_mul2_map_mul2_count_float64_float32_int16 tests. map_mul* tests are flaky, TBD why. * Introduce op_caller() in cu_map(): this solves the test failures (TBH, I cannot explain why/how). * Move `num_items` argument to just before `h_init`, for compatibility with cub `DeviceReduce` API. * Move iterators.py -> _iterators.py, add new iterators.py as public interface. * Pull repeat(), count() out from _iterators.py into iterators.py; add map() forwarding function (to renamed cumap()). * Move `class TransformIterator` out to module scope. This makes it clearer where Python captures are used. It might also be a minor performance improvement because the `class` definition code is only processed once. * Replace the very awkward _Reduce.__handle_d_in() method with the new _d_in_as_cccl_iter() function. * Move _iterators.cache() to iterators.cache_load_modifier() * Fix minor oversights in docstrings. * Enable passing device array as `it` to `map()` (testing with cupy array). * Rename `l_input` -> `l_varr` (in test_reduce.py). * Rename `ldcs()` -> `load_cs()` and add reference to cub `CacheModifiedInputIterator` `LOAD_CS` * Move `*_advance` and `*_dereference` functions out of class scopes, to make it obvious that they are never used as Python methods, but exclusively as source for `numba.cuda.compile()` * Turn state_c_void_p, size, alignment methods into properties. * Improved organization of newly added tests. NO functional changes. Applied ruff format to newly added code. * Add comments pointing to issue #3064 * Change `ntype` to `value_type` in public APIs. Introduce `numba_type_from_any(value_type)` function. * Change names of functions in the public API. * Effectively undo commit 708c3411846de05d5dd169b8e6d6391f5d366233: Move `*_advance` and `*_dereference` functions back to class scope, with comments to explicitly state that these are not actual methods.
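Putting the refactored public API together, here is a hedged end-to-end sketch following the pattern of the tests added later in this patch (requires a working cuda.parallel build and a CUDA device):

```python
import numba.cuda
import numpy
import cuda.parallel.experimental as cudax
from cuda.parallel.experimental import iterators

def add_op(a, b):
    return a + b

def mul2(val):
    return 2 * val

num_items, start = 3, 10
# TransformIterator(mul2, CountingIterator(start)) yields 2*start, 2*(start+1), ...
d_in = iterators.TransformIterator(
    mul2,
    iterators.CountingIterator(start, value_type="int32"),
    op_return_value_type="int32",
)
d_out = numba.cuda.device_array(1, dtype=numpy.int32)
h_init = numpy.array([0], dtype=numpy.int32)

reduce_into = cudax.reduce_into(d_in, d_out, add_op, h_init)
# Two-phase call: first query the temp storage size (num_items is required for
# iterator inputs), then run the reduction.
temp_storage_size = reduce_into(None, d_in, d_out, num_items, h_init)
d_temp_storage = numba.cuda.device_array(temp_storage_size, dtype=numpy.uint8)
reduce_into(d_temp_storage, d_in, d_out, num_items, h_init)

assert d_out.copy_to_host()[0] == sum(mul2(start + i) for i in range(num_items))
```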
* Use `@pytest.fixture` to replace most `@pytest.mark.parametrize`, as suggested by @shwina --- c/parallel/include/cccl/c/types.h | 13 + c/parallel/src/kernels/iterators.cpp | 4 +- c/parallel/src/reduce.cu | 48 +- .../cuda/parallel/experimental/__init__.py | 108 ++++- .../cuda/parallel/experimental/_iterators.py | 438 ++++++++++++++++++ .../cuda/parallel/experimental/iterators.py | 37 ++ python/cuda_parallel/tests/test_reduce.py | 263 ++++++++++- python/cuda_parallel/tests/test_reduce_api.py | 4 +- 8 files changed, 866 insertions(+), 49 deletions(-) create mode 100644 python/cuda_parallel/cuda/parallel/experimental/_iterators.py create mode 100644 python/cuda_parallel/cuda/parallel/experimental/iterators.py diff --git a/c/parallel/include/cccl/c/types.h b/c/parallel/include/cccl/c/types.h index 4cc9c13d26c..3b54dde0967 100644 --- a/c/parallel/include/cccl/c/types.h +++ b/c/parallel/include/cccl/c/types.h @@ -71,6 +71,18 @@ struct cccl_value_t void* state; }; +struct cccl_string_view +{ + const char* begin; + int size; +}; + +struct cccl_string_views +{ + cccl_string_view* views; + int size; +}; + struct cccl_iterator_t { int size; @@ -80,4 +92,5 @@ struct cccl_iterator_t cccl_op_t dereference; cccl_type_info value_type; void* state; + cccl_string_views* ltoirs = nullptr; }; diff --git a/c/parallel/src/kernels/iterators.cpp b/c/parallel/src/kernels/iterators.cpp index 90d718e94e1..31e85e9c22c 100644 --- a/c/parallel/src/kernels/iterators.cpp +++ b/c/parallel/src/kernels/iterators.cpp @@ -51,9 +51,9 @@ struct __align__(OP_ALIGNMENT) input_iterator_state_t { using difference_type = DIFF_T; using pointer = VALUE_T*; using reference = VALUE_T&; - __device__ inline value_type operator*() const { return DEREF(this); } + __device__ inline value_type operator*() const { return DEREF(data); } __device__ inline input_iterator_state_t& operator+=(difference_type diff) { - ADVANCE(this, diff); + ADVANCE(data, diff); return *this; } __device__ inline value_type operator[](difference_type diff) const { diff --git a/c/parallel/src/reduce.cu b/c/parallel/src/reduce.cu index 42d23486bcb..77675f0f29d 100644 --- a/c/parallel/src/reduce.cu +++ b/c/parallel/src/reduce.cu @@ -263,6 +263,12 @@ extern "C" CCCL_C_API CUresult cccl_device_reduce_build( op_src, // 6 policy.vector_load_length); // 7 +#if false // CCCL_DEBUGGING_SWITCH + fflush(stderr); + printf("\nCODE4NVRTC BEGIN\n%sCODE4NVRTC END\n", src.c_str()); + fflush(stdout); +#endif + std::string single_tile_kernel_name = get_single_tile_kernel_name(input_it, output_it, op, init, false); std::string single_tile_second_kernel_name = get_single_tile_kernel_name(input_it, output_it, op, init, true); std::string reduction_kernel_name = get_device_reduce_kernel_name(op, input_it, init); @@ -287,16 +293,23 @@ extern "C" CCCL_C_API CUresult cccl_device_reduce_build( } }; ltoir_list_append({op.ltoir, op.ltoir_size}); - if (cccl_iterator_kind_t::iterator == input_it.type) - { - ltoir_list_append({input_it.advance.ltoir, input_it.advance.ltoir_size}); - ltoir_list_append({input_it.dereference.ltoir, input_it.dereference.ltoir_size}); - } - if (cccl_iterator_kind_t::iterator == output_it.type) - { - ltoir_list_append({output_it.advance.ltoir, output_it.advance.ltoir_size}); - ltoir_list_append({output_it.dereference.ltoir, output_it.dereference.ltoir_size}); - } + auto extract_ltoirs = [ltoir_list_append](const cccl_iterator_t& it) { + if (cccl_iterator_kind_t::iterator == it.type) + { + ltoir_list_append({it.advance.ltoir, it.advance.ltoir_size}); + 
ltoir_list_append({it.dereference.ltoir, it.dereference.ltoir_size}); + if (it.ltoirs != nullptr) + { + for (int i = 0; i < it.ltoirs->size; i++) + { + auto view = it.ltoirs->views[i]; + ltoir_list_append({view.begin, view.size}); + } + } + } + }; + extract_ltoirs(input_it); + extract_ltoirs(output_it); nvrtc_cubin result = make_nvrtc_command_list() @@ -323,8 +336,11 @@ extern "C" CCCL_C_API CUresult cccl_device_reduce_build( build->cubin_size = result.size; build->accumulator_size = accum_t.size; } - catch (...) + catch (const std::exception& exc) { + fflush(stderr); + printf("\nEXCEPTION in cccl_device_reduce_build(): %s\n", exc.what()); + fflush(stdout); error = CUDA_ERROR_UNKNOWN; } @@ -411,8 +427,11 @@ extern "C" CCCL_C_API CUresult cccl_device_reduce( cub::CudaDriverLauncherFactory{cu_device, build.cc}, {get_accumulator_type(op, d_in, init)}); } - catch (...) + catch (const std::exception& exc) { + fflush(stderr); + printf("\nEXCEPTION in cccl_device_reduce(): %s\n", exc.what()); + fflush(stdout); error = CUDA_ERROR_UNKNOWN; } @@ -437,8 +456,11 @@ extern "C" CCCL_C_API CUresult cccl_device_reduce_cleanup(cccl_device_reduce_bui std::unique_ptr cubin(reinterpret_cast(bld_ptr->cubin)); check(cuLibraryUnload(bld_ptr->library)); } - catch (...) + catch (const std::exception& exc) { + fflush(stderr); + printf("\nEXCEPTION in cccl_device_reduce_cleanup(): %s\n", exc.what()); + fflush(stdout); return CUDA_ERROR_UNKNOWN; } diff --git a/python/cuda_parallel/cuda/parallel/experimental/__init__.py b/python/cuda_parallel/cuda/parallel/experimental/__init__.py index f3b3ebe2fd7..78641569880 100644 --- a/python/cuda_parallel/cuda/parallel/experimental/__init__.py +++ b/python/cuda_parallel/cuda/parallel/experimental/__init__.py @@ -12,7 +12,7 @@ from numba.cuda.cudadrv import enums -# Should match C++ +# MUST match `cccl_type_enum` in c/include/cccl/c/types.h class _TypeEnum(ctypes.c_int): INT8 = 0 INT16 = 1 @@ -27,13 +27,30 @@ class _TypeEnum(ctypes.c_int): STORAGE = 10 -# Should match C++ +def _cccl_type_enum_as_name(enum_value): + assert isinstance(enum_value, int) + return ( + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float32", + "float64", + "STORAGE", + )[enum_value] + + +# MUST match `cccl_op_kind_t` in c/include/cccl/c/types.h class _CCCLOpKindEnum(ctypes.c_int): STATELESS = 0 STATEFUL = 1 -# Should match C++ +# MUST match `cccl_iterator_kind_t` in c/include/cccl/c/types.h class _CCCLIteratorKindEnum(ctypes.c_int): POINTER = 0 ITERATOR = 1 @@ -57,13 +74,14 @@ def _type_to_enum(numba_type): return _TypeEnum.STORAGE -# TODO Extract into reusable module +# MUST match `cccl_type_info` in c/include/cccl/c/types.h class _TypeInfo(ctypes.Structure): _fields_ = [("size", ctypes.c_int), ("alignment", ctypes.c_int), ("type", _TypeEnum)] +# MUST match `cccl_op_t` in c/include/cccl/c/types.h class _CCCLOp(ctypes.Structure): _fields_ = [("type", _CCCLOpKindEnum), ("name", ctypes.c_char_p), @@ -74,6 +92,19 @@ class _CCCLOp(ctypes.Structure): ("state", ctypes.c_void_p)] +# MUST match `cccl_string_view` in c/include/cccl/c/types.h +class _CCCLStringView(ctypes.Structure): + _fields_ = [("begin", ctypes.c_char_p), + ("size", ctypes.c_int)] + + +# MUST match `cccl_string_views` in c/include/cccl/c/types.h +class _CCCLStringViews(ctypes.Structure): + _fields_ = [("views", ctypes.POINTER(_CCCLStringView)), + ("size", ctypes.c_int)] + + +# MUST match `cccl_iterator_t` in c/include/cccl/c/types.h class _CCCLIterator(ctypes.Structure): _fields_ = [("size", 
ctypes.c_int), ("alignment", ctypes.c_int), @@ -81,16 +112,17 @@ class _CCCLIterator(ctypes.Structure): ("advance", _CCCLOp), ("dereference", _CCCLOp), ("value_type", _TypeInfo), - ("state", ctypes.c_void_p)] + ("state", ctypes.c_void_p), + ("ltoirs", ctypes.POINTER(_CCCLStringViews))] +# MUST match `cccl_value_t` in c/include/cccl/c/types.h class _CCCLValue(ctypes.Structure): _fields_ = [("type", _TypeInfo), ("state", ctypes.c_void_p)] -def _type_to_info(numpy_type): - numba_type = numba.from_dtype(numpy_type) +def _type_to_info_from_numba_type(numba_type): context = cuda.descriptor.cuda_target.target_context size = context.get_value_type(numba_type).get_abi_size(context.target_data) alignment = context.get_value_type( @@ -98,12 +130,17 @@ def _type_to_info(numpy_type): return _TypeInfo(size, alignment, _type_to_enum(numba_type)) +def _type_to_info(numpy_type): + numba_type = numba.from_dtype(numpy_type) + return _type_to_info_from_numba_type(numba_type) + + def _device_array_to_pointer(array): dtype = array.dtype info = _type_to_info(dtype) # Note: this is slightly slower, but supports all ndarray-like objects as long as they support CAI # TODO: switch to use gpumemoryview once it's ready - return _CCCLIterator(1, 1, _CCCLIteratorKindEnum.POINTER, _CCCLOp(), _CCCLOp(), info, array.__cuda_array_interface__["data"][0]) + return _CCCLIterator(1, 1, _CCCLIteratorKindEnum.POINTER, _CCCLOp(), _CCCLOp(), info, array.__cuda_array_interface__["data"][0], None) def _host_array_to_value(array): @@ -123,6 +160,32 @@ def handle(self): return _CCCLOp(_CCCLOpKindEnum.STATELESS, self.name, ctypes.c_char_p(self.ltoir), len(self.ltoir), 1, 1, None) +def _extract_ctypes_ltoirs(numba_cuda_compile_results): + view_lst = [_CCCLStringView(ltoir, len(ltoir)) + for ltoir, _ in numba_cuda_compile_results] + view_arr = (_CCCLStringView * len(view_lst))(*view_lst) + return ctypes.pointer(_CCCLStringViews(view_arr, len(view_arr))) + + +def _facade_iter_as_cccl_iter(d_in): + def prefix_name(name): + return (d_in.prefix + "_" + name).encode('utf-8') + # type name ltoi ltoir_size size alignment state + adv = _CCCLOp(_CCCLOpKindEnum.STATELESS, prefix_name("advance"), None, 0, 1, 1, None) + drf = _CCCLOp(_CCCLOpKindEnum.STATELESS, prefix_name("dereference"), None, 0, 1, 1, None) + info = _type_to_info_from_numba_type(d_in.ntype) + ltoirs = _extract_ctypes_ltoirs(d_in.ltoirs) + # size alignment type advance dereference value_type state ltoirs + return _CCCLIterator(d_in.size, d_in.alignment, _CCCLIteratorKindEnum.ITERATOR, adv, drf, info, d_in.state_c_void_p, ltoirs) + + +def _d_in_as_cccl_iter(d_in): + if hasattr(d_in, 'ntype'): + return _facade_iter_as_cccl_iter(d_in) + assert hasattr(d_in, 'dtype') + return _device_array_to_pointer(d_in) + + def _get_cuda_path(): cuda_path = os.environ.get('CUDA_PATH', '') if os.path.exists(cuda_path): @@ -177,6 +240,7 @@ def _get_paths(): return _paths +# MUST match `cccl_device_reduce_build_result_t` in c/include/cccl/c/reduce.h class _CCCLDeviceReduceBuildResult(ctypes.Structure): _fields_ = [("cc", ctypes.c_int), ("cubin", ctypes.c_void_p), @@ -195,7 +259,9 @@ def _dtype_validation(dt1, dt2): class _Reduce: def __init__(self, d_in, d_out, op, init): - self._ctor_d_in_dtype = d_in.dtype + d_in_cccl = _d_in_as_cccl_iter(d_in) + self._ctor_d_in_cccl_type_enum_name = _cccl_type_enum_as_name( + d_in_cccl.value_type.type.value) self._ctor_d_out_dtype = d_out.dtype self._ctor_init_dtype = init.dtype cc_major, cc_minor = cuda.get_current_device().compute_capability @@ -203,13 +269,12 @@ 
def __init__(self, d_in, d_out, op, init): bindings = _get_bindings() accum_t = init.dtype self.op_wrapper = _Op(accum_t, op) - d_in_ptr = _device_array_to_pointer(d_in) d_out_ptr = _device_array_to_pointer(d_out) self.build_result = _CCCLDeviceReduceBuildResult() # TODO Figure out caching error = bindings.cccl_device_reduce_build(ctypes.byref(self.build_result), - d_in_ptr, + d_in_cccl, d_out_ptr, self.op_wrapper.handle(), _host_array_to_value(init), @@ -223,9 +288,18 @@ def __init__(self, d_in, d_out, op, init): if error != enums.CUDA_SUCCESS: raise ValueError('Error building reduce') - def __call__(self, temp_storage, d_in, d_out, init): - # TODO validate POINTER vs ITERATOR when iterator support is added - _dtype_validation(self._ctor_d_in_dtype, d_in.dtype) + def __call__(self, temp_storage, d_in, d_out, num_items, init): + d_in_cccl = _d_in_as_cccl_iter(d_in) + if d_in_cccl.type.value == _CCCLIteratorKindEnum.ITERATOR: + assert num_items is not None + else: + assert d_in_cccl.type.value == _CCCLIteratorKindEnum.POINTER + if num_items is None: + num_items = d_in.size + else: + assert num_items == d_in.size + _dtype_validation(self._ctor_d_in_cccl_type_enum_name, + _cccl_type_enum_as_name(d_in_cccl.value_type.type.value)) _dtype_validation(self._ctor_d_out_dtype, d_out.dtype) _dtype_validation(self._ctor_init_dtype, init.dtype) bindings = _get_bindings() @@ -237,15 +311,13 @@ def __call__(self, temp_storage, d_in, d_out, init): # Note: this is slightly slower, but supports all ndarray-like objects as long as they support CAI # TODO: switch to use gpumemoryview once it's ready d_temp_storage = temp_storage.__cuda_array_interface__["data"][0] - d_in_ptr = _device_array_to_pointer(d_in) d_out_ptr = _device_array_to_pointer(d_out) - num_items = ctypes.c_ulonglong(d_in.size) error = bindings.cccl_device_reduce(self.build_result, d_temp_storage, ctypes.byref(temp_storage_bytes), - d_in_ptr, + d_in_cccl, d_out_ptr, - num_items, + ctypes.c_ulonglong(num_items), self.op_wrapper.handle(), _host_array_to_value(init), None) diff --git a/python/cuda_parallel/cuda/parallel/experimental/_iterators.py b/python/cuda_parallel/cuda/parallel/experimental/_iterators.py new file mode 100644 index 00000000000..a229da9f191 --- /dev/null +++ b/python/cuda_parallel/cuda/parallel/experimental/_iterators.py @@ -0,0 +1,438 @@ +import ctypes +import operator + +from numba.core import cgutils +from llvmlite import ir +from numba.core.typing import signature +from numba.core.extending import intrinsic, overload +import numba +import numba.cuda +import numba.types + + +_DEVICE_POINTER_SIZE = 8 +_DEVICE_POINTER_BITWIDTH = _DEVICE_POINTER_SIZE * 8 +_DISTANCE_NUMBA_TYPE = numba.types.uint64 +_DISTANCE_IR_TYPE = ir.IntType(64) +_CHAR_PTR_NUMBA_TYPE = numba.types.CPointer(numba.types.int8) +_CHAR_PTR_IR_TYPE = ir.PointerType(ir.IntType(8)) + + +def numba_type_from_any(value_type): + return getattr(numba.types, str(value_type)) + + +def _sizeof_numba_type(ntype): + mapping = { + numba.types.int8: 1, + numba.types.int16: 2, + numba.types.int32: 4, + numba.types.int64: 8, + numba.types.uint8: 1, + numba.types.uint16: 2, + numba.types.uint32: 4, + numba.types.uint64: 8, + numba.types.float32: 4, + numba.types.float64: 8, + } + return mapping[ntype] + + +def _ctypes_type_given_numba_type(ntype): + mapping = { + numba.types.int8: ctypes.c_int8, + numba.types.int16: ctypes.c_int16, + numba.types.int32: ctypes.c_int32, + numba.types.int64: ctypes.c_int64, + numba.types.uint8: ctypes.c_uint8, + numba.types.uint16: ctypes.c_uint16, 
+ numba.types.uint32: ctypes.c_uint32, + numba.types.uint64: ctypes.c_uint64, + numba.types.float32: ctypes.c_float, + numba.types.float64: ctypes.c_double, + } + return mapping[ntype] + + +def _ncc(abi_name, pyfunc, sig): + return numba.cuda.compile( + pyfunc=pyfunc, sig=sig, abi_info={"abi_name": abi_name}, output="ltoir" + ) + + +def sizeof_pointee(context, ptr): + size = context.get_abi_sizeof(ptr.type.pointee) + return ir.Constant(ir.IntType(_DEVICE_POINTER_BITWIDTH), size) + + +@intrinsic +def pointer_add_intrinsic(context, ptr, offset): + def codegen(context, builder, sig, args): + ptr, index = args + base = builder.ptrtoint(ptr, ir.IntType(_DEVICE_POINTER_BITWIDTH)) + offset = builder.mul(index, sizeof_pointee(context, ptr)) + result = builder.add(base, offset) + return builder.inttoptr(result, ptr.type) + + return ptr(ptr, offset), codegen + + +@overload(operator.add) +def pointer_add(ptr, offset): + if not isinstance(ptr, numba.types.CPointer) or not isinstance( + offset, numba.types.Integer + ): + return + + def impl(ptr, offset): + return pointer_add_intrinsic(ptr, offset) + + return impl + + +class RawPointer: + def __init__(self, ptr, ntype): + self.val = ctypes.c_void_p(ptr) + data_as_ntype_pp = numba.types.CPointer(numba.types.CPointer(ntype)) + self.ntype = ntype + self.prefix = "pointer_" + ntype.name + self.ltoirs = [ + _ncc( + f"{self.prefix}_advance", + RawPointer.pointer_advance, + numba.types.void(data_as_ntype_pp, _DISTANCE_NUMBA_TYPE), + ), + _ncc( + f"{self.prefix}_dereference", + RawPointer.pointer_dereference, + ntype(data_as_ntype_pp), + ), + ] + + # Exclusively for numba.cuda.compile (this is not an actual method). + def pointer_advance(this, distance): + this[0] = this[0] + distance + + # Exclusively for numba.cuda.compile (this is not an actual method). 
+ def pointer_dereference(this): + return this[0][0] + + @property + def state_c_void_p(self): + return ctypes.cast(ctypes.pointer(self.val), ctypes.c_void_p) + + @property + def size(self): + return _DEVICE_POINTER_SIZE + + @property + def alignment(self): + return _DEVICE_POINTER_SIZE + + +def pointer(container, ntype): + return RawPointer(container.__cuda_array_interface__["data"][0], ntype) + + +def _ir_type_given_numba_type(ntype): + bw = ntype.bitwidth + irt = None + if isinstance(ntype, numba.core.types.scalars.Integer): + irt = ir.IntType(bw) + elif isinstance(ntype, numba.core.types.scalars.Float): + if bw == 32: + irt = ir.FloatType() + elif bw == 64: + irt = ir.DoubleType() + return irt + + +@intrinsic +def load_cs(typingctx, base): + # Corresponding to `LOAD_CS` here: + # https://nvidia.github.io/cccl/cub/api/classcub_1_1CacheModifiedInputIterator.html + def codegen(context, builder, sig, args): + rt = _ir_type_given_numba_type(sig.return_type) + if rt is None: + raise RuntimeError(f"Unsupported: {type(sig.return_type)=}") + ftype = ir.FunctionType(rt, [rt.as_pointer()]) + bw = sig.return_type.bitwidth + asm_txt = f"ld.global.cs.b{bw} $0, [$1];" + if bw < 64: + constraint = "=r, l" + else: + constraint = "=l, l" + asm_ir = ir.InlineAsm(ftype, asm_txt, constraint) + return builder.call(asm_ir, args) + + return base.dtype(base), codegen + + +class CacheModifiedPointer: + def __init__(self, ptr, ntype): + self.val = ctypes.c_void_p(ptr) + self.ntype = ntype + data_as_ntype_pp = numba.types.CPointer(numba.types.CPointer(ntype)) + self.prefix = "cache" + ntype.name + self.ltoirs = [ + _ncc( + f"{self.prefix}_advance", + CacheModifiedPointer.cache_advance, + numba.types.void(data_as_ntype_pp, _DISTANCE_NUMBA_TYPE), + ), + _ncc( + f"{self.prefix}_dereference", + CacheModifiedPointer.cache_dereference, + ntype(data_as_ntype_pp), + ), + ] + + # Exclusively for numba.cuda.compile (this is not an actual method). + def cache_advance(this, distance): + this[0] = this[0] + distance + + # Exclusively for numba.cuda.compile (this is not an actual method). + def cache_dereference(this): + return load_cs(this[0]) + + @property + def state_c_void_p(self): + return ctypes.cast(ctypes.pointer(self.val), ctypes.c_void_p) + + @property + def size(self): + return _DEVICE_POINTER_SIZE + + @property + def alignment(self): + return _DEVICE_POINTER_SIZE + + +class ConstantIterator: + def __init__(self, val, ntype): + thisty = numba.types.CPointer(ntype) + self.val = _ctypes_type_given_numba_type(ntype)(val) + self.ntype = ntype + self.prefix = "constant_" + ntype.name + self.ltoirs = [ + _ncc( + f"{self.prefix}_advance", + ConstantIterator.constant_advance, + numba.types.void(thisty, _DISTANCE_NUMBA_TYPE), + ), + _ncc( + f"{self.prefix}_dereference", + ConstantIterator.constant_dereference, + ntype(thisty), + ), + ] + + # Exclusively for numba.cuda.compile (this is not an actual method). + def constant_advance(this, _): + pass + + # Exclusively for numba.cuda.compile (this is not an actual method). 
+ def constant_dereference(this): + return this[0] + + @property + def state_c_void_p(self): + return ctypes.cast(ctypes.pointer(self.val), ctypes.c_void_p) + + @property + def size(self): + return self.ntype.bitwidth // 8 + + @property + def alignment(self): + return self.size + + +class CountingIterator: + def __init__(self, count, ntype): + thisty = numba.types.CPointer(ntype) + self.count = _ctypes_type_given_numba_type(ntype)(count) + self.ntype = ntype + self.prefix = "count_" + ntype.name + self.ltoirs = [ + _ncc( + f"{self.prefix}_advance", + CountingIterator.count_advance, + numba.types.void(thisty, _DISTANCE_NUMBA_TYPE), + ), + _ncc( + f"{self.prefix}_dereference", + CountingIterator.count_dereference, + ntype(thisty), + ), + ] + + # Exclusively for numba.cuda.compile (this is not an actual method). + def count_advance(this, diff): + this[0] += diff + + # Exclusively for numba.cuda.compile (this is not an actual method). + def count_dereference(this): + return this[0] + + @property + def state_c_void_p(self): + return ctypes.cast(ctypes.pointer(self.count), ctypes.c_void_p) + + @property + def size(self): + return self.ntype.bitwidth // 8 + + @property + def alignment(self): + return self.size + + +class TransformIteratorImpl: + def __init__( + self, + it, + op, + op_return_ntype, + transform_advance, + transform_dereference, + op_abi_name, + ): + self.it = it + self.ntype = op_return_ntype + self.prefix = f"transform_{it.prefix}_{op.__name__}" + self.ltoirs = it.ltoirs + [ + _ncc( + f"{self.prefix}_advance", + transform_advance, + numba.types.void( + numba.types.CPointer(numba.types.char), _DISTANCE_NUMBA_TYPE + ), + ), + _ncc( + f"{self.prefix}_dereference", + transform_dereference, + op_return_ntype(numba.types.CPointer(numba.types.char)), + ), + # ATTENTION: NOT op_caller here! 
(see issue #3064) + _ncc(op_abi_name, op, op_return_ntype(it.ntype)), + ] + + @property + def state_c_void_p(self): + return self.it.state_c_void_p + + @property + def size(self): + return self.it.size # TODO fix for stateful op + + @property + def alignment(self): + return self.it.alignment # TODO fix for stateful op + + +def TransformIterator(op, it, op_return_ntype): + # TODO(rwgk): Resolve issue #3064 + + op_return_ntype_ir = _ir_type_given_numba_type(op_return_ntype) + if op_return_ntype_ir is None: + raise RuntimeError(f"Unsupported: {type(op_return_ntype)=}") + if hasattr(it, "dtype"): + assert not hasattr(it, "ntype") + it = pointer(it, getattr(numba.types, str(it.dtype))) + it_ntype_ir = _ir_type_given_numba_type(it.ntype) + if it_ntype_ir is None: + raise RuntimeError(f"Unsupported: {type(it.ntype)=}") + + def source_advance(it_state_ptr, diff): + pass + + def make_advance_codegen(name): + def codegen(context, builder, sig, args): + state_ptr, dist = args + fnty = ir.FunctionType( + ir.VoidType(), (_CHAR_PTR_IR_TYPE, _DISTANCE_IR_TYPE) + ) + fn = cgutils.get_or_insert_function(builder.module, fnty, name) + builder.call(fn, (state_ptr, dist)) + + return signature( + numba.types.void, _CHAR_PTR_NUMBA_TYPE, _DISTANCE_NUMBA_TYPE + ), codegen + + def advance_codegen(func_to_overload, name): + @intrinsic + def intrinsic_impl(typingctx, it_state_ptr, diff): + return make_advance_codegen(name) + + @overload(func_to_overload, target="cuda") + def impl(it_state_ptr, diff): + def impl(it_state_ptr, diff): + return intrinsic_impl(it_state_ptr, diff) + + return impl + + def source_dereference(it_state_ptr): + pass + + def make_dereference_codegen(name): + def codegen(context, builder, sig, args): + (state_ptr,) = args + fnty = ir.FunctionType(it_ntype_ir, (_CHAR_PTR_IR_TYPE,)) + fn = cgutils.get_or_insert_function(builder.module, fnty, name) + return builder.call(fn, (state_ptr,)) + + return signature(it.ntype, _CHAR_PTR_NUMBA_TYPE), codegen + + def dereference_codegen(func_to_overload, name): + @intrinsic + def intrinsic_impl(typingctx, it_state_ptr): + return make_dereference_codegen(name) + + @overload(func_to_overload, target="cuda") + def impl(it_state_ptr): + def impl(it_state_ptr): + return intrinsic_impl(it_state_ptr) + + return impl + + def make_op_codegen(name): + def codegen(context, builder, sig, args): + (val,) = args + fnty = ir.FunctionType(op_return_ntype_ir, (it_ntype_ir,)) + fn = cgutils.get_or_insert_function(builder.module, fnty, name) + return builder.call(fn, (val,)) + + return signature(op_return_ntype, it.ntype), codegen + + def op_codegen(func_to_overload, name): + @intrinsic + def intrinsic_impl(typingctx, val): + return make_op_codegen(name) + + @overload(func_to_overload, target="cuda") + def impl(val): + def impl(val): + return intrinsic_impl(val) + + return impl + + advance_codegen(source_advance, f"{it.prefix}_advance") + dereference_codegen(source_dereference, f"{it.prefix}_dereference") + + def op_caller(value): + return op(value) + + op_abi_name = f"{op.__name__}_{op_return_ntype.name}_{it.ntype.name}" + op_codegen(op_caller, op_abi_name) + + def transform_advance(it_state_ptr, diff): + source_advance(it_state_ptr, diff) # just a function call + + def transform_dereference(it_state_ptr): + # ATTENTION: op_caller here (see issue #3064) + return op_caller(source_dereference(it_state_ptr)) + + return TransformIteratorImpl( + it, op, op_return_ntype, transform_advance, transform_dereference, op_abi_name + ) diff --git 
a/python/cuda_parallel/cuda/parallel/experimental/iterators.py b/python/cuda_parallel/cuda/parallel/experimental/iterators.py new file mode 100644 index 00000000000..605dbc4544e --- /dev/null +++ b/python/cuda_parallel/cuda/parallel/experimental/iterators.py @@ -0,0 +1,37 @@ +from . import _iterators + + +def CacheModifiedInputIterator(device_array, value_type, modifier): + """Python facade for Random Access Cache Modified Iterator that wraps a native device pointer. + + Similar to https://nvidia.github.io/cccl/cub/api/classcub_1_1CacheModifiedInputIterator.html + + Currently the only supported modifier is "stream" (LOAD_CS). + """ + if modifier != "stream": + raise NotImplementedError("Only stream modifier is supported") + return _iterators.CacheModifiedPointer( + device_array.__cuda_array_interface__["data"][0], + _iterators.numba_type_from_any(value_type), + ) + + +def ConstantIterator(value, value_type): + """Python facade (similar to itertools.repeat) for C++ Random Access ConstantIterator.""" + return _iterators.ConstantIterator( + value, _iterators.numba_type_from_any(value_type) + ) + + +def CountingIterator(offset, value_type): + """Python facade (similar to itertools.count) for C++ Random Access CountingIterator.""" + return _iterators.CountingIterator( + offset, _iterators.numba_type_from_any(value_type) + ) + + +def TransformIterator(op, it, op_return_value_type): + """Python facade (similar to built-in map) mimicking a C++ Random Access TransformIterator.""" + return _iterators.TransformIterator( + op, it, _iterators.numba_type_from_any(op_return_value_type) + ) diff --git a/python/cuda_parallel/tests/test_reduce.py b/python/cuda_parallel/tests/test_reduce.py index 78c14b47931..99ce64b146a 100644 --- a/python/cuda_parallel/tests/test_reduce.py +++ b/python/cuda_parallel/tests/test_reduce.py @@ -2,10 +2,19 @@ # # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# TODO(rwgk): Format entire file.
+# fmt: off + +import cupy as cp import numpy import pytest -from numba import cuda +import random +import re +import numba.cuda +import numba.types import cuda.parallel.experimental as cudax +from cuda.parallel.experimental import _iterators +from cuda.parallel.experimental import iterators def random_int(shape, dtype): @@ -32,17 +41,17 @@ def op(a, b): init_value = 42 h_init = numpy.array([init_value], dtype=dtype) - d_output = cuda.device_array(1, dtype=dtype) + d_output = numba.cuda.device_array(1, dtype=dtype) reduce_into = cudax.reduce_into(d_output, d_output, op, h_init) for num_items_pow2 in type_to_problem_sizes(dtype): num_items = 2 ** num_items_pow2 h_input = random_int(num_items, dtype) - d_input = cuda.to_device(h_input) - temp_storage_size = reduce_into(None, d_input, d_output, h_init) - d_temp_storage = cuda.device_array( + d_input = numba.cuda.to_device(h_input) + temp_storage_size = reduce_into(None, d_input, d_output, None, h_init) + d_temp_storage = numba.cuda.device_array( temp_storage_size, dtype=numpy.uint8) - reduce_into(d_temp_storage, d_input, d_output, h_init) + reduce_into(d_temp_storage, d_input, d_output, None, h_init) h_output = d_output.copy_to_host() assert h_output[0] == sum(h_input) + init_value @@ -52,16 +61,16 @@ def op(a, b): return a + b h_init = numpy.array([40.0 + 2.0j], dtype=complex) - d_output = cuda.device_array(1, dtype=complex) + d_output = numba.cuda.device_array(1, dtype=complex) reduce_into = cudax.reduce_into(d_output, d_output, op, h_init) for num_items in [42, 420000]: h_input = numpy.random.random( num_items) + 1j * numpy.random.random(num_items) - d_input = cuda.to_device(h_input) - temp_storage_bytes = reduce_into(None, d_input, d_output, h_init) - d_temp_storage = cuda.device_array(temp_storage_bytes, numpy.uint8) - reduce_into(d_temp_storage, d_input, d_output, h_init) + d_input = numba.cuda.to_device(h_input) + temp_storage_bytes = reduce_into(None, d_input, d_output, None, h_init) + d_temp_storage = numba.cuda.device_array(temp_storage_bytes, numpy.uint8) + reduce_into(d_temp_storage, d_input, d_output, None, h_init) result = d_output.copy_to_host()[0] expected = numpy.sum(h_input, initial=h_init[0]) @@ -75,11 +84,237 @@ def min_op(a, b): dtypes = [numpy.int32, numpy.int64] h_inits = [numpy.array([], dt) for dt in dtypes] h_inputs = [numpy.array([], dt) for dt in dtypes] - d_outputs = [cuda.device_array(1, dt) for dt in dtypes] - d_inputs = [cuda.to_device(h_inp) for h_inp in h_inputs] + d_outputs = [numba.cuda.device_array(1, dt) for dt in dtypes] + d_inputs = [numba.cuda.to_device(h_inp) for h_inp in h_inputs] reduce_into = cudax.reduce_into(d_inputs[0], d_outputs[0], min_op, h_inits[0]) for ix in range(3): with pytest.raises(TypeError, match=r"^dtype mismatch: __init__=int32, __call__=int64$"): - reduce_into(None, d_inputs[int(ix == 0)], d_outputs[int(ix == 1)], h_inits[int(ix == 2)]) + reduce_into(None, d_inputs[int(ix == 0)], d_outputs[int(ix == 1)], None, h_inits[int(ix == 2)]) + + +# fmt: on +def _test_device_sum_with_iterator( + l_varr, start_sum_with, i_input, dtype_inp, dtype_out, use_numpy_array +): + def add_op(a, b): + return a + b + + expected_result = start_sum_with + for v in l_varr: + expected_result = add_op(expected_result, v) + + if use_numpy_array: + h_input = numpy.array(l_varr, dtype_inp) + d_input = numba.cuda.to_device(h_input) + else: + d_input = i_input + + d_output = numba.cuda.device_array(1, dtype_out) # to store device sum + + h_init = numpy.array([start_sum_with], dtype_out) + + reduce_into = 
cudax.reduce_into( + d_in=d_input, d_out=d_output, op=add_op, init=h_init + ) + + temp_storage_size = reduce_into( + None, d_in=d_input, d_out=d_output, num_items=len(l_varr), init=h_init + ) + d_temp_storage = numba.cuda.device_array(temp_storage_size, dtype=numpy.uint8) + + reduce_into(d_temp_storage, d_input, d_output, len(l_varr), h_init) + + h_output = d_output.copy_to_host() + assert h_output[0] == expected_result + + +def mul2(val): + return 2 * val + + +def mul3(val): + return 3 * val + + +SUPPORTED_VALUE_TYPE_NAMES = ( + "int16", + "uint16", + "int32", + "uint32", + "int64", + "uint64", + "float32", + "float64", +) + + +@pytest.fixture(params=SUPPORTED_VALUE_TYPE_NAMES) +def supported_value_type(request): + return request.param + + +@pytest.fixture(params=[True, False]) +def use_numpy_array(request): + return request.param + + +@pytest.mark.parametrize( + "type_obj_from_str", [_iterators.numba_type_from_any, numpy.dtype, cp.dtype] +) +def test_value_type_name_round_trip(type_obj_from_str, supported_value_type): + # If all round trip tests here pass for all value types we are supporting, + # this provides a super easy way to support numba.types, numpy.dtypes, + # cupy.dtypes and plain strings as `value_type` arguments. + type_obj = type_obj_from_str(supported_value_type) + assert str(type_obj) == supported_value_type + + +def test_device_sum_raw_pointer_it( + use_numpy_array, supported_value_type, num_items=3, start_sum_with=10 +): + # Exercise non-public _iterators.pointer() independently from iterators.TransformIterator(). + rng = random.Random(0) + l_varr = [rng.randrange(100) for _ in range(num_items)] + dtype_inp = numpy.dtype(supported_value_type) + dtype_out = dtype_inp + raw_pointer_devarr = numba.cuda.to_device(numpy.array(l_varr, dtype=dtype_inp)) + i_input = _iterators.pointer( + raw_pointer_devarr, _iterators.numba_type_from_any(supported_value_type) + ) + _test_device_sum_with_iterator( + l_varr, start_sum_with, i_input, dtype_inp, dtype_out, use_numpy_array + ) + + +def test_device_sum_cache_modified_input_it( + use_numpy_array, supported_value_type, num_items=3, start_sum_with=10 +): + rng = random.Random(0) + l_varr = [rng.randrange(100) for _ in range(num_items)] + dtype_inp = numpy.dtype(supported_value_type) + dtype_out = dtype_inp + input_devarr = numba.cuda.to_device(numpy.array(l_varr, dtype=dtype_inp)) + i_input = iterators.CacheModifiedInputIterator( + input_devarr, value_type=supported_value_type, modifier="stream" + ) + _test_device_sum_with_iterator( + l_varr, start_sum_with, i_input, dtype_inp, dtype_out, use_numpy_array + ) + + +def test_device_sum_constant_it( + use_numpy_array, supported_value_type, num_items=3, start_sum_with=10 +): + l_varr = [42 for distance in range(num_items)] + dtype_inp = numpy.dtype(supported_value_type) + dtype_out = dtype_inp + i_input = iterators.ConstantIterator(42, value_type=supported_value_type) + _test_device_sum_with_iterator( + l_varr, start_sum_with, i_input, dtype_inp, dtype_out, use_numpy_array + ) + + +def test_device_sum_counting_it( + use_numpy_array, supported_value_type, num_items=3, start_sum_with=10 +): + l_varr = [start_sum_with + distance for distance in range(num_items)] + dtype_inp = numpy.dtype(supported_value_type) + dtype_out = dtype_inp + i_input = iterators.CountingIterator( + start_sum_with, value_type=supported_value_type + ) + _test_device_sum_with_iterator( + l_varr, start_sum_with, i_input, dtype_inp, dtype_out, use_numpy_array + ) + + +@pytest.mark.parametrize( + "value_type_name_pair", + 
list(zip(SUPPORTED_VALUE_TYPE_NAMES, SUPPORTED_VALUE_TYPE_NAMES)) + + [ + ("float32", "int16"), + ("float32", "int32"), + ("float64", "int32"), + ("float64", "int64"), + ("int64", "float32"), + ], +) +def test_device_sum_map_mul2_count_it( + use_numpy_array, value_type_name_pair, num_items=3, start_sum_with=10 +): + l_varr = [2 * (start_sum_with + distance) for distance in range(num_items)] + vtn_out, vtn_inp = value_type_name_pair + dtype_inp = numpy.dtype(vtn_inp) + dtype_out = numpy.dtype(vtn_out) + i_input = iterators.TransformIterator( + mul2, + iterators.CountingIterator(start_sum_with, value_type=vtn_inp), + op_return_value_type=vtn_out, + ) + _test_device_sum_with_iterator( + l_varr, start_sum_with, i_input, dtype_inp, dtype_out, use_numpy_array + ) + + +@pytest.mark.parametrize( + ("fac_out", "fac_mid", "vtn_out", "vtn_mid", "vtn_inp"), + [ + (3, 2, "int32", "int32", "int32"), + (2, 2, "float64", "float32", "int16"), + ], +) +def test_device_sum_map_mul_map_mul_count_it( + use_numpy_array, + fac_out, + fac_mid, + vtn_out, + vtn_mid, + vtn_inp, + num_items=3, + start_sum_with=10, +): + l_varr = [ + fac_out * (fac_mid * (start_sum_with + distance)) + for distance in range(num_items) + ] + dtype_inp = numpy.dtype(vtn_inp) + dtype_out = numpy.dtype(vtn_out) + mul_funcs = {2: mul2, 3: mul3} + i_input = iterators.TransformIterator( + mul_funcs[fac_out], + iterators.TransformIterator( + mul_funcs[fac_mid], + iterators.CountingIterator(start_sum_with, value_type=vtn_inp), + op_return_value_type=vtn_mid, + ), + op_return_value_type=vtn_out, + ) + _test_device_sum_with_iterator( + l_varr, start_sum_with, i_input, dtype_inp, dtype_out, use_numpy_array + ) + + +@pytest.mark.parametrize( + "value_type_name_pair", + [ + ("int32", "int32"), + ("int64", "int32"), + ("int32", "int64"), + ], +) +def test_device_sum_map_mul2_cp_array_it( + use_numpy_array, value_type_name_pair, num_items=3, start_sum_with=10 +): + vtn_out, vtn_inp = value_type_name_pair + dtype_inp = numpy.dtype(vtn_inp) + dtype_out = numpy.dtype(vtn_out) + rng = random.Random(0) + l_d_in = [rng.randrange(100) for _ in range(num_items)] + a_d_in = cp.array(l_d_in, dtype_inp) + i_input = iterators.TransformIterator(mul2, a_d_in, vtn_out) + l_varr = [mul2(v) for v in l_d_in] + _test_device_sum_with_iterator( + l_varr, start_sum_with, i_input, dtype_inp, dtype_out, use_numpy_array + ) diff --git a/python/cuda_parallel/tests/test_reduce_api.py b/python/cuda_parallel/tests/test_reduce_api.py index 9eccee8622c..afed1caef1a 100644 --- a/python/cuda_parallel/tests/test_reduce_api.py +++ b/python/cuda_parallel/tests/test_reduce_api.py @@ -25,13 +25,13 @@ def min_op(a, b): reduce_into = cudax.reduce_into(d_output, d_output, min_op, h_init) # Determine temporary device storage requirements - temp_storage_size = reduce_into(None, d_input, d_output, h_init) + temp_storage_size = reduce_into(None, d_input, d_output, None, h_init) # Allocate temporary storage d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8) # Run reduction - reduce_into(d_temp_storage, d_input, d_output, h_init) + reduce_into(d_temp_storage, d_input, d_output, None, h_init) # Check the result is correct expected_output = 0