Skip to content

Commit

Permalink
Add cuda::is_floating_point supporting half and bfloat
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber committed Jan 21, 2025
1 parent 8da3ace commit a3d5e27
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 0 deletions.
55 changes: 55 additions & 0 deletions libcudacxx/include/cuda/__type_traits/is_floating_point.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef __CUDA__TYPE_TRAITS_IS_FLOATING_POINT_H
#define __CUDA__TYPE_TRAITS_IS_FLOATING_POINT_H

#include <cuda/std/detail/__config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <cuda/std/__type_traits/integral_constant.h>
#include <cuda/std/__type_traits/is_extended_floating_point.h>
#include <cuda/std/__type_traits/is_floating_point.h>
#include <cuda/std/__type_traits/remove_cv.h>

_LIBCUDACXX_BEGIN_NAMESPACE_CUDA

template <class _Tp>
struct __cccl_is_floating_point : _CUDA_VSTD::is_floating_point<_Tp>
{};
#ifdef _LIBCUDACXX_HAS_NVFP16
template <>
struct __cccl_is_floating_point<__half> : _CUDA_VSTD::true_type
{};
#endif // _LIBCUDACXX_HAS_NVFP16
#ifdef _LIBCUDACXX_HAS_NVBF16
template <>
struct __cccl_is_floating_point<__nv_bfloat16> : _CUDA_VSTD::true_type
{};
#endif // _LIBCUDACXX_HAS_NVBF16

template <class _Tp>
struct _CCCL_TYPE_VISIBILITY_DEFAULT is_floating_point : __cccl_is_floating_point<_CUDA_VSTD::remove_cv_t<_Tp>>
{};

#if !defined(_CCCL_NO_VARIABLE_TEMPLATES)
template <class _Tp>
_CCCL_INLINE_VAR constexpr bool is_floating_point_v = is_floating_point<_Tp>::value;
#endif // !_CCCL_NO_VARIABLE_TEMPLATES

_LIBCUDACXX_END_NAMESPACE_CUDA

#endif // __CUDA__TYPE_TRAITS_IS_FLOATING_POINT_H
27 changes: 27 additions & 0 deletions libcudacxx/include/cuda/type_traits
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//===----------------------------------------------------------------------===//
//
// Part of libcu++, the C++ Standard Library for your entire system,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef _CUDA_TYPE_TRAITS_
#define _CUDA_TYPE_TRAITS_

#include <cuda/std/detail/__config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <cuda/__type_traits/is_floating_point.h>
#include <cuda/std/type_traits>

#endif // _CUDA_TYPE_TRAITS_
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// keep this test in sync with `is_floating_point.pass.cpp` for `cuda::std::is_floating_point`

#include <cuda/std/cstddef> // for cuda::std::nullptr_t
#include <cuda/type_traits>

#include "test_macros.h"

TEST_NV_DIAG_SUPPRESS(cuda_demote_unsupported_floating_point)

template <class T>
__host__ __device__ void test_is_floating_point()
{
static_assert(cuda::is_floating_point<T>::value, "");
static_assert(cuda::is_floating_point<const T>::value, "");
static_assert(cuda::is_floating_point<volatile T>::value, "");
static_assert(cuda::is_floating_point<const volatile T>::value, "");
#if TEST_STD_VER > 2011
static_assert(cuda::is_floating_point_v<T>, "");
static_assert(cuda::is_floating_point_v<const T>, "");
static_assert(cuda::is_floating_point_v<volatile T>, "");
static_assert(cuda::is_floating_point_v<const volatile T>, "");
#endif
}

template <class T>
__host__ __device__ void test_is_not_floating_point()
{
static_assert(!cuda::is_floating_point<T>::value, "");
static_assert(!cuda::is_floating_point<const T>::value, "");
static_assert(!cuda::is_floating_point<volatile T>::value, "");
static_assert(!cuda::is_floating_point<const volatile T>::value, "");
#if TEST_STD_VER > 2011
static_assert(!cuda::is_floating_point_v<T>, "");
static_assert(!cuda::is_floating_point_v<const T>, "");
static_assert(!cuda::is_floating_point_v<volatile T>, "");
static_assert(!cuda::is_floating_point_v<const volatile T>, "");
#endif
}

class Empty
{};

class NotEmpty
{
__host__ __device__ virtual ~NotEmpty();
};

union Union
{};

struct bit_zero
{
int : 0;
};

class Abstract
{
__host__ __device__ virtual ~Abstract() = 0;
};

enum Enum
{
zero,
one
};
struct incomplete_type;

typedef void (*FunctionPtr)();

int main(int, char**)
{
test_is_floating_point<float>();
test_is_floating_point<double>();
test_is_floating_point<long double>();
#ifdef _LIBCUDACXX_HAS_NVFP16
test_is_floating_point<__half>();
#endif // _LIBCUDACXX_HAS_NVFP16
#ifdef _LIBCUDACXX_HAS_NVBF16
test_is_floating_point<__nv_bfloat16>();
#endif // _LIBCUDACXX_HAS_NVBF16

test_is_not_floating_point<short>();
test_is_not_floating_point<unsigned short>();
test_is_not_floating_point<int>();
test_is_not_floating_point<unsigned int>();
test_is_not_floating_point<long>();
test_is_not_floating_point<unsigned long>();

test_is_not_floating_point<cuda::std::nullptr_t>();
test_is_not_floating_point<void>();
test_is_not_floating_point<int&>();
test_is_not_floating_point<int&&>();
test_is_not_floating_point<int*>();
test_is_not_floating_point<const int*>();
test_is_not_floating_point<char[3]>();
test_is_not_floating_point<char[]>();
test_is_not_floating_point<Union>();
test_is_not_floating_point<Empty>();
test_is_not_floating_point<bit_zero>();
test_is_not_floating_point<NotEmpty>();
test_is_not_floating_point<Abstract>();
test_is_not_floating_point<Enum>();
test_is_not_floating_point<FunctionPtr>();
test_is_not_floating_point<incomplete_type>();

return 0;
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
//
//===----------------------------------------------------------------------===//

// keep this test in sync with `is_floating_point.pass.cpp` for `cuda::is_floating_point`

// type_traits

// is_floating_point
Expand Down

0 comments on commit a3d5e27

Please sign in to comment.