Skip to content

[SPIRV] Improve builtins matching and type inference in SPIR-V Backend, fix target ext type constants #89948

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,10 @@ lookupBuiltin(StringRef DemangledCall,
std::string BuiltinName =
DemangledCall.substr(0, DemangledCall.find('(')).str();

// Account for possible "__spirv_ocl_" prefix in SPIR-V friendly LLVM IR
if (BuiltinName.rfind("__spirv_ocl_", 0) == 0)
BuiltinName = BuiltinName.substr(12);

// Check if the extracted name contains type information between angle
// brackets. If so, the builtin is an instantiated template - needs to have
// the information after angle brackets and return type removed.
Expand Down Expand Up @@ -2008,6 +2012,13 @@ static bool generateAsyncCopy(const SPIRV::IncomingCall *Call,
const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
unsigned Opcode =
SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;

bool IsSet = Opcode == SPIRV::OpGroupAsyncCopy;
Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
if (Call->isSpirvOp())
return buildOpFromWrapper(MIRBuilder, Opcode, Call,
IsSet ? TypeReg : Register(0));

auto Scope = buildConstantIntReg(SPIRV::Scope::Workgroup, MIRBuilder, GR);

switch (Opcode) {
Expand Down Expand Up @@ -2306,7 +2317,7 @@ Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
// parseBuiltinCallArgumentBaseType(...) as this function only retrieves the
// base types.
if (TypeStr.ends_with("*"))
TypeStr = TypeStr.slice(0, TypeStr.find_first_of(" "));
TypeStr = TypeStr.slice(0, TypeStr.find_first_of(" *"));

return parseBuiltinTypeNameToTargetExtType("opencl." + TypeStr.str() + "_t",
Ctx);
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.td
Original file line number Diff line number Diff line change
Expand Up @@ -585,9 +585,9 @@ defm : DemangledNativeBuiltin<"__spirv_SpecConstantComposite", OpenCL_std, SpecC

// Async Copy and Prefetch builtin records:
defm : DemangledNativeBuiltin<"async_work_group_copy", OpenCL_std, AsyncCopy, 4, 4, OpGroupAsyncCopy>;
defm : DemangledNativeBuiltin<"__spirv_GroupAsyncCopy", OpenCL_std, AsyncCopy, 4, 4, OpGroupAsyncCopy>;
defm : DemangledNativeBuiltin<"__spirv_GroupAsyncCopy", OpenCL_std, AsyncCopy, 6, 6, OpGroupAsyncCopy>;
defm : DemangledNativeBuiltin<"wait_group_events", OpenCL_std, AsyncCopy, 2, 2, OpGroupWaitEvents>;
defm : DemangledNativeBuiltin<"__spirv_GroupWaitEvents", OpenCL_std, AsyncCopy, 2, 2, OpGroupWaitEvents>;
defm : DemangledNativeBuiltin<"__spirv_GroupWaitEvents", OpenCL_std, AsyncCopy, 3, 3, OpGroupWaitEvents>;

// Load and store builtin records:
defm : DemangledNativeBuiltin<"__spirv_Load", OpenCL_std, LoadStore, 1, 3, OpLoad>;
Expand Down
97 changes: 71 additions & 26 deletions llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ class SPIRVEmitIntrinsics
return B.CreateIntrinsic(IntrID, {Types}, Args);
}

void buildAssignPtr(IRBuilder<> &B, Type *ElemTy, Value *Arg);

void replaceMemInstrUses(Instruction *Old, Instruction *New, IRBuilder<> &B);
void processInstrAfterVisit(Instruction *I, IRBuilder<> &B);
void insertAssignPtrTypeIntrs(Instruction *I, IRBuilder<> &B);
Expand All @@ -111,6 +113,7 @@ class SPIRVEmitIntrinsics
void insertPtrCastOrAssignTypeInstr(Instruction *I, IRBuilder<> &B);
void processGlobalValue(GlobalVariable &GV, IRBuilder<> &B);
void processParamTypes(Function *F, IRBuilder<> &B);
void processParamTypesByFunHeader(Function *F, IRBuilder<> &B);
Type *deduceFunParamElementType(Function *F, unsigned OpIdx);
Type *deduceFunParamElementType(Function *F, unsigned OpIdx,
std::unordered_set<Function *> &FVisited);
Expand Down Expand Up @@ -194,6 +197,17 @@ static inline void reportFatalOnTokenType(const Instruction *I) {
false);
}

void SPIRVEmitIntrinsics::buildAssignPtr(IRBuilder<> &B, Type *ElemTy,
Value *Arg) {
CallInst *AssignPtrTyCI =
buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {Arg->getType()},
Constant::getNullValue(ElemTy), Arg,
{B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
GR->addDeducedElementType(AssignPtrTyCI, ElemTy);
GR->addDeducedElementType(Arg, ElemTy);
AssignPtrTypeInstr[Arg] = AssignPtrTyCI;
}

// Set element pointer type to the given value of ValueTy and tries to
// specify this type further (recursively) by Operand value, if needed.
Type *SPIRVEmitIntrinsics::deduceElementTypeByValueDeep(
Expand Down Expand Up @@ -232,6 +246,19 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeByUsersDeep(
return nullptr;
}

// Implements what we know in advance about intrinsics and builtin calls
// TODO: consider feasibility of this particular case to be generalized by
// encoding knowledge about intrinsics and builtin calls by corresponding
// specification rules
static Type *getPointeeTypeByCallInst(StringRef DemangledName,
Function *CalledF, unsigned OpIdx) {
if ((DemangledName.starts_with("__spirv_ocl_printf(") ||
DemangledName.starts_with("printf(")) &&
OpIdx == 0)
return IntegerType::getInt8Ty(CalledF->getContext());
return nullptr;
}

// Deduce and return a successfully deduced Type of the Instruction,
// or nullptr otherwise.
Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(Value *I) {
Expand Down Expand Up @@ -795,6 +822,8 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
return;

// collect information about formal parameter types
std::string DemangledName =
getOclOrSpirvBuiltinDemangledName(CI->getCalledFunction()->getName());
Function *CalledF = CI->getCalledFunction();
SmallVector<Type *, 4> CalledArgTys;
bool HaveTypes = false;
Expand All @@ -811,10 +840,15 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
if (!ElemTy && hasPointeeTypeAttr(CalledArg))
ElemTy = getPointeeTypeByAttr(CalledArg);
if (!ElemTy) {
for (User *U : CalledArg->users()) {
if (Instruction *Inst = dyn_cast<Instruction>(U)) {
if ((ElemTy = deduceElementTypeHelper(Inst)) != nullptr)
break;
ElemTy = getPointeeTypeByCallInst(DemangledName, CalledF, OpIdx);
if (ElemTy) {
GR->addDeducedElementType(CalledArg, ElemTy);
} else {
for (User *U : CalledArg->users()) {
if (Instruction *Inst = dyn_cast<Instruction>(U)) {
if ((ElemTy = deduceElementTypeHelper(Inst)) != nullptr)
break;
}
}
}
}
Expand All @@ -823,8 +857,6 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
}
}

std::string DemangledName =
getOclOrSpirvBuiltinDemangledName(CI->getCalledFunction()->getName());
if (DemangledName.empty() && !HaveTypes)
return;

Expand All @@ -835,8 +867,14 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
continue;

// Constants (nulls/undefs) are handled in insertAssignPtrTypeIntrs()
if (!isa<Instruction>(ArgOperand) && !isa<Argument>(ArgOperand))
continue;
if (!isa<Instruction>(ArgOperand) && !isa<Argument>(ArgOperand)) {
// However, we may have assumptions about the formal argument's type and
// may have a need to insert a ptr cast for the actual parameter of this
// call.
Argument *CalledArg = CalledF->getArg(OpIdx);
if (!GR->findDeducedElementType(CalledArg))
continue;
}

Type *ExpectedType =
OpIdx < CalledArgTys.size() ? CalledArgTys[OpIdx] : nullptr;
Expand Down Expand Up @@ -1102,9 +1140,13 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
(II->paramHasAttr(OpNo, Attribute::ImmArg))))
continue;
B.SetInsertPoint(I);
auto *NewOp =
buildIntrWithMD(Intrinsic::spv_track_constant,
{Op->getType(), Op->getType()}, Op, Op, {}, B);
Value *OpTyVal = Op;
if (Op->getType()->isTargetExtTy())
OpTyVal = Constant::getNullValue(
IntegerType::get(I->getContext(), GR->getPointerSize()));
auto *NewOp = buildIntrWithMD(Intrinsic::spv_track_constant,
{Op->getType(), OpTyVal->getType()}, Op,
OpTyVal, {}, B);
I->setOperand(OpNo, NewOp);
}
}
Expand Down Expand Up @@ -1179,28 +1221,29 @@ Type *SPIRVEmitIntrinsics::deduceFunParamElementType(
return nullptr;
}

void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
void SPIRVEmitIntrinsics::processParamTypesByFunHeader(Function *F,
IRBuilder<> &B) {
B.SetInsertPointPastAllocas(F);
for (unsigned OpIdx = 0; OpIdx < F->arg_size(); ++OpIdx) {
Argument *Arg = F->getArg(OpIdx);
if (!isUntypedPointerTy(Arg->getType()))
continue;
Type *ElemTy = GR->findDeducedElementType(Arg);
if (!ElemTy && hasPointeeTypeAttr(Arg) &&
(ElemTy = getPointeeTypeByAttr(Arg)) != nullptr)
buildAssignPtr(B, ElemTy, Arg);
}
}

void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
B.SetInsertPointPastAllocas(F);
for (unsigned OpIdx = 0; OpIdx < F->arg_size(); ++OpIdx) {
Argument *Arg = F->getArg(OpIdx);
if (!isUntypedPointerTy(Arg->getType()))
continue;
Type *ElemTy = GR->findDeducedElementType(Arg);
if (!ElemTy) {
if (hasPointeeTypeAttr(Arg) &&
(ElemTy = getPointeeTypeByAttr(Arg)) != nullptr) {
GR->addDeducedElementType(Arg, ElemTy);
} else if ((ElemTy = deduceFunParamElementType(F, OpIdx)) != nullptr) {
CallInst *AssignPtrTyCI = buildIntrWithMD(
Intrinsic::spv_assign_ptr_type, {Arg->getType()},
Constant::getNullValue(ElemTy), Arg,
{B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
GR->addDeducedElementType(AssignPtrTyCI, ElemTy);
GR->addDeducedElementType(Arg, ElemTy);
AssignPtrTypeInstr[Arg] = AssignPtrTyCI;
}
}
if (!ElemTy && (ElemTy = deduceFunParamElementType(F, OpIdx)) != nullptr)
buildAssignPtr(B, ElemTy, Arg);
}
}

Expand All @@ -1217,6 +1260,8 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
AggrConstTypes.clear();
AggrStores.clear();

processParamTypesByFunHeader(F, B);

// StoreInst's operand type can be changed during the next transformations,
// so we need to store it in the set. Also store already transformed types.
for (auto &I : instructions(Func)) {
Expand Down
10 changes: 10 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,16 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
SPIRV::OpTypeBool))
MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
break;
case SPIRV::OpConstantI: {
SPIRVType *Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
if (Type->getOpcode() != SPIRV::OpTypeInt && MI.getOperand(2).isImm() &&
MI.getOperand(2).getImm() == 0) {
// Validate the null constant of a target extension type
MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull));
for (unsigned i = MI.getNumOperands() - 1; i > 1; --i)
MI.removeOperand(i);
}
} break;
}
}
}
Expand Down
40 changes: 28 additions & 12 deletions llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ class SPIRVPreLegalizer : public MachineFunctionPass {
};
} // namespace

static void addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR) {
static void
addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR,
DenseMap<MachineInstr *, Type *> &TargetExtConstTypes) {
MachineRegisterInfo &MRI = MF.getRegInfo();
DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT;
SmallVector<MachineInstr *, 10> ToErase, ToEraseComposites;
Expand All @@ -47,21 +49,22 @@ static void addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR) {
if (!isSpvIntrinsic(MI, Intrinsic::spv_track_constant))
continue;
ToErase.push_back(&MI);
Register SrcReg = MI.getOperand(2).getReg();
auto *Const =
cast<Constant>(cast<ConstantAsMetadata>(
MI.getOperand(3).getMetadata()->getOperand(0))
->getValue());
if (auto *GV = dyn_cast<GlobalValue>(Const)) {
Register Reg = GR->find(GV, &MF);
if (!Reg.isValid())
GR->add(GV, &MF, MI.getOperand(2).getReg());
GR->add(GV, &MF, SrcReg);
else
RegsAlreadyAddedToDT[&MI] = Reg;
} else {
Register Reg = GR->find(Const, &MF);
if (!Reg.isValid()) {
if (auto *ConstVec = dyn_cast<ConstantDataVector>(Const)) {
auto *BuildVec = MRI.getVRegDef(MI.getOperand(2).getReg());
auto *BuildVec = MRI.getVRegDef(SrcReg);
assert(BuildVec &&
BuildVec->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
for (unsigned i = 0; i < ConstVec->getNumElements(); ++i) {
Expand All @@ -75,7 +78,13 @@ static void addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR) {
BuildVec->getOperand(1 + i).setReg(ElemReg);
}
}
GR->add(Const, &MF, MI.getOperand(2).getReg());
GR->add(Const, &MF, SrcReg);
if (Const->getType()->isTargetExtTy()) {
// remember association so that we can restore it when assign types
MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
if (SrcMI && SrcMI->getOpcode() == TargetOpcode::G_CONSTANT)
TargetExtConstTypes[SrcMI] = Const->getType();
}
} else {
RegsAlreadyAddedToDT[&MI] = Reg;
// This MI is unused and will be removed. If the MI uses
Expand Down Expand Up @@ -364,8 +373,10 @@ void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
}
} // namespace llvm

static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
MachineIRBuilder MIB) {
static void
generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
MachineIRBuilder MIB,
DenseMap<MachineInstr *, Type *> &TargetExtConstTypes) {
// Get access to information about available extensions
const SPIRVSubtarget *ST =
static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
Expand Down Expand Up @@ -422,11 +433,14 @@ static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
continue;
}
Type *Ty = nullptr;
if (MI.getOpcode() == TargetOpcode::G_CONSTANT)
Ty = MI.getOperand(1).getCImm()->getType();
else if (MI.getOpcode() == TargetOpcode::G_FCONSTANT)
if (MI.getOpcode() == TargetOpcode::G_CONSTANT) {
auto TargetExtIt = TargetExtConstTypes.find(&MI);
Ty = TargetExtIt == TargetExtConstTypes.end()
? MI.getOperand(1).getCImm()->getType()
: TargetExtIt->second;
} else if (MI.getOpcode() == TargetOpcode::G_FCONSTANT) {
Ty = MI.getOperand(1).getFPImm()->getType();
else {
} else {
assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
Type *ElemTy = nullptr;
MachineInstr *ElemMI = MRI.getVRegDef(MI.getOperand(1).getReg());
Expand Down Expand Up @@ -616,10 +630,12 @@ bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {
SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
GR->setCurrentFunc(MF);
MachineIRBuilder MIB(MF);
addConstantsToTrack(MF, GR);
// a registry of target extension constants
DenseMap<MachineInstr *, Type *> TargetExtConstTypes;
addConstantsToTrack(MF, GR, TargetExtConstTypes);
foldConstantsIntoIntrinsics(MF);
insertBitcasts(MF, GR, MIB);
generateAssignInstrs(MF, GR, MIB);
generateAssignInstrs(MF, GR, MIB, TargetExtConstTypes);
processSwitches(MF, GR, MIB);
processInstrsWithTypeFolding(MF, GR, MIB);
removeImplicitFallthroughs(MF, MIB);
Expand Down
40 changes: 40 additions & 0 deletions llvm/test/CodeGen/SPIRV/printf.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; CHECK: %[[#ExtImport:]] = OpExtInstImport "OpenCL.std"
; CHECK: %[[#Char:]] = OpTypeInt 8 0
; CHECK: %[[#CharPtr:]] = OpTypePointer UniformConstant %[[#Char]]
; CHECK: %[[#GV:]] = OpVariable %[[#]] UniformConstant %[[#]]
; CHECK: OpFunction
; CHECK: %[[#Arg1:]] = OpFunctionParameter
; CHECK: %[[#Arg2:]] = OpFunctionParameter
; CHECK: %[[#CastedGV:]] = OpBitcast %[[#CharPtr]] %[[#GV]]
; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#CastedGV]] %[[#ArgConst:]]
; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#CastedGV]] %[[#ArgConst]]
; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#Arg1]] %[[#ArgConst:]]
; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#Arg1]] %[[#ArgConst]]
; CHECK-NEXT: %[[#CastedArg2:]] = OpBitcast %[[#CharPtr]] %[[#Arg2]]
; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#CastedArg2]] %[[#ArgConst]]
; CHECK-NEXT: OpExtInst %[[#]] %[[#ExtImport]] printf %[[#CastedArg2]] %[[#ArgConst]]
; CHECK: OpFunctionEnd

%struct = type { [6 x i8] }

@FmtStr = internal addrspace(2) constant [6 x i8] c"c=%c\0A\00", align 1

define spir_kernel void @foo(ptr addrspace(2) %_arg_fmt1, ptr addrspace(2) byval(%struct) %_arg_fmt2) {
entry:
%r1 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z6printfPU3AS2Kcz(ptr addrspace(2) @FmtStr, i8 signext 97)
%r2 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z18__spirv_ocl_printfPU3AS2Kcz(ptr addrspace(2) @FmtStr, i8 signext 97)
%r3 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z6printfPU3AS2Kcz(ptr addrspace(2) %_arg_fmt1, i8 signext 97)
%r4 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z18__spirv_ocl_printfPU3AS2Kcz(ptr addrspace(2) %_arg_fmt1, i8 signext 97)
%r5 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z6printfPU3AS2Kcz(ptr addrspace(2) %_arg_fmt2, i8 signext 97)
%r6 = tail call spir_func i32 (ptr addrspace(2), ...) @_Z18__spirv_ocl_printfPU3AS2Kcz(ptr addrspace(2) %_arg_fmt2, i8 signext 97)
ret void
}

declare dso_local spir_func i32 @_Z6printfPU3AS2Kcz(ptr addrspace(2), ...)
declare dso_local spir_func i32 @_Z18__spirv_ocl_printfPU3AS2Kcz(ptr addrspace(2), ...)
Loading