@@ -50,7 +50,7 @@ namespace detail
50
50
{
51
51
namespace radix
52
52
{
53
- // default
53
+ // sm90 default
54
54
template <std::size_t KeySize, std::size_t ValueSize, std::size_t OffsetSize>
55
55
struct sm90_small_key_tuning
56
56
{
@@ -1069,7 +1069,130 @@ struct policy_hub
1069
1069
SEGMENTED_RADIX_BITS - 1 >;
1070
1070
};
1071
1071
1072
- using MaxPolicy = Policy900;
1072
+ // todo(@gonidelis): refactor this as to not duplicate SM90.
1073
+ struct Policy1000 : ChainedPolicy<1000 , Policy1000, Policy900>
1074
+ {
1075
+ static constexpr bool ONESWEEP = true ;
1076
+ static constexpr int ONESWEEP_RADIX_BITS = 8 ;
1077
+
1078
+ using HistogramPolicy = AgentRadixSortHistogramPolicy<128 , 16 , 1 , KeyT, ONESWEEP_RADIX_BITS>;
1079
+ using ExclusiveSumPolicy = AgentRadixSortExclusiveSumPolicy<256 , ONESWEEP_RADIX_BITS>;
1080
+
1081
+ private:
1082
+ static constexpr int PRIMARY_RADIX_BITS = (sizeof (KeyT) > 1 ) ? 7 : 5 ;
1083
+ static constexpr int SINGLE_TILE_RADIX_BITS = (sizeof (KeyT) > 1 ) ? 6 : 5 ;
1084
+ static constexpr int SEGMENTED_RADIX_BITS = (sizeof (KeyT) > 1 ) ? 6 : 5 ;
1085
+ static constexpr int OFFSET_64BIT = sizeof (OffsetT) == 8 ? 1 : 0 ;
1086
+ static constexpr int FLOAT_KEYS = ::cuda::std::is_same<KeyT, float >::value ? 1 : 0 ;
1087
+
1088
+ using OnesweepPolicyKey32 = AgentRadixSortOnesweepPolicy<
1089
+ 384 ,
1090
+ KEYS_ONLY ? 20 - OFFSET_64BIT - FLOAT_KEYS
1091
+ : (sizeof (ValueT) < 8 ? (OFFSET_64BIT ? 17 : 23 ) : (OFFSET_64BIT ? 29 : 30 )),
1092
+ DominantT,
1093
+ 1 ,
1094
+ RADIX_RANK_MATCH_EARLY_COUNTS_ANY,
1095
+ BLOCK_SCAN_RAKING_MEMOIZE,
1096
+ RADIX_SORT_STORE_DIRECT,
1097
+ ONESWEEP_RADIX_BITS>;
1098
+
1099
+ using OnesweepPolicyKey64 = AgentRadixSortOnesweepPolicy<
1100
+ 384 ,
1101
+ sizeof (ValueT) < 8 ? 30 : 24 ,
1102
+ DominantT,
1103
+ 1 ,
1104
+ RADIX_RANK_MATCH_EARLY_COUNTS_ANY,
1105
+ BLOCK_SCAN_RAKING_MEMOIZE,
1106
+ RADIX_SORT_STORE_DIRECT,
1107
+ ONESWEEP_RADIX_BITS>;
1108
+
1109
+ using OnesweepLargeKeyPolicy = ::cuda::std::_If<sizeof (KeyT) == 4 , OnesweepPolicyKey32, OnesweepPolicyKey64>;
1110
+
1111
+ using OnesweepSmallKeyPolicySizes =
1112
+ sm100_small_key_tuning<ValueT, sizeof (KeyT), KEYS_ONLY ? 0 : sizeof (ValueT), sizeof (OffsetT)>;
1113
+
1114
+ using OnesweepSmallKeyPolicy = AgentRadixSortOnesweepPolicy<
1115
+ OnesweepSmallKeyPolicySizes::threads,
1116
+ OnesweepSmallKeyPolicySizes::items,
1117
+ DominantT,
1118
+ 1 ,
1119
+ RADIX_RANK_MATCH_EARLY_COUNTS_ANY,
1120
+ BLOCK_SCAN_RAKING_MEMOIZE,
1121
+ RADIX_SORT_STORE_DIRECT,
1122
+ 8 >;
1123
+
1124
+ public:
1125
+ using OnesweepPolicy = ::cuda::std::_If<sizeof (KeyT) < 4 , OnesweepSmallKeyPolicy, OnesweepLargeKeyPolicy>;
1126
+
1127
+ // The Scan, Downsweep and Upsweep policies are never run on SM90, but we have to include them to prevent a
1128
+ // compilation error: When we compile e.g. for SM70 **and** SM90, the host compiler will reach calls to those
1129
+ // kernels, and instantiate them for MaxPolicy (which is Policy900) on the host, which will reach into the policies
1130
+ // below to set the launch bounds. The device compiler pass will also compile all kernels for SM70 **and** SM90,
1131
+ // even though only the Onesweep kernel is used on SM90.
1132
+ using ScanPolicy =
1133
+ AgentScanPolicy<512 ,
1134
+ 23 ,
1135
+ OffsetT,
1136
+ BLOCK_LOAD_WARP_TRANSPOSE,
1137
+ LOAD_DEFAULT,
1138
+ BLOCK_STORE_WARP_TRANSPOSE,
1139
+ BLOCK_SCAN_RAKING_MEMOIZE>;
1140
+
1141
+ using DownsweepPolicy = AgentRadixSortDownsweepPolicy<
1142
+ 512 ,
1143
+ 23 ,
1144
+ DominantT,
1145
+ BLOCK_LOAD_TRANSPOSE,
1146
+ LOAD_DEFAULT,
1147
+ RADIX_RANK_MATCH,
1148
+ BLOCK_SCAN_WARP_SCANS,
1149
+ PRIMARY_RADIX_BITS>;
1150
+
1151
+ using AltDownsweepPolicy = AgentRadixSortDownsweepPolicy<
1152
+ (sizeof (KeyT) > 1 ) ? 256 : 128 ,
1153
+ 47 ,
1154
+ DominantT,
1155
+ BLOCK_LOAD_TRANSPOSE,
1156
+ LOAD_DEFAULT,
1157
+ RADIX_RANK_MEMOIZE,
1158
+ BLOCK_SCAN_WARP_SCANS,
1159
+ PRIMARY_RADIX_BITS - 1 >;
1160
+
1161
+ using UpsweepPolicy = AgentRadixSortUpsweepPolicy<256 , 23 , DominantT, LOAD_DEFAULT, PRIMARY_RADIX_BITS>;
1162
+ using AltUpsweepPolicy = AgentRadixSortUpsweepPolicy<256 , 47 , DominantT, LOAD_DEFAULT, PRIMARY_RADIX_BITS - 1 >;
1163
+
1164
+ using SingleTilePolicy = AgentRadixSortDownsweepPolicy<
1165
+ 256 ,
1166
+ 19 ,
1167
+ DominantT,
1168
+ BLOCK_LOAD_DIRECT,
1169
+ LOAD_LDG,
1170
+ RADIX_RANK_MEMOIZE,
1171
+ BLOCK_SCAN_WARP_SCANS,
1172
+ SINGLE_TILE_RADIX_BITS>;
1173
+
1174
+ using SegmentedPolicy = AgentRadixSortDownsweepPolicy<
1175
+ 192 ,
1176
+ 39 ,
1177
+ DominantT,
1178
+ BLOCK_LOAD_TRANSPOSE,
1179
+ LOAD_DEFAULT,
1180
+ RADIX_RANK_MEMOIZE,
1181
+ BLOCK_SCAN_WARP_SCANS,
1182
+ SEGMENTED_RADIX_BITS>;
1183
+
1184
+ using AltSegmentedPolicy = AgentRadixSortDownsweepPolicy<
1185
+ 384 ,
1186
+ 11 ,
1187
+ DominantT,
1188
+ BLOCK_LOAD_TRANSPOSE,
1189
+ LOAD_DEFAULT,
1190
+ RADIX_RANK_MEMOIZE,
1191
+ BLOCK_SCAN_WARP_SCANS,
1192
+ SEGMENTED_RADIX_BITS - 1 >;
1193
+ };
1194
+
1195
+ using MaxPolicy = Policy1000;
1073
1196
};
1074
1197
1075
1198
} // namespace radix
0 commit comments