Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract merge sort kernels to NVRTC compilable header #3438

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
c939b2c
Move merge_sort kernels to separate file
NaderAlAwar Jan 16, 2025
2f44ceb
Add merge_sort nvrtc test
NaderAlAwar Jan 16, 2025
0de0643
Remove include that contains host code and replace with cuda::std
NaderAlAwar Jan 16, 2025
8d42feb
Remove unneeded headers from merge_sort header
NaderAlAwar Jan 16, 2025
7ada1a4
Move LoadIterator to separate header and replace include
NaderAlAwar Jan 16, 2025
465dd84
Add host device macro to has_nested_type to fix nvrtc issue
NaderAlAwar Jan 17, 2025
2a514ae
Extract make_load_iterator into separate file to avoid nvrtc error
NaderAlAwar Jan 17, 2025
617049f
Extract is_thrust_pointer into separate file to avoid nvrtc error
NaderAlAwar Jan 17, 2025
c657606
Extract policy_wrapper_t into separate file, forward declare LoadIter…
NaderAlAwar Jan 17, 2025
84b66f2
Extract unwrap_contiguous_iterator into separate file to avoid nvrtc …
NaderAlAwar Jan 17, 2025
0b67d2c
Add missing include following header reorganization
NaderAlAwar Jan 17, 2025
559cd52
Add comment explaining why we forward declare make_load_iterator
NaderAlAwar Jan 17, 2025
108dc42
Add missing iterator include
NaderAlAwar Jan 17, 2025
1fd9031
Add missing thrust config include
NaderAlAwar Jan 17, 2025
4169f33
Merge branch 'main' into extract-merge-sort-to-nvrtc-header
NaderAlAwar Jan 17, 2025
42dc031
Merge branch 'main' of https://github.com/naderalawar/cccl into extra…
NaderAlAwar Jan 21, 2025
ff27ef4
Use is_same_v and rearrange include according to formatter
NaderAlAwar Jan 21, 2025
04cdbea
Add missing comment to endif
NaderAlAwar Jan 21, 2025
73c10e6
Use SPDX license instead of longer one
NaderAlAwar Jan 21, 2025
9ba92b8
Use nested namespace specifier
NaderAlAwar Jan 21, 2025
6446316
Use nested namespace specifiers and _v suffix in other files
NaderAlAwar Jan 21, 2025
65c599d
Merge branch 'main' into extract-merge-sort-to-nvrtc-header
NaderAlAwar Jan 21, 2025
9347f1b
Merge branch 'main' into extract-merge-sort-to-nvrtc-header
NaderAlAwar Jan 22, 2025
c404d71
Merge branch 'main' of https://github.com/NaderAlAwar/cccl into extra…
NaderAlAwar Jan 22, 2025
271ff98
Merge branch 'main' into extract-merge-sort-to-nvrtc-header
NaderAlAwar Jan 22, 2025
7efa329
Merge branch 'main' of https://github.com/naderalawar/cccl into extra…
NaderAlAwar Jan 23, 2025
63ef177
Merge branch 'main' into extract-merge-sort-to-nvrtc-header
NaderAlAwar Jan 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions cub/cub/agent/agent_merge_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,11 @@
#include <cub/util_namespace.cuh>
#include <cub/util_type.cuh>

#include <thrust/system/cuda/detail/core/util.h>
#include <thrust/system/cuda/detail/core/load_iterator.h>

#include <cuda/std/__algorithm/max.h>
#include <cuda/std/__algorithm/min.h>
#include <cuda/std/__cccl/cuda_capabilities.h>

CUB_NAMESPACE_BEGIN

Expand Down Expand Up @@ -86,7 +87,7 @@ struct AgentBlockSort
// Types and constants
//---------------------------------------------------------------------

static constexpr bool KEYS_ONLY = std::is_same<ValueT, NullType>::value;
static constexpr bool KEYS_ONLY = ::cuda::std::is_same_v<ValueT, NullType>;

using BlockMergeSortT = BlockMergeSort<KeyT, Policy::BLOCK_THREADS, Policy::ITEMS_PER_THREAD, ValueT>;

Expand Down Expand Up @@ -469,7 +470,7 @@ struct AgentMerge
struct TempStorage : Uninitialized<_TempStorage>
{};

static constexpr bool KEYS_ONLY = std::is_same<ValueT, NullType>::value;
static constexpr bool KEYS_ONLY = ::cuda::std::is_same_v<ValueT, NullType>;
static constexpr int BLOCK_THREADS = Policy::BLOCK_THREADS;
static constexpr int ITEMS_PER_THREAD = Policy::ITEMS_PER_THREAD;
static constexpr int ITEMS_PER_TILE = Policy::ITEMS_PER_TILE;
Expand Down
1 change: 1 addition & 0 deletions cub/cub/device/device_for.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include <thrust/iterator/iterator_traits.h>
#include <thrust/system/cuda/detail/core/util.h>
#include <thrust/type_traits/is_contiguous_iterator.h>
#include <thrust/type_traits/unwrap_contiguous_iterator.h>

#if __cccl_lib_mdspan
# include <cuda/std/__mdspan/extents.h>
Expand Down
272 changes: 1 addition & 271 deletions cub/cub/device/dispatch/dispatch_merge_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#endif // no system header

#include <cub/agent/agent_merge_sort.cuh>
#include <cub/device/dispatch/kernels/merge_sort.cuh>
#include <cub/device/dispatch/tuning/tuning_merge_sort.cuh>
#include <cub/util_device.cuh>
#include <cub/util_math.cuh>
Expand All @@ -53,277 +54,6 @@

CUB_NAMESPACE_BEGIN

namespace detail::merge_sort
{

/**
 * @brief Pairs up two instantiations of one agent template: the first built from the default policy, the second built
 * from the fallback policy. The remaining agent template arguments are spelled only once and shared by both, so
 * callers do not have to repeat the full agent parameter list for each policy.
 *
 * @tparam DefaultPolicyT  Policy plugged into `default_agent_t`
 * @tparam FallbackPolicyT Policy plugged into `fallback_agent_t`
 * @tparam AgentT          Agent class template, instantiated as `AgentT<Policy, AgentParamsT...>`
 * @tparam AgentParamsT    Template arguments handed to `AgentT` after the policy
 */
template <typename DefaultPolicyT, typename FallbackPolicyT, template <typename...> class AgentT, typename... AgentParamsT>
struct dual_policy_agent_helper_t
{
private:
  // The agent instantiated for a given policy, followed by the shared agent parameters
  template <typename PolicyT>
  using agent_for_t = AgentT<PolicyT, AgentParamsT...>;

public:
  using default_agent_t  = agent_for_t<DefaultPolicyT>;
  using fallback_agent_t = agent_for_t<FallbackPolicyT>;

  // Size in bytes of the `TempStorage` member type of each of the two agents
  static constexpr auto default_size  = sizeof(typename default_agent_t::TempStorage);
  static constexpr auto fallback_size = sizeof(typename fallback_agent_t::TempStorage);
};

/**
* @brief Helper class template for merge sort-specific virtual shared memory handling. The merge sort algorithm in its
* current implementation relies on the fact that both the sorting as well as the merging kernels use the same tile
* size. This circumstance needs to be respected when determining whether the fallback policy for large user types is
* applicable: we must either use the fallback for both or for none of the two agents.
*/
template <typename DefaultPolicyT,
typename KeyInputIteratorT,
typename ValueInputIteratorT,
typename KeyIteratorT,
typename ValueIteratorT,
typename OffsetT,
typename CompareOpT,
typename KeyT,
typename ValueT>
class merge_sort_vsmem_helper_t
{
private:
// Default fallback policy with a smaller tile size
using fallback_policy_t = cub::detail::policy_wrapper_t<DefaultPolicyT, 64, 1>;

// Helper for the `AgentBlockSort` template with one member type alias for the agent template instantiated with the
// default policy and one instantiated with the fallback policy
using block_sort_helper_t = dual_policy_agent_helper_t<
DefaultPolicyT,
fallback_policy_t,
merge_sort::AgentBlockSort,
KeyInputIteratorT,
ValueInputIteratorT,
KeyIteratorT,
ValueIteratorT,
OffsetT,
CompareOpT,
KeyT,
ValueT>;
using default_block_sort_agent_t = typename block_sort_helper_t::default_agent_t;
using fallback_block_sort_agent_t = typename block_sort_helper_t::fallback_agent_t;

// Helper for the `AgentMerge` template with one member type alias for the agent template instantiated with the
// default policy and one instantiated with the fallback policy
using merge_helper_t = dual_policy_agent_helper_t<
DefaultPolicyT,
fallback_policy_t,
merge_sort::AgentMerge,
KeyIteratorT,
ValueIteratorT,
OffsetT,
CompareOpT,
KeyT,
ValueT>;
using default_merge_agent_t = typename merge_helper_t::default_agent_t;
using fallback_merge_agent_t = typename merge_helper_t::fallback_agent_t;

// Use fallback if either (a) the default block sort or (b) the block merge agent exceed the maximum shared memory
// available per block and both (1) the fallback block sort and (2) the fallback merge agent would not exceed the
// available shared memory
static constexpr auto max_default_size =
(::cuda::std::max)(block_sort_helper_t::default_size, merge_helper_t::default_size);
static constexpr auto max_fallback_size =
(::cuda::std::max)(block_sort_helper_t::fallback_size, merge_helper_t::fallback_size);
static constexpr bool uses_fallback_policy =
(max_default_size > max_smem_per_block) && (max_fallback_size <= max_smem_per_block);

public:
using policy_t = ::cuda::std::_If<uses_fallback_policy, fallback_policy_t, DefaultPolicyT>;
using block_sort_agent_t =
::cuda::std::_If<uses_fallback_policy, fallback_block_sort_agent_t, default_block_sort_agent_t>;
using merge_agent_t = ::cuda::std::_If<uses_fallback_policy, fallback_merge_agent_t, default_merge_agent_t>;
};
/**
 * @brief Kernel entry point for the block-sort step of merge sort.
 *
 * Resolves the policy and the block-sort agent through `merge_sort_vsmem_helper_t` (using the active policy's
 * `MergeSortPolicy`), obtains the agent's temporary storage, constructs the agent and runs `Process()`.
 * The launch bounds are taken from `BLOCK_THREADS` of the resolved policy.
 *
 * `keys_in` and `items_in` are wrapped by `make_load_iterator` before being handed to the agent; all other arguments
 * are forwarded to the agent unchanged. `vsmem` is only passed on to `get_temp_storage`.
 */
template <typename ChainedPolicyT,
          typename KeyInputIteratorT,
          typename ValueInputIteratorT,
          typename KeyIteratorT,
          typename ValueIteratorT,
          typename OffsetT,
          typename CompareOpT,
          typename KeyT,
          typename ValueT>
__launch_bounds__(
  merge_sort_vsmem_helper_t<typename ChainedPolicyT::ActivePolicy::MergeSortPolicy,
                            KeyInputIteratorT,
                            ValueInputIteratorT,
                            KeyIteratorT,
                            ValueIteratorT,
                            OffsetT,
                            CompareOpT,
                            KeyT,
                            ValueT>::policy_t::BLOCK_THREADS)
  CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceMergeSortBlockSortKernel(
    bool ping,
    KeyInputIteratorT keys_in,
    ValueInputIteratorT items_in,
    KeyIteratorT keys_out,
    ValueIteratorT items_out,
    OffsetT keys_count,
    KeyT* tmp_keys_out,
    ValueT* tmp_items_out,
    CompareOpT compare_op,
    vsmem_t vsmem)
{
  // Shorthand for the namespace that provides `make_load_iterator`
  namespace thrust_core = THRUST_NS_QUALIFIER::cuda_cub::core;

  // Same helper instantiation as the one used for the launch bounds above
  using sort_helper_t = merge_sort_vsmem_helper_t<
    typename ChainedPolicyT::ActivePolicy::MergeSortPolicy,
    KeyInputIteratorT,
    ValueInputIteratorT,
    KeyIteratorT,
    ValueIteratorT,
    OffsetT,
    CompareOpT,
    KeyT,
    ValueT>;
  using active_policy_t    = typename sort_helper_t::policy_t;
  using block_sort_agent_t = typename sort_helper_t::block_sort_agent_t;
  using agent_vsmem_t      = vsmem_helper_impl<block_sort_agent_t>;

  // Statically allocated shared memory for the agent
  __shared__ typename agent_vsmem_t::static_temp_storage_t static_temp_storage;

  // Temporary storage the agent actually works on
  typename block_sort_agent_t::TempStorage& temp_storage =
    agent_vsmem_t::get_temp_storage(static_temp_storage, vsmem);

  block_sort_agent_t agent(
    ping,
    temp_storage,
    thrust_core::make_load_iterator(active_policy_t(), keys_in),
    thrust_core::make_load_iterator(active_policy_t(), items_in),
    keys_count,
    keys_out,
    items_out,
    tmp_keys_out,
    tmp_items_out,
    compare_op);

  agent.Process();

  // If applicable, hints to discard modified cache lines for vsmem
  agent_vsmem_t::discard_temp_storage(temp_storage);
}

/**
 * @brief Kernel entry point for the partition step of merge sort.
 *
 * Each thread derives one partition index from its block and thread coordinates. Threads whose index lies below
 * `num_partitions` construct an `AgentPartition` for that index and run `Process()`; all other threads return
 * without doing any work.
 */
// TODO(bgruber): if we put a call to cudaTriggerProgrammaticLaunchCompletion inside this kernel, the tests fail with
// cudaErrorIllegalAddress.
template <typename KeyIteratorT, typename OffsetT, typename CompareOpT, typename KeyT>
CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceMergeSortPartitionKernel(
  bool ping,
  KeyIteratorT keys_ping,
  KeyT* keys_pong,
  OffsetT keys_count,
  OffsetT num_partitions,
  OffsetT* merge_partitions,
  CompareOpT compare_op,
  OffsetT target_merged_tiles_number,
  int items_per_tile)
{
  using partition_agent_t = AgentPartition<KeyIteratorT, OffsetT, CompareOpT, KeyT>;

  // One partition per thread
  OffsetT partition_idx = blockDim.x * blockIdx.x + threadIdx.x;

  // Surplus threads have no partition to process
  if (partition_idx >= num_partitions)
  {
    return;
  }

  partition_agent_t agent(
    ping,
    keys_ping,
    keys_pong,
    keys_count,
    partition_idx,
    merge_partitions,
    compare_op,
    target_merged_tiles_number,
    items_per_tile,
    num_partitions);

  agent.Process();
}

/**
 * @brief Kernel entry point for the merge step of merge sort.
 *
 * Resolves the policy and the merge agent through `merge_sort_vsmem_helper_t` (using the active policy's
 * `MergeSortPolicy`), obtains the agent's temporary storage, constructs the agent and runs `Process()`.
 * The launch bounds are taken from `BLOCK_THREADS` of the resolved policy.
 *
 * The ping and pong key/item buffers are handed to the agent twice: once wrapped by `make_load_iterator` and once
 * as the raw arguments (pong first, then ping). `vsmem` is only passed on to `get_temp_storage`.
 */
template <typename ChainedPolicyT,
          typename KeyInputIteratorT,
          typename ValueInputIteratorT,
          typename KeyIteratorT,
          typename ValueIteratorT,
          typename OffsetT,
          typename CompareOpT,
          typename KeyT,
          typename ValueT>
__launch_bounds__(
  merge_sort_vsmem_helper_t<typename ChainedPolicyT::ActivePolicy::MergeSortPolicy,
                            KeyInputIteratorT,
                            ValueInputIteratorT,
                            KeyIteratorT,
                            ValueIteratorT,
                            OffsetT,
                            CompareOpT,
                            KeyT,
                            ValueT>::policy_t::BLOCK_THREADS)
  CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceMergeSortMergeKernel(
    bool ping,
    KeyIteratorT keys_ping,
    ValueIteratorT items_ping,
    OffsetT keys_count,
    KeyT* keys_pong,
    ValueT* items_pong,
    CompareOpT compare_op,
    OffsetT* merge_partitions,
    OffsetT target_merged_tiles_number,
    vsmem_t vsmem)
{
  // Shorthand for the namespace that provides `make_load_iterator`
  namespace thrust_core = THRUST_NS_QUALIFIER::cuda_cub::core;

  // Same helper instantiation as the one used for the launch bounds above
  using sort_helper_t = merge_sort_vsmem_helper_t<
    typename ChainedPolicyT::ActivePolicy::MergeSortPolicy,
    KeyInputIteratorT,
    ValueInputIteratorT,
    KeyIteratorT,
    ValueIteratorT,
    OffsetT,
    CompareOpT,
    KeyT,
    ValueT>;
  using active_policy_t = typename sort_helper_t::policy_t;
  using merge_agent_t   = typename sort_helper_t::merge_agent_t;
  using agent_vsmem_t   = vsmem_helper_impl<merge_agent_t>;

  // Statically allocated shared memory for the agent
  __shared__ typename agent_vsmem_t::static_temp_storage_t static_temp_storage;

  // Temporary storage the agent actually works on
  typename merge_agent_t::TempStorage& temp_storage = agent_vsmem_t::get_temp_storage(static_temp_storage, vsmem);

  merge_agent_t agent(
    ping,
    temp_storage,
    thrust_core::make_load_iterator(active_policy_t(), keys_ping),
    thrust_core::make_load_iterator(active_policy_t(), items_ping),
    thrust_core::make_load_iterator(active_policy_t(), keys_pong),
    thrust_core::make_load_iterator(active_policy_t(), items_pong),
    keys_count,
    keys_pong,
    items_pong,
    keys_ping,
    items_ping,
    compare_op,
    merge_partitions,
    target_merged_tiles_number);

  agent.Process();

  // If applicable, hints to discard modified cache lines for vsmem
  agent_vsmem_t::discard_temp_storage(temp_storage);
}

} // namespace detail::merge_sort

/*******************************************************************************
* Policy
******************************************************************************/
Expand Down
1 change: 1 addition & 0 deletions cub/cub/device/dispatch/dispatch_transform.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ _CCCL_NV_DIAG_SUPPRESS(186)
#include <thrust/system/cuda/detail/core/triple_chevron_launch.h>
#include <thrust/type_traits/is_contiguous_iterator.h>
#include <thrust/type_traits/is_trivially_relocatable.h>
#include <thrust/type_traits/unwrap_contiguous_iterator.h>

#include <cuda/cmath>
#include <cuda/ptx>
Expand Down
Loading
Loading