Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the thrust dispatch mechanisms configurable #2310

Merged
merged 13 commits into from
Aug 30, 2024
10 changes: 10 additions & 0 deletions thrust/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,16 @@ option(THRUST_ENABLE_TESTING "Build Thrust testing suite." "ON")
option(THRUST_ENABLE_EXAMPLES "Build Thrust examples." "ON")
option(THRUST_ENABLE_BENCHMARKS "Build Thrust runtime benchmarks." "${CCCL_ENABLE_BENCHMARKS}")

# Force the offset type to 32bit. Improves compile times and binary size at the expense of limiting input sizes
option(THRUST_FORCE_32BIT_OFFSET_TYPE "Requires thrust to use 32 bit offset types." "OFF")

# Force the offset type to 64bit. Improves compile times and binary size potentially degrading runtime performance
option(THRUST_FORCE_64BIT_OFFSET_TYPE "Requires thrust to use 64 bit offset types." "OFF")

if (THRUST_FORCE_32BIT_OFFSET_TYPE AND THRUST_FORCE_64BIT_OFFSET_TYPE)
message(FATAL_ERROR "Only THRUST_FORCE_32BIT_OFFSET_TYPE or THRUST_FORCE_64BIT_OFFSET_TYPE may be defined!")
endif()
miscco marked this conversation as resolved.
Show resolved Hide resolved

# Check if we're actually building anything before continuing. If not, no need
# to search for deps, etc. This is a common approach for packagers that just
# need the install rules. See GH issue NVIDIA/thrust#1211.
Expand Down
8 changes: 8 additions & 0 deletions thrust/cmake/ThrustBuildCompilerTargets.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,14 @@ function(thrust_build_compiler_targets)
)
endforeach()

if (THRUST_FORCE_32BIT_OFFSET_TYPE)
list(APPEND cxx_compile_definitions "THRUST_FORCE_32BIT_OFFSET_TYPE")
miscco marked this conversation as resolved.
Show resolved Hide resolved
endif()

if (THRUST_FORCE_64BIT_OFFSET_TYPE)
list(APPEND cxx_compile_definitions "THRUST_FORCE_64BIT_OFFSET_TYPE")
miscco marked this conversation as resolved.
Show resolved Hide resolved
endif()

foreach (cxx_definition IN LISTS cxx_compile_definitions)
# Add these for both CUDA and CXX targets:
target_compile_definitions(thrust.compiler_interface INTERFACE
Expand Down
13 changes: 13 additions & 0 deletions thrust/cmake/ThrustHeaderTesting.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,19 @@ foreach(thrust_target IN LISTS THRUST_TARGETS)
"CUB_WRAPPED_NAMESPACE=wrapped_cub")
thrust_add_header_test(${thrust_target} base "${header_definitions}")

# We need to ensure that the different dispatch mechanisms work
set(header_definitions
"THRUST_WRAPPED_NAMESPACE=wrapped_thrust"
"CUB_WRAPPED_NAMESPACE=wrapped_cub"
"THRUST_FORCE_32BIT_OFFSET_TYPE")
thrust_add_header_test(${thrust_target} offset_32 "${header_definitions}")

set(header_definitions
"THRUST_WRAPPED_NAMESPACE=wrapped_thrust"
"CUB_WRAPPED_NAMESPACE=wrapped_cub"
"THRUST_FORCE_64BIT_OFFSET_TYPE")
thrust_add_header_test(${thrust_target} offset_64 "${header_definitions}")

thrust_get_target_property(config_device ${thrust_target} DEVICE)
if ("CUDA" STREQUAL "${config_device}")
# Check that BF16 support can be disabled
Expand Down
217 changes: 135 additions & 82 deletions thrust/thrust/system/cuda/detail/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,92 +30,145 @@
#include <thrust/detail/preprocessor.h>

#include <cstdint>
#include <stdexcept>

/**
* Dispatch between 32-bit and 64-bit index based versions of the same algorithm implementation. This version assumes
* that callables for both branches consist of the same tokens, and is intended to be used with Thrust-style dispatch
* interfaces, that always deduce the size type from the arguments.
*/
#define THRUST_INDEX_TYPE_DISPATCH(status, call, count, arguments) \
if (count <= thrust::detail::integer_traits<std::int32_t>::const_max) \
{ \
auto THRUST_PP_CAT2(count, _fixed) = static_cast<std::int32_t>(count); \
status = call arguments; \
} \
else \
{ \
auto THRUST_PP_CAT2(count, _fixed) = static_cast<std::int64_t>(count); \
status = call arguments; \
}

/**
* Dispatch between 32-bit and 64-bit index based versions of the same algorithm implementation. This version assumes
* that callables for both branches consist of the same tokens, and is intended to be used with Thrust-style dispatch
* interfaces, that always deduce the size type from the arguments.
*
* This version of the macro supports providing two count variables, which is necessary for set algorithms.
*/
#define THRUST_DOUBLE_INDEX_TYPE_DISPATCH(status, call, count1, count2, arguments) \
if (count1 + count2 <= thrust::detail::integer_traits<std::int32_t>::const_max) \
{ \
auto THRUST_PP_CAT2(count1, _fixed) = static_cast<std::int32_t>(count1); \
auto THRUST_PP_CAT2(count2, _fixed) = static_cast<std::int32_t>(count2); \
status = call arguments; \
} \
else \
{ \
auto THRUST_PP_CAT2(count1, _fixed) = static_cast<std::int64_t>(count1); \
auto THRUST_PP_CAT2(count2, _fixed) = static_cast<std::int64_t>(count2); \
status = call arguments; \
}
#if defined(THRUST_FORCE_32BIT_OFFSET_TYPE) && defined(THRUST_FORCE_64BIT_OFFSET_TYPE)
# error "Only THRUST_FORCE_32BIT_OFFSET_TYPE or THRUST_FORCE_64BIT_OFFSET_TYPE may be defined!"
#endif // THRUST_FORCE_32BIT_OFFSET_TYPE && THRUST_FORCE_64BIT_OFFSET_TYPE

/**
* Dispatch between 32-bit and 64-bit index based versions of the same algorithm implementation. This version allows
* using different token sequences for callables in both branches, and is intended to be used with CUB-style dispatch
* interfaces, where the "simple" interface always forces the size to be `int` (making it harder for us to use), but the
* complex interface that we end up using doesn't actually provide a way to fully deduce the type from just the call,
* making the size type appear in the token sequence of the callable.
*
* See reduce_n_impl to see an example of how this is meant to be used.
*/
#define THRUST_INDEX_TYPE_DISPATCH2(status, call_32, call_64, count, arguments) \
if (count <= thrust::detail::integer_traits<std::int32_t>::const_max) \
#define _THRUST_INDEX_TYPE_DISPATCH(index_type, status, call, count, arguments) \
{ \
auto THRUST_PP_CAT2(count, _fixed) = static_cast<std::int32_t>(count); \
status = call_32 arguments; \
} \
else \
{ \
auto THRUST_PP_CAT2(count, _fixed) = static_cast<std::int64_t>(count); \
status = call_64 arguments; \
auto THRUST_PP_CAT2(count, _fixed) = static_cast<index_type>(count); \
status = call arguments; \
}

/// Like \ref THRUST_INDEX_TYPE_DISPATCH2 but dispatching to uint32_t and uint64_t, respectively, depending on the
/// `count` argument. `count` must not be negative.
#define THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2(status, call_32, call_64, count, arguments) \
if (static_cast<std::uint64_t>(count) \
<= static_cast<std::uint64_t>(thrust::detail::integer_traits<std::uint32_t>::const_max)) \
{ \
auto THRUST_PP_CAT2(count, _fixed) = static_cast<std::uint32_t>(count); \
status = call_32 arguments; \
} \
else \
{ \
auto THRUST_PP_CAT2(count, _fixed) = static_cast<std::uint64_t>(count); \
status = call_64 arguments; \
#define _THRUST_INDEX_TYPE_DISPATCH2(index_type, status, call, count1, count2, arguments) \
{ \
auto THRUST_PP_CAT2(count1, _fixed) = static_cast<index_type>(count1); \
auto THRUST_PP_CAT2(count2, _fixed) = static_cast<index_type>(count2); \
status = call arguments; \
}

/// Like \ref THRUST_INDEX_TYPE_DISPATCH2 but uses two counts.
#define THRUST_DOUBLE_INDEX_TYPE_DISPATCH2(status, call_32, call_64, count1, count2, arguments) \
if (count1 + count2 <= thrust::detail::integer_traits<std::int32_t>::const_max) \
{ \
auto THRUST_PP_CAT2(count1, _fixed) = static_cast<std::int32_t>(count1); \
auto THRUST_PP_CAT2(count2, _fixed) = static_cast<std::int32_t>(count2); \
status = call_32 arguments; \
} \
else \
{ \
auto THRUST_PP_CAT2(count1, _fixed) = static_cast<std::int64_t>(count1); \
auto THRUST_PP_CAT2(count2, _fixed) = static_cast<std::int64_t>(count2); \
status = call_64 arguments; \
}
#if defined(THRUST_FORCE_64BIT_OFFSET_TYPE)
//! @brief Always dispatches to 64 bit offset version of an algorithm
# define THRUST_INDEX_TYPE_DISPATCH(status, call, count, arguments) \
_THRUST_INDEX_TYPE_DISPATCH(std::int64_t, status, call, count, arguments)

//! Like \ref THRUST_INDEX_TYPE_DISPATCH but with two counts
# define THRUST_DOUBLE_INDEX_TYPE_DISPATCH(status, call, count1, count2, arguments) \
_THRUST_INDEX_TYPE_DISPATCH2(std::int64_t, status, call, count1, count2, arguments)

//! Like \ref THRUST_INDEX_TYPE_DISPATCH but with two different call implementations
# define THRUST_INDEX_TYPE_DISPATCH2(status, call_32, call_64, count, arguments) \
_THRUST_INDEX_TYPE_DISPATCH(std::int64_t, status, call_64, count, arguments)

//! Like \ref THRUST_INDEX_TYPE_DISPATCH2 but uses two counts.
# define THRUST_DOUBLE_INDEX_TYPE_DISPATCH2(status, call_32, call_64, count1, count2, arguments) \
_THRUST_INDEX_TYPE_DISPATCH2(std::int64_t, status, call_64, count1, count2, arguments)

//! Like \ref THRUST_INDEX_TYPE_DISPATCH2 but always dispatching to uint64_t. `count` must not be negative.
# define THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2(status, call_32, call_64, count, arguments) \
_THRUST_INDEX_TYPE_DISPATCH(std::uint64_t, status, call_64, count, arguments)

#elif defined(THRUST_FORCE_32BIT_OFFSET_TYPE)

//! @brief Ensures that the size of the input does not overflow the offset type
# define _THRUST_INDEX_TYPE_DISPATCH_GUARD_OVERFLOW(index_type, count) \
if (static_cast<std::uint64_t>(count) \
> static_cast<std::uint64_t>(thrust::detail::integer_traits<index_type>::const_max)) \
{ \
throw ::std::runtime_error("Offset type overflow"); \
miscco marked this conversation as resolved.
Show resolved Hide resolved
}

//! @brief Ensures that the sizes of the inputs do not overflow the offset type, but two counts
# define _THRUST_INDEX_TYPE_DISPATCH_GUARD_OVERFLOW2(index_type, count1, count2) \
if (static_cast<std::uint64_t>(count1) + static_cast<std::uint64_t>(count2) \
> static_cast<std::uint64_t>(thrust::detail::integer_traits<index_type>::const_max)) \
{ \
throw ::std::runtime_error("Offset type overflow"); \
miscco marked this conversation as resolved.
Show resolved Hide resolved
}

//! @brief Always dispatches to 32 bit offset version of an algorithm but throws if count would overflow
# define THRUST_INDEX_TYPE_DISPATCH(status, call, count, arguments) \
_THRUST_INDEX_TYPE_DISPATCH_GUARD_OVERFLOW(std::int32_t, count) \
_THRUST_INDEX_TYPE_DISPATCH(std::int32_t, status, call, count, arguments)

//! Like \ref THRUST_INDEX_TYPE_DISPATCH but with two counts
# define THRUST_DOUBLE_INDEX_TYPE_DISPATCH(status, call, count1, count2, arguments) \
_THRUST_INDEX_TYPE_DISPATCH_GUARD_OVERFLOW2(std::int32_t, count1, count2) \
_THRUST_INDEX_TYPE_DISPATCH2(std::int32_t, status, call, count1, count2, arguments)

//! Like \ref THRUST_INDEX_TYPE_DISPATCH but with two different call implementations
# define THRUST_INDEX_TYPE_DISPATCH2(status, call_32, call_64, count, arguments) \
_THRUST_INDEX_TYPE_DISPATCH_GUARD_OVERFLOW(std::int32_t, count) \
_THRUST_INDEX_TYPE_DISPATCH(std::int32_t, status, call_32, count, arguments)

//! Like \ref THRUST_INDEX_TYPE_DISPATCH2 but uses two counts.
# define THRUST_DOUBLE_INDEX_TYPE_DISPATCH2(status, call_32, call_64, count1, count2, arguments) \
_THRUST_INDEX_TYPE_DISPATCH_GUARD_OVERFLOW2(std::int32_t, count1, count2) \
_THRUST_INDEX_TYPE_DISPATCH2(std::int32_t, status, call_32, count1, count2, arguments)

//! Like \ref THRUST_INDEX_TYPE_DISPATCH but always dispatching to uint64_t. `count` must not be negative.
# define THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2(status, call_32, call_64, count, arguments) \
_THRUST_INDEX_TYPE_DISPATCH_GUARD_OVERFLOW(std::uint32_t, count) \
_THRUST_INDEX_TYPE_DISPATCH(std::uint32_t, status, call_32, count, arguments)

#else // ^^^ THRUST_FORCE_32BIT_OFFSET_TYPE ^^^ / vvv !THRUST_FORCE_32BIT_OFFSET_TYPE vvv

# define _THRUST_INDEX_TYPE_DISPATCH_SELECT(index_type, count) \
(static_cast<std::uint64_t>(count) \
<= static_cast<std::uint64_t>(thrust::detail::integer_traits<index_type>::const_max))

# define _THRUST_INDEX_TYPE_DISPATCH_SELECT2(index_type, count1, count2) \
(static_cast<std::uint64_t>(count1) + static_cast<std::uint64_t>(count2) \
<= static_cast<std::uint64_t>(thrust::detail::integer_traits<index_type>::const_max))

//! Dispatch between 32-bit and 64-bit index_type based versions of the same algorithm implementation. This version
//! assumes that callables for both branches consist of the same tokens, and is intended to be used with Thrust-style
//! dispatch interfaces, that always deduce the size type from the arguments.
# define THRUST_INDEX_TYPE_DISPATCH(status, call, count, arguments) \
if _THRUST_INDEX_TYPE_DISPATCH_SELECT (std::int32_t, count) \
_THRUST_INDEX_TYPE_DISPATCH(std::int32_t, status, call, count, arguments) \
else \
_THRUST_INDEX_TYPE_DISPATCH(std::int64_t, status, call, count, arguments)

//! Dispatch between 32-bit and 64-bit index_type based versions of the same algorithm implementation. This version
//! assumes that callables for both branches consist of the same tokens, and is intended to be used with Thrust-style
//! dispatch interfaces, that always deduce the size type from the arguments.
//!
//! This version of the macro supports providing two count variables, which is necessary for set algorithms.
# define THRUST_DOUBLE_INDEX_TYPE_DISPATCH(status, call, count1, count2, arguments) \
if _THRUST_INDEX_TYPE_DISPATCH_SELECT2 (std::int32_t, count1, count2) \
_THRUST_INDEX_TYPE_DISPATCH2(std::int32_t, status, call, count1, count2, arguments) \
else \
_THRUST_INDEX_TYPE_DISPATCH2(std::int64_t, status, call, count1, count2, arguments)

//! Dispatch between 32-bit and 64-bit index_type based versions of the same algorithm implementation. This version
//! allows using different token sequences for callables in both branches, and is intended to be used with CUB-style
//! dispatch interfaces, where the "simple" interface always forces the size to be `int` (making it harder for us to
//! use), but the complex interface that we end up using doesn't actually provide a way to fully deduce the type from
//! just the call, making the size type appear in the token sequence of the callable.
//!
//! See reduce_n_impl to see an example of how this is meant to be used.
# define THRUST_INDEX_TYPE_DISPATCH2(status, call_32, call_64, count, arguments) \
if _THRUST_INDEX_TYPE_DISPATCH_SELECT (std::int32_t, count) \
_THRUST_INDEX_TYPE_DISPATCH(std::int32_t, status, call_32, count, arguments) \
else \
_THRUST_INDEX_TYPE_DISPATCH(std::int64_t, status, call_64, count, arguments)

//! Like \ref THRUST_INDEX_TYPE_DISPATCH2 but uses two counts.
# define THRUST_DOUBLE_INDEX_TYPE_DISPATCH2(status, call_32, call_64, count1, count2, arguments) \
if _THRUST_INDEX_TYPE_DISPATCH_SELECT2 (std::int32_t, count1, count2) \
_THRUST_INDEX_TYPE_DISPATCH2(std::int32_t, status, call_32, count1, count2, arguments) \
else \
_THRUST_INDEX_TYPE_DISPATCH2(std::int64_t, status, call_64, count1, count2, arguments)

//! Like \ref THRUST_INDEX_TYPE_DISPATCH2 but dispatching to uint32_t and uint64_t, respectively, depending on the
//! `count` argument. `count` must not be negative.
# define THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2(status, call_32, call_64, count, arguments) \
if _THRUST_INDEX_TYPE_DISPATCH_SELECT (std::uint32_t, count) \
_THRUST_INDEX_TYPE_DISPATCH(std::uint32_t, status, call_32, count, arguments) \
else \
_THRUST_INDEX_TYPE_DISPATCH(std::uint64_t, status, call_64, count, arguments)

#endif // !THRUST_FORCE_32BIT_OFFSET_TYPE
Loading