@@ -92,6 +92,9 @@ class SPIRVEmitIntrinsics
92
92
void insertPtrCastOrAssignTypeInstr (Instruction *I, IRBuilder<> &B);
93
93
void processGlobalValue (GlobalVariable &GV, IRBuilder<> &B);
94
94
void processParamTypes (Function *F, IRBuilder<> &B);
95
+ Type *deduceFunParamType (Function *F, unsigned OpIdx);
96
+ Type *deduceFunParamType (Function *F, unsigned OpIdx,
97
+ std::unordered_set<Function *> &FVisited);
95
98
96
99
public:
97
100
static char ID;
@@ -169,6 +172,10 @@ static inline void reportFatalOnTokenType(const Instruction *I) {
169
172
static Type *deduceElementTypeHelper (Value *I,
170
173
std::unordered_set<Value *> &Visited,
171
174
DenseMap<Value *, Type *> &DeducedElTys) {
175
+ // allow to pass nullptr as an argument
176
+ if (!I)
177
+ return nullptr ;
178
+
172
179
// maybe already known
173
180
auto It = DeducedElTys.find (I);
174
181
if (It != DeducedElTys.end ())
@@ -182,15 +189,20 @@ static Type *deduceElementTypeHelper(Value *I,
182
189
// fallback value in case when we fail to deduce a type
183
190
Type *Ty = nullptr ;
184
191
// look for known basic patterns of type inference
185
- if (auto *Ref = dyn_cast<AllocaInst>(I))
192
+ if (auto *Ref = dyn_cast<AllocaInst>(I)) {
186
193
Ty = Ref->getAllocatedType ();
187
- else if (auto *Ref = dyn_cast<GetElementPtrInst>(I))
194
+ } else if (auto *Ref = dyn_cast<GetElementPtrInst>(I)) {
188
195
Ty = Ref->getResultElementType ();
189
- else if (auto *Ref = dyn_cast<GlobalValue>(I))
196
+ } else if (auto *Ref = dyn_cast<GlobalValue>(I)) {
190
197
Ty = Ref->getValueType ();
191
- else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I))
198
+ } else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I)) {
192
199
Ty = deduceElementTypeHelper (Ref->getPointerOperand (), Visited,
193
200
DeducedElTys);
201
+ } else if (auto *Ref = dyn_cast<BitCastInst>(I)) {
202
+ if (Type *Src = Ref->getSrcTy (), *Dest = Ref->getDestTy ();
203
+ isPointerTy (Src) && isPointerTy (Dest))
204
+ Ty = deduceElementTypeHelper (Ref->getOperand (0 ), Visited, DeducedElTys);
205
+ }
194
206
195
207
// remember the found relationship
196
208
if (Ty)
@@ -795,61 +807,80 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
795
807
}
796
808
}
797
809
798
- void SPIRVEmitIntrinsics::processParamTypes (Function *F, IRBuilder<> &B) {
799
- DenseMap<unsigned , Argument *> Args;
800
- unsigned i = 0 ;
801
- for (Argument &Arg : F->args ()) {
802
- if (isUntypedPointerTy (Arg.getType ()) &&
803
- DeducedElTys.find (&Arg) == DeducedElTys.end () &&
804
- !HasPointeeTypeAttr (&Arg))
805
- Args[i] = &Arg;
806
- i++;
807
- }
808
- if (Args.size () == 0 )
809
- return ;
810
+ Type *SPIRVEmitIntrinsics::deduceFunParamType (Function *F, unsigned OpIdx) {
811
+ std::unordered_set<Function *> FVisited;
812
+ return deduceFunParamType (F, OpIdx, FVisited);
813
+ }
814
+
815
+ Type *SPIRVEmitIntrinsics::deduceFunParamType (
816
+ Function *F, unsigned OpIdx, std::unordered_set<Function *> &FVisited) {
817
+ // maybe a cycle
818
+ if (FVisited.find (F) != FVisited.end ())
819
+ return nullptr ;
820
+ FVisited.insert (F);
810
821
811
- // Args contains opaque pointers without element type definition
812
- B.SetInsertPointPastAllocas (F);
813
822
std::unordered_set<Value *> Visited;
823
+ SmallVector<std::pair<Function *, unsigned >> Lookup;
824
+ // search in function's call sites
814
825
for (User *U : F->users ()) {
815
826
CallInst *CI = dyn_cast<CallInst>(U);
816
- if (!CI)
827
+ if (!CI || OpIdx >= CI-> arg_size () )
817
828
continue ;
818
- for (unsigned OpIdx = 0 ; OpIdx < CI->arg_size () && Args.size () > 0 ;
819
- OpIdx++) {
820
- auto It = Args.find (OpIdx);
821
- Argument *Arg = It == Args.end () ? nullptr : It->second ;
822
- if (!Arg)
823
- continue ;
824
- Value *OpArg = CI->getArgOperand (OpIdx);
825
- if (!isPointerTy (OpArg->getType ()))
829
+ Value *OpArg = CI->getArgOperand (OpIdx);
830
+ if (!isPointerTy (OpArg->getType ()))
831
+ continue ;
832
+ // maybe we already know operand's element type
833
+ if (auto It = DeducedElTys.find (OpArg); It != DeducedElTys.end ())
834
+ return It->second ;
835
+ // search in actual parameter's users
836
+ for (User *OpU : OpArg->users ()) {
837
+ Instruction *Inst = dyn_cast<Instruction>(OpU);
838
+ if (!Inst || Inst == CI)
826
839
continue ;
827
- // maybe we already know the operand's element type
828
- auto DeducedIt = DeducedElTys.find (OpArg);
829
- Type *ElemTy =
830
- DeducedIt == DeducedElTys.end () ? nullptr : DeducedIt->second ;
831
- if (!ElemTy) {
832
- for (User *OpU : OpArg->users ()) {
833
- if (Instruction *Inst = dyn_cast<Instruction>(OpU)) {
834
- Visited.clear ();
835
- ElemTy = deduceElementTypeHelper (Inst, Visited, DeducedElTys);
836
- if (ElemTy)
837
- break ;
838
- }
839
- }
840
+ Visited.clear ();
841
+ if (Type *Ty = deduceElementTypeHelper (Inst, Visited, DeducedElTys))
842
+ return Ty;
843
+ }
844
+ // check if it's a formal parameter of the outer function
845
+ if (!CI->getParent () || !CI->getParent ()->getParent ())
846
+ continue ;
847
+ Function *OuterF = CI->getParent ()->getParent ();
848
+ if (FVisited.find (OuterF) != FVisited.end ())
849
+ continue ;
850
+ for (unsigned i = 0 ; i < OuterF->arg_size (); ++i) {
851
+ if (OuterF->getArg (i) == OpArg) {
852
+ Lookup.push_back (std::make_pair (OuterF, i));
853
+ break ;
840
854
}
841
- if (ElemTy) {
842
- unsigned AddressSpace = getPointerAddressSpace (Arg->getType ());
855
+ }
856
+ }
857
+
858
+ // search in function parameters
859
+ for (auto &Pair : Lookup) {
860
+ if (Type *Ty = deduceFunParamType (Pair.first , Pair.second , FVisited))
861
+ return Ty;
862
+ }
863
+
864
+ return nullptr ;
865
+ }
866
+
867
+ void SPIRVEmitIntrinsics::processParamTypes (Function *F, IRBuilder<> &B) {
868
+ B.SetInsertPointPastAllocas (F);
869
+ DenseMap<Argument *, Type *> Args;
870
+ for (unsigned OpIdx = 0 ; OpIdx < F->arg_size (); ++OpIdx) {
871
+ Argument *Arg = F->getArg (OpIdx);
872
+ if (isUntypedPointerTy (Arg->getType ()) &&
873
+ DeducedElTys.find (Arg) == DeducedElTys.end () &&
874
+ !HasPointeeTypeAttr (Arg)) {
875
+ if (Type *ElemTy = deduceFunParamType (F, OpIdx)) {
843
876
CallInst *AssignPtrTyCI = buildIntrWithMD (
844
877
Intrinsic::spv_assign_ptr_type, {Arg->getType ()},
845
- Constant::getNullValue (ElemTy), Arg, {B.getInt32 (AddressSpace)}, B);
878
+ Constant::getNullValue (ElemTy), Arg,
879
+ {B.getInt32 (getPointerAddressSpace (Arg->getType ()))}, B);
846
880
DeducedElTys[AssignPtrTyCI] = ElemTy;
847
881
DeducedElTys[Arg] = ElemTy;
848
- Args.erase (It);
849
882
}
850
883
}
851
- if (Args.size () == 0 )
852
- break ;
853
884
}
854
885
}
855
886
0 commit comments