Backport to 2.8: Specialize cuda::std::numeric_limits for FP8 types (#3478) #3492

Merged
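For context, this backport makes cuda::std::numeric_limits usable with the CUDA FP8 types. The snippet below is a minimal usage sketch, not part of the PR; it assumes a toolkit that ships cuda_fp8.h and a build where _CCCL_HAS_NVFP8() is enabled, and the asserted values mirror the specializations added in the diff.

#include <cuda/std/limits>
#include <cuda_fp8.h>

__host__ __device__ void check_fp8_limits()
{
  using e4m3 = cuda::std::numeric_limits<__nv_fp8_e4m3>;
  using e5m2 = cuda::std::numeric_limits<__nv_fp8_e5m2>;

  static_assert(e4m3::is_specialized && e5m2::is_specialized, "");
  static_assert(e4m3::digits == 3 && e4m3::max_exponent == 8, "");
  static_assert(e5m2::digits == 2 && e5m2::max_exponent == 15, "");
  static_assert(!e4m3::has_infinity && e5m2::has_infinity, "");

  // max() is constexpr when a builtin bit_cast is available; otherwise it is
  // still callable at run time.
  __nv_fp8_e4m3 m = e4m3::max(); // storage byte 0x7e
  (void) m;
}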
192 changes: 192 additions & 0 deletions libcudacxx/include/cuda/std/limits
@@ -29,6 +29,21 @@
#include <cuda/std/climits>
#include <cuda/std/version>

#if defined(_LIBCUDACXX_HAS_NVFP16)
# include <cuda_fp16.h>
#endif // _LIBCUDACXX_HAS_NVFP16

#if defined(_LIBCUDACXX_HAS_NVBF16)
_CCCL_DIAG_PUSH
_CCCL_DIAG_SUPPRESS_CLANG("-Wunused-function")
# include <cuda_bf16.h>
_CCCL_DIAG_POP
#endif // _LIBCUDACXX_HAS_NVBF16

#if _CCCL_HAS_NVFP8()
# include <cuda_fp8.h>
#endif // _CCCL_HAS_NVFP8()

_CCCL_PUSH_MACROS

_LIBCUDACXX_BEGIN_NAMESPACE_STD
@@ -744,6 +759,183 @@ public:
};
#endif // _LIBCUDACXX_HAS_NVBF16

#if _CCCL_HAS_NVFP8()
# if defined(_CCCL_BUILTIN_BIT_CAST) || _CCCL_STD_VER >= 2014
# define _LIBCUDACXX_CONSTEXPR_FP8_LIMITS constexpr
# else // ^^^ _CCCL_BUILTIN_BIT_CAST || _CCCL_STD_VER >= 2014 ^^^ // vvv !_CCCL_BUILTIN_BIT_CAST && _CCCL_STD_VER <
// 2014 vvv
# define _LIBCUDACXX_CONSTEXPR_FP8_LIMITS
# endif // ^^^ !_CCCL_BUILTIN_BIT_CAST && _CCCL_STD_VER < 2014 ^^^

template <>
class __numeric_limits_impl<__nv_fp8_e4m3, __numeric_limits_type::__floating_point>
{
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS __nv_fp8_e4m3 __make_value(__nv_fp8_storage_t __val)
{
# if defined(_CCCL_BUILTIN_BIT_CAST)
return _CUDA_VSTD::bit_cast<__nv_fp8_e4m3>(__val);
# else // ^^^ _CCCL_BUILTIN_BIT_CAST ^^^ // vvv !_CCCL_BUILTIN_BIT_CAST vvv
__nv_fp8_e4m3 __ret{};
__ret.__x = __val;
return __ret;
# endif // ^^^ !_CCCL_BUILTIN_BIT_CAST ^^^
}

public:
using type = __nv_fp8_e4m3;

static constexpr bool is_specialized = true;

static constexpr bool is_signed = true;
static constexpr int digits = 3;
static constexpr int digits10 = 0;
static constexpr int max_digits10 = 2;
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type min() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0x08u));
}
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type max() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0x7eu));
}
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type lowest() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0xfeu));
}

static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr int radix = __FLT_RADIX__;
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type epsilon() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0x20u));
}
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type round_error() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0x30u));
}

static constexpr int min_exponent = -6;
static constexpr int min_exponent10 = -2;
static constexpr int max_exponent = 8;
static constexpr int max_exponent10 = 2;

static constexpr bool has_infinity = false;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = false;
static constexpr float_denorm_style has_denorm = denorm_present;
static constexpr bool has_denorm_loss = false;
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type infinity() noexcept
{
return type{};
}
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type quiet_NaN() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0x7fu));
}
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type signaling_NaN() noexcept
{
return type{};
}
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type denorm_min() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0x01u));
}

static constexpr bool is_iec559 = false;
static constexpr bool is_bounded = true;
static constexpr bool is_modulo = false;

static constexpr bool traps = false;
static constexpr bool tinyness_before = false;
static constexpr float_round_style round_style = round_to_nearest;
};

template <>
class __numeric_limits_impl<__nv_fp8_e5m2, __numeric_limits_type::__floating_point>
{
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS __nv_fp8_e5m2 __make_value(__nv_fp8_storage_t __val)
{
# if defined(_CCCL_BUILTIN_BIT_CAST)
return _CUDA_VSTD::bit_cast<__nv_fp8_e5m2>(__val);
# else // ^^^ _CCCL_BUILTIN_BIT_CAST ^^^ // vvv !_CCCL_BUILTIN_BIT_CAST vvv
__nv_fp8_e5m2 __ret{};
__ret.__x = __val;
return __ret;
# endif // ^^^ !_CCCL_BUILTIN_BIT_CAST ^^^
}

public:
using type = __nv_fp8_e5m2;

static constexpr bool is_specialized = true;

static constexpr bool is_signed = true;
static constexpr int digits = 2;
static constexpr int digits10 = 0;
static constexpr int max_digits10 = 2;
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type min() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0x04u));
}
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type max() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0x7bu));
}
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type lowest() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0xfbu));
}

static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr int radix = __FLT_RADIX__;
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type epsilon() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0x34u));
}
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type round_error() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0x38u));
}

static constexpr int min_exponent = -15;
static constexpr int min_exponent10 = -5;
static constexpr int max_exponent = 15;
static constexpr int max_exponent10 = 4;

static constexpr bool has_infinity = true;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = true;
static constexpr float_denorm_style has_denorm = denorm_present;
static constexpr bool has_denorm_loss = false;
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type infinity() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0x7cu));
}
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type quiet_NaN() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0x7eu));
}
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type signaling_NaN() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0x7du));
}
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_CONSTEXPR_FP8_LIMITS type denorm_min() noexcept
{
return __make_value(static_cast<__nv_fp8_storage_t>(0x01u));
}

static constexpr bool is_iec559 = false;
static constexpr bool is_bounded = true;
static constexpr bool is_modulo = false;

static constexpr bool traps = false;
static constexpr bool tinyness_before = false;
static constexpr float_round_style round_style = round_to_nearest;
};
#endif // _CCCL_HAS_NVFP8()

template <class _Tp>
class numeric_limits : public __numeric_limits_impl<_Tp>
{};
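As a cross-check on the storage bytes chosen above, the raw encodings can be decoded by hand. This is an illustration only, not part of the PR; it assumes the OCP FP8 formats, E4M3 with bias 7 and E5M2 with bias 15, and does not special-case the NaN/Inf encodings.

#include <cmath>
#include <cstdint>
#include <cstdio>

double decode_fp8(std::uint8_t bits, int exp_bits, int man_bits, int bias)
{
  const int sign     = (bits >> (exp_bits + man_bits)) & 0x1;
  const int exponent = (bits >> man_bits) & ((1 << exp_bits) - 1);
  const int mantissa = bits & ((1 << man_bits) - 1);
  // Subnormals (exponent field 0) have no implicit leading 1; normals do.
  const double magnitude =
    (exponent == 0)
      ? std::ldexp(static_cast<double>(mantissa), 1 - bias - man_bits)
      : std::ldexp(1.0 + mantissa / static_cast<double>(1 << man_bits), exponent - bias);
  return sign ? -magnitude : magnitude;
}

int main()
{
  std::printf("e4m3 max        = %g\n", decode_fp8(0x7e, 4, 3, 7));  // 448
  std::printf("e4m3 epsilon    = %g\n", decode_fp8(0x20, 4, 3, 7));  // 0.125
  std::printf("e4m3 denorm_min = %g\n", decode_fp8(0x01, 4, 3, 7));  // 0.001953125
  std::printf("e5m2 max        = %g\n", decode_fp8(0x7b, 5, 2, 15)); // 57344
  std::printf("e5m2 epsilon    = %g\n", decode_fp8(0x34, 5, 2, 15)); // 0.25
  std::printf("e5m2 denorm_min = %g\n", decode_fp8(0x01, 5, 2, 15)); // 1.52587890625e-05
  return 0;
}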
@@ -11,6 +11,7 @@
#define NUMERIC_LIMITS_MEMBERS_COMMON_H

// Disable all the extended floating point operations and conversions
#define __CUDA_NO_FP8_CONVERSIONS__ 1
#define __CUDA_NO_HALF_CONVERSIONS__ 1
#define __CUDA_NO_HALF_OPERATORS__ 1
#define __CUDA_NO_BFLOAT16_CONVERSIONS__ 1
@@ -24,6 +25,32 @@ __host__ __device__ bool float_eq(T x, T y)
return x == y;
}

#if _CCCL_HAS_NVFP8()
__host__ __device__ inline __nv_fp8_e4m3 make_fp8_e4m3(double x, __nv_saturation_t sat = __NV_NOSAT)
{
__nv_fp8_e4m3 res;
res.__x = __nv_cvt_double_to_fp8(x, sat, __NV_E4M3);
return res;
}

__host__ __device__ inline __nv_fp8_e5m2 make_fp8_e5m2(double x, __nv_saturation_t sat = __NV_NOSAT)
{
__nv_fp8_e5m2 res;
res.__x = __nv_cvt_double_to_fp8(x, sat, __NV_E5M2);
return res;
}

__host__ __device__ inline bool float_eq(__nv_fp8_e4m3 x, __nv_fp8_e4m3 y)
{
return float_eq(__half{__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)}, __half{__nv_cvt_fp8_to_halfraw(y.__x, __NV_E4M3)});
}

__host__ __device__ inline bool float_eq(__nv_fp8_e5m2 x, __nv_fp8_e5m2 y)
{
return float_eq(__half{__nv_cvt_fp8_to_halfraw(x.__x, __NV_E5M2)}, __half{__nv_cvt_fp8_to_halfraw(y.__x, __NV_E5M2)});
}
#endif // _CCCL_HAS_NVFP8()

#if defined(_LIBCUDACXX_HAS_NVFP16)
__host__ __device__ inline bool float_eq(__half x, __half y)
{
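The helpers above build FP8 values from double and compare them through __half, since the test header disables the FP8 conversion operators via __CUDA_NO_FP8_CONVERSIONS__. A hedged usage sketch, assuming these helpers plus cuda/std/limits and FP8 support:

__host__ __device__ inline bool fp8_epsilon_matches()
{
  // 0.125 encodes to storage byte 0x20, which the new e4m3 specialization
  // returns from epsilon().
  return float_eq(make_fp8_e4m3(0.125), cuda::std::numeric_limits<__nv_fp8_e4m3>::epsilon());
}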
@@ -116,6 +116,10 @@ int main(int, char**)
#if defined(_LIBCUDACXX_HAS_NVBF16)
test_type<__nv_bfloat16>();
#endif // _LIBCUDACXX_HAS_NVBF16
#if _CCCL_HAS_NVFP8()
test_type<__nv_fp8_e4m3>();
test_type<__nv_fp8_e5m2>();
#endif // _CCCL_HAS_NVFP8()

return 0;
}
@@ -72,6 +72,10 @@ int main(int, char**)
#if defined(_LIBCUDACXX_HAS_NVBF16)
test<__nv_bfloat16>(__double2bfloat16(9.18354961579912115600575419705e-41));
#endif // _LIBCUDACXX_HAS_NVBF16
#if _CCCL_HAS_NVFP8()
test<__nv_fp8_e4m3>(make_fp8_e4m3(0.001953125));
test<__nv_fp8_e5m2>(make_fp8_e5m2(0.0000152587890625));
#endif // _CCCL_HAS_NVFP8()
#if !defined(__FLT_DENORM_MIN__) && !defined(FLT_TRUE_MIN)
# error Test has no expected values for floating point types
#endif
@@ -61,5 +61,10 @@ int main(int, char**)
#if defined(_LIBCUDACXX_HAS_NVBF16)
test<__nv_bfloat16, 8>();
#endif // _LIBCUDACXX_HAS_NVBF16
#if _CCCL_HAS_NVFP8()
test<__nv_fp8_e4m3, 3>();
test<__nv_fp8_e5m2, 2>();
#endif // _CCCL_HAS_NVFP8()

return 0;
}
@@ -12,10 +12,25 @@

#include <cuda/std/cfloat>
#include <cuda/std/limits>
#include <cuda/std/type_traits>

#include "test_macros.h"

template <class T, int expected>
template <class T, cuda::std::enable_if_t<cuda::std::is_integral<T>::value, int> = 0>
__host__ __device__ constexpr int make_expected_digits10()
{
// digits * log10(2)
return static_cast<int>((cuda::std::numeric_limits<T>::digits * 30103l) / 100000l);
}

template <class T, cuda::std::enable_if_t<!cuda::std::is_integral<T>::value, int> = 0>
__host__ __device__ constexpr int make_expected_digits10()
{
// (digits - 1) * log10(2)
return static_cast<int>(((cuda::std::numeric_limits<T>::digits - 1) * 30103l) / 100000l);
}

template <class T, int expected = make_expected_digits10<T>()>
__host__ __device__ void test()
{
static_assert(cuda::std::numeric_limits<T>::digits10 == expected, "digits10 test 1");
@@ -30,41 +45,45 @@ __host__ __device__ void test()

int main(int, char**)
{
test<bool, 0>();
test<char, 2>();
test<signed char, 2>();
test<unsigned char, 2>();
test<wchar_t, 5 * sizeof(wchar_t) / 2 - 1>(); // 4 -> 9 and 2 -> 4
test<bool>();
test<char>();
test<signed char>();
test<unsigned char>();
test<wchar_t>();
#if TEST_STD_VER > 2017 && defined(__cpp_char8_t)
test<char8_t, 2>();
test<char8_t>();
#endif
#ifndef _LIBCUDACXX_HAS_NO_UNICODE_CHARS
test<char16_t, 4>();
test<char32_t, 9>();
test<char16_t>();
test<char32_t>();
#endif // _LIBCUDACXX_HAS_NO_UNICODE_CHARS
test<short, 4>();
test<unsigned short, 4>();
test<int, 9>();
test<unsigned int, 9>();
test<long, sizeof(long) == 4 ? 9 : 18>();
test<unsigned long, sizeof(long) == 4 ? 9 : 19>();
test<long long, 18>();
test<unsigned long long, 19>();
test<short>();
test<unsigned short>();
test<int>();
test<unsigned int>();
test<long>();
test<unsigned long>();
test<long long>();
test<unsigned long long>();
#ifndef _LIBCUDACXX_HAS_NO_INT128
test<__int128_t, 38>();
test<__uint128_t, 38>();
test<__int128_t>();
test<__uint128_t>();
#endif
test<float, FLT_DIG>();
test<double, DBL_DIG>();
test<float>();
test<double>();
#ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
test<long double, LDBL_DIG>();
test<long double>();
#endif
#if defined(_LIBCUDACXX_HAS_NVFP16)
test<__half, 3>();
test<__half>();
#endif // _LIBCUDACXX_HAS_NVFP16
#if defined(_LIBCUDACXX_HAS_NVBF16)
test<__nv_bfloat16, 2>();
test<__nv_bfloat16>();
#endif // _LIBCUDACXX_HAS_NVBF16
#if _CCCL_HAS_NVFP8()
test<__nv_fp8_e4m3>();
test<__nv_fp8_e5m2>();
#endif // _CCCL_HAS_NVFP8()

return 0;
}
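For reference (not part of the PR), the defaulted expectation above reproduces the values that were previously hard-coded: __half has digits = 11, so (11 - 1) * 30103 / 100000 = 3, and __nv_bfloat16 has digits = 8, so 7 * 30103 / 100000 = 2. The same formula yields 0 for both new FP8 types: (3 - 1) * 30103 / 100000 = 0 for __nv_fp8_e4m3 and (2 - 1) * 30103 / 100000 = 0 for __nv_fp8_e5m2.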
@@ -63,6 +63,10 @@ int main(int, char**)
#if defined(_LIBCUDACXX_HAS_NVBF16)
test<__nv_bfloat16>(__double2bfloat16(0.0078125));
#endif // _LIBCUDACXX_HAS_NVBF16
#if _CCCL_HAS_NVFP8()
test<__nv_fp8_e4m3>(make_fp8_e4m3(0.125));
test<__nv_fp8_e5m2>(make_fp8_e5m2(0.25));
#endif // _CCCL_HAS_NVFP8()

return 0;
}
@@ -60,6 +60,10 @@ int main(int, char**)
#if defined(_LIBCUDACXX_HAS_NVBF16)
test<__nv_bfloat16, cuda::std::denorm_present>();
#endif // _LIBCUDACXX_HAS_NVBF16
#if _CCCL_HAS_NVFP8()
test<__nv_fp8_e4m3, cuda::std::denorm_present>();
test<__nv_fp8_e5m2, cuda::std::denorm_present>();
#endif // _CCCL_HAS_NVFP8()

return 0;
}