Skip to content

[SPIR-V] Do not use OpenCL metadata for ptr element type resolution #82678

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 65 additions & 11 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1775,7 +1775,7 @@ static const Type *getMachineInstrType(MachineInstr *MI) {
return nullptr;
Type *Ty = getMDOperandAsType(NextMI->getOperand(2).getMetadata(), 0);
assert(Ty && "Type is expected");
return getTypedPtrEltType(Ty);
return Ty;
}

static const Type *getBlockStructType(Register ParamReg,
Expand All @@ -1787,7 +1787,7 @@ static const Type *getBlockStructType(Register ParamReg,
// section 6.12.5 should guarantee that we can do this.
MachineInstr *MI = getBlockStructInstr(ParamReg, MRI);
if (MI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE)
return getTypedPtrEltType(MI->getOperand(1).getGlobal()->getType());
return MI->getOperand(1).getGlobal()->getType();
assert(isSpvIntrinsic(*MI, Intrinsic::spv_alloca) &&
"Blocks in OpenCL C must be traceable to allocation site");
return getMachineInstrType(MI);
Expand Down Expand Up @@ -2043,7 +2043,8 @@ static bool generateVectorLoadStoreInst(const SPIRV::IncomingCall *Call,
.addImm(Builtin->Number);
for (auto Argument : Call->Arguments)
MIB.addUse(Argument);
MIB.addImm(Builtin->ElementCount);
if (Builtin->Name.contains("load") && Builtin->ElementCount > 1)
MIB.addImm(Builtin->ElementCount);

// Rounding mode should be passed as a last argument in the MI for builtins
// like "vstorea_halfn_r".
Expand Down Expand Up @@ -2179,6 +2180,61 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
return false;
}

Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
unsigned ArgIdx, LLVMContext &Ctx) {
SmallVector<StringRef, 10> BuiltinArgsTypeStrs;
StringRef BuiltinArgs =
DemangledCall.slice(DemangledCall.find('(') + 1, DemangledCall.find(')'));
BuiltinArgs.split(BuiltinArgsTypeStrs, ',', -1, false);
if (ArgIdx >= BuiltinArgsTypeStrs.size())
return nullptr;
StringRef TypeStr = BuiltinArgsTypeStrs[ArgIdx].trim();

// Parse strings representing OpenCL builtin types.
if (hasBuiltinTypePrefix(TypeStr)) {
// OpenCL builtin types in demangled call strings have the following format:
// e.g. ocl_image2d_ro
bool IsOCLBuiltinType = TypeStr.consume_front("ocl_");
assert(IsOCLBuiltinType && "Invalid OpenCL builtin prefix");

// Check if this is pointer to a builtin type and not just pointer
// representing a builtin type. In case it is a pointer to builtin type,
// this will require additional handling in the method calling
// parseBuiltinCallArgumentBaseType(...) as this function only retrieves the
// base types.
if (TypeStr.ends_with("*"))
TypeStr = TypeStr.slice(0, TypeStr.find_first_of(" "));

return parseBuiltinTypeNameToTargetExtType("opencl." + TypeStr.str() + "_t",
Ctx);
}

// Parse type name in either "typeN" or "type vector[N]" format, where
// N is the number of elements of the vector.
Type *BaseType;
unsigned VecElts = 0;

BaseType = parseBasicTypeName(TypeStr, Ctx);
if (!BaseType)
// Unable to recognize SPIRV type name.
return nullptr;

if (BaseType->isVoidTy())
BaseType = Type::getInt8Ty(Ctx);

// Handle "typeN*" or "type vector[N]*".
TypeStr.consume_back("*");

if (TypeStr.consume_front(" vector["))
TypeStr = TypeStr.substr(0, TypeStr.find(']'));

TypeStr.getAsInteger(10, VecElts);
if (VecElts > 0)
BaseType = VectorType::get(BaseType, VecElts, false);

return BaseType;
}

struct BuiltinType {
StringRef Name;
uint32_t Opcode;
Expand Down Expand Up @@ -2277,9 +2333,8 @@ static SPIRVType *getSampledImageType(const TargetExtType *OpaqueType,
}

namespace SPIRV {
const TargetExtType *
parseBuiltinTypeNameToTargetExtType(std::string TypeName,
MachineIRBuilder &MIRBuilder) {
TargetExtType *parseBuiltinTypeNameToTargetExtType(std::string TypeName,
LLVMContext &Context) {
StringRef NameWithParameters = TypeName;

// Pointers-to-opaque-structs representing OpenCL types are first translated
Expand All @@ -2303,7 +2358,7 @@ parseBuiltinTypeNameToTargetExtType(std::string TypeName,
// Parameterized SPIR-V builtins names follow this format:
// e.g. %spirv.Image._void_1_0_0_0_0_0_0, %spirv.Pipe._0
if (!NameWithParameters.contains('_'))
return TargetExtType::get(MIRBuilder.getContext(), NameWithParameters);
return TargetExtType::get(Context, NameWithParameters);

SmallVector<StringRef> Parameters;
unsigned BaseNameLength = NameWithParameters.find('_') - 1;
Expand All @@ -2312,8 +2367,7 @@ parseBuiltinTypeNameToTargetExtType(std::string TypeName,
SmallVector<Type *, 1> TypeParameters;
bool HasTypeParameter = !isDigit(Parameters[0][0]);
if (HasTypeParameter)
TypeParameters.push_back(parseTypeString(
Parameters[0], MIRBuilder.getMF().getFunction().getContext()));
TypeParameters.push_back(parseTypeString(Parameters[0], Context));
SmallVector<unsigned> IntParameters;
for (unsigned i = HasTypeParameter ? 1 : 0; i < Parameters.size(); i++) {
unsigned IntParameter = 0;
Expand All @@ -2323,7 +2377,7 @@ parseBuiltinTypeNameToTargetExtType(std::string TypeName,
"Invalid format of SPIR-V builtin parameter literal!");
IntParameters.push_back(IntParameter);
}
return TargetExtType::get(MIRBuilder.getContext(),
return TargetExtType::get(Context,
NameWithParameters.substr(0, BaseNameLength),
TypeParameters, IntParameters);
}
Expand All @@ -2343,7 +2397,7 @@ SPIRVType *lowerBuiltinType(const Type *OpaqueType,
const TargetExtType *BuiltinType = dyn_cast<TargetExtType>(OpaqueType);
if (!BuiltinType)
BuiltinType = parseBuiltinTypeNameToTargetExtType(
OpaqueType->getStructName().str(), MIRBuilder);
OpaqueType->getStructName().str(), MIRBuilder.getContext());

unsigned NumStartingVRegs = MIRBuilder.getMRI()->getNumVirtRegs();

Expand Down
16 changes: 13 additions & 3 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,26 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
const SmallVectorImpl<Register> &Args,
SPIRVGlobalRegistry *GR);

/// Parses the provided \p ArgIdx argument base type in the \p DemangledCall
/// skeleton. A base type is either a basic type (e.g. i32 for int), pointer
/// element type (e.g. i8 for char*), or builtin type (TargetExtType).
///
/// \return LLVM Type or nullptr if unrecognized
///
/// \p DemangledCall is the skeleton of the lowered builtin function call.
/// \p ArgIdx is the index of the argument to parse.
Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
unsigned ArgIdx, LLVMContext &Ctx);

/// Translates a string representing a SPIR-V or OpenCL builtin type to a
/// TargetExtType that can be further lowered with lowerBuiltinType().
///
/// \return A TargetExtType representing the builtin SPIR-V type.
///
/// \p TypeName is the full string representation of the SPIR-V or OpenCL
/// builtin type.
const TargetExtType *
parseBuiltinTypeNameToTargetExtType(std::string TypeName,
MachineIRBuilder &MIRBuilder);
TargetExtType *parseBuiltinTypeNameToTargetExtType(std::string TypeName,
LLVMContext &Context);

/// Handles the translation of the provided special opaque/builtin type \p Type
/// to SPIR-V type. Generates the corresponding machine instructions for the
Expand Down
66 changes: 47 additions & 19 deletions llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#include "SPIRVSubtarget.h"
#include "SPIRVUtils.h"
#include "llvm/CodeGen/FunctionLoweringInfo.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/Support/ModRef.h"

using namespace llvm;
Expand Down Expand Up @@ -158,28 +160,54 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,

Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx);

// In case of non-kernel SPIR-V function or already TargetExtType, use the
// original IR type.
if (F.getCallingConv() != CallingConv::SPIR_KERNEL ||
isSpecialOpaqueType(OriginalArgType))
// If OriginalArgType is non-pointer, use the OriginalArgType (the type cannot
// be legally reassigned later).
if (!OriginalArgType->isPointerTy())
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);

SPIRVType *ResArgType = nullptr;
if (MDString *MDKernelArgType = getOCLKernelArgType(F, ArgIdx)) {
StringRef MDTypeStr = MDKernelArgType->getString();
if (MDTypeStr.ends_with("*"))
ResArgType = GR->getOrCreateSPIRVTypeByName(
MDTypeStr, MIRBuilder,
addressSpaceToStorageClass(OriginalArgType->getPointerAddressSpace(),
ST));
else if (MDTypeStr.ends_with("_t"))
ResArgType = GR->getOrCreateSPIRVTypeByName(
"opencl." + MDTypeStr.str(), MIRBuilder,
SPIRV::StorageClass::Function, ArgAccessQual);
// In case OriginalArgType is of pointer type, there are three possibilities:
// 1) This is a pointer of an LLVM IR element type, passed byval/byref.
// 2) This is an OpenCL/SPIR-V builtin type if there is spv_assign_type
// intrinsic assigning a TargetExtType.
// 3) This is a pointer, try to retrieve pointer element type from a
// spv_assign_ptr_type intrinsic or otherwise use default pointer element
// type.
Argument *Arg = F.getArg(ArgIdx);
if (Arg->hasByValAttr() || Arg->hasByRefAttr()) {
Type *ByValRefType = Arg->hasByValAttr() ? Arg->getParamByValType()
: Arg->getParamByRefType();
SPIRVType *ElementType = GR->getOrCreateSPIRVType(ByValRefType, MIRBuilder);
return GR->getOrCreateSPIRVPointerType(
ElementType, MIRBuilder,
addressSpaceToStorageClass(Arg->getType()->getPointerAddressSpace(),
ST));
}
return ResArgType ? ResArgType
: GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder,
ArgAccessQual);

for (auto User : Arg->users()) {
auto *II = dyn_cast<IntrinsicInst>(User);
// Check if this is spv_assign_type assigning OpenCL/SPIR-V builtin type.
if (II && II->getIntrinsicID() == Intrinsic::spv_assign_type) {
MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1));
Type *BuiltinType =
cast<ConstantAsMetadata>(VMD->getMetadata())->getType();
assert(BuiltinType->isTargetExtTy() && "Expected TargetExtType");
return GR->getOrCreateSPIRVType(BuiltinType, MIRBuilder, ArgAccessQual);
}

// Check if this is spv_assign_ptr_type assigning pointer element type.
if (!II || II->getIntrinsicID() != Intrinsic::spv_assign_ptr_type)
continue;

MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1));
SPIRVType *ElementType = GR->getOrCreateSPIRVType(
cast<ConstantAsMetadata>(VMD->getMetadata())->getType(), MIRBuilder);
return GR->getOrCreateSPIRVPointerType(
ElementType, MIRBuilder,
addressSpaceToStorageClass(
cast<ConstantInt>(II->getOperand(2))->getZExtValue(), ST));
}

return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
}

static SPIRV::ExecutionModel::ExecutionModel
Expand Down
Loading