From 94f63bbb6cfd9f94ec04d0a434eaa222e880b937 Mon Sep 17 00:00:00 2001 From: Bernhard Manfred Gruber Date: Tue, 28 Jan 2025 19:38:20 +0100 Subject: [PATCH] Sync ptx_dot_variants.h with libcuda-ptx (#3564) --- .../include/cuda/__ptx/ptx_dot_variants.h | 48 +++++++++++++++++++ .../include/cuda/__ptx/ptx_helper_functions.h | 4 +- 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/libcudacxx/include/cuda/__ptx/ptx_dot_variants.h b/libcudacxx/include/cuda/__ptx/ptx_dot_variants.h index 889839c69fb..574e6d21806 100644 --- a/libcudacxx/include/cuda/__ptx/ptx_dot_variants.h +++ b/libcudacxx/include/cuda/__ptx/ptx_dot_variants.h @@ -9,6 +9,8 @@ // //===----------------------------------------------------------------------===// +// WARNING: The source of truth for this file is libcuda-ptx. Do not modify without syncing with libcuda-ptx. + #ifndef _CUDA_PTX_DOT_VARIANTS_H_ #define _CUDA_PTX_DOT_VARIANTS_H_ @@ -111,6 +113,23 @@ enum class dot_op exch }; +enum class dot_cta_group +{ + cta_group_1, + cta_group_2 +}; + +enum class dot_kind +{ + f16, + f8f6f4, + i8, + mxf4, + mxf4nvf4, + mxf8f6f4, + tf32 +}; + template using sem_t = _CUDA_VSTD::integral_constant; using sem_acq_rel_t = sem_t; @@ -173,6 +192,35 @@ static constexpr op_xor_op_t op_xor_op{}; static constexpr op_cas_t op_cas{}; static constexpr op_exch_t op_exch{}; +template +using cta_group_t = _CUDA_VSTD::integral_constant; +using cta_group_1_t = cta_group_t; +using cta_group_2_t = cta_group_t; + +static constexpr cta_group_1_t cta_group_1{}; +static constexpr cta_group_2_t cta_group_2{}; + +template +using kind_t = _CUDA_VSTD::integral_constant; +using kind_f16_t = kind_t; +using kind_f8f6f4_t = kind_t; +using kind_i8_t = kind_t; +using kind_mxf4_t = kind_t; +using kind_mxf4nvf4_t = kind_t; +using kind_mxf8f6f4_t = kind_t; +using kind_tf32_t = kind_t; + +static constexpr kind_f16_t kind_f16{}; +static constexpr kind_f8f6f4_t kind_f8f6f4{}; +static constexpr kind_i8_t kind_i8{}; +static constexpr kind_mxf4_t kind_mxf4{}; +static constexpr kind_mxf4nvf4_t kind_mxf4nvf4{}; +static constexpr kind_mxf8f6f4_t kind_mxf8f6f4{}; +static constexpr kind_tf32_t kind_tf32{}; + +template +using n32_t = _CUDA_VSTD::integral_constant; + _LIBCUDACXX_END_NAMESPACE_CUDA_PTX #endif // _CUDA_PTX_DOT_VARIANTS_H_ diff --git a/libcudacxx/include/cuda/__ptx/ptx_helper_functions.h b/libcudacxx/include/cuda/__ptx/ptx_helper_functions.h index 9ce2b455d59..b536a87fb63 100644 --- a/libcudacxx/include/cuda/__ptx/ptx_helper_functions.h +++ b/libcudacxx/include/cuda/__ptx/ptx_helper_functions.h @@ -22,15 +22,13 @@ # pragma system_header #endif // no system header +#include #include #include #include _LIBCUDACXX_BEGIN_NAMESPACE_CUDA_PTX -template -using n32_t = _CUDA_VSTD::integral_constant; - /************************************************************* * * Conversion from generic pointer -> state space "pointer"