Skip to content

Commit 5b8e9eb

Browse files
committed
[SYCL][Matrix] Switch to SPV_KHR_cooperative_matrix extension
Signed-off-by: Sidorov, Dmitry <[email protected]>
1 parent 2011829 commit 5b8e9eb

File tree

10 files changed

+502
-3
lines changed

10 files changed

+502
-3
lines changed

clang/lib/CodeGen/CodeGenTypes.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,23 @@ llvm::Type *getJointMatrixINTELExtType(llvm::Type *CompTy,
320320
"spirv.JointMatrixINTEL", {CompTy}, Params);
321321
}
322322

323+
/// getCooperativeMatrixKHRExtType - Build the LLVM target extension type
///   target("spirv.CooperativeMatrixKHR", CompTy, <integer parameters>)
/// from the template arguments of a CooperativeMatrixKHR specialization.
///
/// \param CompTy       LLVM type passed as the single type parameter of the
///                     extension type.
/// \param TemplateArgs Template arguments of the specialization. Exactly five
///                     are expected; argument 0 is not read here (the caller
///                     supplies \p CompTy instead) and arguments 1..4 must be
///                     integral. They are forwarded, in order, as the integer
///                     parameters of the extension type.
/// \param Val          Not used by this function; kept so the signature stays
///                     unchanged for existing callers.
/// \returns the resulting llvm::TargetExtType.
llvm::Type *
getCooperativeMatrixKHRExtType(llvm::Type *CompTy,
                               ArrayRef<TemplateArgument> TemplateArgs,
                               const unsigned Val = 0) {
  assert(TemplateArgs.size() == 5 &&
         "Wrong CooperativeMatrixKHR template parameters number");
  std::vector<unsigned> Params;
  Params.reserve(TemplateArgs.size());
  // Use '<' rather than '!=' so that a malformed (empty) argument list cannot
  // run the index past the end when asserts are compiled out.
  for (size_t I = 1; I < TemplateArgs.size(); ++I) {
    assert(TemplateArgs[I].getKind() == TemplateArgument::Integral &&
           "Wrong CooperativeMatrixKHR template parameter");
    // getExtValue() yields a 64-bit value; the extension type stores unsigned
    // parameters, so make the narrowing explicit.
    Params.push_back(
        static_cast<unsigned>(TemplateArgs[I].getAsIntegral().getExtValue()));
  }

  return llvm::TargetExtType::get(
      CompTy->getContext(), "spirv.CooperativeMatrixKHR", {CompTy}, Params);
}
339+
323340
/// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type
324341
/// which is represented as a pointer to a structure to LLVM extension type
325342
/// with the parameters that follow SPIR-V JointMatrixINTEL type.
@@ -363,6 +380,39 @@ llvm::Type *CodeGenTypes::ConvertSYCLJointMatrixINTELType(RecordDecl *RD) {
363380
return getJointMatrixINTELExtType(CompTy, TemplateArgs);
364381
}
365382

383+
/// ConvertSPVCooperativeMatrixType - Convert SYCL joint_matrix type
384+
/// which is represented as a pointer to a structure to LLVM extension type
385+
/// with the parameters that follow SPIR-V CooperativeMatrixKHR type.
386+
/// The expected representation is:
387+
/// target("spirv.CooperativeMatrixKHR", %element_type, %scope%, %rows%, %cols%,
388+
/// %use%)
389+
llvm::Type *CodeGenTypes::ConvertSPVCooperativeMatrixType(RecordDecl *RD) {
390+
auto *TemplateDecl = cast<ClassTemplateSpecializationDecl>(RD);
391+
ArrayRef<TemplateArgument> TemplateArgs =
392+
TemplateDecl->getTemplateArgs().asArray();
393+
assert(TemplateArgs[0].getKind() == TemplateArgument::Type &&
394+
"1st CooperativeMatrixKHR template parameter must be type");
395+
llvm::Type *CompTy = ConvertType(TemplateArgs[0].getAsType());
396+
397+
if (CompTy->isStructTy()) {
398+
StringRef LlvmTyName = CompTy->getStructName();
399+
// Emit half/int16/float for sycl[::*]::{half,bfloat16,tf32}
400+
if (LlvmTyName.starts_with("class.sycl::") ||
401+
LlvmTyName.starts_with("class.__sycl_internal::"))
402+
LlvmTyName = LlvmTyName.rsplit("::").second;
403+
if (LlvmTyName == "half") {
404+
CompTy = llvm::Type::getHalfTy(getLLVMContext());
405+
} else if (LlvmTyName == "tf32") {
406+
CompTy = llvm::Type::getFloatTy(getLLVMContext());
407+
} else if (LlvmTyName == "bfloat16") {
408+
CompTy = llvm::Type::getInt16Ty(getLLVMContext());
409+
} else {
410+
llvm_unreachable("Wrong matrix base type!");
411+
}
412+
}
413+
return getCooperativeMatrixKHRExtType(CompTy, TemplateArgs);
414+
}
415+
366416
/// ConvertType - Convert the specified type to its LLVM form.
367417
llvm::Type *CodeGenTypes::ConvertType(QualType T) {
368418
T = Context.getCanonicalType(T);
@@ -654,6 +704,10 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
654704
"__spv::__spirv_JointMatrixINTEL") {
655705
ResultType = ConvertSYCLJointMatrixINTELType(RD);
656706
break;
707+
} else if (RD && RD->getQualifiedNameAsString() ==
708+
"__spv::__spirv_CooperativeMatrixKHR") {
709+
ResultType = ConvertSPVCooperativeMatrixType(RD);
710+
break;
657711
} else if (RD && RD->getQualifiedNameAsString() ==
658712
"__spv::__spirv_TaskSequenceINTEL") {
659713
ResultType = llvm::TargetExtType::get(getLLVMContext(),

clang/lib/CodeGen/CodeGenTypes.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,15 @@ class CodeGenTypes {
136136
/// %use%, (optional) %element_type_interpretation%)
137137
llvm::Type *ConvertSYCLJointMatrixINTELType(RecordDecl *RD);
138138

139+
/// ConvertSPVCooperativeMatrixType - Convert SYCL joint_matrix type
140+
/// which is represented as a pointer to a structure to LLVM extension type
141+
/// with the parameters that follow SPIR-V CooperativeMatrixKHR type.
142+
/// The expected representation is:
143+
/// target("spirv.CooperativeMatrixKHR", %element_type, %scope%, %rows%,
144+
/// %cols%, %use%)
145+
///
146+
llvm::Type *ConvertSPVCooperativeMatrixType(RecordDecl *RD);
147+
139148
/// GetFunctionType - Get the LLVM function type for \arg Info.
140149
llvm::FunctionType *GetFunctionType(const CGFunctionInfo &Info);
141150

clang/lib/Driver/ToolChains/Clang.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10395,7 +10395,8 @@ void SPIRVTranslator::ConstructJob(Compilation &C, const JobAction &JA,
1039510395
",+SPV_KHR_uniform_group_instructions"
1039610396
",+SPV_INTEL_masked_gather_scatter"
1039710397
",+SPV_INTEL_tensor_float32_conversion"
10398-
",+SPV_INTEL_optnone";
10398+
",+SPV_INTEL_optnone"
10399+
",+SPV_KHR_cooperative_matrix";
1039910400
if (ShouldPreserveMetadata)
1040010401
ExtArg += ",+SPV_KHR_non_semantic_info";
1040110402
if (IsCPU)
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// RUN: %clang_cc1 -triple spir64-unknown-unknown -disable-llvm-passes -emit-llvm %s -o - | FileCheck %s
2+
// Test that SPIR-V codegen generates the expected LLVM target extension type
3+
// for the CooperativeMatrixKHR type.
4+
#include <stddef.h>
5+
#include <stdint.h>
6+
7+
namespace __spv {
8+
template <typename T, uint32_t S, size_t R, size_t C, uint32_t U>
9+
struct __spirv_CooperativeMatrixKHR;
10+
}
11+
12+
// CHECK: @_Z2f1{{.*}}(target("spirv.CooperativeMatrixKHR", float, 3, 5, 10, 0)
13+
void f1(__spv::__spirv_CooperativeMatrixKHR<float, 3, 5, 10, 0> *matrix) {}
14+
15+
// CHECK: @_Z2f2{{.*}}(target("spirv.CooperativeMatrixKHR", i64, 3, 10, 2, 1)
16+
void f2(__spv::__spirv_CooperativeMatrixKHR<uint64_t, 3, 10, 2, 1> *matrix) {}
17+
18+
// CHECK: @_Z2f3{{.*}}(target("spirv.CooperativeMatrixKHR", i8, 3, 10, 2, 2)
19+
void f3(__spv::__spirv_CooperativeMatrixKHR<char, 3, 10, 2, 2> *matrix) {}
20+
21+
namespace sycl {
22+
class half {};
23+
class bfloat16 {};
24+
class tf32 {};
25+
}
26+
typedef sycl::half my_half;
27+
28+
// CHECK: @_Z2f4{{.*}}(target("spirv.CooperativeMatrixKHR", half, 3, 10, 2, 0)
29+
void f4(__spv::__spirv_CooperativeMatrixKHR<my_half, 3, 10, 2, 0> *matrix) {}
30+
31+
// CHECK: @_Z2f5{{.*}}(target("spirv.CooperativeMatrixKHR", i16, 3, 10, 2, 0)
32+
void f5(__spv::__spirv_CooperativeMatrixKHR<sycl::bfloat16, 3, 10, 2, 0> *matrix) {}
33+
34+
// CHECK: @_Z2f6{{.*}}(target("spirv.CooperativeMatrixKHR", i128, 3, 10, 2, 0)
35+
void f6(__spv::__spirv_CooperativeMatrixKHR<_BitInt(128), 3, 10, 2, 0> *matrix) {}
36+
37+
// CHECK: @_Z2f7{{.*}}(target("spirv.CooperativeMatrixKHR", float, 3, 10, 2, 0)
38+
void f7(__spv::__spirv_CooperativeMatrixKHR<sycl::tf32, 3, 10, 2, 0> *matrix) {}
39+
40+
// CHECK: @_Z2f8{{.*}}(target("spirv.CooperativeMatrixKHR", double, 3, 5, 10, 0)
41+
void f8(__spv::__spirv_CooperativeMatrixKHR<double, 3, 5, 10, 0> *matrix) {}

clang/test/Driver/sycl-spirv-ext.c

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@
5555
// CHECK-DEFAULT-SAME:,+SPV_KHR_uniform_group_instructions
5656
// CHECK-DEFAULT-SAME:,+SPV_INTEL_masked_gather_scatter
5757
// CHECK-DEFAULT-SAME:,+SPV_INTEL_tensor_float32_conversion
58-
// CHECK-DEFAULT-SAME:,+SPV_INTEL_optnone"
58+
// CHECK-DEFAULT-SAME:,+SPV_INTEL_optnone
59+
// CHECK-DEFAULT-SAME:,+SPV_KHR_cooperative_matrix"
5960
// CHECK-FPGA-HW: llvm-spirv{{.*}}"-spirv-ext=-all
6061
// CHECK-FPGA-HW-SAME:,+SPV_EXT_shader_atomic_float_add
6162
// CHECK-FPGA-HW-SAME:,+SPV_EXT_shader_atomic_float_min_max
@@ -119,5 +120,6 @@
119120
// CHECK-CPU-SAME:,+SPV_INTEL_masked_gather_scatter
120121
// CHECK-CPU-SAME:,+SPV_INTEL_tensor_float32_conversion
121122
// CHECK-CPU-SAME:,+SPV_INTEL_optnone
123+
// CHECK-CPU-SAME:,+SPV_KHR_cooperative_matrix
122124
// CHECK-CPU-SAME:,+SPV_INTEL_fp_max_error"
123125

sycl/include/CL/__spirv/spirv_ops.hpp

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
extern __DPCPP_SYCL_EXTERNAL float __spirv_RoundFToTF32INTEL(float a);
2828

29+
#ifndef USE_COOP_MATRIX
2930
template <typename T, typename Tp, std::size_t R, std::size_t C,
3031
__spv::MatrixUse U,
3132
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
@@ -174,6 +175,105 @@ template <typename Ts, typename T, std::size_t R, std::size_t C,
174175
extern __DPCPP_SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
175176
__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *,
176177
Ts val, size_t i);
178+
#else // USE_COOP_MATRIX
179+
template <typename T, typename Tp, std::size_t R, std::size_t C,
180+
__spv::MatrixUse U,
181+
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
182+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
183+
extern __DPCPP_SYCL_EXTERNAL
184+
__spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *
185+
__spirv_CooperativeMatrixLoadKHR(T *Ptr, __spv::MatrixLayout Layout = L,
186+
std::size_t Stride = 0,
187+
int MemOperand = 0);
188+
template <typename T, typename Tp, std::size_t R, std::size_t C,
189+
__spv::MatrixUse U,
190+
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
191+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
192+
extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixStoreKHR(
193+
T *Ptr, __spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *Object,
194+
__spv::MatrixLayout Layout = L, std::size_t Stride = 0, int MemOperand = 0);
195+
196+
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
197+
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
198+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
199+
extern __DPCPP_SYCL_EXTERNAL size_t __spirv_CooperativeMatrixLengthKHR(
200+
__spv::__spirv_CooperativeMatrixKHR<T, S, R, C, U> *);
201+
202+
template <typename T, typename Tp, std::size_t R, std::size_t C,
203+
__spv::MatrixUse U,
204+
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
205+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
206+
extern __DPCPP_SYCL_EXTERNAL
207+
__spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *
208+
__spirv_CooperativeMatrixConstructCheckedINTEL(const T Value, size_t Height,
209+
size_t Stride, size_t Width,
210+
size_t CoordX,
211+
size_t CoordY);
212+
213+
template <typename T, typename Tp, std::size_t R, std::size_t C,
214+
__spv::MatrixUse U,
215+
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
216+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
217+
extern __DPCPP_SYCL_EXTERNAL
218+
__spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *
219+
__spirv_CooperativeMatrixLoadCheckedINTEL(T *Ptr, std::size_t Stride,
220+
size_t Height, size_t Width,
221+
size_t CoordX, size_t CoordY,
222+
__spv::MatrixLayout Layout = L,
223+
int MemOperand = 0);
224+
225+
template <typename T, typename Tp, std::size_t R, std::size_t C,
226+
__spv::MatrixUse U,
227+
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
228+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
229+
extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixStoreCheckedINTEL(
230+
T *Ptr, __spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *Object,
231+
std::size_t Stride, size_t Height, size_t Width, size_t CoordX,
232+
size_t CoordY, __spv::MatrixLayout Layout = L, int MemOperand = 0);
233+
234+
// Declaration of the cooperative matrix multiply-add builtin.
// Shapes are fixed by the template arguments: A is M x K (element type TA,
// use UA), B is K x N (TB, UB), C and the returned matrix are M x N (TC, UC);
// all four share the scope S. The layout parameters LA/LB/LC do not appear in
// the function signature. Operands defaults to 0.
// NOTE(review): Operands is presumably a mask of __spv::MatrixOperands values
// -- confirm against the callers.
template <typename TA, typename TB, typename TC, std::size_t M, std::size_t K,
235+
std::size_t N, __spv::MatrixUse UA, __spv::MatrixUse UB,
236+
__spv::MatrixUse UC,
237+
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
238+
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
239+
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
240+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
241+
extern __DPCPP_SYCL_EXTERNAL
242+
__spv::__spirv_CooperativeMatrixKHR<TC, S, M, N, UC> *
243+
__spirv_CooperativeMatrixMulAddKHR(
244+
__spv::__spirv_CooperativeMatrixKHR<TA, S, M, K, UA> *A,
245+
__spv::__spirv_CooperativeMatrixKHR<TB, S, K, N, UB> *B,
246+
__spv::__spirv_CooperativeMatrixKHR<TC, S, M, N, UC> *C,
247+
size_t Operands = 0);
248+
249+
template <typename T, typename Tp, std::size_t R, std::size_t C,
250+
__spv::MatrixUse U,
251+
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
252+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
253+
extern __DPCPP_SYCL_EXTERNAL
254+
__spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *
255+
__spirv_CompositeConstruct(const T v);
256+
257+
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
258+
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
259+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
260+
extern __DPCPP_SYCL_EXTERNAL __ocl_vec_t<uint32_t, 2>
261+
__spirv_CooperativeMatrixGetElementCoordINTEL(
262+
__spv::__spirv_CooperativeMatrixKHR<T, S, R, C, U> *, size_t i);
263+
264+
// AccessChain followed by load/store serves to extract/insert an element
265+
// from/to the matrix
266+
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
267+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
268+
extern __DPCPP_SYCL_EXTERNAL T *
269+
__spirv_AccessChain(__spv::__spirv_CooperativeMatrixKHR<T, S, R, C, U> **,
270+
size_t i);
271+
272+
template <typename T> extern __DPCPP_SYCL_EXTERNAL T __spirv_Load(T *Ptr);
273+
274+
template <typename T>
275+
extern __DPCPP_SYCL_EXTERNAL void __spirv_Store(T *Ptr, T Obj);
276+
#endif // USE_COOP_MATRIX
177277

178278
template <typename T>
179279
extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixPrefetchINTEL(

sycl/include/CL/__spirv/spirv_types.hpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,30 @@ enum class MatrixLayout : uint32_t {
119119

120120
enum class MatrixUse : uint32_t { MatrixA = 0, MatrixB = 1, Accumulator = 2 };
121121

122+
// Operand flags for cooperative matrix operations. NoneKHR is 0 and every
// other enumerator is a distinct single bit (0x1 .. 0x100).
// NOTE(review): presumably these are OR-ed together (after casting to
// uint32_t) to form an operands mask -- confirm against the callers.
enum class MatrixOperands : uint32_t {
123+
// SPV_KHR_cooperative_matrix operands
124+
NoneKHR = 0,
125+
MatrixASignedComponentsKHR = 0x1,
126+
MatrixBSignedComponentsKHR = 0x2,
127+
MatrixCSignedComponentsKHR = 0x4,
128+
MatrixResultSignedComponentsKHR = 0x8,
129+
SaturatingAccumulationKHR = 0x10,
130+
// SPV_INTEL_joint_matrix operands
131+
MatrixAAndBTF32ComponentsINTEL = 0x20,
132+
MatrixAAndBBFloat16ComponentsINTEL = 0x40,
133+
MatrixCBFloat16ComponentsINTEL = 0x80,
134+
MatrixResultBFloat16ComponentsINTEL = 0x100
135+
};
136+
122137
template <typename T, std::size_t R, std::size_t C, MatrixLayout L,
123138
Scope::Flag S = Scope::Flag::Subgroup,
124139
MatrixUse U = MatrixUse::MatrixA>
125140
struct __spirv_JointMatrixINTEL;
126141

142+
template <typename T, Scope::Flag S = Scope::Flag::Subgroup, std::size_t R = 1,
143+
std::size_t C = 1, MatrixUse U = MatrixUse::MatrixA>
144+
struct __spirv_CooperativeMatrixKHR;
145+
127146
struct __spirv_TaskSequenceINTEL;
128147

129148
} // namespace __spv

0 commit comments

Comments
 (0)