Skip to content

Commit 1d250d9

Browse files
[SPIR-V] Improve type inference in SPIR-V Backend for opaque pointers (#86283)
This PR improves type inference in the SPIR-V Backend for opaque pointers, accounting for a case where there is a chain of function calls that allows formal parameter types to be deduced from actual arguments. The attached test demonstrates the case.
1 parent 99c40f6 commit 1d250d9

File tree

4 files changed

+138
-50
lines changed

4 files changed

+138
-50
lines changed

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 77 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ class SPIRVEmitIntrinsics
9292
void insertPtrCastOrAssignTypeInstr(Instruction *I, IRBuilder<> &B);
9393
void processGlobalValue(GlobalVariable &GV, IRBuilder<> &B);
9494
void processParamTypes(Function *F, IRBuilder<> &B);
95+
Type *deduceFunParamType(Function *F, unsigned OpIdx);
96+
Type *deduceFunParamType(Function *F, unsigned OpIdx,
97+
std::unordered_set<Function *> &FVisited);
9598

9699
public:
97100
static char ID;
@@ -169,6 +172,10 @@ static inline void reportFatalOnTokenType(const Instruction *I) {
169172
static Type *deduceElementTypeHelper(Value *I,
170173
std::unordered_set<Value *> &Visited,
171174
DenseMap<Value *, Type *> &DeducedElTys) {
175+
// allow to pass nullptr as an argument
176+
if (!I)
177+
return nullptr;
178+
172179
// maybe already known
173180
auto It = DeducedElTys.find(I);
174181
if (It != DeducedElTys.end())
@@ -182,15 +189,20 @@ static Type *deduceElementTypeHelper(Value *I,
182189
// fallback value in case when we fail to deduce a type
183190
Type *Ty = nullptr;
184191
// look for known basic patterns of type inference
185-
if (auto *Ref = dyn_cast<AllocaInst>(I))
192+
if (auto *Ref = dyn_cast<AllocaInst>(I)) {
186193
Ty = Ref->getAllocatedType();
187-
else if (auto *Ref = dyn_cast<GetElementPtrInst>(I))
194+
} else if (auto *Ref = dyn_cast<GetElementPtrInst>(I)) {
188195
Ty = Ref->getResultElementType();
189-
else if (auto *Ref = dyn_cast<GlobalValue>(I))
196+
} else if (auto *Ref = dyn_cast<GlobalValue>(I)) {
190197
Ty = Ref->getValueType();
191-
else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I))
198+
} else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I)) {
192199
Ty = deduceElementTypeHelper(Ref->getPointerOperand(), Visited,
193200
DeducedElTys);
201+
} else if (auto *Ref = dyn_cast<BitCastInst>(I)) {
202+
if (Type *Src = Ref->getSrcTy(), *Dest = Ref->getDestTy();
203+
isPointerTy(Src) && isPointerTy(Dest))
204+
Ty = deduceElementTypeHelper(Ref->getOperand(0), Visited, DeducedElTys);
205+
}
194206

195207
// remember the found relationship
196208
if (Ty)
@@ -795,61 +807,80 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
795807
}
796808
}
797809

798-
void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
799-
DenseMap<unsigned, Argument *> Args;
800-
unsigned i = 0;
801-
for (Argument &Arg : F->args()) {
802-
if (isUntypedPointerTy(Arg.getType()) &&
803-
DeducedElTys.find(&Arg) == DeducedElTys.end() &&
804-
!HasPointeeTypeAttr(&Arg))
805-
Args[i] = &Arg;
806-
i++;
807-
}
808-
if (Args.size() == 0)
809-
return;
810+
Type *SPIRVEmitIntrinsics::deduceFunParamType(Function *F, unsigned OpIdx) {
811+
std::unordered_set<Function *> FVisited;
812+
return deduceFunParamType(F, OpIdx, FVisited);
813+
}
814+
815+
Type *SPIRVEmitIntrinsics::deduceFunParamType(
816+
Function *F, unsigned OpIdx, std::unordered_set<Function *> &FVisited) {
817+
// maybe a cycle
818+
if (FVisited.find(F) != FVisited.end())
819+
return nullptr;
820+
FVisited.insert(F);
810821

811-
// Args contains opaque pointers without element type definition
812-
B.SetInsertPointPastAllocas(F);
813822
std::unordered_set<Value *> Visited;
823+
SmallVector<std::pair<Function *, unsigned>> Lookup;
824+
// search in function's call sites
814825
for (User *U : F->users()) {
815826
CallInst *CI = dyn_cast<CallInst>(U);
816-
if (!CI)
827+
if (!CI || OpIdx >= CI->arg_size())
817828
continue;
818-
for (unsigned OpIdx = 0; OpIdx < CI->arg_size() && Args.size() > 0;
819-
OpIdx++) {
820-
auto It = Args.find(OpIdx);
821-
Argument *Arg = It == Args.end() ? nullptr : It->second;
822-
if (!Arg)
823-
continue;
824-
Value *OpArg = CI->getArgOperand(OpIdx);
825-
if (!isPointerTy(OpArg->getType()))
829+
Value *OpArg = CI->getArgOperand(OpIdx);
830+
if (!isPointerTy(OpArg->getType()))
831+
continue;
832+
// maybe we already know operand's element type
833+
if (auto It = DeducedElTys.find(OpArg); It != DeducedElTys.end())
834+
return It->second;
835+
// search in actual parameter's users
836+
for (User *OpU : OpArg->users()) {
837+
Instruction *Inst = dyn_cast<Instruction>(OpU);
838+
if (!Inst || Inst == CI)
826839
continue;
827-
// maybe we already know the operand's element type
828-
auto DeducedIt = DeducedElTys.find(OpArg);
829-
Type *ElemTy =
830-
DeducedIt == DeducedElTys.end() ? nullptr : DeducedIt->second;
831-
if (!ElemTy) {
832-
for (User *OpU : OpArg->users()) {
833-
if (Instruction *Inst = dyn_cast<Instruction>(OpU)) {
834-
Visited.clear();
835-
ElemTy = deduceElementTypeHelper(Inst, Visited, DeducedElTys);
836-
if (ElemTy)
837-
break;
838-
}
839-
}
840+
Visited.clear();
841+
if (Type *Ty = deduceElementTypeHelper(Inst, Visited, DeducedElTys))
842+
return Ty;
843+
}
844+
// check if it's a formal parameter of the outer function
845+
if (!CI->getParent() || !CI->getParent()->getParent())
846+
continue;
847+
Function *OuterF = CI->getParent()->getParent();
848+
if (FVisited.find(OuterF) != FVisited.end())
849+
continue;
850+
for (unsigned i = 0; i < OuterF->arg_size(); ++i) {
851+
if (OuterF->getArg(i) == OpArg) {
852+
Lookup.push_back(std::make_pair(OuterF, i));
853+
break;
840854
}
841-
if (ElemTy) {
842-
unsigned AddressSpace = getPointerAddressSpace(Arg->getType());
855+
}
856+
}
857+
858+
// search in function parameters
859+
for (auto &Pair : Lookup) {
860+
if (Type *Ty = deduceFunParamType(Pair.first, Pair.second, FVisited))
861+
return Ty;
862+
}
863+
864+
return nullptr;
865+
}
866+
867+
void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
868+
B.SetInsertPointPastAllocas(F);
869+
DenseMap<Argument *, Type *> Args;
870+
for (unsigned OpIdx = 0; OpIdx < F->arg_size(); ++OpIdx) {
871+
Argument *Arg = F->getArg(OpIdx);
872+
if (isUntypedPointerTy(Arg->getType()) &&
873+
DeducedElTys.find(Arg) == DeducedElTys.end() &&
874+
!HasPointeeTypeAttr(Arg)) {
875+
if (Type *ElemTy = deduceFunParamType(F, OpIdx)) {
843876
CallInst *AssignPtrTyCI = buildIntrWithMD(
844877
Intrinsic::spv_assign_ptr_type, {Arg->getType()},
845-
Constant::getNullValue(ElemTy), Arg, {B.getInt32(AddressSpace)}, B);
878+
Constant::getNullValue(ElemTy), Arg,
879+
{B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
846880
DeducedElTys[AssignPtrTyCI] = ElemTy;
847881
DeducedElTys[Arg] = ElemTy;
848-
Args.erase(It);
849882
}
850883
}
851-
if (Args.size() == 0)
852-
break;
853884
}
854885
}
855886

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,7 @@ Register SPIRVGlobalRegistry::buildGlobalVariable(
479479
GVar = M->getGlobalVariable(Name);
480480
if (GVar == nullptr) {
481481
const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type.
482+
// Module takes ownership of the global var.
482483
GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false,
483484
GlobalValue::ExternalLinkage, nullptr,
484485
Twine(Name));

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
499499
assert(I.getOperand(1).isReg() && I.getOperand(2).isReg());
500500
Register GV = I.getOperand(1).getReg();
501501
MachineRegisterInfo::def_instr_iterator II = MRI->def_instr_begin(GV);
502+
(void)II;
502503
assert(((*II).getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
503504
(*II).getOpcode() == TargetOpcode::COPY ||
504505
(*II).getOpcode() == SPIRV::OpVariable) &&
@@ -771,10 +772,13 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
771772
SPIRVType *VarTy = GR.getOrCreateSPIRVPointerType(
772773
ArrTy, I, TII, SPIRV::StorageClass::UniformConstant);
773774
// TODO: check if we have such GV, add init, use buildGlobalVariable.
774-
Type *LLVMArrTy = ArrayType::get(
775-
IntegerType::get(GR.CurMF->getFunction().getContext(), 8), Num);
776-
GlobalVariable *GV =
777-
new GlobalVariable(LLVMArrTy, true, GlobalValue::InternalLinkage);
775+
Function &CurFunction = GR.CurMF->getFunction();
776+
Type *LLVMArrTy =
777+
ArrayType::get(IntegerType::get(CurFunction.getContext(), 8), Num);
778+
// Module takes ownership of the global var.
779+
GlobalVariable *GV = new GlobalVariable(*CurFunction.getParent(), LLVMArrTy,
780+
true, GlobalValue::InternalLinkage,
781+
Constant::getNullValue(LLVMArrTy));
778782
Register VarReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
779783
GR.add(GV, GR.CurMF, VarReg);
780784

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-SPIRV-DAG: OpName %[[ArgCum:.*]] "_arg_cum"
5+
; CHECK-SPIRV-DAG: OpName %[[FunTest:.*]] "test"
6+
; CHECK-SPIRV-DAG: OpName %[[Addr:.*]] "addr"
7+
; CHECK-SPIRV-DAG: OpName %[[StubObj:.*]] "stub_object"
8+
; CHECK-SPIRV-DAG: OpName %[[MemOrder:.*]] "mem_order"
9+
; CHECK-SPIRV-DAG: OpName %[[FooStub:.*]] "foo_stub"
10+
; CHECK-SPIRV-DAG: OpName %[[FooObj:.*]] "foo_object"
11+
; CHECK-SPIRV-DAG: OpName %[[FooMemOrder:.*]] "mem_order"
12+
; CHECK-SPIRV-DAG: OpName %[[FooFunc:.*]] "foo"
13+
; CHECK-SPIRV-DAG: %[[TyLong:.*]] = OpTypeInt 32 0
14+
; CHECK-SPIRV-DAG: %[[TyVoid:.*]] = OpTypeVoid
15+
; CHECK-SPIRV-DAG: %[[TyPtrLong:.*]] = OpTypePointer CrossWorkgroup %[[TyLong]]
16+
; CHECK-SPIRV-DAG: %[[TyFunPtrLong:.*]] = OpTypeFunction %[[TyVoid]] %[[TyPtrLong]]
17+
; CHECK-SPIRV-DAG: %[[TyGenPtrLong:.*]] = OpTypePointer Generic %[[TyLong]]
18+
; CHECK-SPIRV-DAG: %[[TyFunGenPtrLongLong:.*]] = OpTypeFunction %[[TyVoid]] %[[TyGenPtrLong]] %[[TyLong]]
19+
; CHECK-SPIRV-DAG: %[[Const3:.*]] = OpConstant %[[TyLong]] 3
20+
; CHECK-SPIRV: %[[FunTest]] = OpFunction %[[TyVoid]] None %[[TyFunPtrLong]]
21+
; CHECK-SPIRV: %[[ArgCum]] = OpFunctionParameter %[[TyPtrLong]]
22+
; CHECK-SPIRV: OpFunctionCall %[[TyVoid]] %[[FooFunc]] %[[Addr]] %[[Const3]]
23+
; CHECK-SPIRV: %[[FooStub]] = OpFunction %[[TyVoid]] None %[[TyFunGenPtrLongLong]]
24+
; CHECK-SPIRV: %[[StubObj]] = OpFunctionParameter %[[TyGenPtrLong]]
25+
; CHECK-SPIRV: %[[MemOrder]] = OpFunctionParameter %[[TyLong]]
26+
; CHECK-SPIRV: %[[FooFunc]] = OpFunction %[[TyVoid]] None %[[TyFunGenPtrLongLong]]
27+
; CHECK-SPIRV: %[[FooObj]] = OpFunctionParameter %[[TyGenPtrLong]]
28+
; CHECK-SPIRV: %[[FooMemOrder]] = OpFunctionParameter %[[TyLong]]
29+
; CHECK-SPIRV: OpFunctionCall %[[TyVoid]] %[[FooStub]] %[[FooObj]] %[[FooMemOrder]]
30+
31+
define spir_kernel void @test(ptr addrspace(1) noundef align 4 %_arg_cum) {
32+
entry:
33+
%lptr = getelementptr inbounds i32, ptr addrspace(1) %_arg_cum, i64 1
34+
%addr = addrspacecast ptr addrspace(1) %lptr to ptr addrspace(4)
35+
%object = bitcast ptr addrspace(4) %addr to ptr addrspace(4)
36+
call spir_func void @foo(ptr addrspace(4) %object, i32 3)
37+
ret void
38+
}
39+
40+
define void @foo_stub(ptr addrspace(4) noundef %stub_object, i32 noundef %mem_order) {
41+
entry:
42+
%object.addr = alloca ptr addrspace(4)
43+
%object.addr.ascast = addrspacecast ptr %object.addr to ptr addrspace(4)
44+
store ptr addrspace(4) %stub_object, ptr addrspace(4) %object.addr.ascast
45+
ret void
46+
}
47+
48+
define void @foo(ptr addrspace(4) noundef %foo_object, i32 noundef %mem_order) {
49+
tail call void @foo_stub(ptr addrspace(4) noundef %foo_object, i32 noundef %mem_order)
50+
ret void
51+
}
52+

0 commit comments

Comments
 (0)