Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Speed up compilation time of segmented sort test
Browse files Browse the repository at this point in the history
  • Loading branch information
gevtushenko committed Oct 22, 2021
1 parent 1072b8c commit 3cda69e
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 153 deletions.
62 changes: 33 additions & 29 deletions cub/device/dispatch/dispatch_merge_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ CUB_NAMESPACE_BEGIN


template <bool UseVShmem,
typename ActivePolicyT,
typename ChainedPolicyT,
typename KeyInputIteratorT,
typename ValueInputIteratorT,
typename KeyIteratorT,
Expand All @@ -48,7 +48,7 @@ template <bool UseVShmem,
typename CompareOpT,
typename KeyT,
typename ValueT>
void __global__ __launch_bounds__(ActivePolicyT::BLOCK_THREADS)
void __global__ __launch_bounds__(ChainedPolicyT::ActivePolicy::MergeSortPolicy::BLOCK_THREADS)
DeviceMergeSortBlockSortKernel(bool ping,
KeyInputIteratorT keys_in,
ValueInputIteratorT items_in,
Expand All @@ -61,6 +61,7 @@ DeviceMergeSortBlockSortKernel(bool ping,
char *vshmem)
{
extern __shared__ char shmem[];
using ActivePolicyT = typename ChainedPolicyT::ActivePolicy::MergeSortPolicy;

using AgentBlockSortT = AgentBlockSort<ActivePolicyT,
KeyInputIteratorT,
Expand Down Expand Up @@ -126,16 +127,15 @@ __global__ void DeviceMergeSortPartitionKernel(bool ping,
}
}

template <
bool UseVShmem,
typename ActivePolicyT,
typename KeyIteratorT,
typename ValueIteratorT,
typename OffsetT,
typename CompareOpT,
typename KeyT,
typename ValueT>
void __global__ __launch_bounds__(ActivePolicyT::BLOCK_THREADS)
template <bool UseVShmem,
typename ChainedPolicyT,
typename KeyIteratorT,
typename ValueIteratorT,
typename OffsetT,
typename CompareOpT,
typename KeyT,
typename ValueT>
void __global__ __launch_bounds__(ChainedPolicyT::ActivePolicy::MergeSortPolicy::BLOCK_THREADS)
DeviceMergeSortMergeKernel(bool ping,
KeyIteratorT keys_ping,
ValueIteratorT items_ping,
Expand All @@ -145,11 +145,11 @@ DeviceMergeSortMergeKernel(bool ping,
CompareOpT compare_op,
OffsetT *merge_partitions,
OffsetT target_merged_tiles_number,
char *vshmem
)
char *vshmem)
{
extern __shared__ char shmem[];

using ActivePolicyT = typename ChainedPolicyT::ActivePolicy::MergeSortPolicy;
using AgentMergeT = AgentMerge<ActivePolicyT,
KeyIteratorT,
ValueIteratorT,
Expand Down Expand Up @@ -241,7 +241,8 @@ template <typename KeyInputIteratorT,
typename KeyIteratorT,
typename ValueIteratorT,
typename OffsetT,
typename MergePolicyT,
typename ChainedPolicyT,
typename ActivePolicyT,
typename CompareOpT,
typename KeyT,
typename ValueT>
Expand Down Expand Up @@ -310,11 +311,11 @@ struct BlockSortLauncher
{
THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(
num_tiles,
MergePolicyT::BLOCK_THREADS,
ActivePolicyT::MergeSortPolicy::BLOCK_THREADS,
block_sort_shmem_size,
stream)
.doit(DeviceMergeSortBlockSortKernel<UseVShmem,
MergePolicyT,
ChainedPolicyT,
KeyInputIteratorT,
ValueInputIteratorT,
KeyIteratorT,
Expand All @@ -336,14 +337,14 @@ struct BlockSortLauncher
}
};

template <
typename KeyIteratorT,
typename ValueIteratorT,
typename OffsetT,
typename MergePolicyT,
typename CompareOpT,
typename KeyT,
typename ValueT>
template <typename KeyIteratorT,
typename ValueIteratorT,
typename OffsetT,
typename ChainedPolicyT,
typename ActivePolicyT,
typename CompareOpT,
typename KeyT,
typename ValueT>
struct MergeLauncher
{
int num_tiles;
Expand Down Expand Up @@ -403,11 +404,11 @@ struct MergeLauncher
{
THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(
num_tiles,
MergePolicyT::BLOCK_THREADS,
ActivePolicyT::MergeSortPolicy::BLOCK_THREADS,
merge_shmem_size,
stream)
.doit(DeviceMergeSortMergeKernel<UseVShmem,
MergePolicyT,
ChainedPolicyT,
KeyIteratorT,
ValueIteratorT,
OffsetT,
Expand Down Expand Up @@ -508,6 +509,7 @@ struct DispatchMergeSort : SelectedPolicy
CUB_RUNTIME_FUNCTION __forceinline__ cudaError_t Invoke()
{
using MergePolicyT = typename ActivePolicyT::MergeSortPolicy;
using MaxPolicyT = typename DispatchMergeSort::MaxPolicy;

using BlockSortAgentT = AgentBlockSort<MergePolicyT,
KeyInputIteratorT,
Expand Down Expand Up @@ -627,7 +629,8 @@ struct DispatchMergeSort : SelectedPolicy
KeyIteratorT,
ValueIteratorT,
OffsetT,
MergePolicyT,
MaxPolicyT,
ActivePolicyT,
CompareOpT,
KeyT,
ValueT>
Expand Down Expand Up @@ -672,7 +675,8 @@ struct DispatchMergeSort : SelectedPolicy
MergeLauncher<KeyIteratorT,
ValueIteratorT,
OffsetT,
MergePolicyT,
MaxPolicyT,
ActivePolicyT,
CompareOpT,
KeyT,
ValueT>
Expand Down
Loading

0 comments on commit 3cda69e

Please sign in to comment.