Skip to content

Commit

Permalink
moves scan_by_key_op to detail ns
Browse files Browse the repository at this point in the history
  • Loading branch information
elstehle committed Oct 8, 2024
1 parent 69b0ee2 commit a099ef0
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 59 deletions.
2 changes: 1 addition & 1 deletion cub/cub/agent/agent_scan_by_key.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ struct AgentScanByKey
using KeyT = cub::detail::value_t<KeysInputIteratorT>;
using InputT = cub::detail::value_t<ValuesInputIteratorT>;
using FlagValuePairT = KeyValuePair<int, AccumT>;
using ReduceBySegmentOpT = ScanBySegmentOp<ScanOpT>;
using ReduceBySegmentOpT = detail::ScanBySegmentOp<ScanOpT>;

using ScanTileStateT = ReduceByKeyScanTileState<AccumT, int>;

Expand Down
116 changes: 58 additions & 58 deletions cub/cub/thread/thread_operators.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,64 @@ struct ArgMin

namespace detail
{
template <typename ScanOpT>
struct ScanBySegmentOp
{
/// Wrapped operator
ScanOpT op;

/// Constructor
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE ScanBySegmentOp() {}

/// Constructor
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE ScanBySegmentOp(ScanOpT op)
: op(op)
{}

/**
* @brief Scan operator
*
* @tparam KeyValuePairT
* KeyValuePair pairing of T (value) and int (head flag)
*
* @param[in] first
* First partial reduction
*
* @param[in] second
* Second partial reduction
*/
template <typename KeyValuePairT>
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE KeyValuePairT operator()(const KeyValuePairT& first, const KeyValuePairT& second)
{
KeyValuePairT retval;
retval.key = first.key | second.key;
#ifdef _NVHPC_CUDA // WAR bug on nvc++
if (second.key)
{
retval.value = second.value;
}
else
{
// If second.value isn't copied into a temporary here, nvc++ will
// crash while compiling the TestScanByKeyWithLargeTypes test in
// thrust/testing/scan_by_key.cu:
auto v2 = second.value;
retval.value = op(first.value, v2);
}
#else // not nvc++:
// if (second.key) {
// The second partial reduction spans a segment reset, so it's value
// aggregate becomes the running aggregate
// else {
// The second partial reduction does not span a reset, so accumulate both
// into the running aggregate
// }
retval.value = (second.key) ? second.value : op(first.value, second.value);
#endif
return retval;
}
};

template <class OpT>
struct basic_binary_op_t
{
Expand Down Expand Up @@ -354,64 +412,6 @@ struct ReduceBySegmentOp
}
};

template <typename ScanOpT>
struct ScanBySegmentOp
{
/// Wrapped operator
ScanOpT op;

/// Constructor
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE ScanBySegmentOp() {}

/// Constructor
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE ScanBySegmentOp(ScanOpT op)
: op(op)
{}

/**
* @brief Scan operator
*
* @tparam KeyValuePairT
* KeyValuePair pairing of T (value) and int (head flag)
*
* @param[in] first
* First partial reduction
*
* @param[in] second
* Second partial reduction
*/
template <typename KeyValuePairT>
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE KeyValuePairT operator()(const KeyValuePairT& first, const KeyValuePairT& second)
{
KeyValuePairT retval;
retval.key = first.key | second.key;
#ifdef _NVHPC_CUDA // WAR bug on nvc++
if (second.key)
{
retval.value = second.value;
}
else
{
// If second.value isn't copied into a temporary here, nvc++ will
// crash while compiling the TestScanByKeyWithLargeTypes test in
// thrust/testing/scan_by_key.cu:
auto v2 = second.value;
retval.value = op(first.value, v2);
}
#else // not nvc++:
// if (second.key) {
// The second partial reduction spans a segment reset, so it's value
// aggregate becomes the running aggregate
// else {
// The second partial reduction does not span a reset, so accumulate both
// into the running aggregate
// }
retval.value = (second.key) ? second.value : op(first.value, second.value);
#endif
return retval;
}
};

/**
* @tparam ReductionOpT Binary reduction operator to apply to values
*/
Expand Down

0 comments on commit a099ef0

Please sign in to comment.