
Commit 5da7179

[AMDGPU] Reland: Add IR LiveReg type-based optimization
1 parent 3386d24 commit 5da7179

11 files changed (+2606 -2036 lines)

llvm/lib/Target/AMDGPU/AMDGPULateCodeGenPrepare.cpp

Lines changed: 295 additions & 7 deletions
@@ -50,6 +50,8 @@ class AMDGPULateCodeGenPrepare
   AssumptionCache *AC = nullptr;
   UniformityInfo *UA = nullptr;
 
+  SmallVector<WeakTrackingVH, 8> DeadInsts;
+
 public:
   static char ID;
 
@@ -81,6 +83,69 @@ class AMDGPULateCodeGenPrepare
   bool visitLoadInst(LoadInst &LI);
 };
 
+using ValueToValueMap = DenseMap<const Value *, Value *>;
+
+class LiveRegOptimizer {
+private:
+  Module *Mod = nullptr;
+  const DataLayout *DL = nullptr;
+  const GCNSubtarget *ST;
+  /// The scalar type to convert to
+  Type *ConvertToScalar;
+  /// The set of visited Instructions
+  SmallPtrSet<Instruction *, 4> Visited;
+  /// Map of Value -> Converted Value
+  ValueToValueMap ValMap;
+  /// Map containing conversions from Optimal Type -> Original Type per BB.
+  DenseMap<BasicBlock *, ValueToValueMap> BBUseValMap;
+
+public:
+  /// Calculate and \p return the type to convert to given a problematic \p
+  /// OriginalType. In some instances, we may widen the type (e.g. v2i8 -> i32).
+  Type *calculateConvertType(Type *OriginalType);
+  /// Convert the virtual register defined by \p V to the compatible vector of
+  /// legal type
+  Value *convertToOptType(Instruction *V, BasicBlock::iterator &InstPt);
+  /// Convert the virtual register defined by \p V back to the original type \p
+  /// ConvertType, stripping away the MSBs in cases where there was an imperfect
+  /// fit (e.g. v2i32 -> v7i8)
+  Value *convertFromOptType(Type *ConvertType, Instruction *V,
+                            BasicBlock::iterator &InstPt,
+                            BasicBlock *InsertBlock);
+  /// Check for problematic PHI nodes or cross-BB values based on the value
+  /// defined by \p I, and coerce to legal types if necessary. For problematic
+  /// PHI nodes, we coerce all incoming values in a single invocation.
+  bool optimizeLiveType(Instruction *I,
+                        SmallVectorImpl<WeakTrackingVH> &DeadInsts);
+
+  // Whether or not the type should be replaced to avoid inefficient
+  // legalization code
+  bool shouldReplace(Type *ITy) {
+    FixedVectorType *VTy = dyn_cast<FixedVectorType>(ITy);
+    if (!VTy)
+      return false;
+
+    auto TLI = ST->getTargetLowering();
+
+    Type *EltTy = VTy->getElementType();
+    // If the element size is not less than the convert to scalar size, then we
+    // can't do any bit packing
+    if (!EltTy->isIntegerTy() ||
+        EltTy->getScalarSizeInBits() > ConvertToScalar->getScalarSizeInBits())
+      return false;
+
+    // Only coerce illegal types
+    TargetLoweringBase::LegalizeKind LK =
+        TLI->getTypeConversion(EltTy->getContext(), EVT::getEVT(EltTy, false));
+    return LK.first != TargetLoweringBase::TypeLegal;
+  }
+
+  LiveRegOptimizer(Module *Mod, const GCNSubtarget *ST) : Mod(Mod), ST(ST) {
+    DL = &Mod->getDataLayout();
+    ConvertToScalar = Type::getInt32Ty(Mod->getContext());
+  }
+};
+
 } // end anonymous namespace
 
 bool AMDGPULateCodeGenPrepare::doInitialization(Module &M) {
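A note on shouldReplace: the filter accepts only fixed vectors whose integer elements are no wider than the i32 convert-to scalar, and then only when TargetLowering reports the element type as illegal, so on AMDGPU a v4i8 qualifies while a v2i32 or v4f32 does not. Below is a minimal standalone sketch of the size-based half of that filter; isPackingCandidate is a hypothetical helper invented for illustration, and the target-legality query is omitted because it needs a GCNSubtarget.

#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/LLVMContext.h"
using namespace llvm;

// Hypothetical helper mirroring only the size checks of shouldReplace().
// The real pass additionally requires TLI->getTypeConversion() to report
// the element type as something other than TypeLegal.
static bool isPackingCandidate(Type *ITy) {
  auto *VTy = dyn_cast<FixedVectorType>(ITy);
  if (!VTy)
    return false;
  Type *EltTy = VTy->getElementType();
  return EltTy->isIntegerTy() && EltTy->getScalarSizeInBits() <= 32;
}

int main() {
  LLVMContext Ctx;
  Type *V4I8 = FixedVectorType::get(Type::getInt8Ty(Ctx), 4);   // accepted
  Type *V4F32 = FixedVectorType::get(Type::getFloatTy(Ctx), 4); // rejected
  return isPackingCandidate(V4I8) && !isPackingCandidate(V4F32) ? 0 : 1;
}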
@@ -96,17 +161,30 @@ bool AMDGPULateCodeGenPrepare::runOnFunction(Function &F) {
   const TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
   const TargetMachine &TM = TPC.getTM<TargetMachine>();
   const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
-  if (ST.hasScalarSubwordLoads())
-    return false;
 
   AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
   UA = &getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();
 
+  // "Optimize" the virtual regs that cross basic block boundaries. When
+  // building the SelectionDAG, vectors of illegal types that cross basic blocks
+  // will be scalarized and widened, with each scalar living in its
+  // own register. To work around this, this optimization converts the
+  // vectors to equivalent vectors of legal type (which are converted back
+  // before uses in subsequent blocks), to pack the bits into fewer physical
+  // registers (used in CopyToReg/CopyFromReg pairs).
+  LiveRegOptimizer LRO(Mod, &ST);
+
   bool Changed = false;
-  for (auto &BB : F)
-    for (Instruction &I : llvm::make_early_inc_range(BB))
-      Changed |= visit(I);
 
+  bool HasScalarSubwordLoads = ST.hasScalarSubwordLoads();
+
+  for (auto &BB : reverse(F))
+    for (Instruction &I : make_early_inc_range(reverse(BB))) {
+      Changed |= !HasScalarSubwordLoads && visit(I);
+      Changed |= LRO.optimizeLiveType(&I, DeadInsts);
+    }
+
+  RecursivelyDeleteTriviallyDeadInstructionsPermissive(DeadInsts);
   return Changed;
 }
 
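The rewritten driver folds the old hasScalarSubwordLoads() early return into the per-instruction visit and now walks blocks and instructions in reverse order, with make_early_inc_range keeping the walk valid while either visitor deletes or replaces the current instruction. A reduced sketch of just the iteration pattern; walkBottomUp and Callback are stand-ins invented here, not part of the patch:

#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instruction.h"
using namespace llvm;

// Visit every instruction of F in reverse block and instruction order.
// make_early_inc_range advances past I before the callback runs, so the
// callback may erase or replace I without invalidating the traversal.
template <typename CallbackT>
static bool walkBottomUp(Function &F, CallbackT Callback) {
  bool Changed = false;
  for (BasicBlock &BB : reverse(F))
    for (Instruction &I : make_early_inc_range(reverse(BB)))
      Changed |= Callback(I);
  return Changed;
}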

@@ -112,0 +191,18 @@
+Type *LiveRegOptimizer::calculateConvertType(Type *OriginalType) {
+  assert(OriginalType->getScalarSizeInBits() <=
+         ConvertToScalar->getScalarSizeInBits());
+
+  FixedVectorType *VTy = cast<FixedVectorType>(OriginalType);
+
+  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
+  TypeSize ConvertScalarSize = DL->getTypeSizeInBits(ConvertToScalar);
+  unsigned ConvertEltCount =
+      (OriginalSize + ConvertScalarSize - 1) / ConvertScalarSize;
+
+  if (OriginalSize <= ConvertScalarSize)
+    return IntegerType::get(Mod->getContext(), ConvertScalarSize);
+
+  return VectorType::get(Type::getIntNTy(Mod->getContext(), ConvertScalarSize),
+                         ConvertEltCount, false);
+}
+
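The conversion is a ceiling division by the 32-bit convert-to width: anything that fits in 32 bits becomes a single i32 (possibly widening, as in v3i8), and anything larger becomes a vector of i32. A few worked instances of the ConvertEltCount arithmetic above, as a self-contained check:

#include <cassert>

// Worked instances of calculateConvertType's arithmetic, bit widths only:
// ConvertEltCount = ceil(OriginalSize / 32).
static unsigned convertEltCount(unsigned OriginalBits) {
  return (OriginalBits + 31) / 32;
}

int main() {
  assert(convertEltCount(3 * 8) == 1);  // v3i8, 24 bits  -> i32 (widened)
  assert(convertEltCount(4 * 8) == 1);  // v4i8, 32 bits  -> i32 (exact fit)
  assert(convertEltCount(7 * 8) == 2);  // v7i8, 56 bits  -> v2i32
  assert(convertEltCount(4 * 16) == 2); // v4i16, 64 bits -> v2i32
  return 0;
}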
@@ -112,0 +209,30 @@
+Value *LiveRegOptimizer::convertToOptType(Instruction *V,
+                                          BasicBlock::iterator &InsertPt) {
+  FixedVectorType *VTy = cast<FixedVectorType>(V->getType());
+  Type *NewTy = calculateConvertType(V->getType());
+
+  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
+  TypeSize NewSize = DL->getTypeSizeInBits(NewTy);
+
+  IRBuilder<> Builder(V->getParent(), InsertPt);
+  // If there is a bitsize match, we can fit the old vector into a new vector of
+  // desired type.
+  if (OriginalSize == NewSize)
+    return Builder.CreateBitCast(V, NewTy, V->getName() + ".bc");
+
+  // If there is a bitsize mismatch, we must use a wider vector.
+  assert(NewSize > OriginalSize);
+  uint64_t ExpandedVecElementCount = NewSize / VTy->getScalarSizeInBits();
+
+  SmallVector<int, 8> ShuffleMask;
+  uint64_t OriginalElementCount = VTy->getElementCount().getFixedValue();
+  for (unsigned I = 0; I < OriginalElementCount; I++)
+    ShuffleMask.push_back(I);
+
+  for (uint64_t I = OriginalElementCount; I < ExpandedVecElementCount; I++)
+    ShuffleMask.push_back(OriginalElementCount);
+
+  Value *ExpandedVec = Builder.CreateShuffleVector(V, ShuffleMask);
+  return Builder.CreateBitCast(ExpandedVec, NewTy, V->getName() + ".bc");
+}
+
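To make the imperfect-fit path concrete: a v7i8 def (56 bits) is padded to v8i8 by a shufflevector whose last lane indexes one past the source elements (reading a poison lane), then bitcast to the v2i32 that calculateConvertType picked. A standalone sketch that builds exactly this sequence; the module and function scaffolding ("pack", the argument, the return) are invented for illustration:

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;

int main() {
  LLVMContext Ctx;
  Module M("sketch", Ctx);
  auto *V7I8 = FixedVectorType::get(Type::getInt8Ty(Ctx), 7);
  auto *V2I32 = FixedVectorType::get(Type::getInt32Ty(Ctx), 2);
  auto *FTy = FunctionType::get(V2I32, {V7I8}, false);
  Function *F = Function::Create(FTy, Function::ExternalLinkage, "pack", M);
  IRBuilder<> B(BasicBlock::Create(Ctx, "entry", F));

  // Pad v7i8 to v8i8; lane index 7 points past the input, i.e. poison.
  SmallVector<int, 8> Mask = {0, 1, 2, 3, 4, 5, 6, 7};
  Value *Wide = B.CreateShuffleVector(F->getArg(0), Mask);
  // Reinterpret the 64 bits as the packed, legal live-out type v2i32.
  B.CreateRet(B.CreateBitCast(Wide, V2I32));

  verifyFunction(*F, &errs());
  M.print(outs(), nullptr);
  return 0;
}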
@@ -112,0 +239,38 @@
+Value *LiveRegOptimizer::convertFromOptType(Type *ConvertType, Instruction *V,
+                                            BasicBlock::iterator &InsertPt,
+                                            BasicBlock *InsertBB) {
+  FixedVectorType *NewVTy = cast<FixedVectorType>(ConvertType);
+
+  TypeSize OriginalSize = DL->getTypeSizeInBits(V->getType());
+  TypeSize NewSize = DL->getTypeSizeInBits(NewVTy);
+
+  IRBuilder<> Builder(InsertBB, InsertPt);
+  // If there is a bitsize match, we simply convert back to the original type.
+  if (OriginalSize == NewSize)
+    return Builder.CreateBitCast(V, NewVTy, V->getName() + ".bc");
+
+  // If there is a bitsize mismatch, then we must have used a wider value to
+  // hold the bits.
+  assert(OriginalSize > NewSize);
+  // For wide scalars, we can just truncate the value.
+  if (!V->getType()->isVectorTy()) {
+    Instruction *Trunc = cast<Instruction>(
+        Builder.CreateTrunc(V, IntegerType::get(Mod->getContext(), NewSize)));
+    return cast<Instruction>(Builder.CreateBitCast(Trunc, NewVTy));
+  }
+
+  // For wider vectors, we must strip the MSBs to convert back to the original
+  // type.
+  VectorType *ExpandedVT = VectorType::get(
+      Type::getIntNTy(Mod->getContext(), NewVTy->getScalarSizeInBits()),
+      (OriginalSize / NewVTy->getScalarSizeInBits()), false);
+  Instruction *Converted =
+      cast<Instruction>(Builder.CreateBitCast(V, ExpandedVT));
+
+  unsigned NarrowElementCount = NewVTy->getElementCount().getFixedValue();
+  SmallVector<int, 8> ShuffleMask(NarrowElementCount);
+  std::iota(ShuffleMask.begin(), ShuffleMask.end(), 0);
+
+  return Builder.CreateShuffleVector(Converted, ShuffleMask);
+}
+
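The use-side inverse follows the same shapes in reverse: an exact fit is a single bitcast, a packed scalar is truncated then bitcast, and a packed vector is bitcast to the widened narrow-element vector and shuffled down to drop the padding lanes. A sketch of the v2i32 -> v7i8 direction, under the same invented scaffolding as the previous example:

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_ostream.h"
#include <numeric>
using namespace llvm;

int main() {
  LLVMContext Ctx;
  Module M("sketch", Ctx);
  auto *V2I32 = FixedVectorType::get(Type::getInt32Ty(Ctx), 2);
  auto *V7I8 = FixedVectorType::get(Type::getInt8Ty(Ctx), 7);
  auto *FTy = FunctionType::get(V7I8, {V2I32}, false);
  Function *F = Function::Create(FTy, Function::ExternalLinkage, "unpack", M);
  IRBuilder<> B(BasicBlock::Create(Ctx, "entry", F));

  // Reinterpret the 64 packed bits as v8i8, the widened narrow-element type.
  Value *Wide =
      B.CreateBitCast(F->getArg(0), FixedVectorType::get(B.getInt8Ty(), 8));
  // Keep lanes 0..6, stripping the MSB padding lane added on the def side.
  SmallVector<int, 8> Mask(7);
  std::iota(Mask.begin(), Mask.end(), 0);
  B.CreateRet(B.CreateShuffleVector(Wide, Mask));

  M.print(outs(), nullptr);
  return 0;
}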
@@ -113,3 +277,127 @@
+bool LiveRegOptimizer::optimizeLiveType(
+    Instruction *I, SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
+  SmallVector<Instruction *, 4> Worklist;
+  SmallPtrSet<PHINode *, 4> PhiNodes;
+  SmallPtrSet<Instruction *, 4> Defs;
+  SmallPtrSet<Instruction *, 4> Uses;
+
+  Worklist.push_back(cast<Instruction>(I));
+  while (!Worklist.empty()) {
+    Instruction *II = Worklist.pop_back_val();
+
+    if (!Visited.insert(II).second)
+      continue;
+
+    if (!shouldReplace(II->getType()))
+      continue;
+
+    if (PHINode *Phi = dyn_cast<PHINode>(II)) {
+      PhiNodes.insert(Phi);
+      // Collect all the incoming values of problematic PHI nodes.
+      for (Value *V : Phi->incoming_values()) {
+        // Repeat the collection process for newly found PHI nodes.
+        if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
+          if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
+            Worklist.push_back(OpPhi);
+          continue;
+        }
+
+        Instruction *IncInst = dyn_cast<Instruction>(V);
+        // Other incoming value types (e.g. vector literals) are unhandled.
+        if (!IncInst && !isa<ConstantAggregateZero>(V))
+          return false;
+
+        // Collect all other incoming values for coercion.
+        if (IncInst)
+          Defs.insert(IncInst);
+      }
+    }
+
+    // Collect all relevant uses.
+    for (User *V : II->users()) {
+      // Repeat the collection process for problematic PHI nodes.
+      if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
+        if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
+          Worklist.push_back(OpPhi);
+        continue;
+      }
+
+      Instruction *UseInst = cast<Instruction>(V);
+      // Collect all uses of PHINodes and any use that crosses BB boundaries.
+      if (UseInst->getParent() != II->getParent() || isa<PHINode>(II)) {
+        Uses.insert(UseInst);
+        if (!Defs.count(II) && !isa<PHINode>(II)) {
+          Defs.insert(II);
+        }
+      }
+    }
+  }
+
+  // Coerce and track the defs.
+  for (Instruction *D : Defs) {
+    if (!ValMap.contains(D)) {
+      BasicBlock::iterator InsertPt = std::next(D->getIterator());
+      Value *ConvertVal = convertToOptType(D, InsertPt);
+      assert(ConvertVal);
+      ValMap[D] = ConvertVal;
+    }
+  }
+
+  // Construct new-typed PHI nodes.
+  for (PHINode *Phi : PhiNodes) {
+    ValMap[Phi] = PHINode::Create(calculateConvertType(Phi->getType()),
+                                  Phi->getNumIncomingValues(),
+                                  Phi->getName() + ".tc", Phi->getIterator());
+  }
+
+  // Connect all the PHI nodes with their new incoming values.
+  for (PHINode *Phi : PhiNodes) {
+    PHINode *NewPhi = cast<PHINode>(ValMap[Phi]);
+    bool MissingIncVal = false;
+    for (int I = 0, E = Phi->getNumIncomingValues(); I < E; I++) {
+      Value *IncVal = Phi->getIncomingValue(I);
+      if (isa<ConstantAggregateZero>(IncVal)) {
+        Type *NewType = calculateConvertType(Phi->getType());
+        NewPhi->addIncoming(ConstantInt::get(NewType, 0, false),
+                            Phi->getIncomingBlock(I));
+      } else if (ValMap.contains(IncVal))
+        NewPhi->addIncoming(ValMap[IncVal], Phi->getIncomingBlock(I));
+      else
+        MissingIncVal = true;
+    }
+    Instruction *DeadInst = Phi;
+    if (MissingIncVal) {
+      DeadInst = cast<Instruction>(ValMap[Phi]);
+      // Do not use the dead phi.
+      ValMap[Phi] = Phi;
+    }
+    DeadInsts.emplace_back(DeadInst);
+  }
+  // Coerce back to the original type and replace the uses.
+  for (Instruction *U : Uses) {
+    // Replace all converted operands for a use.
+    for (auto [OpIdx, Op] : enumerate(U->operands())) {
+      if (ValMap.contains(Op)) {
+        Value *NewVal = nullptr;
+        if (BBUseValMap.contains(U->getParent()) &&
+            BBUseValMap[U->getParent()].contains(ValMap[Op]))
+          NewVal = BBUseValMap[U->getParent()][ValMap[Op]];
+        else {
+          BasicBlock::iterator InsertPt = U->getParent()->getFirstNonPHIIt();
+          NewVal =
+              convertFromOptType(Op->getType(), cast<Instruction>(ValMap[Op]),
+                                 InsertPt, U->getParent());
+          BBUseValMap[U->getParent()][ValMap[Op]] = NewVal;
+        }
+        assert(NewVal);
+        U->setOperand(OpIdx, NewVal);
+      }
+    }
+  }
+
+  return true;
+}
+
 bool AMDGPULateCodeGenPrepare::canWidenScalarExtLoad(LoadInst &LI) const {
   unsigned AS = LI.getPointerAddressSpace();
   // Skip non-constant address space.
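Taken together, optimizeLiveType runs in three phases: a worklist walk collects the transitive web of problematic PHIs plus the defs and uses that cross block boundaries; each def is packed once (cached in ValMap); and each using block unpacks once per packed value (cached in BBUseValMap), with new PHIs built over the packed type. The kind of input it targets is a small-element vector live across blocks, as in the hypothetical builder below, which constructs a v4i8 PHI the pass would retype to an i32 PHI (named "phi.tc") with casts on either side:

#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;

int main() {
  LLVMContext Ctx;
  Module M("sketch", Ctx);
  auto *V4I8 = FixedVectorType::get(Type::getInt8Ty(Ctx), 4);
  auto *FTy = FunctionType::get(V4I8, {V4I8, Type::getInt1Ty(Ctx)}, false);
  Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
  BasicBlock *Entry = BasicBlock::Create(Ctx, "entry", F);
  BasicBlock *Then = BasicBlock::Create(Ctx, "then", F);
  BasicBlock *Exit = BasicBlock::Create(Ctx, "exit", F);

  IRBuilder<> B(Entry);
  // A v4i8 def whose value is live out of this block.
  Value *Def = B.CreateAdd(F->getArg(0), F->getArg(0), "def");
  B.CreateCondBr(F->getArg(1), Then, Exit);
  B.SetInsertPoint(Then);
  B.CreateBr(Exit);

  B.SetInsertPoint(Exit);
  // The problematic PHI: v4i8 incoming values from two predecessors. The
  // pass would replace it with an i32 PHI fed by bitcasts after the defs,
  // converting back before the use (here, the return). The zeroinitializer
  // incoming value exercises the ConstantAggregateZero case above.
  PHINode *Phi = B.CreatePHI(V4I8, 2, "phi");
  Phi->addIncoming(Def, Entry);
  Phi->addIncoming(Constant::getNullValue(V4I8), Then);
  B.CreateRet(Phi);

  M.print(outs(), nullptr);
  return 0;
}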
@@ -119,7 +407,7 @@ bool AMDGPULateCodeGenPrepare::canWidenScalarExtLoad(LoadInst &LI) const {
   // Skip non-simple loads.
   if (!LI.isSimple())
     return false;
-  auto *Ty = LI.getType();
+  Type *Ty = LI.getType();
   // Skip aggregate types.
   if (Ty->isAggregateType())
     return false;
@@ -181,7 +469,7 @@ bool AMDGPULateCodeGenPrepare::visitLoadInst(LoadInst &LI) {
   auto *NewVal = IRB.CreateBitCast(
       IRB.CreateTrunc(IRB.CreateLShr(NewLd, ShAmt), IntNTy), LI.getType());
   LI.replaceAllUsesWith(NewVal);
-  RecursivelyDeleteTriviallyDeadInstructions(&LI);
+  DeadInsts.emplace_back(&LI);
 
   return true;
 }
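One more recurring change: instead of calling RecursivelyDeleteTriviallyDeadInstructions from inside the visitor, dead values are now queued on the pass's DeadInsts vector and reclaimed once after the whole traversal (see runOnFunction above). The handles are WeakTrackingVH, so entries that die or get replaced in the meantime are tolerated. A minimal sketch of the pattern, assuming the usual Local.h helper:

#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/ValueHandle.h"
#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;

// Record now, delete later: safe while iterating with make_early_inc_range.
static void markDead(Instruction &I,
                     SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
  DeadInsts.emplace_back(&I);
}

// One sweep after the traversal; handles that already died are skipped.
static void sweep(SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
  RecursivelyDeleteTriviallyDeadInstructionsPermissive(DeadInsts);
}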

llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp

Lines changed: 2 additions & 2 deletions
@@ -1197,10 +1197,10 @@ bool GCNPassConfig::addPreISel() {
   AMDGPUPassConfig::addPreISel();
 
   if (TM->getOptLevel() > CodeGenOptLevel::None)
-    addPass(createAMDGPULateCodeGenPreparePass());
+    addPass(createSinkingPass());
 
   if (TM->getOptLevel() > CodeGenOptLevel::None)
-    addPass(createSinkingPass());
+    addPass(createAMDGPULateCodeGenPreparePass());
 
   // Merge divergent exit nodes. StructurizeCFG won't recognize the multi-exit
   // regions formed by them.
