
Commit 0e49a40

[ARM] Cleanup for the MVETailPrediction pass
This strips out a lot of the code that should no longer be needed from the MVETailPredictionPass, leaving the important part - find active lane mask instructions and convert them to VCTP operations.

Differential Revision: https://reviews.llvm.org/D91866
1 parent 8057ebf commit 0e49a40
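
For context, here is a minimal before/after sketch in LLVM IR of the rewrite this pass now performs (not taken from this commit; the value names %index, %N, %val, %ptr and %elts.rem, and the masked store, are assumed for illustration). The vectorizer predicates the loop body on @llvm.get.active.lane.mask, and this pass replaces that mask with the MVE VCTP intrinsic, which the ARM Low-overhead loop pass can later turn into a DLSTP/WLSTP tail-predicated loop:

  ; Before: mask produced by get.active.lane.mask on the vector induction
  ; variable %index and the element count %N (assumed names).
  %mask = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 %index, i32 %N)
  call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> %val, <4 x i32>* %ptr, i32 4, <4 x i1> %mask)

  ; After this pass: the same mask comes from the MVE VCTP intrinsic, driven by
  ; a phi %elts.rem (assumed name) counting the elements left to process.
  %mask = call <4 x i1> @llvm.arm.mve.vctp32(i32 %elts.rem)
  call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> %val, <4 x i32>* %ptr, i32 4, <4 x i1> %mask)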

3 files changed, +201 -255 lines changed

llvm/lib/Target/ARM/MVETailPredication.cpp

Lines changed: 44 additions & 223 deletions
@@ -22,23 +22,13 @@
 /// The HardwareLoops pass inserts intrinsics identifying loops that the
 /// backend will attempt to convert into a low-overhead loop. The vectorizer is
 /// responsible for generating a vectorized loop in which the lanes are
-/// predicated upon the iteration counter. This pass looks at these predicated
-/// vector loops, that are targets for low-overhead loops, and prepares it for
-/// code generation. Once the vectorizer has produced a masked loop, there's a
-/// couple of final forms:
-/// - A tail-predicated loop, with implicit predication.
-/// - A loop containing multiple VCPT instructions, predicating multiple VPT
-///   blocks of instructions operating on different vector types.
-///
-/// This pass:
-/// 1) Checks if the predicates of the masked load/store instructions are
-///    generated by intrinsic @llvm.get.active.lanes(). This intrinsic consumes
-///    the the scalar loop tripcount as its second argument, which we extract
-///    to set up the number of elements processed by the loop.
-/// 2) Intrinsic @llvm.get.active.lanes() is then replaced by the MVE target
-///    specific VCTP intrinsic to represent the effect of tail predication.
-///    This will be picked up by the ARM Low-overhead loop pass, which performs
-///    the final transformation to a DLSTP or WLSTP tail-predicated loop.
+/// predicated upon an get.active.lane.mask intrinsic. This pass looks at these
+/// get.active.lane.mask intrinsic and attempts to convert them to VCTP
+/// instructions. This will be picked up by the ARM Low-overhead loop pass later
+/// in the backend, which performs the final transformation to a DLSTP or WLSTP
+/// tail-predicated loop.
+//
+//===----------------------------------------------------------------------===//

 #include "ARM.h"
 #include "ARMSubtarget.h"
@@ -57,6 +47,7 @@
 #include "llvm/InitializePasses.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Utils/LoopUtils.h"
 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"

@@ -112,23 +103,18 @@ class MVETailPredication : public LoopPass {
   bool runOnLoop(Loop *L, LPPassManager&) override;

 private:
-  /// Perform the relevant checks on the loop and convert if possible.
-  bool TryConvert(Value *TripCount);
-
-  /// Return whether this is a vectorized loop, that contains masked
-  /// load/stores.
-  bool IsPredicatedVectorLoop();
+  /// Perform the relevant checks on the loop and convert active lane masks if
+  /// possible.
+  bool TryConvertActiveLaneMask(Value *TripCount);

   /// Perform several checks on the arguments of @llvm.get.active.lane.mask
   /// intrinsic. E.g., check that the loop induction variable and the element
   /// count are of the form we expect, and also perform overflow checks for
   /// the new expressions that are created.
-  bool IsSafeActiveMask(IntrinsicInst *ActiveLaneMask, Value *TripCount,
-                        FixedVectorType *VecTy);
+  bool IsSafeActiveMask(IntrinsicInst *ActiveLaneMask, Value *TripCount);

   /// Insert the intrinsic to represent the effect of tail predication.
-  void InsertVCTPIntrinsic(IntrinsicInst *ActiveLaneMask, Value *TripCount,
-                           FixedVectorType *VecTy);
+  void InsertVCTPIntrinsic(IntrinsicInst *ActiveLaneMask, Value *TripCount);

   /// Rematerialize the iteration count in exit blocks, which enables
   /// ARMLowOverheadLoops to better optimise away loop update statements inside
@@ -138,25 +124,6 @@ class MVETailPredication : public LoopPass {

 } // end namespace

-static bool IsDecrement(Instruction &I) {
-  auto *Call = dyn_cast<IntrinsicInst>(&I);
-  if (!Call)
-    return false;
-
-  Intrinsic::ID ID = Call->getIntrinsicID();
-  return ID == Intrinsic::loop_decrement_reg;
-}
-
-static bool IsMasked(Instruction *I) {
-  auto *Call = dyn_cast<IntrinsicInst>(I);
-  if (!Call)
-    return false;
-
-  Intrinsic::ID ID = Call->getIntrinsicID();
-  return ID == Intrinsic::masked_store || ID == Intrinsic::masked_load ||
-         isGatherScatter(Call);
-}
-
 bool MVETailPredication::runOnLoop(Loop *L, LPPassManager&) {
   if (skipLoop(L) || !EnableTailPredication)
     return false;
@@ -207,147 +174,11 @@ bool MVETailPredication::runOnLoop(Loop *L, LPPassManager&) {
     return false;
   }

-  // Search for the hardware loop intrinic that decrements the loop counter.
-  IntrinsicInst *Decrement = nullptr;
-  for (auto *BB : L->getBlocks()) {
-    for (auto &I : *BB) {
-      if (IsDecrement(I)) {
-        Decrement = cast<IntrinsicInst>(&I);
-        break;
-      }
-    }
-  }
-
-  if (!Decrement)
-    return false;
-
-  LLVM_DEBUG(dbgs() << "ARM TP: Running on Loop: " << *L << *Setup << "\n"
-                    << *Decrement << "\n");
-
-  if (!TryConvert(Setup->getArgOperand(0))) {
-    LLVM_DEBUG(dbgs() << "ARM TP: Can't tail-predicate this loop.\n");
-    return false;
-  }
-
-  return true;
-}
-
-static FixedVectorType *getVectorType(IntrinsicInst *I) {
-  unsigned ID = I->getIntrinsicID();
-  FixedVectorType *VecTy;
-  if (ID == Intrinsic::masked_load || isGather(I)) {
-    if (ID == Intrinsic::arm_mve_vldr_gather_base_wb ||
-        ID == Intrinsic::arm_mve_vldr_gather_base_wb_predicated)
-      // then the type is a StructType
-      VecTy = dyn_cast<FixedVectorType>(I->getType()->getContainedType(0));
-    else
-      VecTy = dyn_cast<FixedVectorType>(I->getType());
-  } else if (ID == Intrinsic::masked_store) {
-    VecTy = dyn_cast<FixedVectorType>(I->getOperand(0)->getType());
-  } else {
-    VecTy = dyn_cast<FixedVectorType>(I->getOperand(2)->getType());
-  }
-  assert(VecTy && "No scalable vectors expected here");
-  return VecTy;
-}
-
-bool MVETailPredication::IsPredicatedVectorLoop() {
-  // Check that the loop contains at least one masked load/store intrinsic.
-  // We only support 'normal' vector instructions - other than masked
-  // load/stores.
-  bool ActiveLaneMask = false;
-  for (auto *BB : L->getBlocks()) {
-    for (auto &I : *BB) {
-      auto *Int = dyn_cast<IntrinsicInst>(&I);
-      if (!Int)
-        continue;
-
-      switch (Int->getIntrinsicID()) {
-      case Intrinsic::get_active_lane_mask:
-        ActiveLaneMask = true;
-        continue;
-      case Intrinsic::sadd_sat:
-      case Intrinsic::uadd_sat:
-      case Intrinsic::ssub_sat:
-      case Intrinsic::usub_sat:
-      case Intrinsic::vector_reduce_add:
-        continue;
-      case Intrinsic::fma:
-      case Intrinsic::trunc:
-      case Intrinsic::rint:
-      case Intrinsic::round:
-      case Intrinsic::floor:
-      case Intrinsic::ceil:
-      case Intrinsic::fabs:
-        if (ST->hasMVEFloatOps())
-          continue;
-        break;
-      default:
-        break;
-      }
-      if (IsMasked(&I)) {
-        auto *VecTy = getVectorType(Int);
-        unsigned Lanes = VecTy->getNumElements();
-        unsigned ElementWidth = VecTy->getScalarSizeInBits();
-        // MVE vectors are 128-bit, but don't support 128 x i1.
-        // TODO: Can we support vectors larger than 128-bits?
-        unsigned MaxWidth = TTI->getRegisterBitWidth(true);
-        if (Lanes * ElementWidth > MaxWidth || Lanes == MaxWidth)
-          return false;
-        MaskedInsts.push_back(cast<IntrinsicInst>(&I));
-        continue;
-      }
-
-      for (const Use &U : Int->args()) {
-        if (isa<VectorType>(U->getType()))
-          return false;
-      }
-    }
-  }
-
-  if (!ActiveLaneMask) {
-    LLVM_DEBUG(dbgs() << "ARM TP: No get.active.lane.mask intrinsic found.\n");
-    return false;
-  }
-  return !MaskedInsts.empty();
-}
-
-// Look through the exit block to see whether there's a duplicate predicate
-// instruction. This can happen when we need to perform a select on values
-// from the last and previous iteration. Instead of doing a straight
-// replacement of that predicate with the vctp, clone the vctp and place it
-// in the block. This means that the VPR doesn't have to be live into the
-// exit block which should make it easier to convert this loop into a proper
-// tail predicated loop.
-static void Cleanup(SetVector<Instruction*> &MaybeDead, Loop *L) {
-  BasicBlock *Exit = L->getUniqueExitBlock();
-  if (!Exit) {
-    LLVM_DEBUG(dbgs() << "ARM TP: can't find loop exit block\n");
-    return;
-  }
-
-  // Drop references and add operands to check for dead.
-  SmallPtrSet<Instruction*, 4> Dead;
-  while (!MaybeDead.empty()) {
-    auto *I = MaybeDead.front();
-    MaybeDead.remove(I);
-    if (I->hasNUsesOrMore(1))
-      continue;
+  LLVM_DEBUG(dbgs() << "ARM TP: Running on Loop: " << *L << *Setup << "\n");

-    for (auto &U : I->operands())
-      if (auto *OpI = dyn_cast<Instruction>(U))
-        MaybeDead.insert(OpI);
+  bool Changed = TryConvertActiveLaneMask(Setup->getArgOperand(0));

-    Dead.insert(I);
-  }
-
-  for (auto *I : Dead) {
-    LLVM_DEBUG(dbgs() << "ARM TP: removing dead insn: "; I->dump());
-    I->eraseFromParent();
-  }
-
-  for (auto I : L->blocks())
-    DeleteDeadPHIs(I);
+  return Changed;
 }

 // The active lane intrinsic has this form:
@@ -368,15 +199,18 @@ static void Cleanup(SetVector<Instruction*> &MaybeDead, Loop *L) {
 // 3) The IV must be an induction phi with an increment equal to the
 //    vector width.
 bool MVETailPredication::IsSafeActiveMask(IntrinsicInst *ActiveLaneMask,
-    Value *TripCount, FixedVectorType *VecTy) {
+                                          Value *TripCount) {
   bool ForceTailPredication =
     EnableTailPredication == TailPredication::ForceEnabledNoReductions ||
     EnableTailPredication == TailPredication::ForceEnabled;

   Value *ElemCount = ActiveLaneMask->getOperand(1);
   auto *EC= SE->getSCEV(ElemCount);
   auto *TC = SE->getSCEV(TripCount);
-  int VectorWidth = VecTy->getNumElements();
+  int VectorWidth =
+      cast<FixedVectorType>(ActiveLaneMask->getType())->getNumElements();
+  if (VectorWidth != 4 && VectorWidth != 8 && VectorWidth != 16)
+    return false;
   ConstantInt *ConstElemCount = nullptr;

   // 1) Smoke tests that the original scalar loop TripCount (TC) belongs to
@@ -503,21 +337,22 @@ bool MVETailPredication::IsSafeActiveMask(IntrinsicInst *ActiveLaneMask,
   if (VectorWidth == StepValue)
     return true;

-  LLVM_DEBUG(dbgs() << "ARM TP: Step value " << StepValue << " doesn't match "
-             "vector width " << VectorWidth << "\n");
+  LLVM_DEBUG(dbgs() << "ARM TP: Step value " << StepValue
+                    << " doesn't match vector width " << VectorWidth << "\n");

   return false;
 }

 void MVETailPredication::InsertVCTPIntrinsic(IntrinsicInst *ActiveLaneMask,
-    Value *TripCount, FixedVectorType *VecTy) {
+                                             Value *TripCount) {
   IRBuilder<> Builder(L->getLoopPreheader()->getTerminator());
   Module *M = L->getHeader()->getModule();
   Type *Ty = IntegerType::get(M->getContext(), 32);
-  unsigned VectorWidth = VecTy->getNumElements();
+  unsigned VectorWidth =
+      cast<FixedVectorType>(ActiveLaneMask->getType())->getNumElements();

   // Insert a phi to count the number of elements processed by the loop.
-  Builder.SetInsertPoint(L->getHeader()->getFirstNonPHI() );
+  Builder.SetInsertPoint(L->getHeader()->getFirstNonPHI());
   PHINode *Processed = Builder.CreatePHI(Ty, 2);
   Processed->addIncoming(ActiveLaneMask->getOperand(1), L->getLoopPreheader());

@@ -553,50 +388,36 @@ void MVETailPredication::InsertVCTPIntrinsic(IntrinsicInst *ActiveLaneMask,
              << "ARM TP: Inserted VCTP: " << *VCTPCall << "\n");
 }

-bool MVETailPredication::TryConvert(Value *TripCount) {
-  if (!IsPredicatedVectorLoop()) {
-    LLVM_DEBUG(dbgs() << "ARM TP: no masked instructions in loop.\n");
+bool MVETailPredication::TryConvertActiveLaneMask(Value *TripCount) {
+  SmallVector<IntrinsicInst *, 4> ActiveLaneMasks;
+  for (auto *BB : L->getBlocks())
+    for (auto &I : *BB)
+      if (auto *Int = dyn_cast<IntrinsicInst>(&I))
+        if (Int->getIntrinsicID() == Intrinsic::get_active_lane_mask)
+          ActiveLaneMasks.push_back(Int);
+
+  if (ActiveLaneMasks.empty())
     return false;
-  }

   LLVM_DEBUG(dbgs() << "ARM TP: Found predicated vector loop.\n");
-  SetVector<Instruction*> Predicates;
-
-  auto getPredicateOp = [](IntrinsicInst *I) {
-    unsigned IntrinsicID = I->getIntrinsicID();
-    if (IntrinsicID == Intrinsic::arm_mve_vldr_gather_offset_predicated ||
-        IntrinsicID == Intrinsic::arm_mve_vstr_scatter_offset_predicated)
-      return 5;
-    return (IntrinsicID == Intrinsic::masked_load || isGather(I)) ? 2 : 3;
-  };
-
-  // Walk through the masked intrinsics and try to find whether the predicate
-  // operand is generated by intrinsic @llvm.get.active.lane.mask().
-  for (auto *I : MaskedInsts) {
-    Value *PredOp = I->getArgOperand(getPredicateOp(I));
-    auto *Predicate = dyn_cast<Instruction>(PredOp);
-    if (!Predicate || Predicates.count(Predicate))
-      continue;

-    auto *ActiveLaneMask = dyn_cast<IntrinsicInst>(Predicate);
-    if (!ActiveLaneMask ||
-        ActiveLaneMask->getIntrinsicID() != Intrinsic::get_active_lane_mask)
-      continue;
-
-    Predicates.insert(Predicate);
+  for (auto *ActiveLaneMask : ActiveLaneMasks) {
     LLVM_DEBUG(dbgs() << "ARM TP: Found active lane mask: "
                       << *ActiveLaneMask << "\n");

-    auto *VecTy = getVectorType(I);
-    if (!IsSafeActiveMask(ActiveLaneMask, TripCount, VecTy)) {
+    if (!IsSafeActiveMask(ActiveLaneMask, TripCount)) {
       LLVM_DEBUG(dbgs() << "ARM TP: Not safe to insert VCTP.\n");
       return false;
     }
     LLVM_DEBUG(dbgs() << "ARM TP: Safe to insert VCTP.\n");
-    InsertVCTPIntrinsic(ActiveLaneMask, TripCount, VecTy);
+    InsertVCTPIntrinsic(ActiveLaneMask, TripCount);
   }

-  Cleanup(Predicates, L);
+  // Remove dead instructions and now dead phis.
+  for (auto *II : ActiveLaneMasks)
+    RecursivelyDeleteTriviallyDeadInstructions(II);
+  for (auto I : L->blocks())
+    DeleteDeadPHIs(I);
   return true;
 }

Comments (0)