Skip to content

Commit 6d0c311

Browse files
improve GEP support
1 parent 7c53f10 commit 6d0c311

File tree

3 files changed

+133
-45
lines changed

3 files changed

+133
-45
lines changed

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 121 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "SPIRVSubtarget.h"
1818
#include "SPIRVTargetMachine.h"
1919
#include "SPIRVUtils.h"
20+
#include "llvm/ADT/DenseSet.h"
2021
#include "llvm/IR/IRBuilder.h"
2122
#include "llvm/IR/InstIterator.h"
2223
#include "llvm/IR/InstVisitor.h"
@@ -80,7 +81,8 @@ class SPIRVEmitIntrinsics
8081
unsigned TodoTypeSz = 0;
8182
DenseMap<Value *, bool> TodoType;
8283
void insertTodoType(Value *Op) {
83-
if (CanTodoType) {
84+
// TODO: add isa<CallInst>(Op) to no-insert
85+
if (CanTodoType && !isa<GetElementPtrInst>(Op)) {
8486
auto It = TodoType.try_emplace(Op, true);
8587
if (It.second)
8688
++TodoTypeSz;
@@ -94,9 +96,14 @@ class SPIRVEmitIntrinsics
9496
}
9597
}
9698
bool isTodoType(Value *Op) {
99+
if (isa<GetElementPtrInst>(Op))
100+
return false;
97101
auto It = TodoType.find(Op);
98102
return It != TodoType.end() && It->second;
99103
}
104+
// a register of Instructions that were visited by deduceOperandElementType()
105+
// to validate operand types with an instruction
106+
std::unordered_set<Instruction *> TypeValidated;
100107

101108
// well known result types of builtins
102109
enum WellKnownTypes { Event };
@@ -178,10 +185,17 @@ class SPIRVEmitIntrinsics
178185
Type *&KnownElemTy, bool IsPostprocessing);
179186

180187
CallInst *buildSpvPtrcast(Function *F, Value *Op, Type *ElemTy);
181-
void propagateElemType(Value *Op, Type *ElemTy);
182-
void propagateElemTypeRec(Value *Op, Type *PtrElemTy, Type *CastElemTy);
188+
void replaceUsesOfWithSpvPtrcast(Value *Op, Type *ElemTy, Instruction *I,
189+
DenseMap<Function *, CallInst *> Ptrcasts);
190+
void propagateElemType(Value *Op, Type *ElemTy,
191+
DenseSet<std::pair<Value *, Value *>> &VisitedSubst);
192+
void
193+
propagateElemTypeRec(Value *Op, Type *PtrElemTy, Type *CastElemTy,
194+
DenseSet<std::pair<Value *, Value *>> &VisitedSubst);
183195
void propagateElemTypeRec(Value *Op, Type *PtrElemTy, Type *CastElemTy,
184-
std::unordered_set<Value *> &Visited);
196+
DenseSet<std::pair<Value *, Value *>> &VisitedSubst,
197+
std::unordered_set<Value *> &Visited,
198+
DenseMap<Function *, CallInst *> Ptrcasts);
185199

186200
void replaceAllUsesWith(Value *Src, Value *Dest, bool DeleteOld = true);
187201

@@ -439,38 +453,63 @@ CallInst *SPIRVEmitIntrinsics::buildSpvPtrcast(Function *F, Value *Op,
439453
return PtrCasted;
440454
}
441455

442-
void SPIRVEmitIntrinsics::propagateElemType(Value *Op, Type *ElemTy) {
456+
void SPIRVEmitIntrinsics::replaceUsesOfWithSpvPtrcast(
457+
Value *Op, Type *ElemTy, Instruction *I,
458+
DenseMap<Function *, CallInst *> Ptrcasts) {
459+
Function *F = I->getParent()->getParent();
460+
CallInst *PtrCastedI = nullptr;
461+
auto It = Ptrcasts.find(F);
462+
if (It == Ptrcasts.end()) {
463+
PtrCastedI = buildSpvPtrcast(F, Op, ElemTy);
464+
Ptrcasts[F] = PtrCastedI;
465+
} else {
466+
PtrCastedI = It->second;
467+
}
468+
I->replaceUsesOfWith(Op, PtrCastedI);
469+
}
470+
471+
void SPIRVEmitIntrinsics::propagateElemType(
472+
Value *Op, Type *ElemTy,
473+
DenseSet<std::pair<Value *, Value *>> &VisitedSubst) {
474+
DenseMap<Function *, CallInst *> Ptrcasts;
443475
SmallVector<User *> Users(Op->users());
444476
for (auto *U : Users) {
445-
if (!isa<Instruction>(U))
477+
if (!isa<Instruction>(U) || isa<BitCastInst>(U) || isSpvIntrinsic(U))
446478
continue;
447-
if (isa<BitCastInst>(U) || isa<GetElementPtrInst>(U) || isSpvIntrinsic(U))
479+
if (!VisitedSubst.insert(std::make_pair(U, Op)).second)
448480
continue;
449-
U->replaceUsesOfWith(
450-
Op, buildSpvPtrcast(dyn_cast<Instruction>(U)->getParent()->getParent(),
451-
Op, ElemTy));
481+
Instruction *UI = dyn_cast<Instruction>(U);
482+
// If the instruction was validated already, we need to keep it valid by
483+
// keeping current Op type.
484+
if (isa<GetElementPtrInst>(UI) ||
485+
TypeValidated.find(UI) != TypeValidated.end())
486+
replaceUsesOfWithSpvPtrcast(Op, ElemTy, UI, Ptrcasts);
452487
}
453488
}
454489

455-
void SPIRVEmitIntrinsics::propagateElemTypeRec(Value *Op, Type *PtrElemTy,
456-
Type *CastElemTy) {
457-
if (!isNestedPointer(PtrElemTy))
458-
return;
490+
void SPIRVEmitIntrinsics::propagateElemTypeRec(
491+
Value *Op, Type *PtrElemTy, Type *CastElemTy,
492+
DenseSet<std::pair<Value *, Value *>> &VisitedSubst) {
459493
std::unordered_set<Value *> Visited;
460-
propagateElemTypeRec(Op, PtrElemTy, CastElemTy, Visited);
494+
DenseMap<Function *, CallInst *> Ptrcasts;
495+
propagateElemTypeRec(Op, PtrElemTy, CastElemTy, VisitedSubst, Visited,
496+
Ptrcasts);
461497
}
462498

463499
void SPIRVEmitIntrinsics::propagateElemTypeRec(
464500
Value *Op, Type *PtrElemTy, Type *CastElemTy,
465-
std::unordered_set<Value *> &Visited) {
501+
DenseSet<std::pair<Value *, Value *>> &VisitedSubst,
502+
std::unordered_set<Value *> &Visited,
503+
DenseMap<Function *, CallInst *> Ptrcasts) {
466504
if (!Visited.insert(Op).second)
467505
return;
468506
SmallVector<User *> Users(Op->users());
469507
for (auto *U : Users) {
470-
if (!isa<Instruction>(U))
508+
if (!isa<Instruction>(U) || isa<BitCastInst>(U) || isSpvIntrinsic(U))
471509
continue;
472-
if (isa<BitCastInst>(U) || isSpvIntrinsic(U))
510+
if (!VisitedSubst.insert(std::make_pair(U, Op)).second)
473511
continue;
512+
/*
474513
if (auto *Ref = dyn_cast<GetElementPtrInst>(U)) {
475514
CallInst *AssignCI = GR->findAssignPtrTypeInstr(Ref);
476515
if (AssignCI && Ref->getPointerOperand() == Op) {
@@ -485,14 +524,18 @@ void SPIRVEmitIntrinsics::propagateElemTypeRec(
485524
assert(NewElemTy && "Expected valid GEP indices");
486525
updateAssignType(AssignCI, Ref, PoisonValue::get(NewElemTy));
487526
// recursively propagate change
488-
if (isNestedPointer(NewElemTy))
489-
propagateElemTypeRec(Ref, NewElemTy, PrevElemTy, Visited);
527+
propagateElemTypeRec(Ref, NewElemTy, PrevElemTy, VisitedSubst, Visited,
528+
Ptrcasts);
490529
}
491530
continue;
492531
}
493-
U->replaceUsesOfWith(
494-
Op, buildSpvPtrcast(dyn_cast<Instruction>(U)->getParent()->getParent(),
495-
Op, CastElemTy));
532+
*/
533+
Instruction *UI = dyn_cast<Instruction>(U);
534+
// If the instruction was validated already, we need to keep it valid by
535+
// keeping current Op type.
536+
if (isa<GetElementPtrInst>(UI) ||
537+
TypeValidated.find(UI) != TypeValidated.end())
538+
replaceUsesOfWithSpvPtrcast(Op, CastElemTy, UI, Ptrcasts);
496539
}
497540
}
498541

@@ -600,13 +643,34 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
600643
if (auto *Ref = dyn_cast<AllocaInst>(I)) {
601644
maybeAssignPtrType(Ty, I, Ref->getAllocatedType(), UnknownElemTypeI8);
602645
} else if (auto *Ref = dyn_cast<GetElementPtrInst>(I)) {
603-
Ty = Ref->getResultElementType();
646+
// TODO: not sure if GetElementPtrInst::getTypeAtIndex() does anything
647+
// useful here
648+
if (isNestedPointer(Ref->getSourceElementType())) {
649+
Ty = Ref->getSourceElementType();
650+
for (Use &U : drop_begin(Ref->indices()))
651+
Ty = GetElementPtrInst::getTypeAtIndex(Ty, U.get());
652+
} else {
653+
Ty = Ref->getResultElementType();
654+
}
655+
/*
656+
if (Type *PtrElemTy = GR->findDeducedElementType(Ref->getPointerOperand()))
657+
{ Ty = PtrElemTy; for (Use &U : drop_begin(Ref->indices())) Ty =
658+
GetElementPtrInst::getTypeAtIndex(Ty, U.get()); if
659+
(isTodoType(Ref->getPointerOperand())) insertTodoType(Ref); } else if
660+
(isNestedPointer(Ref->getSourceElementType())) { Ty =
661+
Ref->getSourceElementType(); for (Use &U : drop_begin(Ref->indices())) Ty =
662+
GetElementPtrInst::getTypeAtIndex(Ty, U.get()); } else { Ty =
663+
Ref->getResultElementType();
664+
}
665+
*/
666+
/*
604667
if (isNestedPointer(Ref->getSourceElementType())) {
605668
Type *PtrElemTy = GR->findDeducedElementType(Ref->getPointerOperand());
606669
Ty = PtrElemTy ? PtrElemTy : Ref->getSourceElementType();
607670
for (Use &U : drop_begin(Ref->indices()))
608671
Ty = GetElementPtrInst::getTypeAtIndex(Ty, U.get());
609672
}
673+
*/
610674
} else if (auto *Ref = dyn_cast<LoadInst>(I)) {
611675
Value *Op = Ref->getPointerOperand();
612676
Type *KnownTy = GR->findDeducedElementType(Op);
@@ -934,7 +998,6 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(
934998
Uncomplete = isTodoType(I);
935999
Ops.push_back(std::make_pair(Ref->getPointerOperand(), 0));
9361000
} else if (auto *Ref = dyn_cast<GetElementPtrInst>(I)) {
937-
// TODO: ensure that getPointerOperand() and GEP result type are consistent
9381001
if (GR->findDeducedElementType(Ref->getPointerOperand()))
9391002
return;
9401003
KnownElemTy = Ref->getSourceElementType();
@@ -992,17 +1055,20 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(
9921055
GR->addDeducedElementType(CurrF, OpElemTy);
9931056
GR->addReturnType(CurrF, TypedPointerType::get(
9941057
OpElemTy, getPointerAddressSpace(RetTy)));
1058+
DenseSet<std::pair<Value *, Value *>> VisitedSubst{
1059+
std::make_pair(I, Op)};
9951060
for (User *U : CurrF->users()) {
9961061
CallInst *CI = dyn_cast<CallInst>(U);
9971062
if (!CI || CI->getCalledFunction() != CurrF)
9981063
continue;
9991064
if (CallInst *AssignCI = GR->findAssignPtrTypeInstr(CI)) {
10001065
if (Type *PrevElemTy = GR->findDeducedElementType(CI)) {
10011066
updateAssignType(AssignCI, CI, PoisonValue::get(OpElemTy));
1002-
propagateElemType(CI, PrevElemTy);
1067+
propagateElemType(CI, PrevElemTy, VisitedSubst);
10031068
}
10041069
}
10051070
}
1071+
TypeValidated.insert(I);
10061072
}
10071073
return;
10081074
}
@@ -1075,7 +1141,9 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(
10751141
} else {
10761142
Type *PrevElemTy = GR->findDeducedElementType(Op);
10771143
updateAssignType(AssignCI, Op, OpTyVal);
1078-
propagateElemTypeRec(Op, KnownElemTy, PrevElemTy);
1144+
DenseSet<std::pair<Value *, Value *>> VisitedSubst{
1145+
std::make_pair(I, Op)};
1146+
propagateElemTypeRec(Op, KnownElemTy, PrevElemTy, VisitedSubst);
10791147
}
10801148
} else {
10811149
eraseTodoType(Op);
@@ -1087,6 +1155,7 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(
10871155
I->setOperand(OpIt.second, PtrCastI);
10881156
}
10891157
}
1158+
TypeValidated.insert(I);
10901159
}
10911160

10921161
void SPIRVEmitIntrinsics::replaceMemInstrUses(Instruction *Old,
@@ -1293,10 +1362,7 @@ void SPIRVEmitIntrinsics::insertAssignPtrTypeTargetExt(
12931362
Type *VTy = V->getType();
12941363

12951364
// A couple of sanity checks.
1296-
assert((isPointerTy(VTy) ||
1297-
(isa<TargetExtType>(VTy) &&
1298-
isTypedPointerWrapper(dyn_cast<TargetExtType>(VTy)))) &&
1299-
"Expect a pointer type!");
1365+
assert((isPointerTy(VTy)) && "Expect a pointer type!");
13001366
if (Type *ElemTy = getPointeeType(VTy))
13011367
if (ElemTy != AssignedType)
13021368
report_fatal_error("Unexpected pointer element type!");
@@ -1329,6 +1395,7 @@ void SPIRVEmitIntrinsics::insertAssignPtrTypeTargetExt(
13291395
void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
13301396
Instruction *I, Value *Pointer, Type *ExpectedElementType,
13311397
unsigned OperandToReplace, IRBuilder<> &B) {
1398+
TypeValidated.insert(I);
13321399
// If Pointer is the result of nop BitCastInst (ptr -> ptr), use the source
13331400
// pointer instead. The BitCastInst should be later removed when visited.
13341401
while (BitCastInst *BC = dyn_cast<BitCastInst>(Pointer))
@@ -1392,8 +1459,11 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
13921459
// uncomplete, update spv_assign_ptr_type arguments.
13931460
if (CallInst *AssignCI = GR->findAssignPtrTypeInstr(Pointer)) {
13941461
Type *PrevElemTy = GR->findDeducedElementType(Pointer);
1462+
assert(PrevElemTy);
1463+
DenseSet<std::pair<Value *, Value *>> VisitedSubst{
1464+
std::make_pair(I, Pointer)};
13951465
updateAssignType(AssignCI, Pointer, ExpectedElementVal);
1396-
propagateElemTypeRec(Pointer, ExpectedElementType, PrevElemTy);
1466+
propagateElemType(Pointer, PrevElemTy, VisitedSubst);
13971467
} else {
13981468
buildAssignPtr(B, ExpectedElementType, Pointer);
13991469
}
@@ -1422,15 +1492,20 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
14221492
}
14231493
if (SI) {
14241494
Value *Op = SI->getValueOperand();
1495+
Value *Pointer = SI->getPointerOperand();
1496+
// if (!GR->findDeducedElementType(Pointer) || isTodoType(Pointer)) {
14251497
Type *OpTy = Op->getType();
14261498
if (auto *OpI = dyn_cast<Instruction>(Op))
14271499
OpTy = restoreMutatedType(GR, OpI, OpTy);
14281500
if (OpTy == Op->getType())
14291501
OpTy = deduceElementTypeByValueDeep(OpTy, Op, false);
1430-
return replacePointerOperandWithPtrCast(I, SI->getPointerOperand(), OpTy, 1,
1431-
B);
1432-
} else if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
1502+
replacePointerOperandWithPtrCast(I, Pointer, OpTy, 1, B);
1503+
//}
1504+
return;
1505+
}
1506+
if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
14331507
Value *Pointer = LI->getPointerOperand();
1508+
// if (!GR->findDeducedElementType(Pointer) || isTodoType(Pointer)) {
14341509
Type *OpTy = LI->getType();
14351510
if (auto *PtrTy = dyn_cast<PointerType>(OpTy)) {
14361511
// TODO: isNestedPointer() instead of dyn_cast<PointerType>
@@ -1443,8 +1518,11 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
14431518
insertTodoType(Pointer);
14441519
}
14451520
}
1446-
return replacePointerOperandWithPtrCast(I, Pointer, OpTy, 0, B);
1447-
} else if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(I)) {
1521+
replacePointerOperandWithPtrCast(I, Pointer, OpTy, 0, B);
1522+
//}
1523+
return;
1524+
}
1525+
if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(I)) {
14481526
Value *Pointer = GEPI->getPointerOperand();
14491527
Type *OpTy = GEPI->getSourceElementType();
14501528
replacePointerOperandWithPtrCast(I, Pointer, OpTy, 0, B);
@@ -1522,7 +1600,8 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
15221600
if (!ExpectedType || ExpectedType->isVoidTy())
15231601
continue;
15241602

1525-
if (ExpectedType->isTargetExtTy())
1603+
if (ExpectedType->isTargetExtTy() &&
1604+
!isTypedPointerWrapper(cast<TargetExtType>(ExpectedType)))
15261605
insertAssignPtrTypeTargetExt(cast<TargetExtType>(ExpectedType),
15271606
ArgOperand, B);
15281607
else
@@ -2155,6 +2234,7 @@ bool SPIRVEmitIntrinsics::postprocessTypes(Module &M) {
21552234
unsigned SzTodo = TodoTypeSz;
21562235
DenseMap<Value *, SmallPtrSet<Value *, 4>> ToProcess;
21572236
for (auto [Op, Enabled] : TodoType) {
2237+
// TODO: add isa<CallInst>(Op) to continue
21582238
if (!Enabled || isa<GetElementPtrInst>(Op))
21592239
continue;
21602240
CallInst *AssignCI = GR->findAssignPtrTypeInstr(Op);
@@ -2168,11 +2248,12 @@ bool SPIRVEmitIntrinsics::postprocessTypes(Module &M) {
21682248
std::unordered_set<Value *> Visited;
21692249
if (Type *ElemTy = deduceElementTypeHelper(Op, Visited, false, true)) {
21702250
if (ElemTy != KnownTy) {
2251+
DenseSet<std::pair<Value *, Value *>> VisitedSubst;
21712252
if (isa<CallInst>(Op)) {
2172-
propagateElemType(CI, ElemTy);
2253+
propagateElemType(CI, ElemTy, VisitedSubst);
21732254
} else {
21742255
updateAssignType(AssignCI, CI, PoisonValue::get(ElemTy));
2175-
propagateElemTypeRec(CI, ElemTy, KnownTy);
2256+
propagateElemTypeRec(CI, ElemTy, KnownTy, VisitedSubst);
21762257
}
21772258
eraseTodoType(Op);
21782259
continue;

llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -394,11 +394,11 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
394394
case SPIRV::OpGenericCastToPtr:
395395
validateAccessChain(STI, MRI, GR, MI);
396396
break;
397-
case SPIRV::OpPtrAccessChain:
398-
case SPIRV::OpInBoundsPtrAccessChain:
399-
if (MI.getNumOperands() == 4)
400-
validateAccessChain(STI, MRI, GR, MI);
401-
break;
397+
// case SPIRV::OpPtrAccessChain:
398+
// case SPIRV::OpInBoundsPtrAccessChain:
399+
// if (MI.getNumOperands() == 4)
400+
// validateAccessChain(STI, MRI, GR, MI);
401+
// break;
402402

403403
case SPIRV::OpFunctionCall:
404404
// ensure there is no mismatch between actual and expected arg types:

llvm/lib/Target/SPIRV/SPIRVUtils.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,13 @@ inline bool isTypedPointerWrapper(const TargetExtType *ExtTy) {
282282
ExtTy->getNumTypeParameters() == 1;
283283
}
284284

285+
// True if this is an instance of PointerType or TypedPointerType.
286+
inline bool isPointerTyOrWrapper(const Type *Ty) {
287+
if (auto *ExtTy = dyn_cast<TargetExtType>(Ty))
288+
return isTypedPointerWrapper(ExtTy);
289+
return isPointerTy(Ty);
290+
}
291+
285292
inline Type *applyWrappers(Type *Ty) {
286293
if (auto *ExtTy = dyn_cast<TargetExtType>(Ty)) {
287294
if (isTypedPointerWrapper(ExtTy))

0 commit comments

Comments
 (0)