From 9fe8dfed5fa5660f5b0426313fa6aeb10724d28b Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Mon, 10 Feb 2025 10:43:45 -0500 Subject: [PATCH] Fixes following merge from main --- c/parallel/src/kernels/operators.cpp | 2 +- c/parallel/src/kernels/operators.h | 2 +- c/parallel/src/merge_sort.cu | 6 +++--- c/parallel/src/reduce.cu | 2 +- c/parallel/src/scan.cu | 8 +++++--- 5 files changed, 11 insertions(+), 9 deletions(-) diff --git a/c/parallel/src/kernels/operators.cpp b/c/parallel/src/kernels/operators.cpp index 87dc5ce6d30..bbde349993e 100644 --- a/c/parallel/src/kernels/operators.cpp +++ b/c/parallel/src/kernels/operators.cpp @@ -71,7 +71,7 @@ make_kernel_binary_operator_full_source(std::string_view input_t, cccl_op_t oper : std::format(stateful_binary_op_template, return_type)); } -std::string make_kernel_user_arithmetic_operator(std::string_view input_t, cccl_op_t operation) +std::string make_kernel_user_binary_operator(std::string_view input_t, cccl_op_t operation) { return make_kernel_binary_operator_full_source(input_t, operation, "VALUE_T"); } diff --git a/c/parallel/src/kernels/operators.h b/c/parallel/src/kernels/operators.h index 2e8e11df39e..2e269857572 100644 --- a/c/parallel/src/kernels/operators.h +++ b/c/parallel/src/kernels/operators.h @@ -14,6 +14,6 @@ #include -std::string make_kernel_user_arithmetic_operator(std::string_view input_value_t, cccl_op_t operation); +std::string make_kernel_user_binary_operator(std::string_view input_value_t, cccl_op_t operation); std::string make_kernel_user_comparison_operator(std::string_view input_value_t, cccl_op_t operation); diff --git a/c/parallel/src/merge_sort.cu b/c/parallel/src/merge_sort.cu index 4ff17376e4f..edb03ceca17 100644 --- a/c/parallel/src/merge_sort.cu +++ b/c/parallel/src/merge_sort.cu @@ -410,7 +410,7 @@ extern "C" CCCL_C_API CUresult cccl_device_merge_sort_build( ltoir_list_append({output_items_it.dereference.ltoir, output_items_it.dereference.ltoir_size}); } - nvrtc_cubin result = + nvrtc_link_result result = make_nvrtc_command_list() .add_program(nvrtc_translation_unit{src.c_str(), name}) .add_expression({block_sort_kernel_name}) @@ -424,13 +424,13 @@ extern "C" CCCL_C_API CUresult cccl_device_merge_sort_build( .add_link_list(ltoir_list) .finalize_program(num_lto_args, lopts); - cuLibraryLoadData(&build->library, result.cubin.get(), nullptr, nullptr, 0, nullptr, nullptr, 0); + cuLibraryLoadData(&build->library, result.data.get(), nullptr, nullptr, 0, nullptr, nullptr, 0); check(cuLibraryGetKernel(&build->block_sort_kernel, build->library, block_sort_kernel_lowered_name.c_str())); check(cuLibraryGetKernel(&build->partition_kernel, build->library, partition_kernel_lowered_name.c_str())); check(cuLibraryGetKernel(&build->merge_kernel, build->library, merge_kernel_lowered_name.c_str())); build->cc = cc; - build->cubin = (void*) result.cubin.release(); + build->cubin = (void*) result.data.release(); build->cubin_size = result.size; } catch (const std::exception& exc) diff --git a/c/parallel/src/reduce.cu b/c/parallel/src/reduce.cu index bc38090bcc2..09ee9268e92 100644 --- a/c/parallel/src/reduce.cu +++ b/c/parallel/src/reduce.cu @@ -273,7 +273,7 @@ extern "C" CCCL_C_API CUresult cccl_device_reduce_build( const std::string output_iterator_src = make_kernel_output_iterator(offset_t, "output_iterator_t", accum_cpp, output_it); - const std::string op_src = make_kernel_user_arithmetic_operator(accum_cpp, op); + const std::string op_src = make_kernel_user_binary_operator(accum_cpp, op); const std::string src = std::format( "#include \n" diff --git a/c/parallel/src/scan.cu b/c/parallel/src/scan.cu index 1777876182f..844097a4867 100644 --- a/c/parallel/src/scan.cu +++ b/c/parallel/src/scan.cu @@ -294,8 +294,10 @@ extern "C" CCCL_C_API CUresult cccl_device_scan_build( const auto input_it_value_t = cccl_type_enum_to_string(input_it.value_type.type); const auto offset_t = cccl_type_enum_to_string(cccl_type_enum::UINT64); - const std::string input_iterator_src = make_kernel_input_iterator(offset_t, input_it_value_t, input_it); - const std::string output_iterator_src = make_kernel_output_iterator(offset_t, accum_cpp, output_it); + const std::string input_iterator_src = + make_kernel_input_iterator(offset_t, "input_iterator_state_t", input_it_value_t, input_it); + const std::string output_iterator_src = + make_kernel_output_iterator(offset_t, "output_iterator_t", accum_cpp, output_it); const std::string op_src = make_kernel_user_binary_operator(accum_cpp, op); @@ -472,8 +474,8 @@ extern "C" CCCL_C_API CUresult cccl_device_scan( indirect_arg_t, ::cuda::std::size_t, void, - scan::dynamic_scan_policy_t<&scan::get_policy>, cub::ForceInclusive::No, + scan::dynamic_scan_policy_t<&scan::get_policy>, scan::scan_kernel_source, cub::detail::CudaDriverLauncherFactory>:: Dispatch(