Skip to content

Commit 612f475

Browse files
authored Feb 4, 2025
Add b200 tunings for radix_sort.keys (#3611) (#3655)
1 parent 25901d7 commit 612f475

File tree

1 file changed

+125
-2
lines changed

1 file changed

+125
-2
lines changed
 

‎cub/cub/device/dispatch/tuning/tuning_radix_sort.cuh

+125-2
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ namespace detail
5050
{
5151
namespace radix
5252
{
53-
// default
53+
// sm90 default
5454
template <std::size_t KeySize, std::size_t ValueSize, std::size_t OffsetSize>
5555
struct sm90_small_key_tuning
5656
{
@@ -1069,7 +1069,130 @@ struct policy_hub
10691069
SEGMENTED_RADIX_BITS - 1>;
10701070
};
10711071

1072-
using MaxPolicy = Policy900;
1072+
// todo(@gonidelis): refactor this so as not to duplicate the SM90 policy.
1073+
struct Policy1000 : ChainedPolicy<1000, Policy1000, Policy900>
1074+
{
1075+
static constexpr bool ONESWEEP = true;
1076+
static constexpr int ONESWEEP_RADIX_BITS = 8;
1077+
1078+
using HistogramPolicy = AgentRadixSortHistogramPolicy<128, 16, 1, KeyT, ONESWEEP_RADIX_BITS>;
1079+
using ExclusiveSumPolicy = AgentRadixSortExclusiveSumPolicy<256, ONESWEEP_RADIX_BITS>;
1080+
1081+
private:
1082+
static constexpr int PRIMARY_RADIX_BITS = (sizeof(KeyT) > 1) ? 7 : 5;
1083+
static constexpr int SINGLE_TILE_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5;
1084+
static constexpr int SEGMENTED_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5;
1085+
static constexpr int OFFSET_64BIT = sizeof(OffsetT) == 8 ? 1 : 0;
1086+
static constexpr int FLOAT_KEYS = ::cuda::std::is_same<KeyT, float>::value ? 1 : 0;
1087+
1088+
using OnesweepPolicyKey32 = AgentRadixSortOnesweepPolicy<
1089+
384,
1090+
KEYS_ONLY ? 20 - OFFSET_64BIT - FLOAT_KEYS
1091+
: (sizeof(ValueT) < 8 ? (OFFSET_64BIT ? 17 : 23) : (OFFSET_64BIT ? 29 : 30)),
1092+
DominantT,
1093+
1,
1094+
RADIX_RANK_MATCH_EARLY_COUNTS_ANY,
1095+
BLOCK_SCAN_RAKING_MEMOIZE,
1096+
RADIX_SORT_STORE_DIRECT,
1097+
ONESWEEP_RADIX_BITS>;
1098+
1099+
using OnesweepPolicyKey64 = AgentRadixSortOnesweepPolicy<
1100+
384,
1101+
sizeof(ValueT) < 8 ? 30 : 24,
1102+
DominantT,
1103+
1,
1104+
RADIX_RANK_MATCH_EARLY_COUNTS_ANY,
1105+
BLOCK_SCAN_RAKING_MEMOIZE,
1106+
RADIX_SORT_STORE_DIRECT,
1107+
ONESWEEP_RADIX_BITS>;
1108+
1109+
using OnesweepLargeKeyPolicy = ::cuda::std::_If<sizeof(KeyT) == 4, OnesweepPolicyKey32, OnesweepPolicyKey64>;
1110+
1111+
using OnesweepSmallKeyPolicySizes =
1112+
sm100_small_key_tuning<ValueT, sizeof(KeyT), KEYS_ONLY ? 0 : sizeof(ValueT), sizeof(OffsetT)>;
1113+
1114+
using OnesweepSmallKeyPolicy = AgentRadixSortOnesweepPolicy<
1115+
OnesweepSmallKeyPolicySizes::threads,
1116+
OnesweepSmallKeyPolicySizes::items,
1117+
DominantT,
1118+
1,
1119+
RADIX_RANK_MATCH_EARLY_COUNTS_ANY,
1120+
BLOCK_SCAN_RAKING_MEMOIZE,
1121+
RADIX_SORT_STORE_DIRECT,
1122+
8>;
1123+
1124+
public:
1125+
using OnesweepPolicy = ::cuda::std::_If<sizeof(KeyT) < 4, OnesweepSmallKeyPolicy, OnesweepLargeKeyPolicy>;
1126+
1127+
// The Scan, Downsweep and Upsweep policies are never run on SM100, but we have to include them to prevent a
1128+
// compilation error: When we compile e.g. for SM70 **and** SM100, the host compiler will reach calls to those
1129+
// kernels, and instantiate them for MaxPolicy (which is Policy1000) on the host, which will reach into the policies
1130+
// below to set the launch bounds. The device compiler pass will also compile all kernels for SM70 **and** SM100,
1131+
// even though only the Onesweep kernel is used on SM100.
1132+
using ScanPolicy =
1133+
AgentScanPolicy<512,
1134+
23,
1135+
OffsetT,
1136+
BLOCK_LOAD_WARP_TRANSPOSE,
1137+
LOAD_DEFAULT,
1138+
BLOCK_STORE_WARP_TRANSPOSE,
1139+
BLOCK_SCAN_RAKING_MEMOIZE>;
1140+
1141+
using DownsweepPolicy = AgentRadixSortDownsweepPolicy<
1142+
512,
1143+
23,
1144+
DominantT,
1145+
BLOCK_LOAD_TRANSPOSE,
1146+
LOAD_DEFAULT,
1147+
RADIX_RANK_MATCH,
1148+
BLOCK_SCAN_WARP_SCANS,
1149+
PRIMARY_RADIX_BITS>;
1150+
1151+
using AltDownsweepPolicy = AgentRadixSortDownsweepPolicy<
1152+
(sizeof(KeyT) > 1) ? 256 : 128,
1153+
47,
1154+
DominantT,
1155+
BLOCK_LOAD_TRANSPOSE,
1156+
LOAD_DEFAULT,
1157+
RADIX_RANK_MEMOIZE,
1158+
BLOCK_SCAN_WARP_SCANS,
1159+
PRIMARY_RADIX_BITS - 1>;
1160+
1161+
using UpsweepPolicy = AgentRadixSortUpsweepPolicy<256, 23, DominantT, LOAD_DEFAULT, PRIMARY_RADIX_BITS>;
1162+
using AltUpsweepPolicy = AgentRadixSortUpsweepPolicy<256, 47, DominantT, LOAD_DEFAULT, PRIMARY_RADIX_BITS - 1>;
1163+
1164+
using SingleTilePolicy = AgentRadixSortDownsweepPolicy<
1165+
256,
1166+
19,
1167+
DominantT,
1168+
BLOCK_LOAD_DIRECT,
1169+
LOAD_LDG,
1170+
RADIX_RANK_MEMOIZE,
1171+
BLOCK_SCAN_WARP_SCANS,
1172+
SINGLE_TILE_RADIX_BITS>;
1173+
1174+
using SegmentedPolicy = AgentRadixSortDownsweepPolicy<
1175+
192,
1176+
39,
1177+
DominantT,
1178+
BLOCK_LOAD_TRANSPOSE,
1179+
LOAD_DEFAULT,
1180+
RADIX_RANK_MEMOIZE,
1181+
BLOCK_SCAN_WARP_SCANS,
1182+
SEGMENTED_RADIX_BITS>;
1183+
1184+
using AltSegmentedPolicy = AgentRadixSortDownsweepPolicy<
1185+
384,
1186+
11,
1187+
DominantT,
1188+
BLOCK_LOAD_TRANSPOSE,
1189+
LOAD_DEFAULT,
1190+
RADIX_RANK_MEMOIZE,
1191+
BLOCK_SCAN_WARP_SCANS,
1192+
SEGMENTED_RADIX_BITS - 1>;
1193+
};
1194+
1195+
using MaxPolicy = Policy1000;
10731196
};
10741197

10751198
} // namespace radix

0 commit comments

Comments
 (0)