Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nv/target] Add sm_120 macros. #3550

Merged
merged 5 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions libcudacxx/include/nv/detail/__target_macros
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
// Selector lookup table: _NV_TARGET_ARCH_TO_SELECTOR_<arch> names the
// nv::target selector for an architecture value written as the SM number
// times ten (900 -> sm_90, 1200 -> sm_120).
#define _NV_TARGET_ARCH_TO_SELECTOR_900 nv::target::sm_90
#define _NV_TARGET_ARCH_TO_SELECTOR_1000 nv::target::sm_100
#define _NV_TARGET_ARCH_TO_SELECTOR_1010 nv::target::sm_101
#define _NV_TARGET_ARCH_TO_SELECTOR_1200 nv::target::sm_120

// SM-number lookup table: _NV_TARGET_ARCH_TO_SM_<arch> yields the plain SM
// number for an architecture value written as the SM number times ten
// (350 -> 35).
#define _NV_TARGET_ARCH_TO_SM_350 35
#define _NV_TARGET_ARCH_TO_SM_370 37
Expand All @@ -54,6 +55,7 @@
// Newest entries of the arch -> SM-number table; the macro suffix is the SM
// number times ten, the value is the SM number itself (1200 -> 120).
#define _NV_TARGET_ARCH_TO_SM_900 90
#define _NV_TARGET_ARCH_TO_SM_1000 100
#define _NV_TARGET_ARCH_TO_SM_1010 101
#define _NV_TARGET_ARCH_TO_SM_1200 120

// Only enable when compiling for CUDA/stdpar
#if defined(_NV_COMPILER_NVCXX) && defined(_NVHPC_CUDA)
Expand All @@ -76,6 +78,7 @@
// In this configuration each _NV_TARGET_VAL_SM_<N> expands to the matching
// nv::target selector rather than to an integer.
// NOTE(review): presumably the _NV_COMPILER_NVCXX / _NVHPC_CUDA branch --
// confirm against the enclosing #if.
# define _NV_TARGET_VAL_SM_90 nv::target::sm_90
# define _NV_TARGET_VAL_SM_100 nv::target::sm_100
# define _NV_TARGET_VAL_SM_101 nv::target::sm_101
# define _NV_TARGET_VAL_SM_120 nv::target::sm_120

// Host/device conditions map straight onto the nv::target::is_host and
// nv::target::is_device names in this configuration.
# define _NV_TARGET___NV_IS_HOST nv::target::is_host
# define _NV_TARGET___NV_IS_DEVICE nv::target::is_device
Expand Down Expand Up @@ -112,6 +115,7 @@
// Integer form of the SM targets: the SM number times ten (sm_120 -> 1200).
// NOTE(review): this is the same scale as __CUDA_ARCH__, which
// _NV_TARGET_VAL is set to just below when that macro is defined -- confirm
// the enclosing branch.
# define _NV_TARGET_VAL_SM_90 900
# define _NV_TARGET_VAL_SM_100 1000
# define _NV_TARGET_VAL_SM_101 1010
# define _NV_TARGET_VAL_SM_120 1200

# if defined(__CUDA_ARCH__)
# define _NV_TARGET_VAL __CUDA_ARCH__
Expand Down Expand Up @@ -160,6 +164,7 @@
// Integer form of the SM targets: the SM number times ten (sm_120 -> 1200).
# define _NV_TARGET_VAL_SM_90 900
# define _NV_TARGET_VAL_SM_100 1000
# define _NV_TARGET_VAL_SM_101 1010
# define _NV_TARGET_VAL_SM_120 1200

// The current target value is pinned to 0 here, which is below every SM value
// defined above.
// NOTE(review): presumably the fallback branch for compilers without CUDA
// device compilation -- confirm against the enclosing #if/#else.
# define _NV_TARGET_VAL 0

Expand Down Expand Up @@ -191,6 +196,7 @@
// Each "provides SM <N>" condition expands to _NV_TARGET_PROVIDES applied to
// that SM's target value; the outer parentheses keep the expansion a single
// operand wherever it is substituted.
// NOTE(review): the macro names look like "_NV_TARGET_" prefixed onto the
// __NV_PROVIDES_SM_<N> tokens -- the prefixing step is not visible here.
#define _NV_TARGET___NV_PROVIDES_SM_90 (_NV_TARGET_PROVIDES(_NV_TARGET_VAL_SM_90))
#define _NV_TARGET___NV_PROVIDES_SM_100 (_NV_TARGET_PROVIDES(_NV_TARGET_VAL_SM_100))
#define _NV_TARGET___NV_PROVIDES_SM_101 (_NV_TARGET_PROVIDES(_NV_TARGET_VAL_SM_101))
#define _NV_TARGET___NV_PROVIDES_SM_120 (_NV_TARGET_PROVIDES(_NV_TARGET_VAL_SM_120))

// Each "is exactly SM <N>" condition expands to _NV_TARGET_IS_EXACTLY applied
// to that SM's target value, wrapped in parentheses so the expansion stays a
// single operand.
#define _NV_TARGET___NV_IS_EXACTLY_SM_35 (_NV_TARGET_IS_EXACTLY(_NV_TARGET_VAL_SM_35))
#define _NV_TARGET___NV_IS_EXACTLY_SM_37 (_NV_TARGET_IS_EXACTLY(_NV_TARGET_VAL_SM_37))
Expand All @@ -210,6 +216,7 @@
// Newest "is exactly SM <N>" conditions; each expands to
// _NV_TARGET_IS_EXACTLY applied to that SM's target value, parenthesized so
// the expansion stays a single operand.
#define _NV_TARGET___NV_IS_EXACTLY_SM_90 (_NV_TARGET_IS_EXACTLY(_NV_TARGET_VAL_SM_90))
#define _NV_TARGET___NV_IS_EXACTLY_SM_100 (_NV_TARGET_IS_EXACTLY(_NV_TARGET_VAL_SM_100))
#define _NV_TARGET___NV_IS_EXACTLY_SM_101 (_NV_TARGET_IS_EXACTLY(_NV_TARGET_VAL_SM_101))
#define _NV_TARGET___NV_IS_EXACTLY_SM_120 (_NV_TARGET_IS_EXACTLY(_NV_TARGET_VAL_SM_120))

// User-facing spelling of the "provides" conditions: NV_PROVIDES_SM_<N> is an
// alias for the double-underscore token __NV_PROVIDES_SM_<N>.
#define NV_PROVIDES_SM_35 __NV_PROVIDES_SM_35
#define NV_PROVIDES_SM_37 __NV_PROVIDES_SM_37
Expand All @@ -229,6 +236,7 @@
// Newest user-facing "provides" conditions; each NV_PROVIDES_SM_<N> is an
// alias for the double-underscore token __NV_PROVIDES_SM_<N>.
#define NV_PROVIDES_SM_90 __NV_PROVIDES_SM_90
#define NV_PROVIDES_SM_100 __NV_PROVIDES_SM_100
#define NV_PROVIDES_SM_101 __NV_PROVIDES_SM_101
#define NV_PROVIDES_SM_120 __NV_PROVIDES_SM_120

// User-facing spelling of the "is exactly" conditions: NV_IS_EXACTLY_SM_<N>
// is an alias for the double-underscore token __NV_IS_EXACTLY_SM_<N>.
#define NV_IS_EXACTLY_SM_35 __NV_IS_EXACTLY_SM_35
#define NV_IS_EXACTLY_SM_37 __NV_IS_EXACTLY_SM_37
Expand All @@ -248,6 +256,7 @@
// Newest user-facing "is exactly" conditions; each NV_IS_EXACTLY_SM_<N> is an
// alias for the double-underscore token __NV_IS_EXACTLY_SM_<N>.
#define NV_IS_EXACTLY_SM_90 __NV_IS_EXACTLY_SM_90
#define NV_IS_EXACTLY_SM_100 __NV_IS_EXACTLY_SM_100
#define NV_IS_EXACTLY_SM_101 __NV_IS_EXACTLY_SM_101
#define NV_IS_EXACTLY_SM_120 __NV_IS_EXACTLY_SM_120

// Disable SM_90a support on non-supporting compilers.
// Will re-enable for nvcc below.
Expand Down Expand Up @@ -381,6 +390,12 @@
# define _NV_TARGET_BOOL___NV_IS_EXACTLY_SM_101 0
# endif

// Reduce the "is exactly SM 120" condition to a literal 1 or 0 at
// preprocessing time. The result is published under the _NV_TARGET_BOOL_
// prefix so it can be reached by pasting that prefix onto the condition name
// (see _NV_ARCH_COND_CAT1).
# if (_NV_TARGET___NV_IS_EXACTLY_SM_120)
# define _NV_TARGET_BOOL___NV_IS_EXACTLY_SM_120 1
# else
# define _NV_TARGET_BOOL___NV_IS_EXACTLY_SM_120 0
# endif

// Re-enable sm_90a support in nvcc.
// The earlier definition is dropped first so the macro can be redefined as an
// alias for the __NV_HAS_FEATURE_SM_90a token.
# undef NV_HAS_FEATURE_SM_90a
# define NV_HAS_FEATURE_SM_90a __NV_HAS_FEATURE_SM_90a
Expand Down Expand Up @@ -529,6 +544,12 @@
# define _NV_TARGET_BOOL___NV_PROVIDES_SM_101 0
# endif

// Reduce the "provides SM 120" condition to a literal 1 or 0 at preprocessing
// time. The result is published under the _NV_TARGET_BOOL_ prefix so it can
// be reached by pasting that prefix onto the condition name
// (see _NV_ARCH_COND_CAT1).
# if (_NV_TARGET___NV_PROVIDES_SM_120)
# define _NV_TARGET_BOOL___NV_PROVIDES_SM_120 1
# else
# define _NV_TARGET_BOOL___NV_PROVIDES_SM_120 0
# endif

// _NV_ARCH_COND_CAT1 pastes the _NV_TARGET_BOOL_ prefix onto the condition
// token, selecting one of the 0/1 macros defined above.
// _NV_ARCH_COND_CAT passes that result through _NV_EVAL.
// NOTE(review): _NV_EVAL is defined outside this view; presumably it forces an
// extra expansion pass so the pasted name is replaced by its 0/1 value.
# define _NV_ARCH_COND_CAT1(cond) _NV_TARGET_BOOL_##cond
# define _NV_ARCH_COND_CAT(cond) _NV_EVAL(_NV_ARCH_COND_CAT1(cond))

Expand Down
9 changes: 7 additions & 2 deletions libcudacxx/include/nv/target
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,10 @@ constexpr base_int_t sm_89_bit = 1 << 15;
// One distinct bit per SM architecture; sm_120 continues the sequence at
// bit 19, directly after sm_101 (bit 18).
constexpr base_int_t sm_90_bit = 1 << 16;
constexpr base_int_t sm_100_bit = 1 << 17;
constexpr base_int_t sm_101_bit = 1 << 18;
constexpr base_int_t sm_120_bit = 1 << 19;
// Union of the per-architecture bits: every sm_*_bit constant appears exactly
// once, including sm_120_bit. Extend this list whenever a new architecture
// bit is introduced, otherwise that architecture is missing from the mask.
constexpr base_int_t all_devices =
  sm_35_bit | sm_37_bit | sm_50_bit | sm_52_bit | sm_53_bit | sm_60_bit | sm_61_bit | sm_62_bit | sm_70_bit | sm_72_bit
  | sm_75_bit | sm_80_bit | sm_86_bit | sm_87_bit | sm_89_bit | sm_90_bit | sm_100_bit | sm_101_bit | sm_120_bit;

// Store a set of targets as a set of bits
struct _NV_BITSET_ATTRIBUTE target_description
Expand Down Expand Up @@ -103,6 +104,7 @@ enum class sm_selector : base_int_t
sm_90 = 90,
sm_100 = 100,
sm_101 = 101,
sm_120 = 120,
};

constexpr base_int_t toint(sm_selector a)
Expand Down Expand Up @@ -130,12 +132,14 @@ constexpr base_int_t bitexact(sm_selector a)
: toint(a) == 90 ? sm_90_bit
: toint(a) == 100 ? sm_100_bit
: toint(a) == 101 ? sm_101_bit
: toint(a) == 120 ? sm_120_bit
: 0;
}

constexpr base_int_t bitrounddown(sm_selector a)
{
return toint(a) >= 101 ? sm_101_bit
return toint(a) >= 120 ? sm_120_bit
: toint(a) >= 101 ? sm_101_bit
: toint(a) >= 100 ? sm_100_bit
: toint(a) >= 90 ? sm_90_bit
: toint(a) >= 89 ? sm_89_bit
Expand Down Expand Up @@ -214,6 +218,7 @@ constexpr sm_selector sm_89 = sm_selector::sm_89;
// Shorthand constants mirroring the sm_selector enumerators, so a selector
// can be written as sm_120 instead of sm_selector::sm_120.
constexpr sm_selector sm_90 = sm_selector::sm_90;
constexpr sm_selector sm_100 = sm_selector::sm_100;
constexpr sm_selector sm_101 = sm_selector::sm_101;
constexpr sm_selector sm_120 = sm_selector::sm_120;

// Bring detail::is_exactly and detail::provides into the enclosing scope so
// they can be named without the detail:: qualifier.
using detail::is_exactly;
using detail::provides;
Expand Down
Loading