Skip to content

Commit a9d62df

Browse files
improve type inference for GEP
1 parent 9ba240b commit a9d62df

File tree

3 files changed

+112
-59
lines changed

3 files changed

+112
-59
lines changed

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 101 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,11 @@ class SPIRVEmitIntrinsics
186186
CallInst *CI, SmallVector<std::pair<Value *, unsigned>> &Ops,
187187
Type *&KnownElemTy, bool IsPostprocessing);
188188

189-
CallInst *buildSpvPtrcast(Instruction *I, Type *ElemTy);
190-
void propagateElemTypeInUses(Instruction *I, Type *ElemTy);
189+
CallInst *buildSpvPtrcast(Value *Op, Type *ElemTy);
190+
void propagateElemType(Value *Op, Type *ElemTy);
191+
void propagateElemTypeRec(Value *Op, Type *PtrElemTy, CallInst *PtrCasted);
192+
void propagateElemTypeRec(Value *Op, Type *PtrElemTy, CallInst *PtrCasted,
193+
std::unordered_set<Value *> &Visited);
191194

192195
void replaceAllUsesWith(Value *Src, Value *Dest, bool DeleteOld = true);
193196

@@ -245,10 +248,8 @@ bool expectIgnoredInIRTranslation(const Instruction *I) {
245248
}
246249

247250
bool allowEmitFakeUse(const Value *Arg) {
248-
if (const auto *II = dyn_cast<IntrinsicInst>(Arg))
249-
if (Function *F = II->getCalledFunction())
250-
if (F->getName().starts_with("llvm.spv."))
251-
return false;
251+
if (isSpvIntrinsic(Arg))
252+
return false;
252253
if (dyn_cast<AtomicCmpXchgInst>(Arg) || dyn_cast<InsertValueInst>(Arg) ||
253254
dyn_cast<UndefValue>(Arg))
254255
return false;
@@ -425,6 +426,79 @@ void SPIRVEmitIntrinsics::updateAssignType(CallInst *AssignCI, Value *Arg,
425426
GR->addDeducedElementType(Arg, ElemTy);
426427
}
427428

429+
CallInst *SPIRVEmitIntrinsics::buildSpvPtrcast(Value *Op, Type *ElemTy) {
430+
IRBuilder<> B(Op->getContext());
431+
if (auto *OpI = dyn_cast<Instruction>(Op)) {
432+
// spv_ptrcast's argument Op denotes an instruction that generates
433+
// a value, and we may use getInsertionPointAfterDef()
434+
setInsertPointAfterDef(B, OpI);
435+
} else if (auto *OpA = dyn_cast<Argument>(Op)) {
436+
B.SetInsertPointPastAllocas(OpA->getParent());
437+
B.SetCurrentDebugLocation(DebugLoc());
438+
} else {
439+
B.SetInsertPoint(CurrF->getEntryBlock().getFirstNonPHIOrDbgOrAlloca());
440+
}
441+
Type *OpTy = Op->getType();
442+
SmallVector<Type *, 2> Types = {OpTy, OpTy};
443+
SmallVector<Value *, 2> Args = {Op, buildMD(PoisonValue::get(ElemTy)),
444+
B.getInt32(getPointerAddressSpace(OpTy))};
445+
CallInst *PtrCasted =
446+
B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
447+
buildAssignPtr(B, ElemTy, PtrCasted);
448+
return PtrCasted;
449+
}
450+
451+
void SPIRVEmitIntrinsics::propagateElemType(Value *Op, Type *ElemTy) {
452+
CallInst *PtrCasted = buildSpvPtrcast(Op, ElemTy);
453+
SmallVector<User *> Users(Op->users());
454+
for (auto *U : Users) {
455+
if (isa<BitCastInst>(U) || isa<GetElementPtrInst>(U) || isSpvIntrinsic(U))
456+
continue;
457+
U->replaceUsesOfWith(Op, PtrCasted);
458+
}
459+
}
460+
461+
void SPIRVEmitIntrinsics::propagateElemTypeRec(Value *Op, Type *PtrElemTy,
462+
CallInst *PtrCasted) {
463+
if (!isNestedPointer(PtrElemTy))
464+
return;
465+
std::unordered_set<Value *> Visited;
466+
propagateElemTypeRec(Op, PtrElemTy, PtrCasted, Visited);
467+
}
468+
469+
void SPIRVEmitIntrinsics::propagateElemTypeRec(
470+
Value *Op, Type *PtrElemTy, CallInst *PtrCasted,
471+
std::unordered_set<Value *> &Visited) {
472+
if (!Visited.insert(Op).second)
473+
return;
474+
SmallVector<User *> Users(Op->users());
475+
for (auto *U : Users) {
476+
if (isa<BitCastInst>(U) || isSpvIntrinsic(U))
477+
continue;
478+
if (auto *Ref = dyn_cast<GetElementPtrInst>(U)) {
479+
CallInst *AssignCI = GR->findAssignPtrTypeInstr(Ref);
480+
if (AssignCI && Ref->getPointerOperand() == Op) {
481+
Type *PrevElemTy = GR->findDeducedElementType(Ref);
482+
assert(PrevElemTy && "Expected valid element type");
483+
// evaluate a new GEP type
484+
Type *NewElemTy = PtrElemTy;
485+
for (Use &RefUse : drop_begin(Ref->indices()))
486+
NewElemTy =
487+
GetElementPtrInst::getTypeAtIndex(NewElemTy, RefUse.get());
488+
// record the new GEP type
489+
assert(NewElemTy && "Expected valid GEP indices");
490+
updateAssignType(AssignCI, Ref, PoisonValue::get(NewElemTy));
491+
// recursively propagate change
492+
if (isNestedPointer(NewElemTy))
493+
propagateElemTypeRec(Ref, NewElemTy, buildSpvPtrcast(Ref, PrevElemTy),
494+
Visited);
495+
}
496+
continue;
497+
}
498+
U->replaceUsesOfWith(Op, PtrCasted);
499+
}
500+
}
501+
428502
// Set element pointer type to the given value of ValueTy and tries to
429503
// specify this type further (recursively) by Operand value, if needed.
430504

@@ -530,7 +604,9 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
530604
maybeAssignPtrType(Ty, I, Ref->getAllocatedType(), UnknownElemTypeI8);
531605
} else if (auto *Ref = dyn_cast<GetElementPtrInst>(I)) {
532606
Ty = Ref->getResultElementType();
533-
if (isNestedPointer(Ty)) {
607+
if (isNestedPointer(Ref->getSourceElementType())) {
608+
Type *PtrElemTy = GR->findDeducedElementType(Ref->getPointerOperand());
609+
Ty = PtrElemTy ? PtrElemTy : Ref->getSourceElementType();
534610
for (Use &U : drop_begin(Ref->indices()))
535611
Ty = GetElementPtrInst::getTypeAtIndex(Ty, U.get());
536612
}
@@ -936,7 +1012,7 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(
9361012
if (CallInst *AssignCI = GR->findAssignPtrTypeInstr(CI)) {
9371013
if (Type *PrevElemTy = GR->findDeducedElementType(CI)) {
9381014
updateAssignType(AssignCI, CI, PoisonValue::get(OpElemTy));
939-
propagateElemTypeInUses(CI, PrevElemTy);
1015+
propagateElemType(CI, PrevElemTy);
9401016
}
9411017
}
9421018
}
@@ -1012,25 +1088,13 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(
10121088
GR->addAssignPtrTypeInstr(Op, CI);
10131089
} else {
10141090
updateAssignType(AssignCI, Op, OpTyVal);
1091+
propagateElemTypeRec(
1092+
Op, KnownElemTy,
1093+
buildSpvPtrcast(Op, GR->findDeducedElementType(Op)));
10151094
}
10161095
} else {
10171096
eraseTodoType(Op);
1018-
if (auto *OpI = dyn_cast<Instruction>(Op)) {
1019-
// spv_ptrcast's argument Op denotes an instruction that generates
1020-
// a value, and we may use getInsertionPointAfterDef()
1021-
setInsertPointAfterDef(B, OpI);
1022-
} else if (auto *OpA = dyn_cast<Argument>(Op)) {
1023-
B.SetInsertPointPastAllocas(OpA->getParent());
1024-
B.SetCurrentDebugLocation(DebugLoc());
1025-
} else {
1026-
B.SetInsertPoint(CurrF->getEntryBlock().getFirstNonPHIOrDbgOrAlloca());
1027-
}
1028-
SmallVector<Type *, 2> Types = {OpTy, OpTy};
1029-
SmallVector<Value *, 2> Args = {Op, buildMD(OpTyVal),
1030-
B.getInt32(getPointerAddressSpace(OpTy))};
1031-
CallInst *PtrCastI =
1032-
B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
1033-
buildAssignPtr(B, KnownElemTy, PtrCastI);
1097+
CallInst *PtrCastI = buildSpvPtrcast(Op, KnownElemTy);
10341098
if (OpIt.second == std::numeric_limits<unsigned>::max())
10351099
dyn_cast<CallInst>(I)->setCalledOperand(PtrCastI);
10361100
else
@@ -2191,35 +2255,6 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
21912255
return true;
21922256
}
21932257

2194-
CallInst *SPIRVEmitIntrinsics::buildSpvPtrcast(Instruction *I, Type *ElemTy) {
2195-
IRBuilder<> B(I->getContext());
2196-
B.SetInsertPoint(*I->getInsertionPointAfterDef());
2197-
B.SetCurrentDebugLocation(I->getDebugLoc());
2198-
Type *OpTy = I->getType();
2199-
SmallVector<Type *, 2> Types = {OpTy, OpTy};
2200-
SmallVector<Value *, 2> Args = {I, buildMD(PoisonValue::get(ElemTy)),
2201-
B.getInt32(getPointerAddressSpace(OpTy))};
2202-
CallInst *PtrCasted =
2203-
B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
2204-
buildAssignPtr(B, ElemTy, PtrCasted);
2205-
return PtrCasted;
2206-
}
2207-
2208-
void SPIRVEmitIntrinsics::propagateElemTypeInUses(Instruction *I,
2209-
Type *ElemTy) {
2210-
CallInst *PtrCasted = buildSpvPtrcast(I, ElemTy);
2211-
SmallVector<User *> Users(I->users());
2212-
for (auto *U : Users) {
2213-
if (isa<BitCastInst>(U) || isa<GetElementPtrInst>(U))
2214-
continue;
2215-
if (const auto *II = dyn_cast<IntrinsicInst>(U))
2216-
if (Function *F = II->getCalledFunction())
2217-
if (F->getName().starts_with("llvm.spv."))
2218-
continue;
2219-
U->replaceUsesOfWith(I, PtrCasted);
2220-
}
2221-
}
2222-
22232258
// Try to deduce a better type for pointers to untyped ptr.
22242259
bool SPIRVEmitIntrinsics::postprocessTypes(Module &M) {
22252260
if (!GR || TodoTypeSz == 0)
@@ -2228,7 +2263,7 @@ bool SPIRVEmitIntrinsics::postprocessTypes(Module &M) {
22282263
unsigned SzTodo = TodoTypeSz;
22292264
DenseMap<Value *, SmallPtrSet<Value *, 4>> ToProcess;
22302265
for (auto [Op, Enabled] : TodoType) {
2231-
if (!Enabled)
2266+
if (!Enabled || isa<GetElementPtrInst>(Op))
22322267
continue;
22332268
CallInst *AssignCI = GR->findAssignPtrTypeInstr(Op);
22342269
Type *KnownTy = GR->findDeducedElementType(Op);
@@ -2241,14 +2276,22 @@ bool SPIRVEmitIntrinsics::postprocessTypes(Module &M) {
22412276
std::unordered_set<Value *> Visited;
22422277
if (Type *ElemTy = deduceElementTypeHelper(Op, Visited, false, true)) {
22432278
if (ElemTy != KnownTy) {
2244-
if (mayUpdateOpType(CI)) {
2245-
updateAssignType(AssignCI, CI, PoisonValue::get(ElemTy));
2246-
propagateElemTypeInUses(CI, KnownTy);
2279+
if (isa<CallInst>(Op)) {
2280+
propagateElemType(CI, ElemTy);
22472281
} else {
2248-
propagateElemTypeInUses(CI, ElemTy);
2282+
updateAssignType(AssignCI, CI, PoisonValue::get(ElemTy));
2283+
propagateElemTypeRec(CI, ElemTy, buildSpvPtrcast(CI, KnownTy));
22492284
}
22502285
eraseTodoType(Op);
22512286
continue;
2287+
/*
2288+
if (mayUpdateOpType(CI)) {
2289+
updateAssignType(AssignCI, CI, PoisonValue::get(ElemTy));
2290+
propagateElemType(CI, KnownTy);
2291+
} else {
2292+
propagateElemType(CI, ElemTy);
2293+
}
2294+
*/
22522295
}
22532296
}
22542297
}

llvm/lib/Target/SPIRV/SPIRVUtils.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "llvm/CodeGen/MachineInstr.h"
2323
#include "llvm/CodeGen/MachineInstrBuilder.h"
2424
#include "llvm/Demangle/Demangle.h"
25+
#include "llvm/IR/IntrinsicInst.h"
2526
#include "llvm/IR/IntrinsicsSPIRV.h"
2627
#include <queue>
2728
#include <vector>
@@ -747,4 +748,12 @@ bool isNestedPointer(const Type *Ty) {
747748
return false;
748749
}
749750

751+
bool isSpvIntrinsic(const Value *Arg) {
752+
if (const auto *II = dyn_cast<IntrinsicInst>(Arg))
753+
if (Function *F = II->getCalledFunction())
754+
if (F->getName().starts_with("llvm.spv."))
755+
return true;
756+
return false;
757+
}
758+
750759
} // namespace llvm

llvm/lib/Target/SPIRV/SPIRVUtils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,8 @@ uint64_t getIConstVal(Register ConstReg, const MachineRegisterInfo *MRI);
196196

197197
// Check if MI is a SPIR-V specific intrinsic call.
198198
bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID);
199+
// Check if it's a SPIR-V specific intrinsic call.
200+
bool isSpvIntrinsic(const Value *Arg);
199201

200202
// Get type of i-th operand of the metadata node.
201203
Type *getMDOperandAsType(const MDNode *N, unsigned I);
@@ -301,7 +303,6 @@ inline Type *getPointeeType(const Type *Ty) {
301303
else if (auto *ExtTy = dyn_cast<TargetExtType>(Ty))
302304
if (isTypedPointerWrapper(ExtTy))
303305
return ExtTy->getTypeParameter(0);
304-
//return applyWrappers(ExtTy->getTypeParameter(0));
305306
}
306307
return nullptr;
307308
}

0 commit comments

Comments
 (0)