Skip to content

Commit 443971c

Browse files
[SYCL][SPIR-V] Change the LLVM type name of SPIR-V matrix types. (#6535)
This eliminates the need for the SPIR-V translator to query the pointer element type of the members in the struct to figure out what matrix type it really is.
1 parent 2eb9d4b commit 443971c

File tree

3 files changed

+85
-3
lines changed

3 files changed

+85
-3
lines changed

clang/lib/CodeGen/CodeGenTypes.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,54 @@ void CodeGenTypes::addRecordTypeName(const RecordDecl *RD,
5151
StringRef suffix) {
5252
SmallString<256> TypeName;
5353
llvm::raw_svector_ostream OS(TypeName);
54+
// If RD is spirv_JointMatrixINTEL type, mangle differently.
55+
if (CGM.getTriple().isSPIRV() || CGM.getTriple().isSPIR()) {
56+
if (RD->getQualifiedNameAsString() == "__spv::__spirv_JointMatrixINTEL") {
57+
if (auto TemplateDecl = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
58+
ArrayRef<TemplateArgument> TemplateArgs =
59+
TemplateDecl->getTemplateArgs().asArray();
60+
OS << "spirv.JointMatrixINTEL.";
61+
for (auto &TemplateArg : TemplateArgs) {
62+
OS << "_";
63+
if (TemplateArg.getKind() == TemplateArgument::Type) {
64+
llvm::Type *TTy = ConvertType(TemplateArg.getAsType());
65+
if (TTy->isIntegerTy()) {
66+
switch (TTy->getIntegerBitWidth()) {
67+
case 8:
68+
OS << "char";
69+
break;
70+
case 16:
71+
OS << "short";
72+
break;
73+
case 32:
74+
OS << "int";
75+
break;
76+
case 64:
77+
OS << "long";
78+
break;
79+
default:
80+
OS << "i" << TTy->getIntegerBitWidth();
81+
break;
82+
}
83+
} else if (TTy->isBFloatTy())
84+
OS << "bfloat16";
85+
else if (TTy->isStructTy()) {
86+
StringRef LlvmTyName = TTy->getStructName();
87+
// Emit half/bfloat16 for sycl[::*]::{half,bfloat16}
88+
if (LlvmTyName.startswith("class.sycl::") ||
89+
LlvmTyName.startswith("class.__sycl_internal::"))
90+
LlvmTyName = LlvmTyName.rsplit("::").second;
91+
OS << LlvmTyName;
92+
} else
93+
TTy->print(OS, false, true);
94+
} else if (TemplateArg.getKind() == TemplateArgument::Integral)
95+
OS << TemplateArg.getAsIntegral();
96+
}
97+
Ty->setName(OS.str());
98+
return;
99+
}
100+
}
101+
}
54102
OS << RD->getKindName() << '.';
55103

56104
// FIXME: We probably want to make more tweaks to the printing policy. For

clang/test/CodeGenSYCL/matrix.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// RUN: %clang_cc1 -triple spir64-unknown-unknown -disable-llvm-passes -emit-llvm %s -o - -no-opaque-pointers | FileCheck %s
2+
// Test that SPIR-V codegen generates the expected LLVM struct name for the
3+
// JointMatrixINTEL type.
4+
#include <stddef.h>
5+
#include <stdint.h>
6+
7+
namespace __spv {
8+
template <typename T, size_t R, size_t C, uint32_t U, uint32_t S>
9+
struct __spirv_JointMatrixINTEL;
10+
}
11+
12+
// CHECK: @_Z2f1{{.*}}(%spirv.JointMatrixINTEL._float_5_10_0_1
13+
void f1(__spv::__spirv_JointMatrixINTEL<float, 5, 10, 0, 1> *matrix) {}
14+
15+
// CHECK: @_Z2f2{{.*}}(%spirv.JointMatrixINTEL._long_10_2_0_0
16+
void f2(__spv::__spirv_JointMatrixINTEL<uint64_t, 10, 2, 0, 0> *matrix) {}
17+
18+
// CHECK: @_Z2f3{{.*}}(%spirv.JointMatrixINTEL._char_10_2_0_0
19+
void f3(__spv::__spirv_JointMatrixINTEL<char, 10, 2, 0, 0> *matrix) {}
20+
21+
namespace sycl {
22+
class half {};
23+
class bfloat16 {};
24+
}
25+
typedef sycl::half my_half;
26+
27+
// CHECK: @_Z2f4{{.*}}(%spirv.JointMatrixINTEL._half_10_2_0_0
28+
void f4(__spv::__spirv_JointMatrixINTEL<my_half, 10, 2, 0, 0> *matrix) {}
29+
30+
// CHECK: @_Z2f5{{.*}}(%spirv.JointMatrixINTEL._bfloat16_10_2_0_0
31+
void f5(__spv::__spirv_JointMatrixINTEL<sycl::bfloat16, 10, 2, 0, 0> *matrix) {}
32+
33+
// CHECK: @_Z2f6{{.*}}(%spirv.JointMatrixINTEL._i128_10_2_0_0
34+
void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0> *matrix) {}

sycl/test/matrix/matrix-int8-test.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
// RUN: %clangxx -fsycl -fsycl-device-only -O2 -S -emit-llvm -o - %s | FileCheck %s
22

3-
// CHECK-DAG: %"struct.__spv::__spirv_JointMatrixINTEL" = type { [12 x [48 x [1 x [4 x i8]]]] addrspace(4)* }
4-
// CHECK-DAG: %"struct.__spv::__spirv_JointMatrixINTEL.[[#]]" = type { [12 x [12 x [1 x [4 x i32]]]] addrspace(4)* }
5-
// CHECK-DAG: %"struct.__spv::__spirv_JointMatrixINTEL.[[#]]" = type { [48 x [12 x [4 x [4 x i8]]]] addrspace(4)* }
3+
// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_0_3 = type { [12 x [48 x [1 x [4 x i8]]]] addrspace(4)* }
4+
// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_0_3 = type { [12 x [12 x [1 x [4 x i32]]]] addrspace(4)* }
5+
// CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_3_3 = type { [48 x [12 x [4 x [4 x i8]]]] addrspace(4)* }
66

77
#include <sycl/sycl.hpp>
88
#if (SYCL_EXT_ONEAPI_MATRIX == 2)

0 commit comments

Comments
 (0)