@@ -76,6 +76,10 @@ class SPIRVEmitIntrinsics
76
76
DenseSet<Instruction *> AggrStores;
77
77
SPIRV::InstructionSet::InstructionSet InstrSet;
78
78
79
+ // map of function declarations to <pointer arg index => element type>
80
+ DenseMap<const Function *, SmallVector<std::pair<unsigned , Type *>>>
81
+ FDeclPtrTys;
82
+
79
83
// a register of Instructions that don't have a complete type definition
80
84
bool CanTodoType = true ;
81
85
unsigned TodoTypeSz = 0 ;
@@ -184,6 +188,10 @@ class SPIRVEmitIntrinsics
184
188
void deduceOperandElementTypeFunctionPointer (
185
189
CallInst *CI, SmallVector<std::pair<Value *, unsigned >> &Ops,
186
190
Type *&KnownElemTy, bool IsPostprocessing);
191
+ bool deduceOperandElementTypeFunctionRet (
192
+ Instruction *I, SmallPtrSet<Instruction *, 4 > *UncompleteRets,
193
+ const SmallPtrSet<Value *, 4 > *AskOps, bool IsPostprocessing,
194
+ Type *&KnownElemTy, Value *Op, Function *F);
187
195
188
196
CallInst *buildSpvPtrcast (Function *F, Value *Op, Type *ElemTy);
189
197
void replaceUsesOfWithSpvPtrcast (Value *Op, Type *ElemTy, Instruction *I,
@@ -205,6 +213,7 @@ class SPIRVEmitIntrinsics
205
213
bool runOnFunction (Function &F);
206
214
bool postprocessTypes (Module &M);
207
215
bool processFunctionPointers (Module &M);
216
+ void parseFunDeclarations (Module &M);
208
217
209
218
public:
210
219
static char ID;
@@ -957,6 +966,47 @@ void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer(
957
966
IsNewFTy ? FunctionType::get (RetTy, ArgTys, FTy->isVarArg ()) : FTy;
958
967
}
959
968
969
+ bool SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionRet (
970
+ Instruction *I, SmallPtrSet<Instruction *, 4 > *UncompleteRets,
971
+ const SmallPtrSet<Value *, 4 > *AskOps, bool IsPostprocessing,
972
+ Type *&KnownElemTy, Value *Op, Function *F) {
973
+ KnownElemTy = GR->findDeducedElementType (F);
974
+ if (KnownElemTy)
975
+ return false ;
976
+ if (Type *OpElemTy = GR->findDeducedElementType (Op)) {
977
+ GR->addDeducedElementType (F, OpElemTy);
978
+ GR->addReturnType (
979
+ F, TypedPointerType::get (OpElemTy,
980
+ getPointerAddressSpace (F->getReturnType ())));
981
+ // non-recursive update of types in function uses
982
+ DenseSet<std::pair<Value *, Value *>> VisitedSubst{std::make_pair (I, Op)};
983
+ for (User *U : F->users ()) {
984
+ CallInst *CI = dyn_cast<CallInst>(U);
985
+ if (!CI || CI->getCalledFunction () != F)
986
+ continue ;
987
+ if (CallInst *AssignCI = GR->findAssignPtrTypeInstr (CI)) {
988
+ if (Type *PrevElemTy = GR->findDeducedElementType (CI)) {
989
+ updateAssignType (AssignCI, CI, PoisonValue::get (OpElemTy));
990
+ propagateElemType (CI, PrevElemTy, VisitedSubst);
991
+ }
992
+ }
993
+ }
994
+ // Non-recursive update of types in the function uncomplete returns.
995
+ // This may happen just once per a function, the latch is a pair of
996
+ // findDeducedElementType(F) / addDeducedElementType(F, ...).
997
+ // With or without the latch it is a non-recursive call due to
998
+ // UncompleteRets set to nullptr in this call.
999
+ if (UncompleteRets)
1000
+ for (Instruction *UncompleteRetI : *UncompleteRets)
1001
+ deduceOperandElementType (UncompleteRetI, nullptr , AskOps,
1002
+ IsPostprocessing);
1003
+ } else if (UncompleteRets) {
1004
+ UncompleteRets->insert (I);
1005
+ }
1006
+ TypeValidated.insert (I);
1007
+ return true ;
1008
+ }
1009
+
960
1010
// If the Instruction has Pointer operands with unresolved types, this function
961
1011
// tries to deduce them. If the Instruction has Pointer operands with known
962
1012
// types which differ from expected, this function tries to insert a bitcast to
@@ -1039,46 +1089,15 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(
1039
1089
Ops.push_back (std::make_pair (Op, i));
1040
1090
}
1041
1091
} else if (auto *Ref = dyn_cast<ReturnInst>(I)) {
1042
- Type *RetTy = CurrF->getReturnType ();
1043
- if (!isPointerTy (RetTy))
1092
+ if (!isPointerTy (CurrF->getReturnType ()))
1044
1093
return ;
1045
1094
Value *Op = Ref->getReturnValue ();
1046
1095
if (!Op)
1047
1096
return ;
1048
- if (!(KnownElemTy = GR->findDeducedElementType (CurrF))) {
1049
- if (Type *OpElemTy = GR->findDeducedElementType (Op)) {
1050
- GR->addDeducedElementType (CurrF, OpElemTy);
1051
- GR->addReturnType (CurrF, TypedPointerType::get (
1052
- OpElemTy, getPointerAddressSpace (RetTy)));
1053
- // non-recursive update of types in function uses
1054
- DenseSet<std::pair<Value *, Value *>> VisitedSubst{
1055
- std::make_pair (I, Op)};
1056
- for (User *U : CurrF->users ()) {
1057
- CallInst *CI = dyn_cast<CallInst>(U);
1058
- if (!CI || CI->getCalledFunction () != CurrF)
1059
- continue ;
1060
- if (CallInst *AssignCI = GR->findAssignPtrTypeInstr (CI)) {
1061
- if (Type *PrevElemTy = GR->findDeducedElementType (CI)) {
1062
- updateAssignType (AssignCI, CI, PoisonValue::get (OpElemTy));
1063
- propagateElemType (CI, PrevElemTy, VisitedSubst);
1064
- }
1065
- }
1066
- }
1067
- // Non-recursive update of types in the function uncomplete returns.
1068
- // This may happen just once per a function, the latch is a pair of
1069
- // findDeducedElementType(F) / addDeducedElementType(F, ...).
1070
- // With or without the latch it is a non-recursive call due to
1071
- // UncompleteRets set to nullptr in this call.
1072
- if (UncompleteRets)
1073
- for (Instruction *UncompleteRetI : *UncompleteRets)
1074
- deduceOperandElementType (UncompleteRetI, nullptr , AskOps,
1075
- IsPostprocessing);
1076
- } else if (UncompleteRets) {
1077
- UncompleteRets->insert (I);
1078
- }
1079
- TypeValidated.insert (I);
1097
+ if (deduceOperandElementTypeFunctionRet (I, UncompleteRets, AskOps,
1098
+ IsPostprocessing, KnownElemTy, Op,
1099
+ CurrF))
1080
1100
return ;
1081
- }
1082
1101
Uncomplete = isTodoType (CurrF);
1083
1102
Ops.push_back (std::make_pair (Op, 0 ));
1084
1103
} else if (auto *Ref = dyn_cast<ICmpInst>(I)) {
@@ -2157,6 +2176,8 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
2157
2176
AggrConstTypes.clear ();
2158
2177
AggrStores.clear ();
2159
2178
2179
+ DenseMap<Function *, DenseMap<unsigned , Type *>> FDeclPtrTys;
2180
+
2160
2181
processParamTypesByFunHeader (CurrF, B);
2161
2182
2162
2183
// StoreInst's operand type can be changed during the next transformations,
@@ -2180,6 +2201,31 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
2180
2201
for (auto &I : instructions (Func))
2181
2202
Worklist.push_back (&I);
2182
2203
2204
+ // Apply types parsed from demangled function declarations.
2205
+ for (auto &I : Worklist) {
2206
+ CallInst *CI = dyn_cast<CallInst>(I);
2207
+ if (!CI || !CI->getCalledFunction ())
2208
+ continue ;
2209
+ auto It = FDeclPtrTys.find (CI->getCalledFunction ());
2210
+ if (It == FDeclPtrTys.end ())
2211
+ continue ;
2212
+ unsigned Sz = CI->arg_size ();
2213
+ for (auto [Idx, ElemTy] : It->second )
2214
+ if (Idx < Sz) {
2215
+ Value *Arg = CI->getArgOperand (Idx);
2216
+ GR->addDeducedElementType (Arg, ElemTy);
2217
+ if (CallInst *Ref = dyn_cast<CallInst>(Arg))
2218
+ if (Function *RefF = Ref->getCalledFunction ();
2219
+ RefF && isPointerTy (RefF->getReturnType ()) &&
2220
+ !GR->findDeducedElementType (RefF)) {
2221
+ GR->addDeducedElementType (RefF, ElemTy);
2222
+ GR->addReturnType (RefF, TypedPointerType::get (
2223
+ ElemTy, getPointerAddressSpace (
2224
+ RefF->getReturnType ())));
2225
+ }
2226
+ }
2227
+ }
2228
+
2183
2229
// Pass forward: use operand to deduce instructions result.
2184
2230
for (auto &I : Worklist) {
2185
2231
// Don't emit intrinsincs for convergence intrinsics.
@@ -2287,9 +2333,44 @@ bool SPIRVEmitIntrinsics::postprocessTypes(Module &M) {
2287
2333
return SzTodo > TodoTypeSz;
2288
2334
}
2289
2335
2336
+ // Parse and store argument types of function declarations where needed.
2337
+ void SPIRVEmitIntrinsics::parseFunDeclarations (Module &M) {
2338
+ for (auto &F : M) {
2339
+ if (!F.isDeclaration () || F.isIntrinsic ())
2340
+ continue ;
2341
+ // get the demangled name
2342
+ std::string DemangledName = getOclOrSpirvBuiltinDemangledName (F.getName ());
2343
+ if (DemangledName.empty ())
2344
+ continue ;
2345
+ // find pointer arguments
2346
+ SmallVector<unsigned > Idxs;
2347
+ for (unsigned OpIdx = 0 ; OpIdx < F.arg_size (); ++OpIdx)
2348
+ if (isPointerTy (F.getArg (OpIdx)->getType ()))
2349
+ Idxs.push_back (OpIdx);
2350
+ if (!Idxs.size ())
2351
+ continue ;
2352
+ // parse function arguments
2353
+ LLVMContext &Ctx = F.getContext ();
2354
+ SmallVector<StringRef, 10 > TypeStrs;
2355
+ SPIRV::parseBuiltinTypeStr (TypeStrs, DemangledName, Ctx);
2356
+ if (!TypeStrs.size ())
2357
+ continue ;
2358
+ // find type info for pointer arguments
2359
+ for (unsigned Idx : Idxs) {
2360
+ if (Idx >= TypeStrs.size ())
2361
+ continue ;
2362
+ if (Type *ElemTy =
2363
+ SPIRV::parseBuiltinCallArgumentType (TypeStrs[Idx].trim (), Ctx))
2364
+ FDeclPtrTys[&F].push_back (std::make_pair (Idx, ElemTy));
2365
+ }
2366
+ }
2367
+ }
2368
+
2290
2369
bool SPIRVEmitIntrinsics::runOnModule (Module &M) {
2291
2370
bool Changed = false ;
2292
2371
2372
+ parseFunDeclarations (M);
2373
+
2293
2374
TodoType.clear ();
2294
2375
for (auto &F : M)
2295
2376
Changed |= runOnFunction (F);
0 commit comments