Skip to content

Commit

Permalink
nvcc 11.1 workarounds
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber committed Jan 15, 2025
1 parent e87f17d commit 8efd1b4
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 13 deletions.
2 changes: 2 additions & 0 deletions cub/cub/device/device_spmv.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,9 @@ struct CCCL_DEPRECATED_BECAUSE("Use the cuSPARSE library instead") DeviceSpmv
{
CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceSpmv::CsrMV");

_CCCL_SUPPRESS_DEPRECATED_PUSH
SpmvParams<ValueT, int> spmv_params;
_CCCL_SUPPRESS_DEPRECATED_POP
spmv_params.d_values = d_values;
spmv_params.d_row_end_offsets = d_row_offsets + 1;
spmv_params.d_column_indices = d_column_indices;
Expand Down
39 changes: 26 additions & 13 deletions cub/cub/device/dispatch/dispatch_spmv_orig.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ _CCCL_SUPPRESS_DEPRECATED_PUSH
template <typename AgentSpmvPolicyT, typename ValueT, typename OffsetT>
CCCL_DEPRECATED_BECAUSE("Use the cuSPARSE library instead")
CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSpmv1ColKernel(SpmvParams<ValueT, OffsetT> spmv_params) //
_CCCL_SUPPRESS_DEPRECATED_POP
{
using VectorValueIteratorT =
CacheModifiedInputIterator<AgentSpmvPolicyT::VECTOR_VALUES_LOAD_MODIFIER, ValueT, OffsetT>;
Expand All @@ -109,6 +108,7 @@ CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSpmv1ColKernel(SpmvParams<ValueT, Offset
spmv_params.d_vector_y[row_idx] = value;
}
}
_CCCL_SUPPRESS_DEPRECATED_POP

/**
* @brief Spmv search kernel. Identifies merge path starting coordinates for each tile.
Expand Down Expand Up @@ -214,6 +214,7 @@ CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSpmvSearchKernel(
* @param[in] num_segment_fixup_tiles
* Number of reduce-by-key tiles (fixup grid size)
*/
_CCCL_SUPPRESS_DEPRECATED_PUSH
template <typename SpmvPolicyT,
typename ScanTileStateT,
typename ValueT,
Expand Down Expand Up @@ -243,6 +244,7 @@ __launch_bounds__(int(SpmvPolicyT::BLOCK_THREADS)) CUB_DETAIL_KERNEL_ATTRIBUTES
// Initialize fixup tile status
tile_state.InitializeStatus(num_segment_fixup_tiles);
}
_CCCL_SUPPRESS_DEPRECATED_POP

/**
* @tparam ValueT
Expand All @@ -254,6 +256,7 @@ __launch_bounds__(int(SpmvPolicyT::BLOCK_THREADS)) CUB_DETAIL_KERNEL_ATTRIBUTES
* @tparam HAS_BETA
* Whether the input parameter beta is non-zero
*/
_CCCL_SUPPRESS_DEPRECATED_PUSH
template <typename ValueT, typename OffsetT, bool HAS_BETA>
CCCL_DEPRECATED_BECAUSE("Use the cuSPARSE library instead")
CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSpmvEmptyMatrixKernel(SpmvParams<ValueT, OffsetT> spmv_params)
Expand All @@ -272,6 +275,7 @@ CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSpmvEmptyMatrixKernel(SpmvParams<ValueT,
spmv_params.d_vector_y[row] = result;
}
}
_CCCL_SUPPRESS_DEPRECATED_POP

/**
* @brief Multi-block reduce-by-key sweep kernel entry point
Expand Down Expand Up @@ -319,8 +323,7 @@ __launch_bounds__(int(AgentSegmentFixupPolicyT::BLOCK_THREADS))
AggregatesOutputIteratorT d_aggregates_out,
OffsetT num_items,
int num_tiles,
ScanTileStateT tile_state) //
_CCCL_SUPPRESS_DEPRECATED_POP
ScanTileStateT tile_state)
{
// Thread block type for reducing tiles of value segments
using AgentSegmentFixupT =
Expand All @@ -338,6 +341,7 @@ __launch_bounds__(int(AgentSegmentFixupPolicyT::BLOCK_THREADS))
AgentSegmentFixupT(temp_storage, d_pairs_in, d_aggregates_out, ::cuda::std::equal_to<>{}, ::cuda::std::plus<>{})
.ConsumeRange(num_items, num_tiles, tile_state);
}
_CCCL_SUPPRESS_DEPRECATED_POP

/******************************************************************************
* Dispatch
Expand Down Expand Up @@ -366,7 +370,9 @@ struct CCCL_DEPRECATED_BECAUSE("Use the cuSPARSE library instead") DispatchSpmv
};

// SpmvParams bundle type
_CCCL_SUPPRESS_DEPRECATED_PUSH
using SpmvParamsT = SpmvParams<ValueT, OffsetT>;
_CCCL_SUPPRESS_DEPRECATED_POP

// 2D merge path coordinate type
using CoordinateT = typename CubVector<OffsetT, 2>::Type;
Expand All @@ -384,6 +390,7 @@ struct CCCL_DEPRECATED_BECAUSE("Use the cuSPARSE library instead") DispatchSpmv
/// SM35
struct Policy350
{
_CCCL_SUPPRESS_DEPRECATED_PUSH
using SpmvPolicyT =
AgentSpmvPolicy<(sizeof(ValueT) > 4) ? 96 : 128,
(sizeof(ValueT) > 4) ? 4 : 7,
Expand All @@ -394,13 +401,15 @@ struct CCCL_DEPRECATED_BECAUSE("Use the cuSPARSE library instead") DispatchSpmv
LOAD_LDG,
(sizeof(ValueT) > 4) ? true : false,
BLOCK_SCAN_WARP_SCANS>;
_CCCL_SUPPRESS_DEPRECATED_POP

using SegmentFixupPolicyT = AgentSegmentFixupPolicy<128, 3, BLOCK_LOAD_VECTORIZE, LOAD_LDG, BLOCK_SCAN_WARP_SCANS>;
};

/// SM37
struct Policy370
{
_CCCL_SUPPRESS_DEPRECATED_PUSH
using SpmvPolicyT =
AgentSpmvPolicy<(sizeof(ValueT) > 4) ? 128 : 128,
(sizeof(ValueT) > 4) ? 9 : 14,
Expand All @@ -411,13 +420,15 @@ struct CCCL_DEPRECATED_BECAUSE("Use the cuSPARSE library instead") DispatchSpmv
LOAD_LDG,
false,
BLOCK_SCAN_WARP_SCANS>;
_CCCL_SUPPRESS_DEPRECATED_POP

using SegmentFixupPolicyT = AgentSegmentFixupPolicy<128, 3, BLOCK_LOAD_VECTORIZE, LOAD_LDG, BLOCK_SCAN_WARP_SCANS>;
};

/// SM50
struct Policy500
{
_CCCL_SUPPRESS_DEPRECATED_PUSH
using SpmvPolicyT =
AgentSpmvPolicy<(sizeof(ValueT) > 4) ? 64 : 128,
(sizeof(ValueT) > 4) ? 6 : 7,
Expand All @@ -428,6 +439,7 @@ struct CCCL_DEPRECATED_BECAUSE("Use the cuSPARSE library instead") DispatchSpmv
LOAD_LDG,
(sizeof(ValueT) > 4) ? true : false,
(sizeof(ValueT) > 4) ? BLOCK_SCAN_WARP_SCANS : BLOCK_SCAN_RAKING_MEMOIZE>;
_CCCL_SUPPRESS_DEPRECATED_POP

using SegmentFixupPolicyT =
AgentSegmentFixupPolicy<128, 3, BLOCK_LOAD_VECTORIZE, LOAD_LDG, BLOCK_SCAN_RAKING_MEMOIZE>;
Expand All @@ -436,16 +448,17 @@ struct CCCL_DEPRECATED_BECAUSE("Use the cuSPARSE library instead") DispatchSpmv
/// SM60
struct Policy600
{
using SpmvPolicyT =
AgentSpmvPolicy<(sizeof(ValueT) > 4) ? 64 : 128,
(sizeof(ValueT) > 4) ? 5 : 7,
LOAD_DEFAULT,
LOAD_DEFAULT,
LOAD_DEFAULT,
LOAD_DEFAULT,
LOAD_DEFAULT,
false,
BLOCK_SCAN_WARP_SCANS>;
using SpmvPolicyT = _CCCL_SUPPRESS_DEPRECATED_PUSH AgentSpmvPolicy<
(sizeof(ValueT) > 4) ? 64 : 128,
(sizeof(ValueT) > 4) ? 5 : 7,
LOAD_DEFAULT,
LOAD_DEFAULT,
LOAD_DEFAULT,
LOAD_DEFAULT,
LOAD_DEFAULT,
false,
BLOCK_SCAN_WARP_SCANS>;
_CCCL_SUPPRESS_DEPRECATED_POP

using SegmentFixupPolicyT = AgentSegmentFixupPolicy<128, 3, BLOCK_LOAD_DIRECT, LOAD_LDG, BLOCK_SCAN_WARP_SCANS>;
};
Expand Down

0 comments on commit 8efd1b4

Please sign in to comment.