-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[SPIR-V] Improve type inference in SPIR-V Backend for opaque pointers #86283
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-backend-spir-v Author: Vyacheslav Levytskyy (VyacheslavLevytskyy) Changes: This PR improves type inference in the SPIR-V Backend for opaque pointers, accounting for the case where a chain of function calls makes it possible to deduce formal parameter types from actual arguments. The attached test demonstrates the case. Full diff: https://github.com/llvm/llvm-project/pull/86283.diff 4 Files Affected:
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 458af9229ed7b1..5828db6669ff18 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -92,6 +92,9 @@ class SPIRVEmitIntrinsics
void insertPtrCastOrAssignTypeInstr(Instruction *I, IRBuilder<> &B);
void processGlobalValue(GlobalVariable &GV, IRBuilder<> &B);
void processParamTypes(Function *F, IRBuilder<> &B);
+ Type *deduceFunParamType(Function *F, unsigned OpIdx);
+ Type *deduceFunParamType(Function *F, unsigned OpIdx,
+ std::unordered_set<Function *> &FVisited);
public:
static char ID;
@@ -169,6 +172,10 @@ static inline void reportFatalOnTokenType(const Instruction *I) {
static Type *deduceElementTypeHelper(Value *I,
std::unordered_set<Value *> &Visited,
DenseMap<Value *, Type *> &DeducedElTys) {
+ // allow to pass nullptr as an argument
+ if (!I)
+ return nullptr;
+
// maybe already known
auto It = DeducedElTys.find(I);
if (It != DeducedElTys.end())
@@ -182,15 +189,20 @@ static Type *deduceElementTypeHelper(Value *I,
// fallback value in case when we fail to deduce a type
Type *Ty = nullptr;
// look for known basic patterns of type inference
- if (auto *Ref = dyn_cast<AllocaInst>(I))
+ if (auto *Ref = dyn_cast<AllocaInst>(I)) {
Ty = Ref->getAllocatedType();
- else if (auto *Ref = dyn_cast<GetElementPtrInst>(I))
+ } else if (auto *Ref = dyn_cast<GetElementPtrInst>(I)) {
Ty = Ref->getResultElementType();
- else if (auto *Ref = dyn_cast<GlobalValue>(I))
+ } else if (auto *Ref = dyn_cast<GlobalValue>(I)) {
Ty = Ref->getValueType();
- else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I))
+ } else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I)) {
Ty = deduceElementTypeHelper(Ref->getPointerOperand(), Visited,
DeducedElTys);
+ } else if (auto *Ref = dyn_cast<BitCastInst>(I)) {
+ if (Type *Src = Ref->getSrcTy(), *Dest = Ref->getDestTy();
+ isPointerTy(Src) && isPointerTy(Dest))
+ Ty = deduceElementTypeHelper(Ref->getOperand(0), Visited, DeducedElTys);
+ }
// remember the found relationship
if (Ty)
@@ -795,61 +807,80 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
}
}
-void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
- DenseMap<unsigned, Argument *> Args;
- unsigned i = 0;
- for (Argument &Arg : F->args()) {
- if (isUntypedPointerTy(Arg.getType()) &&
- DeducedElTys.find(&Arg) == DeducedElTys.end() &&
- !HasPointeeTypeAttr(&Arg))
- Args[i] = &Arg;
- i++;
- }
- if (Args.size() == 0)
- return;
+Type *SPIRVEmitIntrinsics::deduceFunParamType(Function *F, unsigned OpIdx) {
+ std::unordered_set<Function *> FVisited;
+ return deduceFunParamType(F, OpIdx, FVisited);
+}
+
+Type *SPIRVEmitIntrinsics::deduceFunParamType(
+ Function *F, unsigned OpIdx, std::unordered_set<Function *> &FVisited) {
+ // maybe a cycle
+ if (FVisited.find(F) != FVisited.end())
+ return nullptr;
+ FVisited.insert(F);
- // Args contains opaque pointers without element type definition
- B.SetInsertPointPastAllocas(F);
std::unordered_set<Value *> Visited;
+ SmallVector<std::pair<Function *, unsigned>> Lookup;
+ // search in function's call sites
for (User *U : F->users()) {
CallInst *CI = dyn_cast<CallInst>(U);
- if (!CI)
+ if (!CI || OpIdx >= CI->arg_size())
continue;
- for (unsigned OpIdx = 0; OpIdx < CI->arg_size() && Args.size() > 0;
- OpIdx++) {
- auto It = Args.find(OpIdx);
- Argument *Arg = It == Args.end() ? nullptr : It->second;
- if (!Arg)
- continue;
- Value *OpArg = CI->getArgOperand(OpIdx);
- if (!isPointerTy(OpArg->getType()))
+ Value *OpArg = CI->getArgOperand(OpIdx);
+ if (!isPointerTy(OpArg->getType()))
+ continue;
+ // maybe we already know operand's element type
+ if (auto It = DeducedElTys.find(OpArg); It != DeducedElTys.end())
+ return It->second;
+ // search in actual parameter's users
+ for (User *OpU : OpArg->users()) {
+ Instruction *Inst = dyn_cast<Instruction>(OpU);
+ if (!Inst || Inst == CI)
continue;
- // maybe we already know the operand's element type
- auto DeducedIt = DeducedElTys.find(OpArg);
- Type *ElemTy =
- DeducedIt == DeducedElTys.end() ? nullptr : DeducedIt->second;
- if (!ElemTy) {
- for (User *OpU : OpArg->users()) {
- if (Instruction *Inst = dyn_cast<Instruction>(OpU)) {
- Visited.clear();
- ElemTy = deduceElementTypeHelper(Inst, Visited, DeducedElTys);
- if (ElemTy)
- break;
- }
- }
+ Visited.clear();
+ if (Type *Ty = deduceElementTypeHelper(Inst, Visited, DeducedElTys))
+ return Ty;
+ }
+ // check if it's a formal parameter of the outer function
+ if (!CI->getParent() || !CI->getParent()->getParent())
+ continue;
+ Function *OuterF = CI->getParent()->getParent();
+ if (FVisited.find(OuterF) != FVisited.end())
+ continue;
+ for (unsigned i = 0; i < OuterF->arg_size(); ++i) {
+ if (OuterF->getArg(i) == OpArg) {
+ Lookup.push_back(std::make_pair(OuterF, i));
+ break;
}
- if (ElemTy) {
- unsigned AddressSpace = getPointerAddressSpace(Arg->getType());
+ }
+ }
+
+ // search in function parameters
+ for (auto &Pair : Lookup) {
+ if (Type *Ty = deduceFunParamType(Pair.first, Pair.second, FVisited))
+ return Ty;
+ }
+
+ return nullptr;
+}
+
+void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
+ B.SetInsertPointPastAllocas(F);
+ DenseMap<Argument *, Type *> Args;
+ for (unsigned OpIdx = 0; OpIdx < F->arg_size(); ++OpIdx) {
+ Argument *Arg = F->getArg(OpIdx);
+ if (isUntypedPointerTy(Arg->getType()) &&
+ DeducedElTys.find(Arg) == DeducedElTys.end() &&
+ !HasPointeeTypeAttr(Arg)) {
+ if (Type *ElemTy = deduceFunParamType(F, OpIdx)) {
CallInst *AssignPtrTyCI = buildIntrWithMD(
Intrinsic::spv_assign_ptr_type, {Arg->getType()},
- Constant::getNullValue(ElemTy), Arg, {B.getInt32(AddressSpace)}, B);
+ Constant::getNullValue(ElemTy), Arg,
+ {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
DeducedElTys[AssignPtrTyCI] = ElemTy;
DeducedElTys[Arg] = ElemTy;
- Args.erase(It);
}
}
- if (Args.size() == 0)
- break;
}
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 42f8397a3023b1..ee52163a5d127f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -479,6 +479,7 @@ Register SPIRVGlobalRegistry::buildGlobalVariable(
GVar = M->getGlobalVariable(Name);
if (GVar == nullptr) {
const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type.
+ // Module takes ownership of the global var.
GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false,
GlobalValue::ExternalLinkage, nullptr,
Twine(Name));
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 5bb8f6084f9671..39228e2196b3af 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -499,6 +499,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
assert(I.getOperand(1).isReg() && I.getOperand(2).isReg());
Register GV = I.getOperand(1).getReg();
MachineRegisterInfo::def_instr_iterator II = MRI->def_instr_begin(GV);
+ (void)II;
assert(((*II).getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
(*II).getOpcode() == TargetOpcode::COPY ||
(*II).getOpcode() == SPIRV::OpVariable) &&
@@ -771,10 +772,13 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
SPIRVType *VarTy = GR.getOrCreateSPIRVPointerType(
ArrTy, I, TII, SPIRV::StorageClass::UniformConstant);
// TODO: check if we have such GV, add init, use buildGlobalVariable.
- Type *LLVMArrTy = ArrayType::get(
- IntegerType::get(GR.CurMF->getFunction().getContext(), 8), Num);
- GlobalVariable *GV =
- new GlobalVariable(LLVMArrTy, true, GlobalValue::InternalLinkage);
+ Function &CurFunction = GR.CurMF->getFunction();
+ Type *LLVMArrTy =
+ ArrayType::get(IntegerType::get(CurFunction.getContext(), 8), Num);
+ // Module takes ownership of the global var.
+ GlobalVariable *GV = new GlobalVariable(*CurFunction.getParent(), LLVMArrTy,
+ true, GlobalValue::InternalLinkage,
+ Constant::getNullValue(LLVMArrTy));
Register VarReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
GR.add(GV, GR.CurMF, VarReg);
diff --git a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll
new file mode 100644
index 00000000000000..703f1e22a0321a
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll
@@ -0,0 +1,52 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV-DAG: OpName %[[ArgCum:.*]] "_arg_cum"
+; CHECK-SPIRV-DAG: OpName %[[FunTest:.*]] "test"
+; CHECK-SPIRV-DAG: OpName %[[Addr:.*]] "addr"
+; CHECK-SPIRV-DAG: OpName %[[StubObj:.*]] "stub_object"
+; CHECK-SPIRV-DAG: OpName %[[MemOrder:.*]] "mem_order"
+; CHECK-SPIRV-DAG: OpName %[[FooStub:.*]] "foo_stub"
+; CHECK-SPIRV-DAG: OpName %[[FooObj:.*]] "foo_object"
+; CHECK-SPIRV-DAG: OpName %[[FooMemOrder:.*]] "mem_order"
+; CHECK-SPIRV-DAG: OpName %[[FooFunc:.*]] "foo"
+; CHECK-SPIRV-DAG: %[[TyLong:.*]] = OpTypeInt 32 0
+; CHECK-SPIRV-DAG: %[[TyVoid:.*]] = OpTypeVoid
+; CHECK-SPIRV-DAG: %[[TyPtrLong:.*]] = OpTypePointer CrossWorkgroup %[[TyLong]]
+; CHECK-SPIRV-DAG: %[[TyFunPtrLong:.*]] = OpTypeFunction %[[TyVoid]] %[[TyPtrLong]]
+; CHECK-SPIRV-DAG: %[[TyGenPtrLong:.*]] = OpTypePointer Generic %[[TyLong]]
+; CHECK-SPIRV-DAG: %[[TyFunGenPtrLongLong:.*]] = OpTypeFunction %[[TyVoid]] %[[TyGenPtrLong]] %[[TyLong]]
+; CHECK-SPIRV-DAG: %[[Const3:.*]] = OpConstant %[[TyLong]] 3
+; CHECK-SPIRV: %[[FunTest]] = OpFunction %[[TyVoid]] None %[[TyFunPtrLong]]
+; CHECK-SPIRV: %[[ArgCum]] = OpFunctionParameter %[[TyPtrLong]]
+; CHECK-SPIRV: OpFunctionCall %[[TyVoid]] %[[FooFunc]] %[[Addr]] %[[Const3]]
+; CHECK-SPIRV: %[[FooStub]] = OpFunction %[[TyVoid]] None %[[TyFunGenPtrLongLong]]
+; CHECK-SPIRV: %[[StubObj]] = OpFunctionParameter %[[TyGenPtrLong]]
+; CHECK-SPIRV: %[[MemOrder]] = OpFunctionParameter %[[TyLong]]
+; CHECK-SPIRV: %[[FooFunc]] = OpFunction %[[TyVoid]] None %[[TyFunGenPtrLongLong]]
+; CHECK-SPIRV: %[[FooObj]] = OpFunctionParameter %[[TyGenPtrLong]]
+; CHECK-SPIRV: %[[FooMemOrder]] = OpFunctionParameter %[[TyLong]]
+; CHECK-SPIRV: OpFunctionCall %[[TyVoid]] %[[FooStub]] %[[FooObj]] %[[FooMemOrder]]
+
+define spir_kernel void @test(ptr addrspace(1) noundef align 4 %_arg_cum) {
+entry:
+ %lptr = getelementptr inbounds i32, ptr addrspace(1) %_arg_cum, i64 1
+ %addr = addrspacecast ptr addrspace(1) %lptr to ptr addrspace(4)
+ %object = bitcast ptr addrspace(4) %addr to ptr addrspace(4)
+ call spir_func void @foo(ptr addrspace(4) %object, i32 3)
+ ret void
+}
+
+define void @foo_stub(ptr addrspace(4) noundef %stub_object, i32 noundef %mem_order) {
+entry:
+ %object.addr = alloca ptr addrspace(4)
+ %object.addr.ascast = addrspacecast ptr %object.addr to ptr addrspace(4)
+ store ptr addrspace(4) %stub_object, ptr addrspace(4) %object.addr.ascast
+ ret void
+}
+
+define void @foo(ptr addrspace(4) noundef %foo_object, i32 noundef %mem_order) {
+ tail call void @foo_stub(ptr addrspace(4) noundef %foo_object, i32 noundef %mem_order)
+ ret void
+}
+
|
This PR improves type inference in the SPIR-V Backend for opaque pointers, accounting for the case where a chain of function calls makes it possible to deduce formal parameter types from actual arguments. The attached test demonstrates the case.