Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport to 2.8: Sync ptx_dot_variants.h with libcuda-ptx (#3564) #3577

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions libcudacxx/include/cuda/__ptx/ptx_dot_variants.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
//
//===----------------------------------------------------------------------===//

// WARNING: The source of truth for this file is libcuda-ptx. Do not modify without syncing with libcuda-ptx.

#ifndef _CUDA_PTX_DOT_VARIANTS_H_
#define _CUDA_PTX_DOT_VARIANTS_H_

Expand Down Expand Up @@ -111,6 +113,23 @@ enum class dot_op
exch
};

enum class dot_cta_group
{
cta_group_1,
cta_group_2
};

enum class dot_kind
{
f16,
f8f6f4,
i8,
mxf4,
mxf4nvf4,
mxf8f6f4,
tf32
};

template <dot_sem __sem>
using sem_t = _CUDA_VSTD::integral_constant<dot_sem, __sem>;
using sem_acq_rel_t = sem_t<dot_sem::acq_rel>;
Expand Down Expand Up @@ -173,6 +192,35 @@ static constexpr op_xor_op_t op_xor_op{};
static constexpr op_cas_t op_cas{};
static constexpr op_exch_t op_exch{};

template <dot_cta_group __cta_group>
using cta_group_t = _CUDA_VSTD::integral_constant<dot_cta_group, __cta_group>;
using cta_group_1_t = cta_group_t<dot_cta_group::cta_group_1>;
using cta_group_2_t = cta_group_t<dot_cta_group::cta_group_1>;

static constexpr cta_group_1_t cta_group_1{};
static constexpr cta_group_2_t cta_group_2{};

template <dot_kind __kind>
using kind_t = _CUDA_VSTD::integral_constant<dot_kind, __kind>;
using kind_f16_t = kind_t<dot_kind::f16>;
using kind_f8f6f4_t = kind_t<dot_kind::f8f6f4>;
using kind_i8_t = kind_t<dot_kind::i8>;
using kind_mxf4_t = kind_t<dot_kind::mxf4>;
using kind_mxf4nvf4_t = kind_t<dot_kind::mxf4nvf4>;
using kind_mxf8f6f4_t = kind_t<dot_kind::mxf8f6f4>;
using kind_tf32_t = kind_t<dot_kind::tf32>;

static constexpr kind_f16_t kind_f16{};
static constexpr kind_f8f6f4_t kind_f8f6f4{};
static constexpr kind_i8_t kind_i8{};
static constexpr kind_mxf4_t kind_mxf4{};
static constexpr kind_mxf4nvf4_t kind_mxf4nvf4{};
static constexpr kind_mxf8f6f4_t kind_mxf8f6f4{};
static constexpr kind_tf32_t kind_tf32{};

template <int n>
using n32_t = _CUDA_VSTD::integral_constant<int, n>;

_LIBCUDACXX_END_NAMESPACE_CUDA_PTX

#endif // _CUDA_PTX_DOT_VARIANTS_H_
4 changes: 1 addition & 3 deletions libcudacxx/include/cuda/__ptx/ptx_helper_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,13 @@
# pragma system_header
#endif // no system header

#include <cuda/std/__type_traits/enable_if.h>
#include <cuda/std/__type_traits/integral_constant.h>
#include <cuda/std/cstddef>
#include <cuda/std/cstdint>

_LIBCUDACXX_BEGIN_NAMESPACE_CUDA_PTX

template <int __n>
using n32_t = _CUDA_VSTD::integral_constant<int, __n>;

/*************************************************************
*
* Conversion from generic pointer -> state space "pointer"
Expand Down
Loading