Skip to content

Commit

Permalink
Sync ptx_dot_variants.h with libcuda-ptx (#3564)
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber authored Jan 28, 2025
1 parent 4567491 commit e08bda4
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 3 deletions.
48 changes: 48 additions & 0 deletions libcudacxx/include/cuda/__ptx/ptx_dot_variants.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
//
//===----------------------------------------------------------------------===//

// WARNING: The source of truth for this file is libcuda-ptx. Do not modify without syncing with libcuda-ptx.

#ifndef _CUDA_PTX_DOT_VARIANTS_H_
#define _CUDA_PTX_DOT_VARIANTS_H_

Expand Down Expand Up @@ -111,6 +113,23 @@ enum class dot_op
exch
};

// Enumerates the values of the `cta_group` dot-variant.
// The enumerators are numbered explicitly; the values equal the implicit
// numbering (0, 1) that a plain enumerator list would produce.
// NOTE(review): presumably mirrors the PTX `.cta_group::1` / `.cta_group::2`
// qualifiers -- confirm against the PTX ISA / libcuda-ptx.
enum class dot_cta_group
{
  cta_group_1 = 0,
  cta_group_2 = 1
};

// Enumerates the values of the `kind` dot-variant.
// The enumerators are numbered explicitly; the values equal the implicit
// numbering (0..6) that a plain enumerator list would produce.
// NOTE(review): the names look like operand/data-format selectors of a PTX
// `.kind::*` qualifier -- confirm against the PTX ISA / libcuda-ptx.
enum class dot_kind
{
  f16      = 0,
  f8f6f4   = 1,
  i8       = 2,
  mxf4     = 3,
  mxf4nvf4 = 4,
  mxf8f6f4 = 5,
  tf32     = 6
};

// Wraps a `dot_sem` enumerator in an integral_constant so that each value is
// represented by its own distinct type.
template <dot_sem __sem>
using sem_t = _CUDA_VSTD::integral_constant<dot_sem, __sem>;
// Tag type for `dot_sem::acq_rel`.
using sem_acq_rel_t = sem_t<dot_sem::acq_rel>;
Expand Down Expand Up @@ -173,6 +192,35 @@ static constexpr op_xor_op_t op_xor_op{};
static constexpr op_cas_t op_cas{};
static constexpr op_exch_t op_exch{};

// Tag types and tag objects for the `cta_group` dot-variant.
//
// `cta_group_t<V>` wraps a `dot_cta_group` enumerator in an integral_constant,
// so each enumerator is represented by its own distinct type.
template <dot_cta_group __cta_group>
using cta_group_t = _CUDA_VSTD::integral_constant<dot_cta_group, __cta_group>;
// Tag type for `dot_cta_group::cta_group_1`.
using cta_group_1_t = cta_group_t<dot_cta_group::cta_group_1>;
// Tag type for `dot_cta_group::cta_group_2`. It must wrap `cta_group_2`:
// wrapping `cta_group_1` here would make both tag types identical, so the two
// variants could no longer be told apart by overload resolution.
using cta_group_2_t = cta_group_t<dot_cta_group::cta_group_2>;

// Tag objects, one per tag type.
static constexpr cta_group_1_t cta_group_1{};
static constexpr cta_group_2_t cta_group_2{};

// Tag types and tag objects for the `kind` dot-variant.
//
// `kind_t<V>` wraps a `dot_kind` enumerator in an integral_constant, so each
// enumerator is represented by its own distinct type. Every enumerator gets a
// `kind_<name>_t` tag type together with a `kind_<name>` tag object.
template <dot_kind __kind>
using kind_t = _CUDA_VSTD::integral_constant<dot_kind, __kind>;

// dot_kind::f16
using kind_f16_t = kind_t<dot_kind::f16>;
static constexpr kind_f16_t kind_f16{};

// dot_kind::f8f6f4
using kind_f8f6f4_t = kind_t<dot_kind::f8f6f4>;
static constexpr kind_f8f6f4_t kind_f8f6f4{};

// dot_kind::i8
using kind_i8_t = kind_t<dot_kind::i8>;
static constexpr kind_i8_t kind_i8{};

// dot_kind::mxf4
using kind_mxf4_t = kind_t<dot_kind::mxf4>;
static constexpr kind_mxf4_t kind_mxf4{};

// dot_kind::mxf4nvf4
using kind_mxf4nvf4_t = kind_t<dot_kind::mxf4nvf4>;
static constexpr kind_mxf4nvf4_t kind_mxf4nvf4{};

// dot_kind::mxf8f6f4
using kind_mxf8f6f4_t = kind_t<dot_kind::mxf8f6f4>;
static constexpr kind_mxf8f6f4_t kind_mxf8f6f4{};

// dot_kind::tf32
using kind_tf32_t = kind_t<dot_kind::tf32>;
static constexpr kind_tf32_t kind_tf32{};

// Wraps a compile-time `int` in an integral_constant so that the value is
// carried by the type.
// The template parameter uses a reserved (`__`-prefixed) identifier so that a
// user-defined macro named `n` cannot break this header.
template <int __n>
using n32_t = _CUDA_VSTD::integral_constant<int, __n>;

_LIBCUDACXX_END_NAMESPACE_CUDA_PTX

#endif // _CUDA_PTX_DOT_VARIANTS_H_
4 changes: 1 addition & 3 deletions libcudacxx/include/cuda/__ptx/ptx_helper_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
# pragma system_header
#endif // no system header

#include <cuda/std/__type_traits/enable_if.h>
#include <cuda/std/__type_traits/integral_constant.h>
#include <cuda/std/cstddef>
#include <cuda/std/cstdint>
Expand All @@ -30,9 +31,6 @@

_LIBCUDACXX_BEGIN_NAMESPACE_CUDA_PTX

template <int __n>
using n32_t = _CUDA_VSTD::integral_constant<int, __n>;

/*************************************************************
*
* Conversion from generic pointer -> state space "pointer"
Expand Down

0 comments on commit e08bda4

Please sign in to comment.