Skip to content

Commit

Permalink
[SYCL][Matrix] Switch to SPV_KHR_cooperative_matrix extension
Browse files Browse the repository at this point in the history
Signed-off-by: Sidorov, Dmitry <[email protected]>
  • Loading branch information
MrSidims committed Apr 8, 2024
1 parent 2011829 commit 5b8e9eb
Show file tree
Hide file tree
Showing 10 changed files with 502 additions and 3 deletions.
54 changes: 54 additions & 0 deletions clang/lib/CodeGen/CodeGenTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,23 @@ llvm::Type *getJointMatrixINTELExtType(llvm::Type *CompTy,
"spirv.JointMatrixINTEL", {CompTy}, Params);
}

llvm::Type *
getCooperativeMatrixKHRExtType(llvm::Type *CompTy,
ArrayRef<TemplateArgument> TemplateArgs,
const unsigned Val = 0) {
assert(TemplateArgs.size() == 5 &&
"Wrong CooperativeMatrixKHR template parameters number");
std::vector<unsigned> Params;
for (size_t I = 1; I != TemplateArgs.size(); ++I) {
assert(TemplateArgs[I].getKind() == TemplateArgument::Integral &&
"Wrong CooperativeMatrixKHR template parameter");
Params.push_back(TemplateArgs[I].getAsIntegral().getExtValue());
}

return llvm::TargetExtType::get(
CompTy->getContext(), "spirv.CooperativeMatrixKHR", {CompTy}, Params);
}

/// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type
/// which is represented as a pointer to a structure to LLVM extension type
/// with the parameters that follow SPIR-V JointMatrixINTEL type.
Expand Down Expand Up @@ -363,6 +380,39 @@ llvm::Type *CodeGenTypes::ConvertSYCLJointMatrixINTELType(RecordDecl *RD) {
return getJointMatrixINTELExtType(CompTy, TemplateArgs);
}

/// ConvertSPVCooperativeMatrixType - Convert SYCL joint_matrix type
/// which is represented as a pointer to a structure to LLVM extension type
/// with the parameters that follow SPIR-V CooperativeMatrixKHR type.
/// The expected representation is:
/// target("spirv.CooperativeMatrixKHR", %element_type, %scope%, %rows%, %cols%,
/// %use%)
llvm::Type *CodeGenTypes::ConvertSPVCooperativeMatrixType(RecordDecl *RD) {
auto *TemplateDecl = cast<ClassTemplateSpecializationDecl>(RD);
ArrayRef<TemplateArgument> TemplateArgs =
TemplateDecl->getTemplateArgs().asArray();
assert(TemplateArgs[0].getKind() == TemplateArgument::Type &&
"1st CooperativeMatrixKHR template parameter must be type");
llvm::Type *CompTy = ConvertType(TemplateArgs[0].getAsType());

if (CompTy->isStructTy()) {
StringRef LlvmTyName = CompTy->getStructName();
// Emit half/int16/float for sycl[::*]::{half,bfloat16,tf32}
if (LlvmTyName.starts_with("class.sycl::") ||
LlvmTyName.starts_with("class.__sycl_internal::"))
LlvmTyName = LlvmTyName.rsplit("::").second;
if (LlvmTyName == "half") {
CompTy = llvm::Type::getHalfTy(getLLVMContext());
} else if (LlvmTyName == "tf32") {
CompTy = llvm::Type::getFloatTy(getLLVMContext());
} else if (LlvmTyName == "bfloat16") {
CompTy = llvm::Type::getInt16Ty(getLLVMContext());
} else {
llvm_unreachable("Wrong matrix base type!");
}
}
return getCooperativeMatrixKHRExtType(CompTy, TemplateArgs);
}

/// ConvertType - Convert the specified type to its LLVM form.
llvm::Type *CodeGenTypes::ConvertType(QualType T) {
T = Context.getCanonicalType(T);
Expand Down Expand Up @@ -654,6 +704,10 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
"__spv::__spirv_JointMatrixINTEL") {
ResultType = ConvertSYCLJointMatrixINTELType(RD);
break;
} else if (RD && RD->getQualifiedNameAsString() ==
"__spv::__spirv_CooperativeMatrixKHR") {
ResultType = ConvertSPVCooperativeMatrixType(RD);
break;
} else if (RD && RD->getQualifiedNameAsString() ==
"__spv::__spirv_TaskSequenceINTEL") {
ResultType = llvm::TargetExtType::get(getLLVMContext(),
Expand Down
9 changes: 9 additions & 0 deletions clang/lib/CodeGen/CodeGenTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,15 @@ class CodeGenTypes {
/// %use%, (optional) %element_type_interpretation%)
llvm::Type *ConvertSYCLJointMatrixINTELType(RecordDecl *RD);

/// ConvertSPVCooperativeMatrixType - Convert SYCL joint_matrix type
/// which is represented as a pointer to a structure to LLVM extension type
/// with the parameters that follow SPIR-V CooperativeMatrixKHR type.
/// The expected representation is:
/// target("spirv.CooperativeMatrixKHR", %element_type, %scope%, %rows%,
/// %cols%, %use%)
///
llvm::Type *ConvertSPVCooperativeMatrixType(RecordDecl *RD);

/// GetFunctionType - Get the LLVM function type for \arg Info.
llvm::FunctionType *GetFunctionType(const CGFunctionInfo &Info);

Expand Down
3 changes: 2 additions & 1 deletion clang/lib/Driver/ToolChains/Clang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10395,7 +10395,8 @@ void SPIRVTranslator::ConstructJob(Compilation &C, const JobAction &JA,
",+SPV_KHR_uniform_group_instructions"
",+SPV_INTEL_masked_gather_scatter"
",+SPV_INTEL_tensor_float32_conversion"
",+SPV_INTEL_optnone";
",+SPV_INTEL_optnone"
",+SPV_KHR_cooperative_matrix";
if (ShouldPreserveMetadata)
ExtArg += ",+SPV_KHR_non_semantic_info";
if (IsCPU)
Expand Down
41 changes: 41 additions & 0 deletions clang/test/CodeGenSYCL/cooperative_matrix.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// RUN: %clang_cc1 -triple spir64-unknown-unknown -disable-llvm-passes -emit-llvm %s -o - | FileCheck %s
// Test that SPIR-V codegen generates the expected LLVM struct name for the
// CooperativeMatrixKHR type.
#include <stddef.h>
#include <stdint.h>

namespace __spv {
template <typename T, uint32_t S, size_t R, size_t C, uint32_t U>
struct __spirv_CooperativeMatrixKHR;
}

// CHECK: @_Z2f1{{.*}}(target("spirv.CooperativeMatrixKHR", float, 3, 5, 10, 0)
void f1(__spv::__spirv_CooperativeMatrixKHR<float, 3, 5, 10, 0> *matrix) {}

// CHECK: @_Z2f2{{.*}}(target("spirv.CooperativeMatrixKHR", i64, 3, 10, 2, 1)
void f2(__spv::__spirv_CooperativeMatrixKHR<uint64_t, 3, 10, 2, 1> *matrix) {}

// CHECK: @_Z2f3{{.*}}(target("spirv.CooperativeMatrixKHR", i8, 3, 10, 2, 2)
void f3(__spv::__spirv_CooperativeMatrixKHR<char, 3, 10, 2, 2> *matrix) {}

namespace sycl {
class half {};
class bfloat16 {};
class tf32 {};
}
typedef sycl::half my_half;

// CHECK: @_Z2f4{{.*}}(target("spirv.CooperativeMatrixKHR", half, 3, 10, 2, 0)
void f4(__spv::__spirv_CooperativeMatrixKHR<my_half, 3, 10, 2, 0> *matrix) {}

// CHECK: @_Z2f5{{.*}}(target("spirv.CooperativeMatrixKHR", i16, 3, 10, 2, 0)
void f5(__spv::__spirv_CooperativeMatrixKHR<sycl::bfloat16, 3, 10, 2, 0> *matrix) {}

// CHECK: @_Z2f6{{.*}}(target("spirv.CooperativeMatrixKHR", i128, 3, 10, 2, 0)
void f6(__spv::__spirv_CooperativeMatrixKHR<_BitInt(128), 3, 10, 2, 0> *matrix) {}

// CHECK: @_Z2f7{{.*}}(target("spirv.CooperativeMatrixKHR", float, 3, 10, 2, 0)
void f7(__spv::__spirv_CooperativeMatrixKHR<sycl::tf32, 3, 10, 2, 0> *matrix) {}

// CHECK: @_Z2f8{{.*}}(target("spirv.CooperativeMatrixKHR", double, 3, 5, 10, 0)
void f8(__spv::__spirv_CooperativeMatrixKHR<double, 3, 5, 10, 0> *matrix) {}
File renamed without changes.
4 changes: 3 additions & 1 deletion clang/test/Driver/sycl-spirv-ext.c
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@
// CHECK-DEFAULT-SAME:,+SPV_KHR_uniform_group_instructions
// CHECK-DEFAULT-SAME:,+SPV_INTEL_masked_gather_scatter
// CHECK-DEFAULT-SAME:,+SPV_INTEL_tensor_float32_conversion
// CHECK-DEFAULT-SAME:,+SPV_INTEL_optnone"
// CHECK-DEFAULT-SAME:,+SPV_INTEL_optnone
// CHECK-DEFAULT-SAME:,+SPV_KHR_cooperative_matrix"
// CHECK-FPGA-HW: llvm-spirv{{.*}}"-spirv-ext=-all
// CHECK-FPGA-HW-SAME:,+SPV_EXT_shader_atomic_float_add
// CHECK-FPGA-HW-SAME:,+SPV_EXT_shader_atomic_float_min_max
Expand Down Expand Up @@ -119,5 +120,6 @@
// CHECK-CPU-SAME:,+SPV_INTEL_masked_gather_scatter
// CHECK-CPU-SAME:,+SPV_INTEL_tensor_float32_conversion
// CHECK-CPU-SAME:,+SPV_INTEL_optnone
// CHECK-CPU-SAME:,+SPV_KHR_cooperative_matrix
// CHECK-CPU-SAME:,+SPV_INTEL_fp_max_error"

100 changes: 100 additions & 0 deletions sycl/include/CL/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

extern __DPCPP_SYCL_EXTERNAL float __spirv_RoundFToTF32INTEL(float a);

#ifndef USE_COOP_MATRIX
template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
Expand Down Expand Up @@ -174,6 +175,105 @@ template <typename Ts, typename T, std::size_t R, std::size_t C,
extern __DPCPP_SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *,
Ts val, size_t i);
#else // USE_COOP_MATRIX
template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *
__spirv_CooperativeMatrixLoadKHR(T *Ptr, __spv::MatrixLayout Layout = L,
std::size_t Stride = 0,
int MemOperand = 0);
template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixStoreKHR(
T *Ptr, __spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *Object,
__spv::MatrixLayout Layout = L, std::size_t Stride = 0, int MemOperand = 0);

template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL size_t __spirv_CooperativeMatrixLengthKHR(
__spv::__spirv_CooperativeMatrixKHR<T, S, R, C, U> *);

template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *
__spirv_CooperativeMatrixConstructCheckedINTEL(const T Value, size_t Height,
size_t Stride, size_t Width,
size_t CoordX,
size_t CoordY);

template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *
__spirv_CooperativeMatrixLoadCheckedINTEL(T *Ptr, std::size_t Stride,
size_t Height, size_t Width,
size_t CoordX, size_t CoordY,
__spv::MatrixLayout Layout = L,
int MemOperand = 0);

template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixStoreCheckedINTEL(
T *Ptr, __spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *Object,
std::size_t Stride, size_t Height, size_t Width, size_t CoordX,
size_t CoordY, __spv::MatrixLayout Layout = L, int MemOperand = 0);

template <typename TA, typename TB, typename TC, std::size_t M, std::size_t K,
std::size_t N, __spv::MatrixUse UA, __spv::MatrixUse UB,
__spv::MatrixUse UC,
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_CooperativeMatrixKHR<TC, S, M, N, UC> *
__spirv_CooperativeMatrixMulAddKHR(
__spv::__spirv_CooperativeMatrixKHR<TA, S, M, K, UA> *A,
__spv::__spirv_CooperativeMatrixKHR<TB, S, K, N, UB> *B,
__spv::__spirv_CooperativeMatrixKHR<TC, S, M, N, UC> *C,
size_t Operands = 0);

template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *
__spirv_CompositeConstruct(const T v);

template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL __ocl_vec_t<uint32_t, 2>
__spirv_CooperativeMatrixGetElementCoordINTEL(
__spv::__spirv_CooperativeMatrixKHR<T, S, R, C, U> *, size_t i);

// AccessChain followed by load/store serves to extract/insert and element
// from/to the matrix
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL T *
__spirv_AccessChain(__spv::__spirv_CooperativeMatrixKHR<T, S, R, C, U> **,
size_t i);

template <typename T> extern __DPCPP_SYCL_EXTERNAL T __spirv_Load(T *Ptr);

template <typename T>
extern __DPCPP_SYCL_EXTERNAL void __spirv_Store(T *Ptr, T Obj);
#endif // USE_COOP_MATRIX

template <typename T>
extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixPrefetchINTEL(
Expand Down
19 changes: 19 additions & 0 deletions sycl/include/CL/__spirv/spirv_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,30 @@ enum class MatrixLayout : uint32_t {

enum class MatrixUse : uint32_t { MatrixA = 0, MatrixB = 1, Accumulator = 2 };

enum class MatrixOperands : uint32_t {
// SPV_KHR_cooperative_matrix operands
NoneKHR = 0,
MatrixASignedComponentsKHR = 0x1,
MatrixBSignedComponentsKHR = 0x2,
MatrixCSignedComponentsKHR = 0x4,
MatrixResultSignedComponentsKHR = 0x8,
SaturatingAccumulationKHR = 0x10,
// SPV_INTE_joint_matrix operands
MatrixAAndBTF32ComponentsINTEL = 0x20,
MatrixAAndBBFloat16ComponentsINTEL = 0x40,
MatrixCBFloat16ComponentsINTEL = 0x80,
MatrixResultBFloat16ComponentsINTEL = 0x100
};

template <typename T, std::size_t R, std::size_t C, MatrixLayout L,
Scope::Flag S = Scope::Flag::Subgroup,
MatrixUse U = MatrixUse::MatrixA>
struct __spirv_JointMatrixINTEL;

template <typename T, Scope::Flag S = Scope::Flag::Subgroup, std::size_t R = 1,
std::size_t C = 1, MatrixUse U = MatrixUse::MatrixA>
struct __spirv_CooperativeMatrixKHR;

struct __spirv_TaskSequenceINTEL;

} // namespace __spv
Expand Down
Loading

0 comments on commit 5b8e9eb

Please sign in to comment.