Skip to content

Commit

Permalink
Backport to 2.8: some FP8 support (#3479)
Browse files Browse the repository at this point in the history
* add `_CCCL_HAS_NVFP8` macro (#3429)

* Add cuda::is_floating_point supporting half and bfloat (#3379)

Co-authored-by: Michael Schellenberger Costa <[email protected]>

* Specialize __is_extended_floating_point for FP8 types (#3470)

Also ensure that FP8 support is only enabled when FP16 and BF16 support are both available, since FP8 requires them

Co-authored-by: Michael Schellenberger Costa <[email protected]>

---------

Co-authored-by: Federico Busato <[email protected]>
Co-authored-by: Michael Schellenberger Costa <[email protected]>
  • Loading branch information
3 people authored Jan 22, 2025
1 parent d5ca93c commit 51b08b0
Show file tree
Hide file tree
Showing 8 changed files with 305 additions and 0 deletions.
45 changes: 45 additions & 0 deletions libcudacxx/include/cuda/__type_traits/is_floating_point.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef __CUDA__TYPE_TRAITS_IS_FLOATING_POINT_H
#define __CUDA__TYPE_TRAITS_IS_FLOATING_POINT_H

#include <cuda/std/detail/__config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <cuda/std/__type_traits/integral_constant.h>
#include <cuda/std/__type_traits/is_extended_floating_point.h>
#include <cuda/std/__type_traits/is_floating_point.h>
#include <cuda/std/__type_traits/remove_cv.h>

_LIBCUDACXX_BEGIN_NAMESPACE_CUDA

// Type trait that reports whether `_Tp` is a floating-point type, ignoring cv-qualifiers.
//
// The cv-qualifiers are stripped from `_Tp` first; the result is true when the stripped type
// satisfies either `cuda::std::is_floating_point` or the internal
// `__is_extended_floating_point` trait, and false otherwise.
template <class _Tp>
struct _CCCL_TYPE_VISIBILITY_DEFAULT is_floating_point
    : _CUDA_VSTD::integral_constant<
        bool,
        _CUDA_VSTD::__is_extended_floating_point<_CUDA_VSTD::remove_cv_t<_Tp>>::value
          || _CUDA_VSTD::is_floating_point<_CUDA_VSTD::remove_cv_t<_Tp>>::value>
{};

#if !defined(_CCCL_NO_VARIABLE_TEMPLATES)
// Variable-template form of the cv-insensitive floating-point check: true when the cv-stripped
// `_Tp` satisfies `cuda::std::is_floating_point_v` or `__is_extended_floating_point_v`.
// Only provided when variable templates are available.
template <class _Tp>
_CCCL_INLINE_VAR constexpr bool is_floating_point_v =
  _CUDA_VSTD::__is_extended_floating_point_v<_CUDA_VSTD::remove_cv_t<_Tp>>
  || _CUDA_VSTD::is_floating_point_v<_CUDA_VSTD::remove_cv_t<_Tp>>;
#endif // !_CCCL_NO_VARIABLE_TEMPLATES

_LIBCUDACXX_END_NAMESPACE_CUDA

#endif // __CUDA__TYPE_TRAITS_IS_FLOATING_POINT_H
10 changes: 10 additions & 0 deletions libcudacxx/include/cuda/std/__cccl/extended_floating_point.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,14 @@
# endif
#endif // !_CCCL_HAS_NVBF16

// _CCCL_HAS_NVFP8() expands to 1 only when all of the following hold:
//   * FP8 support has not been switched off through _CCCL_DISABLE_NVFP8_SUPPORT,
//   * the header <cuda_fp8.h> can be included,
//   * both _CCCL_HAS_NVFP16 and _CCCL_HAS_NVBF16 are defined.
// In every other case it expands to 0, so it is always usable as `#if _CCCL_HAS_NVFP8()`.
#if !defined(_CCCL_DISABLE_NVFP8_SUPPORT) && _CCCL_HAS_INCLUDE(<cuda_fp8.h>) && defined(_CCCL_HAS_NVFP16) \
  && defined(_CCCL_HAS_NVBF16)
#  define _CCCL_HAS_NVFP8() 1
#else // FP8 disabled, header missing, or FP16/BF16 unavailable
#  define _CCCL_HAS_NVFP8() 0
#endif // FP8 availability

#endif // __CCCL_EXTENDED_FLOATING_POINT_H
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ _CCCL_DIAG_SUPPRESS_CLANG("-Wunused-function")
_CCCL_DIAG_POP
#endif // _LIBCUDACXX_HAS_NVBF16

#if _CCCL_HAS_NVFP8()
# include <cuda_fp8.h>
#endif // _CCCL_HAS_NVFP8()

_LIBCUDACXX_BEGIN_NAMESPACE_STD

template <class _Tp>
Expand Down Expand Up @@ -71,6 +75,22 @@ _CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v<__nv_bfloat16> =
# endif // !_CCCL_NO_INLINE_VARIABLES
#endif // _LIBCUDACXX_HAS_NVBF16

#if _CCCL_HAS_NVFP8()
// Register the two FP8 types as extended floating-point types. These specializations only
// exist when _CCCL_HAS_NVFP8() evaluates to non-zero.
template <>
struct __is_extended_floating_point<__nv_fp8_e5m2> : bool_constant<true>
{};
template <>
struct __is_extended_floating_point<__nv_fp8_e4m3> : bool_constant<true>
{};

// Matching specializations of the variable template, provided when inline variables are usable.
#  ifndef _CCCL_NO_INLINE_VARIABLES
template <>
_CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v<__nv_fp8_e5m2> = true;
template <>
_CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v<__nv_fp8_e4m3> = true;
#  endif // !_CCCL_NO_INLINE_VARIABLES
#endif // _CCCL_HAS_NVFP8()

_LIBCUDACXX_END_NAMESPACE_STD

#endif // _LIBCUDACXX___TYPE_TRAITS_IS_EXTENDED_FLOATING_POINT_H
27 changes: 27 additions & 0 deletions libcudacxx/include/cuda/type_traits
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//===----------------------------------------------------------------------===//
//
// Part of libcu++, the C++ Standard Library for your entire system,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef _CUDA_TYPE_TRAITS_
#define _CUDA_TYPE_TRAITS_

#include <cuda/std/detail/__config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <cuda/__type_traits/is_floating_point.h>
#include <cuda/std/type_traits>

#endif // _CUDA_TYPE_TRAITS_
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//
#include <cuda/std/__cccl/extended_floating_point.h>

#include "test_macros.h"

#if !_CCCL_HAS_NVFP8()
# include <cuda_fp8.h>
#endif
#if !defined(_CCCL_HAS_NVFP16)
# include <cuda_fp16.h>
#endif
#if !defined(_CCCL_HAS_NVBF16)
# include <cuda_bf16.h>
#endif

int main(int, char**)
{
#if !_CCCL_HAS_NVFP8()
auto x = __nv_fp8_e4m3(1.0f);
unused(x);
#else
static_assert(false);
#endif
#if !defined(_CCCL_HAS_NVFP16)
auto y = __half(1.0f);
unused(y);
#else
static_assert(false);
#endif
#if !defined(_CCCL_HAS_NVBF16)
auto z = __nv_bfloat16(1.0f);
unused(z);
#else
static_assert(false);
#endif
return 0;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//
#include <cuda/std/__cccl/extended_floating_point.h>

#include "test_macros.h"

#if _CCCL_HAS_NVFP8()
# include <cuda_fp8.h>
#endif
#if defined(_CCCL_HAS_NVFP16)
# include <cuda_fp16.h>
#endif
#if defined(_CCCL_HAS_NVBF16)
# include <cuda_bf16.h>
#endif

int main(int, char**)
{
  // For every extended floating-point feature the library reports as available, construct one
  // value of the matching type from a float literal; `unused` keeps the variable referenced.
#if _CCCL_HAS_NVFP8()
  __nv_fp8_e4m3 fp8_value(1.0f);
  unused(fp8_value);
#endif

#if defined(_CCCL_HAS_NVFP16)
  __half half_value(1.0f);
  unused(half_value);
#endif

#if defined(_CCCL_HAS_NVBF16)
  __nv_bfloat16 bfloat_value(1.0f);
  unused(bfloat_value);
#endif

  return 0;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// keep this test in sync with `is_floating_point.pass.cpp` for `cuda::std::is_floating_point`

#include <cuda/std/cstddef> // for cuda::std::nullptr_t
#include <cuda/type_traits>

#include "test_macros.h"

TEST_NV_DIAG_SUPPRESS(cuda_demote_unsupported_floating_point)

// Compile-time check that cuda::is_floating_point reports true for T under every
// cv-qualification. The `_v` variable template is checked as well when TEST_STD_VER is newer
// than 2011.
template <class T>
__host__ __device__ void test_is_floating_point()
{
  using ConstT         = const T;
  using VolatileT      = volatile T;
  using ConstVolatileT = const volatile T;

  static_assert(cuda::is_floating_point<T>::value, "");
  static_assert(cuda::is_floating_point<ConstT>::value, "");
  static_assert(cuda::is_floating_point<VolatileT>::value, "");
  static_assert(cuda::is_floating_point<ConstVolatileT>::value, "");
#if TEST_STD_VER >= 2014
  static_assert(cuda::is_floating_point_v<T>, "");
  static_assert(cuda::is_floating_point_v<ConstT>, "");
  static_assert(cuda::is_floating_point_v<VolatileT>, "");
  static_assert(cuda::is_floating_point_v<ConstVolatileT>, "");
#endif // TEST_STD_VER >= 2014
}

// Compile-time check that cuda::is_floating_point reports false for T under every
// cv-qualification. The `_v` variable template is checked as well when TEST_STD_VER is newer
// than 2011.
template <class T>
__host__ __device__ void test_is_not_floating_point()
{
  using ConstT         = const T;
  using VolatileT      = volatile T;
  using ConstVolatileT = const volatile T;

  static_assert(!cuda::is_floating_point<T>::value, "");
  static_assert(!cuda::is_floating_point<ConstT>::value, "");
  static_assert(!cuda::is_floating_point<VolatileT>::value, "");
  static_assert(!cuda::is_floating_point<ConstVolatileT>::value, "");
#if TEST_STD_VER >= 2014
  static_assert(!cuda::is_floating_point_v<T>, "");
  static_assert(!cuda::is_floating_point_v<ConstT>, "");
  static_assert(!cuda::is_floating_point_v<VolatileT>, "");
  static_assert(!cuda::is_floating_point_v<ConstVolatileT>, "");
#endif // TEST_STD_VER >= 2014
}

// Helper types used as negative inputs for the trait checks below.

// Class with no members.
class Empty {};

// Class that declares a (private) virtual destructor.
class NotEmpty
{
  __host__ __device__ virtual ~NotEmpty();
};

// Union with no members.
union Union {};

// Struct whose only member is an unnamed zero-width bit-field.
struct bit_zero
{
  int : 0;
};

// Class whose (private) destructor is pure virtual.
class Abstract
{
  __host__ __device__ virtual ~Abstract() = 0;
};

// Unscoped enumeration with two enumerators.
enum Enum { zero, one };

// Only forward-declared here, so the type is incomplete where it is used.
struct incomplete_type;

// Pointer to a function taking no arguments and returning void.
using FunctionPtr = void (*)();

// Instantiates the compile-time checks for cuda::is_floating_point: first the types that must be
// reported as floating point, then a set of types that must not be. All verification happens
// through the static_asserts inside the helper templates.
int main(int, char**)
{
// Standard floating-point types.
test_is_floating_point<float>();
test_is_floating_point<double>();
test_is_floating_point<long double>();
// Extended floating-point types, each checked only when its feature macro enables it.
#ifdef _LIBCUDACXX_HAS_NVFP16
test_is_floating_point<__half>();
#endif // _LIBCUDACXX_HAS_NVFP16
#ifdef _LIBCUDACXX_HAS_NVBF16
test_is_floating_point<__nv_bfloat16>();
#endif // _LIBCUDACXX_HAS_NVBF16
#if _CCCL_HAS_NVFP8()
test_is_floating_point<__nv_fp8_e4m3>();
test_is_floating_point<__nv_fp8_e5m2>();
#endif // _CCCL_HAS_NVFP8()

// Integer types.
test_is_not_floating_point<short>();
test_is_not_floating_point<unsigned short>();
test_is_not_floating_point<int>();
test_is_not_floating_point<unsigned int>();
test_is_not_floating_point<long>();
test_is_not_floating_point<unsigned long>();

// nullptr_t, void, references, pointers, arrays, class/union/enum types, a function pointer and
// an incomplete type.
test_is_not_floating_point<cuda::std::nullptr_t>();
test_is_not_floating_point<void>();
test_is_not_floating_point<int&>();
test_is_not_floating_point<int&&>();
test_is_not_floating_point<int*>();
test_is_not_floating_point<const int*>();
test_is_not_floating_point<char[3]>();
test_is_not_floating_point<char[]>();
test_is_not_floating_point<Union>();
test_is_not_floating_point<Empty>();
test_is_not_floating_point<bit_zero>();
test_is_not_floating_point<NotEmpty>();
test_is_not_floating_point<Abstract>();
test_is_not_floating_point<Enum>();
test_is_not_floating_point<FunctionPtr>();
test_is_not_floating_point<incomplete_type>();

return 0;
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
//
//===----------------------------------------------------------------------===//

// keep this test in sync with `is_floating_point.pass.cpp` for `cuda::is_floating_point`

// type_traits

// is_floating_point
Expand Down

0 comments on commit 51b08b0

Please sign in to comment.