Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: make launchers a CUB detail; make kernel source functions hidden. #3209

Merged
merged 4 commits into from
Dec 21, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions c/parallel/src/reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
//===----------------------------------------------------------------------===//

#include <cub/detail/choose_offset.cuh>
#include <cub/detail/launcher/cuda_driver.cuh>
#include <cub/device/device_reduce.cuh>
#include <cub/grid/grid_even_share.cuh>
#include <cub/launcher/cuda_driver.cuh>
#include <cub/util_device.cuh>

#include <cuda/std/cstdint>
Expand Down Expand Up @@ -405,7 +405,7 @@ extern "C" CCCL_C_API CUresult cccl_device_reduce(
dynamic_reduce_policy_t<&get_policy>,
::cuda::std::__identity,
reduce_kernel_source,
cub::CudaDriverLauncherFactory>::
cub::detail::CudaDriverLauncherFactory>::
Dispatch(
d_temp_storage,
*temp_storage_bytes,
Expand All @@ -417,7 +417,7 @@ extern "C" CCCL_C_API CUresult cccl_device_reduce(
stream,
{},
{build},
cub::CudaDriverLauncherFactory{cu_device, build.cc},
cub::detail::CudaDriverLauncherFactory{cu_device, build.cc},
{get_accumulator_type(op, d_in, init)});
}
catch (const std::exception& exc)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@

CUB_NAMESPACE_BEGIN

namespace detail
{

struct CudaDriverLauncher
{
dim3 grid;
Expand Down Expand Up @@ -78,6 +81,8 @@ struct CudaDriverLauncherFactory
int cc;
};

} // namespace detail

CUB_NAMESPACE_END

#endif // _CCCL_CUDACC_AT_LEAST(0)
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

CUB_NAMESPACE_BEGIN

namespace detail
{

struct TripleChevronFactory
{
CUB_RUNTIME_FUNCTION THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron
Expand Down Expand Up @@ -50,4 +53,6 @@ struct TripleChevronFactory
}
};

} // namespace detail

CUB_NAMESPACE_END
6 changes: 3 additions & 3 deletions cub/cub/device/dispatch/dispatch_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@
#endif // no system header

#include <cub/agent/agent_reduce.cuh>
#include <cub/detail/launcher/cuda_runtime.cuh>
#include <cub/device/dispatch/kernels/reduce.cuh>
#include <cub/device/dispatch/tuning/tuning_reduce.cuh>
#include <cub/grid/grid_even_share.cuh>
#include <cub/iterator/arg_index_input_iterator.cuh>
#include <cub/launcher/cuda_runtime.cuh>
#include <cub/thread/thread_operators.cuh>
#include <cub/thread/thread_store.cuh>
#include <cub/util_debug.cuh>
Expand Down Expand Up @@ -273,7 +273,7 @@ template <typename InputIteratorT,
InitT,
AccumT,
TransformOpT>,
typename KernelLauncherFactory = TripleChevronFactory>
typename KernelLauncherFactory = detail::TripleChevronFactory>
struct DispatchReduce
{
//---------------------------------------------------------------------------
Expand Down Expand Up @@ -754,7 +754,7 @@ template <
InitT,
AccumT,
TransformOpT>,
typename KernelLauncherFactory = TripleChevronFactory>
typename KernelLauncherFactory = detail::TripleChevronFactory>
using DispatchTransformReduce =
DispatchReduce<InputIteratorT,
OutputIteratorT,
Expand Down
7 changes: 5 additions & 2 deletions cub/cub/util_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,10 @@ CUB_RUNTIME_FUNCTION PolicyWrapper<PolicyT> MakePolicyWrapper(PolicyT policy)
return PolicyWrapper<PolicyT>{policy};
}

// Forward declaration only: KernelConfig::Init below names detail::TripleChevronFactory as the
// default for its LauncherFactory template parameter. The type is completed by the launcher
// header that is included at the very end of this file.
namespace detail
{
struct TripleChevronFactory;
}

/**
* Kernel dispatch configuration
Expand All @@ -654,7 +657,7 @@ struct KernelConfig
, sm_occupancy(0)
{}

template <typename AgentPolicyT, typename KernelPtrT, typename LauncherFactory = TripleChevronFactory>
template <typename AgentPolicyT, typename KernelPtrT, typename LauncherFactory = detail::TripleChevronFactory>
CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN _CCCL_FORCEINLINE cudaError_t
Init(KernelPtrT kernel_ptr, AgentPolicyT agent_policy = {}, LauncherFactory launcher_factory = {})
{
Expand Down Expand Up @@ -784,4 +787,4 @@ private:

CUB_NAMESPACE_END

#include <cub/launcher/cuda_runtime.cuh> // to complete the definition of TripleChevronFactory
#include <cub/detail/launcher/cuda_runtime.cuh> // to complete the definition of TripleChevronFactory
9 changes: 5 additions & 4 deletions cub/cub/util_macro.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,11 @@ _CCCL_DIAG_SUPPRESS_NVHPC(attribute_requires_external_linkage)
#endif

#ifndef CUB_DEFINE_SUB_POLICY_GETTER
# define CUB_DEFINE_SUB_POLICY_GETTER(name) \
CUB_RUNTIME_FUNCTION static constexpr PolicyWrapper<typename StaticPolicyT::name##Policy> name() \
{ \
return MakePolicyWrapper(typename StaticPolicyT::name##Policy()); \
# define CUB_DEFINE_SUB_POLICY_GETTER(name) \
_CCCL_HIDE_FROM_ABI CUB_RUNTIME_FUNCTION static constexpr PolicyWrapper<typename StaticPolicyT::name##Policy> \
name() \
{ \
return MakePolicyWrapper(typename StaticPolicyT::name##Policy()); \
}
#endif

Expand Down
Loading