Skip to content

Commit

Permalink
Add cuda::is_floating_point supporting half and bfloat (NVIDIA#3379)
Browse files Browse the repository at this point in the history
Co-authored-by: Michael Schellenberger Costa <[email protected]>
  • Loading branch information
2 people authored and davebayer committed Jan 22, 2025
1 parent b329447 commit f05b524
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 0 deletions.
45 changes: 45 additions & 0 deletions libcudacxx/include/cuda/__type_traits/is_floating_point.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef __CUDA__TYPE_TRAITS_IS_FLOATING_POINT_H
#define __CUDA__TYPE_TRAITS_IS_FLOATING_POINT_H

#include <cuda/std/detail/__config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <cuda/std/__type_traits/integral_constant.h>
#include <cuda/std/__type_traits/is_extended_floating_point.h>
#include <cuda/std/__type_traits/is_floating_point.h>
#include <cuda/std/__type_traits/remove_cv.h>

_LIBCUDACXX_BEGIN_NAMESPACE_CUDA

template <class _Tp>
struct _CCCL_TYPE_VISIBILITY_DEFAULT is_floating_point
: _CUDA_VSTD::bool_constant<_CUDA_VSTD::is_floating_point<_CUDA_VSTD::remove_cv_t<_Tp>>::value
|| _CUDA_VSTD::__is_extended_floating_point<_CUDA_VSTD::remove_cv_t<_Tp>>::value>
{};

#if !defined(_CCCL_NO_VARIABLE_TEMPLATES)
template <class _Tp>
_CCCL_INLINE_VAR constexpr bool is_floating_point_v =
_CUDA_VSTD::is_floating_point_v<_CUDA_VSTD::remove_cv_t<_Tp>>
|| _CUDA_VSTD::__is_extended_floating_point_v<_CUDA_VSTD::remove_cv_t<_Tp>>;
#endif // !_CCCL_NO_VARIABLE_TEMPLATES

_LIBCUDACXX_END_NAMESPACE_CUDA

#endif // __CUDA__TYPE_TRAITS_IS_FLOATING_POINT_H
27 changes: 27 additions & 0 deletions libcudacxx/include/cuda/type_traits
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//===----------------------------------------------------------------------===//
//
// Part of libcu++, the C++ Standard Library for your entire system,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef _CUDA_TYPE_TRAITS_
#define _CUDA_TYPE_TRAITS_

#include <cuda/std/detail/__config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <cuda/__type_traits/is_floating_point.h>
#include <cuda/std/type_traits>

#endif // _CUDA_TYPE_TRAITS_
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// keep this test in sync with `is_floating_point.pass.cpp` for `cuda::std::is_floating_point`

#include <cuda/std/cstddef> // for cuda::std::nullptr_t
#include <cuda/type_traits>

#include "test_macros.h"

TEST_NV_DIAG_SUPPRESS(cuda_demote_unsupported_floating_point)

template <class T>
__host__ __device__ void test_is_floating_point()
{
static_assert(cuda::is_floating_point<T>::value, "");
static_assert(cuda::is_floating_point<const T>::value, "");
static_assert(cuda::is_floating_point<volatile T>::value, "");
static_assert(cuda::is_floating_point<const volatile T>::value, "");
#if TEST_STD_VER > 2011
static_assert(cuda::is_floating_point_v<T>, "");
static_assert(cuda::is_floating_point_v<const T>, "");
static_assert(cuda::is_floating_point_v<volatile T>, "");
static_assert(cuda::is_floating_point_v<const volatile T>, "");
#endif
}

template <class T>
__host__ __device__ void test_is_not_floating_point()
{
static_assert(!cuda::is_floating_point<T>::value, "");
static_assert(!cuda::is_floating_point<const T>::value, "");
static_assert(!cuda::is_floating_point<volatile T>::value, "");
static_assert(!cuda::is_floating_point<const volatile T>::value, "");
#if TEST_STD_VER > 2011
static_assert(!cuda::is_floating_point_v<T>, "");
static_assert(!cuda::is_floating_point_v<const T>, "");
static_assert(!cuda::is_floating_point_v<volatile T>, "");
static_assert(!cuda::is_floating_point_v<const volatile T>, "");
#endif
}

class Empty
{};

class NotEmpty
{
__host__ __device__ virtual ~NotEmpty();
};

union Union
{};

struct bit_zero
{
int : 0;
};

class Abstract
{
__host__ __device__ virtual ~Abstract() = 0;
};

enum Enum
{
zero,
one
};
struct incomplete_type;

typedef void (*FunctionPtr)();

int main(int, char**)
{
test_is_floating_point<float>();
test_is_floating_point<double>();
test_is_floating_point<long double>();
#ifdef _LIBCUDACXX_HAS_NVFP16
test_is_floating_point<__half>();
#endif // _LIBCUDACXX_HAS_NVFP16
#ifdef _LIBCUDACXX_HAS_NVBF16
test_is_floating_point<__nv_bfloat16>();
#endif // _LIBCUDACXX_HAS_NVBF16

test_is_not_floating_point<short>();
test_is_not_floating_point<unsigned short>();
test_is_not_floating_point<int>();
test_is_not_floating_point<unsigned int>();
test_is_not_floating_point<long>();
test_is_not_floating_point<unsigned long>();

test_is_not_floating_point<cuda::std::nullptr_t>();
test_is_not_floating_point<void>();
test_is_not_floating_point<int&>();
test_is_not_floating_point<int&&>();
test_is_not_floating_point<int*>();
test_is_not_floating_point<const int*>();
test_is_not_floating_point<char[3]>();
test_is_not_floating_point<char[]>();
test_is_not_floating_point<Union>();
test_is_not_floating_point<Empty>();
test_is_not_floating_point<bit_zero>();
test_is_not_floating_point<NotEmpty>();
test_is_not_floating_point<Abstract>();
test_is_not_floating_point<Enum>();
test_is_not_floating_point<FunctionPtr>();
test_is_not_floating_point<incomplete_type>();

return 0;
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
//
//===----------------------------------------------------------------------===//

// keep this test in sync with `is_floating_point.pass.cpp` for `cuda::is_floating_point`

// type_traits

// is_floating_point
Expand Down

0 comments on commit f05b524

Please sign in to comment.