Skip to content

Commit 8175190

Browse files
[SPIR-V] Emit proper pointer type for OpenCL kernel arguments (#67726)
1 parent b858309 commit 8175190

10 files changed

+315
-86
lines changed

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 52 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2010,60 +2010,6 @@ static Type *parseTypeString(const StringRef Name, LLVMContext &Context) {
20102010
llvm_unreachable("Unable to recognize type!");
20112011
}
20122012

2013-
static const TargetExtType *parseToTargetExtType(const Type *OpaqueType,
2014-
MachineIRBuilder &MIRBuilder) {
2015-
assert(isSpecialOpaqueType(OpaqueType) &&
2016-
"Not a SPIR-V/OpenCL special opaque type!");
2017-
assert(!OpaqueType->isTargetExtTy() &&
2018-
"This already is SPIR-V/OpenCL TargetExtType!");
2019-
2020-
StringRef NameWithParameters = OpaqueType->getStructName();
2021-
2022-
// Pointers-to-opaque-structs representing OpenCL types are first translated
2023-
// to equivalent SPIR-V types. OpenCL builtin type names should have the
2024-
// following format: e.g. %opencl.event_t
2025-
if (NameWithParameters.startswith("opencl.")) {
2026-
const SPIRV::OpenCLType *OCLTypeRecord =
2027-
SPIRV::lookupOpenCLType(NameWithParameters);
2028-
if (!OCLTypeRecord)
2029-
report_fatal_error("Missing TableGen record for OpenCL type: " +
2030-
NameWithParameters);
2031-
NameWithParameters = OCLTypeRecord->SpirvTypeLiteral;
2032-
// Continue with the SPIR-V builtin type...
2033-
}
2034-
2035-
// Names of the opaque structs representing a SPIR-V builtins without
2036-
// parameters should have the following format: e.g. %spirv.Event
2037-
assert(NameWithParameters.startswith("spirv.") &&
2038-
"Unknown builtin opaque type!");
2039-
2040-
// Parameterized SPIR-V builtins names follow this format:
2041-
// e.g. %spirv.Image._void_1_0_0_0_0_0_0, %spirv.Pipe._0
2042-
if (NameWithParameters.find('_') == std::string::npos)
2043-
return TargetExtType::get(OpaqueType->getContext(), NameWithParameters);
2044-
2045-
SmallVector<StringRef> Parameters;
2046-
unsigned BaseNameLength = NameWithParameters.find('_') - 1;
2047-
SplitString(NameWithParameters.substr(BaseNameLength + 1), Parameters, "_");
2048-
2049-
SmallVector<Type *, 1> TypeParameters;
2050-
bool HasTypeParameter = !isDigit(Parameters[0][0]);
2051-
if (HasTypeParameter)
2052-
TypeParameters.push_back(parseTypeString(
2053-
Parameters[0], MIRBuilder.getMF().getFunction().getContext()));
2054-
SmallVector<unsigned> IntParameters;
2055-
for (unsigned i = HasTypeParameter ? 1 : 0; i < Parameters.size(); i++) {
2056-
unsigned IntParameter = 0;
2057-
bool ValidLiteral = !Parameters[i].getAsInteger(10, IntParameter);
2058-
assert(ValidLiteral &&
2059-
"Invalid format of SPIR-V builtin parameter literal!");
2060-
IntParameters.push_back(IntParameter);
2061-
}
2062-
return TargetExtType::get(OpaqueType->getContext(),
2063-
NameWithParameters.substr(0, BaseNameLength),
2064-
TypeParameters, IntParameters);
2065-
}
2066-
20672013
//===----------------------------------------------------------------------===//
20682014
// Implementation functions for builtin types.
20692015
//===----------------------------------------------------------------------===//
@@ -2127,6 +2073,56 @@ static SPIRVType *getSampledImageType(const TargetExtType *OpaqueType,
21272073
}
21282074

21292075
namespace SPIRV {
2076+
const TargetExtType *
2077+
parseBuiltinTypeNameToTargetExtType(std::string TypeName,
2078+
MachineIRBuilder &MIRBuilder) {
2079+
StringRef NameWithParameters = TypeName;
2080+
2081+
// Pointers-to-opaque-structs representing OpenCL types are first translated
2082+
// to equivalent SPIR-V types. OpenCL builtin type names should have the
2083+
// following format: e.g. %opencl.event_t
2084+
if (NameWithParameters.startswith("opencl.")) {
2085+
const SPIRV::OpenCLType *OCLTypeRecord =
2086+
SPIRV::lookupOpenCLType(NameWithParameters);
2087+
if (!OCLTypeRecord)
2088+
report_fatal_error("Missing TableGen record for OpenCL type: " +
2089+
NameWithParameters);
2090+
NameWithParameters = OCLTypeRecord->SpirvTypeLiteral;
2091+
// Continue with the SPIR-V builtin type...
2092+
}
2093+
2094+
// Names of the opaque structs representing a SPIR-V builtins without
2095+
// parameters should have the following format: e.g. %spirv.Event
2096+
assert(NameWithParameters.startswith("spirv.") &&
2097+
"Unknown builtin opaque type!");
2098+
2099+
// Parameterized SPIR-V builtins names follow this format:
2100+
// e.g. %spirv.Image._void_1_0_0_0_0_0_0, %spirv.Pipe._0
2101+
if (NameWithParameters.find('_') == std::string::npos)
2102+
return TargetExtType::get(MIRBuilder.getContext(), NameWithParameters);
2103+
2104+
SmallVector<StringRef> Parameters;
2105+
unsigned BaseNameLength = NameWithParameters.find('_') - 1;
2106+
SplitString(NameWithParameters.substr(BaseNameLength + 1), Parameters, "_");
2107+
2108+
SmallVector<Type *, 1> TypeParameters;
2109+
bool HasTypeParameter = !isDigit(Parameters[0][0]);
2110+
if (HasTypeParameter)
2111+
TypeParameters.push_back(parseTypeString(
2112+
Parameters[0], MIRBuilder.getMF().getFunction().getContext()));
2113+
SmallVector<unsigned> IntParameters;
2114+
for (unsigned i = HasTypeParameter ? 1 : 0; i < Parameters.size(); i++) {
2115+
unsigned IntParameter = 0;
2116+
bool ValidLiteral = !Parameters[i].getAsInteger(10, IntParameter);
2117+
assert(ValidLiteral &&
2118+
"Invalid format of SPIR-V builtin parameter literal!");
2119+
IntParameters.push_back(IntParameter);
2120+
}
2121+
return TargetExtType::get(MIRBuilder.getContext(),
2122+
NameWithParameters.substr(0, BaseNameLength),
2123+
TypeParameters, IntParameters);
2124+
}
2125+
21302126
SPIRVType *lowerBuiltinType(const Type *OpaqueType,
21312127
SPIRV::AccessQualifier::AccessQualifier AccessQual,
21322128
MachineIRBuilder &MIRBuilder,
@@ -2141,7 +2137,8 @@ SPIRVType *lowerBuiltinType(const Type *OpaqueType,
21412137
// will be removed in the future release of LLVM.
21422138
const TargetExtType *BuiltinType = dyn_cast<TargetExtType>(OpaqueType);
21432139
if (!BuiltinType)
2144-
BuiltinType = parseToTargetExtType(OpaqueType, MIRBuilder);
2140+
BuiltinType = parseBuiltinTypeNameToTargetExtType(
2141+
OpaqueType->getStructName().str(), MIRBuilder);
21452142

21462143
unsigned NumStartingVRegs = MIRBuilder.getMRI()->getNumVirtRegs();
21472144

llvm/lib/Target/SPIRV/SPIRVBuiltins.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,18 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
3737
const Register OrigRet, const Type *OrigRetTy,
3838
const SmallVectorImpl<Register> &Args,
3939
SPIRVGlobalRegistry *GR);
40+
41+
/// Translates a string representing a SPIR-V or OpenCL builtin type to a
42+
/// TargetExtType that can be further lowered with lowerBuiltinType().
43+
///
44+
/// \return A TargetExtType representing the builtin SPIR-V type.
45+
///
46+
/// \p TypeName is the full string representation of the SPIR-V or OpenCL
47+
/// builtin type.
48+
const TargetExtType *
49+
parseBuiltinTypeNameToTargetExtType(std::string TypeName,
50+
MachineIRBuilder &MIRBuilder);
51+
4052
/// Handles the translation of the provided special opaque/builtin type \p Type
4153
/// to SPIR-V type. Generates the corresponding machine instructions for the
4254
/// target type or gets the already existing OpType<...> register from the

llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -194,23 +194,38 @@ getKernelArgTypeQual(const Function &KernelFunction, unsigned ArgIdx) {
194194
return {};
195195
}
196196

197-
static Type *getArgType(const Function &F, unsigned ArgIdx) {
197+
static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
198+
SPIRVGlobalRegistry *GR,
199+
MachineIRBuilder &MIRBuilder) {
200+
// Read argument's access qualifier from metadata or default.
201+
SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
202+
getArgAccessQual(F, ArgIdx);
203+
198204
Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx);
205+
206+
// In case of non-kernel SPIR-V function or already TargetExtType, use the
207+
// original IR type.
199208
if (F.getCallingConv() != CallingConv::SPIR_KERNEL ||
200209
isSpecialOpaqueType(OriginalArgType))
201-
return OriginalArgType;
210+
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
202211

203212
MDString *MDKernelArgType =
204213
getKernelArgAttribute(F, ArgIdx, "kernel_arg_type");
205-
if (!MDKernelArgType || !MDKernelArgType->getString().endswith("_t"))
206-
return OriginalArgType;
207-
208-
std::string KernelArgTypeStr = "opencl." + MDKernelArgType->getString().str();
209-
Type *ExistingOpaqueType =
210-
StructType::getTypeByName(F.getContext(), KernelArgTypeStr);
211-
return ExistingOpaqueType
212-
? ExistingOpaqueType
213-
: StructType::create(F.getContext(), KernelArgTypeStr);
214+
if (!MDKernelArgType || (MDKernelArgType->getString().ends_with("*") &&
215+
MDKernelArgType->getString().ends_with("_t")))
216+
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
217+
218+
if (MDKernelArgType->getString().ends_with("*"))
219+
return GR->getOrCreateSPIRVTypeByName(
220+
MDKernelArgType->getString(), MIRBuilder,
221+
addressSpaceToStorageClass(OriginalArgType->getPointerAddressSpace()));
222+
223+
if (MDKernelArgType->getString().ends_with("_t"))
224+
return GR->getOrCreateSPIRVTypeByName(
225+
"opencl." + MDKernelArgType->getString().str(), MIRBuilder,
226+
SPIRV::StorageClass::Function, ArgAccessQual);
227+
228+
llvm_unreachable("Unable to recognize argument type name.");
214229
}
215230

216231
static bool isEntryPoint(const Function &F) {
@@ -262,10 +277,8 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
262277
// TODO: handle the case of multiple registers.
263278
if (VRegs[i].size() > 1)
264279
return false;
265-
SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
266-
getArgAccessQual(F, i);
267-
auto *SpirvTy = GR->assignTypeToVReg(getArgType(F, i), VRegs[i][0],
268-
MIRBuilder, ArgAccessQual);
280+
auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder);
281+
GR->assignSPIRVTypeToVReg(SpirvTy, VRegs[i][0], MIRBuilder.getMF());
269282
ArgTypeVRegs.push_back(SpirvTy);
270283

271284
if (Arg.hasName())

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 55 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -956,40 +956,82 @@ SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
956956
}
957957

958958
// TODO: maybe use tablegen to implement this.
959-
SPIRVType *
960-
SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(StringRef TypeStr,
961-
MachineIRBuilder &MIRBuilder) {
959+
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
960+
StringRef TypeStr, MachineIRBuilder &MIRBuilder,
961+
SPIRV::StorageClass::StorageClass SC,
962+
SPIRV::AccessQualifier::AccessQualifier AQ) {
962963
unsigned VecElts = 0;
963964
auto &Ctx = MIRBuilder.getMF().getFunction().getContext();
964965

966+
// Parse strings representing either a SPIR-V or OpenCL builtin type.
967+
if (hasBuiltinTypePrefix(TypeStr))
968+
return getOrCreateSPIRVType(
969+
SPIRV::parseBuiltinTypeNameToTargetExtType(TypeStr.str(), MIRBuilder),
970+
MIRBuilder, AQ);
971+
965972
// Parse type name in either "typeN" or "type vector[N]" format, where
966973
// N is the number of elements of the vector.
967-
Type *Type;
974+
Type *Ty;
975+
976+
if (TypeStr.starts_with("atomic_"))
977+
TypeStr = TypeStr.substr(strlen("atomic_"));
978+
968979
if (TypeStr.startswith("void")) {
969-
Type = Type::getVoidTy(Ctx);
980+
Ty = Type::getVoidTy(Ctx);
970981
TypeStr = TypeStr.substr(strlen("void"));
982+
} else if (TypeStr.startswith("bool")) {
983+
Ty = Type::getIntNTy(Ctx, 1);
984+
TypeStr = TypeStr.substr(strlen("bool"));
985+
} else if (TypeStr.startswith("char") || TypeStr.startswith("uchar")) {
986+
Ty = Type::getInt8Ty(Ctx);
987+
TypeStr = TypeStr.startswith("char") ? TypeStr.substr(strlen("char"))
988+
: TypeStr.substr(strlen("uchar"));
989+
} else if (TypeStr.startswith("short") || TypeStr.startswith("ushort")) {
990+
Ty = Type::getInt16Ty(Ctx);
991+
TypeStr = TypeStr.startswith("short") ? TypeStr.substr(strlen("short"))
992+
: TypeStr.substr(strlen("ushort"));
971993
} else if (TypeStr.startswith("int") || TypeStr.startswith("uint")) {
972-
Type = Type::getInt32Ty(Ctx);
994+
Ty = Type::getInt32Ty(Ctx);
973995
TypeStr = TypeStr.startswith("int") ? TypeStr.substr(strlen("int"))
974996
: TypeStr.substr(strlen("uint"));
975-
} else if (TypeStr.startswith("float")) {
976-
Type = Type::getFloatTy(Ctx);
977-
TypeStr = TypeStr.substr(strlen("float"));
997+
} else if (TypeStr.starts_with("long") || TypeStr.starts_with("ulong")) {
998+
Ty = Type::getInt64Ty(Ctx);
999+
TypeStr = TypeStr.startswith("long") ? TypeStr.substr(strlen("long"))
1000+
: TypeStr.substr(strlen("ulong"));
9781001
} else if (TypeStr.startswith("half")) {
979-
Type = Type::getHalfTy(Ctx);
1002+
Ty = Type::getHalfTy(Ctx);
9801003
TypeStr = TypeStr.substr(strlen("half"));
981-
} else if (TypeStr.startswith("opencl.sampler_t")) {
982-
Type = StructType::create(Ctx, "opencl.sampler_t");
1004+
} else if (TypeStr.startswith("float")) {
1005+
Ty = Type::getFloatTy(Ctx);
1006+
TypeStr = TypeStr.substr(strlen("float"));
1007+
} else if (TypeStr.startswith("double")) {
1008+
Ty = Type::getDoubleTy(Ctx);
1009+
TypeStr = TypeStr.substr(strlen("double"));
9831010
} else
9841011
llvm_unreachable("Unable to recognize SPIRV type name.");
1012+
1013+
auto SpirvTy = getOrCreateSPIRVType(Ty, MIRBuilder, AQ);
1014+
1015+
// Handle "type*" or "type* vector[N]".
1016+
if (TypeStr.starts_with("*")) {
1017+
SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);
1018+
TypeStr = TypeStr.substr(strlen("*"));
1019+
}
1020+
1021+
// Handle "typeN*" or "type vector[N]*".
1022+
bool IsPtrToVec = TypeStr.consume_back("*");
1023+
9851024
if (TypeStr.startswith(" vector[")) {
9861025
TypeStr = TypeStr.substr(strlen(" vector["));
9871026
TypeStr = TypeStr.substr(0, TypeStr.find(']'));
9881027
}
9891028
TypeStr.getAsInteger(10, VecElts);
990-
auto SpirvTy = getOrCreateSPIRVType(Type, MIRBuilder);
9911029
if (VecElts > 0)
9921030
SpirvTy = getOrCreateSPIRVVectorType(SpirvTy, VecElts, MIRBuilder);
1031+
1032+
if (IsPtrToVec)
1033+
SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);
1034+
9931035
return SpirvTy;
9941036
}
9951037

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,11 @@ class SPIRVGlobalRegistry {
138138

139139
// Either generate a new OpTypeXXX instruction or return an existing one
140140
// corresponding to the given string containing the name of the builtin type.
141-
SPIRVType *getOrCreateSPIRVTypeByName(StringRef TypeStr,
142-
MachineIRBuilder &MIRBuilder);
141+
SPIRVType *getOrCreateSPIRVTypeByName(
142+
StringRef TypeStr, MachineIRBuilder &MIRBuilder,
143+
SPIRV::StorageClass::StorageClass SC = SPIRV::StorageClass::Function,
144+
SPIRV::AccessQualifier::AccessQualifier AQ =
145+
SPIRV::AccessQualifier::ReadWrite);
143146

144147
// Return the SPIR-V type instruction corresponding to the given VReg, or
145148
// nullptr if no such type instruction exists.

llvm/lib/Target/SPIRV/SPIRVUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ const Type *getTypedPtrEltType(const Type *Ty) {
332332
return PType->getNonOpaquePointerElementType();
333333
}
334334

335-
static bool hasBuiltinTypePrefix(StringRef Name) {
335+
bool hasBuiltinTypePrefix(StringRef Name) {
336336
if (Name.starts_with("opencl.") || Name.starts_with("spirv."))
337337
return true;
338338
return false;

llvm/lib/Target/SPIRV/SPIRVUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ std::string getOclOrSpirvBuiltinDemangledName(StringRef Name);
9292
// element type, otherwise return Type.
9393
const Type *getTypedPtrEltType(const Type *Type);
9494

95+
// Check if a string contains a builtin prefix.
96+
bool hasBuiltinTypePrefix(StringRef Name);
97+
9598
// Check if given LLVM type is a special opaque builtin type.
9699
bool isSpecialOpaqueType(const Type *Ty);
97100
} // namespace llvm
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
2+
3+
; CHECK: %[[#FLOAT32:]] = OpTypeFloat 32
4+
; CHECK: %[[#PTR:]] = OpTypePointer CrossWorkgroup %[[#FLOAT32]]
5+
; CHECK: %[[#ARG:]] = OpFunctionParameter %[[#PTR]]
6+
; CHECK: %[[#GEP:]] = OpInBoundsPtrAccessChain %[[#PTR]] %[[#ARG]] %[[#]]
7+
; CHECK: %[[#]] = OpLoad %[[#FLOAT32]] %[[#GEP]] Aligned 4
8+
9+
define spir_kernel void @test1(ptr addrspace(1) %arg1) !kernel_arg_addr_space !1 !kernel_arg_access_qual !2 !kernel_arg_type !3 !kernel_arg_type_qual !4 {
10+
%a = getelementptr inbounds float, ptr addrspace(1) %arg1, i64 1
11+
%b = load float, ptr addrspace(1) %a, align 4
12+
ret void
13+
}
14+
15+
!1 = !{i32 1}
16+
!2 = !{!"none"}
17+
!3 = !{!"float*"}
18+
!4 = !{!""}

0 commit comments

Comments
 (0)