Skip to content

Commit

Permalink
Extract merge sort kernels to NVRTC compilable header (NVIDIA#3438)
Browse files Browse the repository at this point in the history
* Move merge_sort kernels to separate file

* Add merge_sort nvrtc test

* Remove include that contains host code and replace with cuda::std

* Remove unneeded headers from merge_sort header

* Move LoadIterator to separate header and replace include

* Add host device macro to has_nested_type to fix nvrtc issue

* Extract make_load_iterator into separate file to avoid nvrtc error

* Extract is_thrust_pointer into separate file to avoid nvrtc error

* Extract policy_wrapper_t into separate file, forward declare LoadIterator, and use ::cuda::std instead of std to avoid nvrtc errors

* Extract unwrap_contiguous_iterator into separate file to avoid nvrtc errors

* Add missing include following header reorganization

* Add comment explaining why we forward declare make_load_iterator

* Add missing iterator include

* Add missing thrust config include

* Use is_same_v and rearrange includes according to formatter

* Add missing comment to endif

* Use SPDX license instead of longer one

* Use nested namespace specifier

* Use nested namespace specifiers and _v suffix in other files
  • Loading branch information
NaderAlAwar authored and davebayer committed Jan 29, 2025
1 parent eee6780 commit e8b7b9c
Show file tree
Hide file tree
Showing 21 changed files with 659 additions and 437 deletions.
7 changes: 4 additions & 3 deletions cub/cub/agent/agent_merge_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,11 @@
#include <cub/util_namespace.cuh>
#include <cub/util_type.cuh>

#include <thrust/system/cuda/detail/core/util.h>
#include <thrust/system/cuda/detail/core/load_iterator.h>

#include <cuda/std/__algorithm/max.h>
#include <cuda/std/__algorithm/min.h>
#include <cuda/std/__cccl/cuda_capabilities.h>

CUB_NAMESPACE_BEGIN

Expand Down Expand Up @@ -86,7 +87,7 @@ struct AgentBlockSort
// Types and constants
//---------------------------------------------------------------------

static constexpr bool KEYS_ONLY = std::is_same<ValueT, NullType>::value;
static constexpr bool KEYS_ONLY = ::cuda::std::is_same_v<ValueT, NullType>;

using BlockMergeSortT = BlockMergeSort<KeyT, Policy::BLOCK_THREADS, Policy::ITEMS_PER_THREAD, ValueT>;

Expand Down Expand Up @@ -469,7 +470,7 @@ struct AgentMerge
struct TempStorage : Uninitialized<_TempStorage>
{};

static constexpr bool KEYS_ONLY = std::is_same<ValueT, NullType>::value;
static constexpr bool KEYS_ONLY = ::cuda::std::is_same_v<ValueT, NullType>;
static constexpr int BLOCK_THREADS = Policy::BLOCK_THREADS;
static constexpr int ITEMS_PER_THREAD = Policy::ITEMS_PER_THREAD;
static constexpr int ITEMS_PER_TILE = Policy::ITEMS_PER_TILE;
Expand Down
1 change: 1 addition & 0 deletions cub/cub/device/device_for.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include <thrust/iterator/iterator_traits.h>
#include <thrust/system/cuda/detail/core/util.h>
#include <thrust/type_traits/is_contiguous_iterator.h>
#include <thrust/type_traits/unwrap_contiguous_iterator.h>

#if __cccl_lib_mdspan
# include <cuda/std/__mdspan/extents.h>
Expand Down
272 changes: 1 addition & 271 deletions cub/cub/device/dispatch/dispatch_merge_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#endif // no system header

#include <cub/agent/agent_merge_sort.cuh>
#include <cub/device/dispatch/kernels/merge_sort.cuh>
#include <cub/device/dispatch/tuning/tuning_merge_sort.cuh>
#include <cub/util_device.cuh>
#include <cub/util_math.cuh>
Expand All @@ -53,277 +54,6 @@

CUB_NAMESPACE_BEGIN

namespace detail::merge_sort
{

/**
 * @brief Instantiates one agent template twice: once with a primary (default) policy and once with a backup
 * (fallback) policy, so that the remaining agent template arguments only have to be spelled out a single time.
 *
 * Besides the two agent types, the size in bytes of each agent's `TempStorage` is exposed as a compile-time constant.
 *
 * @tparam PrimaryPolicyT Policy used as first template argument of `default_agent_t`
 * @tparam BackupPolicyT  Policy used as first template argument of `fallback_agent_t`
 * @tparam AgentTemplate  Agent class template to instantiate; its first template parameter receives the policy
 * @tparam AgentArgsT     Remaining template arguments, passed unchanged to both instantiations
 */
template <typename PrimaryPolicyT,
          typename BackupPolicyT,
          template <typename...> class AgentTemplate,
          typename... AgentArgsT>
struct dual_policy_agent_helper_t
{
  // Agent instantiated with the primary policy
  using default_agent_t = AgentTemplate<PrimaryPolicyT, AgentArgsT...>;
  // Agent instantiated with the backup policy
  using fallback_agent_t = AgentTemplate<BackupPolicyT, AgentArgsT...>;

  // `sizeof` the temporary storage required by each of the two agents
  static constexpr auto default_size  = sizeof(typename default_agent_t::TempStorage);
  static constexpr auto fallback_size = sizeof(typename fallback_agent_t::TempStorage);
};

/**
 * @brief Picks the policy and the agent types shared by the merge sort block-sort and merge kernels.
 *
 * The current merge sort implementation requires the sorting and the merging kernel to use the same tile size.
 * Therefore the decision between the default policy and the fallback policy is taken once, for both agents together:
 * either both agents use the fallback policy or neither of them does.
 *
 * The fallback is selected only if the larger of the two default agents' `TempStorage` exceeds
 * `max_smem_per_block` while the larger of the two fallback agents' `TempStorage` does not.
 */
template <typename DefaultPolicyT,
          typename KeyInputIteratorT,
          typename ValueInputIteratorT,
          typename KeyIteratorT,
          typename ValueIteratorT,
          typename OffsetT,
          typename CompareOpT,
          typename KeyT,
          typename ValueT>
class merge_sort_vsmem_helper_t
{
private:
  // Fallback policy with a smaller tile size.
  // NOTE(review): `64, 1` is presumably (threads per block, items per thread) - confirm against `policy_wrapper_t`.
  using fallback_policy_t = cub::detail::policy_wrapper_t<DefaultPolicyT, 64, 1>;

  // `AgentBlockSort` instantiated with the default and with the fallback policy
  using sort_agents_t = dual_policy_agent_helper_t<
    DefaultPolicyT,
    fallback_policy_t,
    merge_sort::AgentBlockSort,
    KeyInputIteratorT,
    ValueInputIteratorT,
    KeyIteratorT,
    ValueIteratorT,
    OffsetT,
    CompareOpT,
    KeyT,
    ValueT>;

  // `AgentMerge` instantiated with the default and with the fallback policy
  using merge_agents_t = dual_policy_agent_helper_t<
    DefaultPolicyT,
    fallback_policy_t,
    merge_sort::AgentMerge,
    KeyIteratorT,
    ValueIteratorT,
    OffsetT,
    CompareOpT,
    KeyT,
    ValueT>;

  // Largest temporary storage footprint across both agents, per policy
  static constexpr auto largest_default_storage =
    (::cuda::std::max)(sort_agents_t::default_size, merge_agents_t::default_size);
  static constexpr auto largest_fallback_storage =
    (::cuda::std::max)(sort_agents_t::fallback_size, merge_agents_t::fallback_size);

  // The fallback is only worthwhile if the default policy does not fit and the fallback policy does
  static constexpr bool default_exceeds_smem = (largest_default_storage > max_smem_per_block);
  static constexpr bool fallback_fits_smem   = (largest_fallback_storage <= max_smem_per_block);
  static constexpr bool uses_fallback_policy = default_exceeds_smem && fallback_fits_smem;

public:
  // Policy that both kernels have to use
  using policy_t = ::cuda::std::_If<uses_fallback_policy, fallback_policy_t, DefaultPolicyT>;

  // Block-sort agent matching `policy_t`
  using block_sort_agent_t =
    ::cuda::std::_If<uses_fallback_policy,
                     typename sort_agents_t::fallback_agent_t,
                     typename sort_agents_t::default_agent_t>;

  // Merge agent matching `policy_t`
  using merge_agent_t =
    ::cuda::std::_If<uses_fallback_policy,
                     typename merge_agents_t::fallback_agent_t,
                     typename merge_agents_t::default_agent_t>;
};
/**
 * @brief Kernel entry point that builds the block-sort agent and runs its `Process()` member.
 *
 * The policy and the agent type are the ones selected by `merge_sort_vsmem_helper_t` for
 * `ChainedPolicyT::ActivePolicy::MergeSortPolicy`; the launch bound is that policy's `BLOCK_THREADS`.
 *
 * The agent's temporary storage is obtained from `vsmem_helper_impl`, which is handed both a statically allocated
 * shared memory buffer and the `vsmem` argument. `keys_in` and `items_in` are wrapped by `make_load_iterator` before
 * being passed to the agent; all remaining arguments are forwarded to the agent unchanged.
 */
template <typename ChainedPolicyT,
          typename KeyInputIteratorT,
          typename ValueInputIteratorT,
          typename KeyIteratorT,
          typename ValueIteratorT,
          typename OffsetT,
          typename CompareOpT,
          typename KeyT,
          typename ValueT>
__launch_bounds__(
  merge_sort_vsmem_helper_t<typename ChainedPolicyT::ActivePolicy::MergeSortPolicy,
                            KeyInputIteratorT,
                            ValueInputIteratorT,
                            KeyIteratorT,
                            ValueIteratorT,
                            OffsetT,
                            CompareOpT,
                            KeyT,
                            ValueT>::policy_t::BLOCK_THREADS)
  CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceMergeSortBlockSortKernel(
    bool ping,
    KeyInputIteratorT keys_in,
    ValueInputIteratorT items_in,
    KeyIteratorT keys_out,
    ValueIteratorT items_out,
    OffsetT keys_count,
    KeyT* tmp_keys_out,
    ValueT* tmp_items_out,
    CompareOpT compare_op,
    vsmem_t vsmem)
{
  // Same helper instantiation as in the launch bounds above
  using SelectionT = merge_sort_vsmem_helper_t<
    typename ChainedPolicyT::ActivePolicy::MergeSortPolicy,
    KeyInputIteratorT,
    ValueInputIteratorT,
    KeyIteratorT,
    ValueIteratorT,
    OffsetT,
    CompareOpT,
    KeyT,
    ValueT>;
  using SelectedPolicyT = typename SelectionT::policy_t;
  using SortAgentT      = typename SelectionT::block_sort_agent_t;
  using StorageHelperT  = vsmem_helper_impl<SortAgentT>;

  // Static shared memory allocation
  __shared__ typename StorageHelperT::static_temp_storage_t shared_storage;

  // Temporary storage the agent works on, as provided by the vsmem helper
  typename SortAgentT::TempStorage& agent_storage = StorageHelperT::get_temp_storage(shared_storage, vsmem);

  SortAgentT block_sort_agent(
    ping,
    agent_storage,
    THRUST_NS_QUALIFIER::cuda_cub::core::make_load_iterator(SelectedPolicyT(), keys_in),
    THRUST_NS_QUALIFIER::cuda_cub::core::make_load_iterator(SelectedPolicyT(), items_in),
    keys_count,
    keys_out,
    items_out,
    tmp_keys_out,
    tmp_items_out,
    compare_op);
  block_sort_agent.Process();

  // If applicable, hints to discard modified cache lines for vsmem
  StorageHelperT::discard_temp_storage(agent_storage);
}

// TODO(bgruber): if we put a call to cudaTriggerProgrammaticLaunchCompletion inside this kernel, the tests fail with
// cudaErrorIllegalAddress.
/**
 * @brief Kernel entry point that runs one `AgentPartition` per thread.
 *
 * Each thread derives a partition index from the x components of `blockDim`, `blockIdx` and `threadIdx`. Threads whose
 * index is not below `num_partitions` do nothing; every other thread constructs an `AgentPartition` from the kernel
 * arguments plus its partition index and calls `Process()` on it.
 */
template <typename KeyIteratorT, typename OffsetT, typename CompareOpT, typename KeyT>
CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceMergeSortPartitionKernel(
  bool ping,
  KeyIteratorT keys_ping,
  KeyT* keys_pong,
  OffsetT keys_count,
  OffsetT num_partitions,
  OffsetT* merge_partitions,
  CompareOpT compare_op,
  OffsetT target_merged_tiles_number,
  int items_per_tile)
{
  OffsetT partition_id = blockDim.x * blockIdx.x + threadIdx.x;

  // Surplus threads of the last block have no partition to compute
  if (partition_id >= num_partitions)
  {
    return;
  }

  AgentPartition<KeyIteratorT, OffsetT, CompareOpT, KeyT> partition_agent(
    ping,
    keys_ping,
    keys_pong,
    keys_count,
    partition_id,
    merge_partitions,
    compare_op,
    target_merged_tiles_number,
    items_per_tile,
    num_partitions);
  partition_agent.Process();
}

/**
 * @brief Kernel entry point that builds the merge agent and runs its `Process()` member.
 *
 * The policy and the agent type are the ones selected by `merge_sort_vsmem_helper_t` for
 * `ChainedPolicyT::ActivePolicy::MergeSortPolicy`; the launch bound is that policy's `BLOCK_THREADS`.
 *
 * The agent's temporary storage is obtained from `vsmem_helper_impl`, which is handed both a statically allocated
 * shared memory buffer and the `vsmem` argument. The ping and pong key/item buffers are passed to the agent twice:
 * once wrapped by `make_load_iterator` and once as the plain iterators/pointers received by the kernel.
 */
template <typename ChainedPolicyT,
          typename KeyInputIteratorT,
          typename ValueInputIteratorT,
          typename KeyIteratorT,
          typename ValueIteratorT,
          typename OffsetT,
          typename CompareOpT,
          typename KeyT,
          typename ValueT>
__launch_bounds__(
  merge_sort_vsmem_helper_t<typename ChainedPolicyT::ActivePolicy::MergeSortPolicy,
                            KeyInputIteratorT,
                            ValueInputIteratorT,
                            KeyIteratorT,
                            ValueIteratorT,
                            OffsetT,
                            CompareOpT,
                            KeyT,
                            ValueT>::policy_t::BLOCK_THREADS)
  CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceMergeSortMergeKernel(
    bool ping,
    KeyIteratorT keys_ping,
    ValueIteratorT items_ping,
    OffsetT keys_count,
    KeyT* keys_pong,
    ValueT* items_pong,
    CompareOpT compare_op,
    OffsetT* merge_partitions,
    OffsetT target_merged_tiles_number,
    vsmem_t vsmem)
{
  // Same helper instantiation as in the launch bounds above
  using SelectionT = merge_sort_vsmem_helper_t<
    typename ChainedPolicyT::ActivePolicy::MergeSortPolicy,
    KeyInputIteratorT,
    ValueInputIteratorT,
    KeyIteratorT,
    ValueIteratorT,
    OffsetT,
    CompareOpT,
    KeyT,
    ValueT>;
  using SelectedPolicyT = typename SelectionT::policy_t;
  using MergeAgentT     = typename SelectionT::merge_agent_t;
  using StorageHelperT  = vsmem_helper_impl<MergeAgentT>;

  // Static shared memory allocation
  __shared__ typename StorageHelperT::static_temp_storage_t shared_storage;

  // Temporary storage the agent works on, as provided by the vsmem helper
  typename MergeAgentT::TempStorage& agent_storage = StorageHelperT::get_temp_storage(shared_storage, vsmem);

  MergeAgentT merge_agent(
    ping,
    agent_storage,
    // Load-iterator views of the ping and pong buffers
    THRUST_NS_QUALIFIER::cuda_cub::core::make_load_iterator(SelectedPolicyT(), keys_ping),
    THRUST_NS_QUALIFIER::cuda_cub::core::make_load_iterator(SelectedPolicyT(), items_ping),
    THRUST_NS_QUALIFIER::cuda_cub::core::make_load_iterator(SelectedPolicyT(), keys_pong),
    THRUST_NS_QUALIFIER::cuda_cub::core::make_load_iterator(SelectedPolicyT(), items_pong),
    keys_count,
    // The same buffers again, as received by the kernel
    keys_pong,
    items_pong,
    keys_ping,
    items_ping,
    compare_op,
    merge_partitions,
    target_merged_tiles_number);
  merge_agent.Process();

  // If applicable, hints to discard modified cache lines for vsmem
  StorageHelperT::discard_temp_storage(agent_storage);
}

} // namespace detail::merge_sort

/*******************************************************************************
* Policy
******************************************************************************/
Expand Down
1 change: 1 addition & 0 deletions cub/cub/device/dispatch/dispatch_transform.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ _CCCL_NV_DIAG_SUPPRESS(186)
#include <thrust/system/cuda/detail/core/triple_chevron_launch.h>
#include <thrust/type_traits/is_contiguous_iterator.h>
#include <thrust/type_traits/is_trivially_relocatable.h>
#include <thrust/type_traits/unwrap_contiguous_iterator.h>

#include <cuda/cmath>
#include <cuda/ptx>
Expand Down
Loading

0 comments on commit e8b7b9c

Please sign in to comment.