Skip to content

Commit 687e037

Browse files
authored
[SYCL][SPIR-V] Don't use 'print' to get JointMatrix typename (#7054)
1 parent e6c4c15 commit 687e037

File tree

2 files changed

+24
-6
lines changed

2 files changed

+24
-6
lines changed

clang/lib/CodeGen/CodeGenTypes.cpp

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,19 +80,30 @@ void CodeGenTypes::addRecordTypeName(const RecordDecl *RD,
8080
OS << "i" << TTy->getIntegerBitWidth();
8181
break;
8282
}
83-
} else if (TTy->isBFloatTy())
83+
} else if (TTy->isHalfTy()) {
84+
OS << "half";
85+
} else if (TTy->isFloatTy()) {
86+
OS << "float";
87+
} else if (TTy->isDoubleTy()) {
88+
OS << "double";
89+
} else if (TTy->isBFloatTy()) {
8490
OS << "bfloat16";
85-
else if (TTy->isStructTy()) {
91+
} else if (TTy->isStructTy()) {
8692
StringRef LlvmTyName = TTy->getStructName();
87-
// Emit half/bfloat16 for sycl[::*]::{half,bfloat16}
93+
// Emit half/bfloat16/tf32 for sycl[::*]::{half,bfloat16,tf32}
8894
if (LlvmTyName.startswith("class.sycl::") ||
8995
LlvmTyName.startswith("class.__sycl_internal::"))
9096
LlvmTyName = LlvmTyName.rsplit("::").second;
97+
if (LlvmTyName != "half" && LlvmTyName != "bfloat16" &&
98+
LlvmTyName != "tf32")
99+
llvm_unreachable("Wrong matrix base type!");
91100
OS << LlvmTyName;
92-
} else
93-
TTy->print(OS, false, true);
94-
} else if (TemplateArg.getKind() == TemplateArgument::Integral)
101+
} else {
102+
llvm_unreachable("Wrong matrix base type!");
103+
}
104+
} else if (TemplateArg.getKind() == TemplateArgument::Integral) {
95105
OS << TemplateArg.getAsIntegral();
106+
}
96107
}
97108
Ty->setName(OS.str());
98109
return;

clang/test/CodeGenSYCL/matrix.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ void f3(__spv::__spirv_JointMatrixINTEL<char, 10, 2, 0, 0> *matrix) {}
2121
namespace sycl {
2222
class half {};
2323
class bfloat16 {};
24+
class tf32 {};
2425
}
2526
typedef sycl::half my_half;
2627

@@ -32,3 +33,9 @@ void f5(__spv::__spirv_JointMatrixINTEL<sycl::bfloat16, 10, 2, 0, 0> *matrix) {}
3233

3334
// CHECK: @_Z2f6{{.*}}(%spirv.JointMatrixINTEL._i128_10_2_0_0
3435
void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0> *matrix) {}
36+
37+
// CHECK: @_Z2f7{{.*}}(%spirv.JointMatrixINTEL._tf32_10_2_0_0
38+
void f7(__spv::__spirv_JointMatrixINTEL<sycl::tf32, 10, 2, 0, 0> *matrix) {}
39+
40+
// CHECK: @_Z2f8{{.*}}(%spirv.JointMatrixINTEL._double_5_10_0_1
41+
void f8(__spv::__spirv_JointMatrixINTEL<double, 5, 10, 0, 1> *matrix) {}

0 commit comments

Comments
 (0)