Skip to content

Commit 43222bd

Browse files
[SPIR-V] Do not use OpenCL metadata for ptr element type resolution (#82678)
This pull request aims to remove any dependency on OpenCL/SPIR-V type information in LLVM IR metadata. While, using metadata might simplify and prettify the resulting SPIR-V output (and restore some of the information missed in the transformation to opaque pointers), the overall methodology for resolving kernel parameter types is highly inefficient. The high-level strategy is to assign kernel parameter types in this order: 1. Resolving the types using builtin function calls as mangled names must contain type information or by looking up builtin definition in SPIRVBuiltins.td. Then: - Assigning the type temporarily using an intrinsic and later setting the right SPIR-V type in SPIRVGlobalRegistry after IRTranslation - Inserting a bitcast 2. Defaulting to LLVM IR types (in case of pointers the generic i8* type or types from byval/byref attributes) In case of type incompatibility (e.g. parameter defined initially as sampler_t and later used as image_t) the error will be found early on before IRTranslation (in the SPIRVEmitIntrinsics pass).
1 parent 23bc5b6 commit 43222bd

37 files changed

+523
-286
lines changed

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 65 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1775,7 +1775,7 @@ static const Type *getMachineInstrType(MachineInstr *MI) {
17751775
return nullptr;
17761776
Type *Ty = getMDOperandAsType(NextMI->getOperand(2).getMetadata(), 0);
17771777
assert(Ty && "Type is expected");
1778-
return getTypedPtrEltType(Ty);
1778+
return Ty;
17791779
}
17801780

17811781
static const Type *getBlockStructType(Register ParamReg,
@@ -1787,7 +1787,7 @@ static const Type *getBlockStructType(Register ParamReg,
17871787
// section 6.12.5 should guarantee that we can do this.
17881788
MachineInstr *MI = getBlockStructInstr(ParamReg, MRI);
17891789
if (MI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE)
1790-
return getTypedPtrEltType(MI->getOperand(1).getGlobal()->getType());
1790+
return MI->getOperand(1).getGlobal()->getType();
17911791
assert(isSpvIntrinsic(*MI, Intrinsic::spv_alloca) &&
17921792
"Blocks in OpenCL C must be traceable to allocation site");
17931793
return getMachineInstrType(MI);
@@ -2043,7 +2043,8 @@ static bool generateVectorLoadStoreInst(const SPIRV::IncomingCall *Call,
20432043
.addImm(Builtin->Number);
20442044
for (auto Argument : Call->Arguments)
20452045
MIB.addUse(Argument);
2046-
MIB.addImm(Builtin->ElementCount);
2046+
if (Builtin->Name.contains("load") && Builtin->ElementCount > 1)
2047+
MIB.addImm(Builtin->ElementCount);
20472048

20482049
// Rounding mode should be passed as a last argument in the MI for builtins
20492050
// like "vstorea_halfn_r".
@@ -2179,6 +2180,61 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
21792180
return false;
21802181
}
21812182

2183+
Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
2184+
unsigned ArgIdx, LLVMContext &Ctx) {
2185+
SmallVector<StringRef, 10> BuiltinArgsTypeStrs;
2186+
StringRef BuiltinArgs =
2187+
DemangledCall.slice(DemangledCall.find('(') + 1, DemangledCall.find(')'));
2188+
BuiltinArgs.split(BuiltinArgsTypeStrs, ',', -1, false);
2189+
if (ArgIdx >= BuiltinArgsTypeStrs.size())
2190+
return nullptr;
2191+
StringRef TypeStr = BuiltinArgsTypeStrs[ArgIdx].trim();
2192+
2193+
// Parse strings representing OpenCL builtin types.
2194+
if (hasBuiltinTypePrefix(TypeStr)) {
2195+
// OpenCL builtin types in demangled call strings have the following format:
2196+
// e.g. ocl_image2d_ro
2197+
bool IsOCLBuiltinType = TypeStr.consume_front("ocl_");
2198+
assert(IsOCLBuiltinType && "Invalid OpenCL builtin prefix");
2199+
2200+
// Check if this is pointer to a builtin type and not just pointer
2201+
// representing a builtin type. In case it is a pointer to builtin type,
2202+
// this will require additional handling in the method calling
2203+
// parseBuiltinCallArgumentBaseType(...) as this function only retrieves the
2204+
// base types.
2205+
if (TypeStr.ends_with("*"))
2206+
TypeStr = TypeStr.slice(0, TypeStr.find_first_of(" "));
2207+
2208+
return parseBuiltinTypeNameToTargetExtType("opencl." + TypeStr.str() + "_t",
2209+
Ctx);
2210+
}
2211+
2212+
// Parse type name in either "typeN" or "type vector[N]" format, where
2213+
// N is the number of elements of the vector.
2214+
Type *BaseType;
2215+
unsigned VecElts = 0;
2216+
2217+
BaseType = parseBasicTypeName(TypeStr, Ctx);
2218+
if (!BaseType)
2219+
// Unable to recognize SPIRV type name.
2220+
return nullptr;
2221+
2222+
if (BaseType->isVoidTy())
2223+
BaseType = Type::getInt8Ty(Ctx);
2224+
2225+
// Handle "typeN*" or "type vector[N]*".
2226+
TypeStr.consume_back("*");
2227+
2228+
if (TypeStr.consume_front(" vector["))
2229+
TypeStr = TypeStr.substr(0, TypeStr.find(']'));
2230+
2231+
TypeStr.getAsInteger(10, VecElts);
2232+
if (VecElts > 0)
2233+
BaseType = VectorType::get(BaseType, VecElts, false);
2234+
2235+
return BaseType;
2236+
}
2237+
21822238
struct BuiltinType {
21832239
StringRef Name;
21842240
uint32_t Opcode;
@@ -2277,9 +2333,8 @@ static SPIRVType *getSampledImageType(const TargetExtType *OpaqueType,
22772333
}
22782334

22792335
namespace SPIRV {
2280-
const TargetExtType *
2281-
parseBuiltinTypeNameToTargetExtType(std::string TypeName,
2282-
MachineIRBuilder &MIRBuilder) {
2336+
TargetExtType *parseBuiltinTypeNameToTargetExtType(std::string TypeName,
2337+
LLVMContext &Context) {
22832338
StringRef NameWithParameters = TypeName;
22842339

22852340
// Pointers-to-opaque-structs representing OpenCL types are first translated
@@ -2303,7 +2358,7 @@ parseBuiltinTypeNameToTargetExtType(std::string TypeName,
23032358
// Parameterized SPIR-V builtins names follow this format:
23042359
// e.g. %spirv.Image._void_1_0_0_0_0_0_0, %spirv.Pipe._0
23052360
if (!NameWithParameters.contains('_'))
2306-
return TargetExtType::get(MIRBuilder.getContext(), NameWithParameters);
2361+
return TargetExtType::get(Context, NameWithParameters);
23072362

23082363
SmallVector<StringRef> Parameters;
23092364
unsigned BaseNameLength = NameWithParameters.find('_') - 1;
@@ -2312,8 +2367,7 @@ parseBuiltinTypeNameToTargetExtType(std::string TypeName,
23122367
SmallVector<Type *, 1> TypeParameters;
23132368
bool HasTypeParameter = !isDigit(Parameters[0][0]);
23142369
if (HasTypeParameter)
2315-
TypeParameters.push_back(parseTypeString(
2316-
Parameters[0], MIRBuilder.getMF().getFunction().getContext()));
2370+
TypeParameters.push_back(parseTypeString(Parameters[0], Context));
23172371
SmallVector<unsigned> IntParameters;
23182372
for (unsigned i = HasTypeParameter ? 1 : 0; i < Parameters.size(); i++) {
23192373
unsigned IntParameter = 0;
@@ -2323,7 +2377,7 @@ parseBuiltinTypeNameToTargetExtType(std::string TypeName,
23232377
"Invalid format of SPIR-V builtin parameter literal!");
23242378
IntParameters.push_back(IntParameter);
23252379
}
2326-
return TargetExtType::get(MIRBuilder.getContext(),
2380+
return TargetExtType::get(Context,
23272381
NameWithParameters.substr(0, BaseNameLength),
23282382
TypeParameters, IntParameters);
23292383
}
@@ -2343,7 +2397,7 @@ SPIRVType *lowerBuiltinType(const Type *OpaqueType,
23432397
const TargetExtType *BuiltinType = dyn_cast<TargetExtType>(OpaqueType);
23442398
if (!BuiltinType)
23452399
BuiltinType = parseBuiltinTypeNameToTargetExtType(
2346-
OpaqueType->getStructName().str(), MIRBuilder);
2400+
OpaqueType->getStructName().str(), MIRBuilder.getContext());
23472401

23482402
unsigned NumStartingVRegs = MIRBuilder.getMRI()->getNumVirtRegs();
23492403

llvm/lib/Target/SPIRV/SPIRVBuiltins.h

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,26 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
3838
const SmallVectorImpl<Register> &Args,
3939
SPIRVGlobalRegistry *GR);
4040

41+
/// Parses the provided \p ArgIdx argument base type in the \p DemangledCall
42+
/// skeleton. A base type is either a basic type (e.g. i32 for int), pointer
43+
/// element type (e.g. i8 for char*), or builtin type (TargetExtType).
44+
///
45+
/// \return LLVM Type or nullptr if unrecognized
46+
///
47+
/// \p DemangledCall is the skeleton of the lowered builtin function call.
48+
/// \p ArgIdx is the index of the argument to parse.
49+
Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
50+
unsigned ArgIdx, LLVMContext &Ctx);
51+
4152
/// Translates a string representing a SPIR-V or OpenCL builtin type to a
4253
/// TargetExtType that can be further lowered with lowerBuiltinType().
4354
///
4455
/// \return A TargetExtType representing the builtin SPIR-V type.
4556
///
4657
/// \p TypeName is the full string representation of the SPIR-V or OpenCL
4758
/// builtin type.
48-
const TargetExtType *
49-
parseBuiltinTypeNameToTargetExtType(std::string TypeName,
50-
MachineIRBuilder &MIRBuilder);
59+
TargetExtType *parseBuiltinTypeNameToTargetExtType(std::string TypeName,
60+
LLVMContext &Context);
5161

5262
/// Handles the translation of the provided special opaque/builtin type \p Type
5363
/// to SPIR-V type. Generates the corresponding machine instructions for the

llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp

Lines changed: 47 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
#include "SPIRVSubtarget.h"
2323
#include "SPIRVUtils.h"
2424
#include "llvm/CodeGen/FunctionLoweringInfo.h"
25+
#include "llvm/IR/IntrinsicInst.h"
26+
#include "llvm/IR/IntrinsicsSPIRV.h"
2527
#include "llvm/Support/ModRef.h"
2628

2729
using namespace llvm;
@@ -158,28 +160,54 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
158160

159161
Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx);
160162

161-
// In case of non-kernel SPIR-V function or already TargetExtType, use the
162-
// original IR type.
163-
if (F.getCallingConv() != CallingConv::SPIR_KERNEL ||
164-
isSpecialOpaqueType(OriginalArgType))
163+
// If OriginalArgType is non-pointer, use the OriginalArgType (the type cannot
164+
// be legally reassigned later).
165+
if (!OriginalArgType->isPointerTy())
165166
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
166167

167-
SPIRVType *ResArgType = nullptr;
168-
if (MDString *MDKernelArgType = getOCLKernelArgType(F, ArgIdx)) {
169-
StringRef MDTypeStr = MDKernelArgType->getString();
170-
if (MDTypeStr.ends_with("*"))
171-
ResArgType = GR->getOrCreateSPIRVTypeByName(
172-
MDTypeStr, MIRBuilder,
173-
addressSpaceToStorageClass(OriginalArgType->getPointerAddressSpace(),
174-
ST));
175-
else if (MDTypeStr.ends_with("_t"))
176-
ResArgType = GR->getOrCreateSPIRVTypeByName(
177-
"opencl." + MDTypeStr.str(), MIRBuilder,
178-
SPIRV::StorageClass::Function, ArgAccessQual);
168+
// In case OriginalArgType is of pointer type, there are three possibilities:
169+
// 1) This is a pointer of an LLVM IR element type, passed byval/byref.
170+
// 2) This is an OpenCL/SPIR-V builtin type if there is spv_assign_type
171+
// intrinsic assigning a TargetExtType.
172+
// 3) This is a pointer, try to retrieve pointer element type from a
173+
// spv_assign_ptr_type intrinsic or otherwise use default pointer element
174+
// type.
175+
Argument *Arg = F.getArg(ArgIdx);
176+
if (Arg->hasByValAttr() || Arg->hasByRefAttr()) {
177+
Type *ByValRefType = Arg->hasByValAttr() ? Arg->getParamByValType()
178+
: Arg->getParamByRefType();
179+
SPIRVType *ElementType = GR->getOrCreateSPIRVType(ByValRefType, MIRBuilder);
180+
return GR->getOrCreateSPIRVPointerType(
181+
ElementType, MIRBuilder,
182+
addressSpaceToStorageClass(Arg->getType()->getPointerAddressSpace(),
183+
ST));
179184
}
180-
return ResArgType ? ResArgType
181-
: GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder,
182-
ArgAccessQual);
185+
186+
for (auto User : Arg->users()) {
187+
auto *II = dyn_cast<IntrinsicInst>(User);
188+
// Check if this is spv_assign_type assigning OpenCL/SPIR-V builtin type.
189+
if (II && II->getIntrinsicID() == Intrinsic::spv_assign_type) {
190+
MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1));
191+
Type *BuiltinType =
192+
cast<ConstantAsMetadata>(VMD->getMetadata())->getType();
193+
assert(BuiltinType->isTargetExtTy() && "Expected TargetExtType");
194+
return GR->getOrCreateSPIRVType(BuiltinType, MIRBuilder, ArgAccessQual);
195+
}
196+
197+
// Check if this is spv_assign_ptr_type assigning pointer element type.
198+
if (!II || II->getIntrinsicID() != Intrinsic::spv_assign_ptr_type)
199+
continue;
200+
201+
MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1));
202+
SPIRVType *ElementType = GR->getOrCreateSPIRVType(
203+
cast<ConstantAsMetadata>(VMD->getMetadata())->getType(), MIRBuilder);
204+
return GR->getOrCreateSPIRVPointerType(
205+
ElementType, MIRBuilder,
206+
addressSpaceToStorageClass(
207+
cast<ConstantInt>(II->getOperand(2))->getZExtValue(), ST));
208+
}
209+
210+
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
183211
}
184212

185213
static SPIRV::ExecutionModel::ExecutionModel

0 commit comments

Comments
 (0)