Skip to content

Commit d2a5e8d

Browse files
authored
[SYCL][Matrix] Add generation of spirv.CooperativeMatrixKHR type (#13645)
Represented via target extension type. Joint matrix type will be removed soon. --------- Signed-off-by: Sidorov, Dmitry <[email protected]>
1 parent 6934bcf commit d2a5e8d

File tree

4 files changed

+103
-0
lines changed

4 files changed

+103
-0
lines changed

clang/lib/CodeGen/CodeGenTypes.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,22 @@ llvm::Type *getJointMatrixINTELExtType(llvm::Type *CompTy,
320320
"spirv.JointMatrixINTEL", {CompTy}, Params);
321321
}
322322

323+
llvm::Type *
324+
getCooperativeMatrixKHRExtType(llvm::Type *CompTy,
325+
ArrayRef<TemplateArgument> TemplateArgs) {
326+
assert(TemplateArgs.size() == 5 &&
327+
"Wrong CooperativeMatrixKHR template parameters number");
328+
std::vector<unsigned> Params;
329+
for (size_t I = 1; I != TemplateArgs.size(); ++I) {
330+
assert(TemplateArgs[I].getKind() == TemplateArgument::Integral &&
331+
"Wrong CooperativeMatrixKHR template parameter");
332+
Params.push_back(TemplateArgs[I].getAsIntegral().getExtValue());
333+
}
334+
335+
return llvm::TargetExtType::get(
336+
CompTy->getContext(), "spirv.CooperativeMatrixKHR", {CompTy}, Params);
337+
}
338+
323339
/// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type
324340
/// which is represented as a pointer to a structure to LLVM extension type
325341
/// with the parameters that follow SPIR-V JointMatrixINTEL type.
@@ -363,6 +379,39 @@ llvm::Type *CodeGenTypes::ConvertSYCLJointMatrixINTELType(RecordDecl *RD) {
363379
return getJointMatrixINTELExtType(CompTy, TemplateArgs);
364380
}
365381

382+
/// ConvertSPVCooperativeMatrixType - Convert SYCL joint_matrix type
383+
/// which is represented as a pointer to a structure to LLVM extension type
384+
/// with the parameters that follow SPIR-V CooperativeMatrixKHR type.
385+
/// The expected representation is:
386+
/// target("spirv.CooperativeMatrixKHR", %element_type, %scope%, %rows%, %cols%,
387+
/// %use%)
388+
llvm::Type *CodeGenTypes::ConvertSPVCooperativeMatrixType(RecordDecl *RD) {
389+
auto *TemplateDecl = cast<ClassTemplateSpecializationDecl>(RD);
390+
ArrayRef<TemplateArgument> TemplateArgs =
391+
TemplateDecl->getTemplateArgs().asArray();
392+
assert(TemplateArgs[0].getKind() == TemplateArgument::Type &&
393+
"1st CooperativeMatrixKHR template parameter must be type");
394+
llvm::Type *CompTy = ConvertType(TemplateArgs[0].getAsType());
395+
396+
if (CompTy->isStructTy()) {
397+
StringRef LlvmTyName = CompTy->getStructName();
398+
// Emit half/int16/float for sycl[::*]::{half,bfloat16,tf32}
399+
if (LlvmTyName.starts_with("class.sycl::") ||
400+
LlvmTyName.starts_with("class.__sycl_internal::"))
401+
LlvmTyName = LlvmTyName.rsplit("::").second;
402+
if (LlvmTyName == "half") {
403+
CompTy = llvm::Type::getHalfTy(getLLVMContext());
404+
} else if (LlvmTyName == "tf32") {
405+
CompTy = llvm::Type::getFloatTy(getLLVMContext());
406+
} else if (LlvmTyName == "bfloat16") {
407+
CompTy = llvm::Type::getInt16Ty(getLLVMContext());
408+
} else {
409+
llvm_unreachable("Wrong matrix base type!");
410+
}
411+
}
412+
return getCooperativeMatrixKHRExtType(CompTy, TemplateArgs);
413+
}
414+
366415
/// ConvertType - Convert the specified type to its LLVM form.
367416
llvm::Type *CodeGenTypes::ConvertType(QualType T) {
368417
T = Context.getCanonicalType(T);
@@ -654,6 +703,10 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
654703
"__spv::__spirv_JointMatrixINTEL") {
655704
ResultType = ConvertSYCLJointMatrixINTELType(RD);
656705
break;
706+
} else if (RD && RD->getQualifiedNameAsString() ==
707+
"__spv::__spirv_CooperativeMatrixKHR") {
708+
ResultType = ConvertSPVCooperativeMatrixType(RD);
709+
break;
657710
} else if (RD && RD->getQualifiedNameAsString() ==
658711
"__spv::__spirv_TaskSequenceINTEL") {
659712
ResultType = llvm::TargetExtType::get(getLLVMContext(),

clang/lib/CodeGen/CodeGenTypes.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,15 @@ class CodeGenTypes {
136136
/// %use%, (optional) %element_type_interpretation%)
137137
llvm::Type *ConvertSYCLJointMatrixINTELType(RecordDecl *RD);
138138

139+
/// ConvertSPVCooperativeMatrixType - Convert SYCL joint_matrix type
140+
/// which is represented as a pointer to a structure to LLVM extension type
141+
/// with the parameters that follow SPIR-V CooperativeMatrixKHR type.
142+
/// The expected representation is:
143+
/// target("spirv.CooperativeMatrixKHR", %element_type, %scope%, %rows%,
144+
/// %cols%, %use%)
145+
///
146+
llvm::Type *ConvertSPVCooperativeMatrixType(RecordDecl *RD);
147+
139148
/// GetFunctionType - Get the LLVM function type for \arg Info.
140149
llvm::FunctionType *GetFunctionType(const CGFunctionInfo &Info);
141150

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// RUN: %clang_cc1 -triple spir64-unknown-unknown -disable-llvm-passes -emit-llvm %s -o - | FileCheck %s
2+
// Test that SPIR-V codegen generates the expected LLVM struct name for the
3+
// CooperativeMatrixKHR type.
4+
#include <stddef.h>
5+
#include <stdint.h>
6+
7+
namespace __spv {
8+
template <typename T, uint32_t S, size_t R, size_t C, uint32_t U>
9+
struct __spirv_CooperativeMatrixKHR;
10+
}
11+
12+
// CHECK: @_Z2f1{{.*}}(target("spirv.CooperativeMatrixKHR", float, 3, 5, 10, 0)
13+
void f1(__spv::__spirv_CooperativeMatrixKHR<float, 3, 5, 10, 0> *matrix) {}
14+
15+
// CHECK: @_Z2f2{{.*}}(target("spirv.CooperativeMatrixKHR", i64, 3, 10, 2, 1)
16+
void f2(__spv::__spirv_CooperativeMatrixKHR<uint64_t, 3, 10, 2, 1> *matrix) {}
17+
18+
// CHECK: @_Z2f3{{.*}}(target("spirv.CooperativeMatrixKHR", i8, 3, 10, 2, 2)
19+
void f3(__spv::__spirv_CooperativeMatrixKHR<char, 3, 10, 2, 2> *matrix) {}
20+
21+
namespace sycl {
22+
class half {};
23+
class bfloat16 {};
24+
class tf32 {};
25+
}
26+
typedef sycl::half my_half;
27+
28+
// CHECK: @_Z2f4{{.*}}(target("spirv.CooperativeMatrixKHR", half, 3, 10, 2, 0)
29+
void f4(__spv::__spirv_CooperativeMatrixKHR<my_half, 3, 10, 2, 0> *matrix) {}
30+
31+
// CHECK: @_Z2f5{{.*}}(target("spirv.CooperativeMatrixKHR", i16, 3, 10, 2, 0)
32+
void f5(__spv::__spirv_CooperativeMatrixKHR<sycl::bfloat16, 3, 10, 2, 0> *matrix) {}
33+
34+
// CHECK: @_Z2f6{{.*}}(target("spirv.CooperativeMatrixKHR", i128, 3, 10, 2, 0)
35+
void f6(__spv::__spirv_CooperativeMatrixKHR<_BitInt(128), 3, 10, 2, 0> *matrix) {}
36+
37+
// CHECK: @_Z2f7{{.*}}(target("spirv.CooperativeMatrixKHR", float, 3, 10, 2, 0)
38+
void f7(__spv::__spirv_CooperativeMatrixKHR<sycl::tf32, 3, 10, 2, 0> *matrix) {}
39+
40+
// CHECK: @_Z2f8{{.*}}(target("spirv.CooperativeMatrixKHR", double, 3, 5, 10, 0)
41+
void f8(__spv::__spirv_CooperativeMatrixKHR<double, 3, 5, 10, 0> *matrix) {}

0 commit comments

Comments
 (0)