diff --git a/c/parallel/src/scan.cu b/c/parallel/src/scan.cu index 777ab718b80..053d5104c59 100644 --- a/c/parallel/src/scan.cu +++ b/c/parallel/src/scan.cu @@ -160,9 +160,11 @@ std::string get_scan_kernel_name(cccl_iterator_t input_it, cccl_iterator_t outpu init_t); } +static constexpr auto ptx_u64_assignment_regex = R"(\.visible\s+\.global\s+\.align\s+\d+\s+\.u64\s+{}\s*=\s*(\d+);)"; + size_t find_size_t(char* ptx, std::string_view name) { - std::regex regex(std::format(R"(\.visible\s+\.global\s+\.align\s+\d+\s+\.u64\s+{}\s*=\s*(\d+);)", name)); + std::regex regex(std::format(ptx_u64_assignment_regex, name)); std::cmatch match; if (std::regex_search(ptx, match, regex)) { @@ -175,6 +177,20 @@ size_t find_size_t(char* ptx, std::string_view name) } } +size_t find_size_t(char* ptx, std::string_view name, size_t default_value) +{ + std::regex regex(std::format(ptx_u64_assignment_regex, name)); + std::cmatch match; + if (std::regex_search(ptx, match, regex)) + { + return std::stoi(match[1].str()); + } + else + { + return default_value; + } +} + struct scan_tile_state { // scan_tile_state implements the same (host) interface as cub::ScanTileStateT, except @@ -188,95 +204,38 @@ struct scan_tile_state void* d_tile_partial; void* d_tile_inclusive; - bool is_primitive; - size_t status_size; - size_t uninitialized_size; + size_t description_bytes_per_tile; + size_t payload_bytes_per_tile; - scan_tile_state(bool is_primitive, size_t status_size, size_t uninitialized_size) + scan_tile_state(size_t description_bytes_per_tile, size_t payload_bytes_per_tile) : d_tile_status(nullptr) , d_tile_partial(nullptr) , d_tile_inclusive(nullptr) - , is_primitive(is_primitive) - , status_size(status_size) - , uninitialized_size(uninitialized_size) + , description_bytes_per_tile(description_bytes_per_tile) + , payload_bytes_per_tile(payload_bytes_per_tile) {} cudaError_t Init(int num_tiles, void* d_temp_storage, size_t temp_storage_bytes) { - return is_primitive ? InitPrimitive(num_tiles, d_temp_storage, temp_storage_bytes) - : InitStorage(num_tiles, d_temp_storage, temp_storage_bytes); - } - - cudaError_t AllocationSize(int num_tiles, size_t& temp_storage_bytes) const - { - return is_primitive ? AllocationSizePrimitive(num_tiles, temp_storage_bytes) - : AllocationSizeStorage(num_tiles, temp_storage_bytes); - } - - cudaError_t InitPrimitive(int, void* d_temp_storage, size_t) - { - d_tile_status = d_temp_storage; + void* allocations[3] = {}; + auto status = cub::detail::tile_state_init( + description_bytes_per_tile, payload_bytes_per_tile, num_tiles, d_temp_storage, temp_storage_bytes, allocations); + if (status != cudaSuccess) + { + return status; + } + d_tile_status = allocations[0]; + d_tile_partial = allocations[1]; + d_tile_inclusive = allocations[2]; return cudaSuccess; } - cudaError_t AllocationSizePrimitive(int num_tiles, size_t& d_temp_storage_bytes) const + cudaError_t AllocationSize(int num_tiles, size_t& temp_storage_bytes) const { - d_temp_storage_bytes = (num_tiles + TILE_STATUS_PADDING) * status_size; + temp_storage_bytes = + cub::detail::tile_state_allocation_size(description_bytes_per_tile, payload_bytes_per_tile, num_tiles); return cudaSuccess; } - - cudaError_t InitStorage(int num_tiles, void* d_temp_storage, size_t temp_storage_bytes) - { - cudaError_t error = cudaSuccess; - do - { - void* allocations[3] = {}; - size_t allocation_sizes[3]; - - // bytes needed for tile status descriptors - allocation_sizes[0] = (num_tiles + TILE_STATUS_PADDING) * status_size; - - // bytes needed for partials - allocation_sizes[1] = (num_tiles + TILE_STATUS_PADDING) * uninitialized_size; - - // bytes needed for inclusives - allocation_sizes[2] = (num_tiles + TILE_STATUS_PADDING) * uninitialized_size; - - // Compute allocation pointers into the single storage blob - error = CubDebug(cub::AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes)); - - if (cudaSuccess != error) - { - break; - } - - // Alias the offsets - d_tile_status = allocations[0]; - d_tile_partial = allocations[1]; - d_tile_inclusive = allocations[2]; - } while (0); - - return error; - } - - cudaError_t AllocationSizeStorage(int num_tiles, size_t& d_temp_storage_bytes) const - { - // Specify storage allocation requirements - size_t allocation_sizes[3]; - - // bytes needed for tile status descriptors - allocation_sizes[0] = (num_tiles + TILE_STATUS_PADDING) * status_size; - - // bytes needed for partials - allocation_sizes[1] = (num_tiles + TILE_STATUS_PADDING) * uninitialized_size; - - // bytes needed for inclusives - allocation_sizes[2] = (num_tiles + TILE_STATUS_PADDING) * uninitialized_size; - - // Set the necessary size of the blob - void* allocations[3] = {}; - return CubDebug(cub::AliasTemporaries(nullptr, d_temp_storage_bytes, allocations, allocation_sizes)); - } }; template @@ -444,50 +403,29 @@ extern "C" CCCL_C_API CUresult cccl_device_scan_build( constexpr size_t num_ptx_lto_args = 3; const char* ptx_lopts[num_ptx_lto_args] = {"-lto", arch.c_str(), "-ptx"}; - size_t status_size{0}; - size_t uninitialized_size{0}; - if (accum_t.type == cccl_type_enum::STORAGE) - { - std::string src = std::format( - "#include \n" - "#include \n" - "struct __align__({1}) storage_t {{\n" - " char data[{0}];\n" - "}};\n" - "__device__ size_t status_size = sizeof(typename cub::ScanTileState<{2}>::StatusWord);\n" - "__device__ size_t uninitialized_size = sizeof(cub::Uninitialized<{2}>);\n", - accum_t.size, - accum_t.alignment, - accum_cpp); - auto compile_result = - make_nvrtc_command_list() - .add_program(nvrtc_translation_unit{src.c_str(), "tile_state_info"}) - .compile_program({ptx_args, num_ptx_args}) - .cleanup_program() - .finalize_program(num_ptx_lto_args, ptx_lopts); - auto ptx_code = compile_result.cubin.get(); - status_size = scan::find_size_t(ptx_code, "status_size"); - uninitialized_size = scan::find_size_t(ptx_code, "uninitialized_size"); - } - else - { - std::string src = std::format( - "#include \n" - "#include \n" - "__device__ size_t status_size = sizeof(typename cub::ScanTileState<{0}>::TxnWord);\n", - accum_cpp); - auto compile_result = - make_nvrtc_command_list() - .add_program(nvrtc_translation_unit{src.c_str(), "tile_state_info"}) - .compile_program({ptx_args, num_ptx_args}) - .cleanup_program() - .finalize_program(num_ptx_lto_args, ptx_lopts); - auto ptx_code = compile_result.cubin.get(); - status_size = scan::find_size_t(ptx_code, "status_size"); - } - - bool is_primitive = not(accum_t.type == cccl_type_enum::STORAGE); - auto tile_state = std::make_unique(is_primitive, status_size, uninitialized_size); + size_t description_bytes_per_tile; + size_t payload_bytes_per_tile; + std::string ptx_src = std::format( + "#include \n" + "#include \n" + "struct __align__({1}) storage_t {{\n" + " char data[{0}];\n" + "}};\n" + "__device__ size_t description_bytes_per_tile = cub::ScanTileState<{2}>::description_bytes_per_tile;\n" + "__device__ size_t payload_bytes_per_tile = cub::ScanTileState<{2}>::payload_bytes_per_tile;\n", + accum_t.size, + accum_t.alignment, + accum_cpp); + auto compile_result = + make_nvrtc_command_list() + .add_program(nvrtc_translation_unit{ptx_src.c_str(), "tile_state_info"}) + .compile_program({ptx_args, num_ptx_args}) + .cleanup_program() + .finalize_program(num_ptx_lto_args, ptx_lopts); + auto ptx_code = compile_result.cubin.get(); + description_bytes_per_tile = scan::find_size_t(ptx_code, "description_bytes_per_tile"); + payload_bytes_per_tile = scan::find_size_t(ptx_code, "payload_bytes_per_tile", 0); + auto tile_state = std::make_unique(description_bytes_per_tile, payload_bytes_per_tile); build->cc = cc; build->cubin = (void*) result.cubin.release();