Skip to content

Commit

Permalink
Support float_eq on CTK < 12.2
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber committed Jan 28, 2025
1 parent fe27263 commit 4aaedf9
Showing 1 changed file with 17 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#define __CUDA_NO_BFLOAT16_CONVERSIONS__ 1
#define __CUDA_NO_BFLOAT16_OPERATORS__ 1

#include <cuda/std/__bit/bit_cast.h>
#include <cuda/std/limits>

template <class T>
Expand All @@ -42,26 +43,42 @@ __host__ __device__ inline __nv_fp8_e5m2 make_fp8_e5m2(double x, __nv_saturation

// Equality helper for __nv_fp8_e4m3 operands.
//
// With CUDA 12.2 or newer, each operand's storage byte (`__x`) is converted to
// a raw half using the E4M3 interpretation, wrapped in a __half, and the pair
// is forwarded to float_eq for __half operands.
//
// With older toolkits, the two objects are compared bit for bit instead. That
// is a comparison of representations, not of numeric values: two operands are
// equal exactly when their bytes are identical.
__host__ __device__ inline bool float_eq(__nv_fp8_e4m3 x, __nv_fp8_e4m3 y)
{
# if _CCCL_CUDACC_AT_LEAST(12, 2)
  const __half lhs_half{__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)};
  const __half rhs_half{__nv_cvt_fp8_to_halfraw(y.__x, __NV_E4M3)};
  return float_eq(lhs_half, rhs_half);
# else
  const unsigned char lhs_bits = ::cuda::std::bit_cast<unsigned char>(x);
  const unsigned char rhs_bits = ::cuda::std::bit_cast<unsigned char>(y);
  return lhs_bits == rhs_bits;
# endif
}

// Equality helper for __nv_fp8_e5m2 operands.
//
// With CUDA 12.2 or newer, each operand's storage byte (`__x`) is converted to
// a raw half using the E5M2 interpretation, wrapped in a __half, and the pair
// is forwarded to float_eq for __half operands.
//
// With older toolkits, the two objects are compared bit for bit instead. That
// is a comparison of representations, not of numeric values: two operands are
// equal exactly when their bytes are identical.
__host__ __device__ inline bool float_eq(__nv_fp8_e5m2 x, __nv_fp8_e5m2 y)
{
# if _CCCL_CUDACC_AT_LEAST(12, 2)
  const __half lhs_half{__nv_cvt_fp8_to_halfraw(x.__x, __NV_E5M2)};
  const __half rhs_half{__nv_cvt_fp8_to_halfraw(y.__x, __NV_E5M2)};
  return float_eq(lhs_half, rhs_half);
# else
  const unsigned char lhs_bits = ::cuda::std::bit_cast<unsigned char>(x);
  const unsigned char rhs_bits = ::cuda::std::bit_cast<unsigned char>(y);
  return lhs_bits == rhs_bits;
# endif
}
#endif // _CCCL_HAS_NVFP8

#if defined(_CCCL_HAS_NVFP16)
// Equality helper for __half operands.
//
// With CUDA 12.2 or newer, the comparison is delegated to the __heq intrinsic.
// With older toolkits, both operands are converted to float with __half2float
// and the resulting floats are compared with ==.
__host__ __device__ inline bool float_eq(__half x, __half y)
{
# if _CCCL_CUDACC_AT_LEAST(12, 2)
  const bool same = __heq(x, y);
  return same;
# else
  const float lhs = __half2float(x);
  const float rhs = __half2float(y);
  return lhs == rhs;
# endif
}
#endif // _CCCL_HAS_NVFP16

#if defined(_CCCL_HAS_NVBF16)
// Equality helper for __nv_bfloat16 operands.
//
// With CUDA 12.2 or newer, the comparison is delegated to the __heq intrinsic.
// With older toolkits, both operands are converted to float with
// __bfloat162float and the resulting floats are compared with ==.
__host__ __device__ inline bool float_eq(__nv_bfloat16 x, __nv_bfloat16 y)
{
# if _CCCL_CUDACC_AT_LEAST(12, 2)
  const bool same = __heq(x, y);
  return same;
# else
  const float lhs = __bfloat162float(x);
  const float rhs = __bfloat162float(y);
  return lhs == rhs;
# endif
}
#endif // _CCCL_HAS_NVBF16

Expand Down

0 comments on commit 4aaedf9

Please sign in to comment.