Skip to content

Commit 9872269

Browse files
committed
[ARM] WLS/LE Code Generation
Backend changes to enable WLS/LE low-overhead loops for armv8.1-m: 1) Use TTI to communicate to the HardwareLoop pass that we should try to generate intrinsics that guard the loop entry, as well as setting the loop trip count. 2) Lower the BRCOND that uses said intrinsic to an Arm-specific node: ARMWLS. 3) In ISelDAGToDAG, select the node to a new pseudo instruction: t2WhileLoopStart. 4) Add support in ArmLowOverheadLoops to handle the new pseudo instruction. Differential Revision: https://reviews.llvm.org/D63816 llvm-svn: 364733
1 parent 0384a78 commit 9872269

22 files changed

+765
-92
lines changed

llvm/lib/CodeGen/HardwareLoops.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ static bool CanGenerateTest(Loop *L, Value *Count) {
294294
// Check that the icmp is checking for equality of Count and zero and that
295295
// a non-zero value results in entering the loop.
296296
auto ICmp = cast<ICmpInst>(BI->getCondition());
297+
LLVM_DEBUG(dbgs() << " - Found condition: " << *ICmp << "\n");
297298
if (!ICmp->isEquality())
298299
return false;
299300

llvm/lib/Target/ARM/ARMISelDAGToDAG.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2998,6 +2998,16 @@ void ARMDAGToDAGISel::Select(SDNode *N) {
29982998
// Other cases are autogenerated.
29992999
break;
30003000
}
3001+
case ARMISD::WLS: {
3002+
SDValue Ops[] = { N->getOperand(1), // Loop count
3003+
N->getOperand(2), // Exit target
3004+
N->getOperand(0) };
3005+
SDNode *LoopStart =
3006+
CurDAG->getMachineNode(ARM::t2WhileLoopStart, dl, MVT::Other, Ops);
3007+
ReplaceUses(N, LoopStart);
3008+
CurDAG->RemoveDeadNode(N);
3009+
return;
3010+
}
30013011
case ARMISD::BRCOND: {
30023012
// Pattern: (ARMbrcond:void (bb:Other):$dst, (imm:i32):$cc)
30033013
// Emits: (Bcc:void (bb:Other):$dst, (imm:i32):$cc)

llvm/lib/Target/ARM/ARMISelLowering.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,10 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM,
633633
if (Subtarget->hasMVEIntegerOps())
634634
addMVEVectorTypes(Subtarget->hasMVEFloatOps());
635635

636+
// Combine low-overhead loop intrinsics so that we can lower i1 types.
637+
if (Subtarget->hasLOB())
638+
setTargetDAGCombine(ISD::BRCOND);
639+
636640
if (Subtarget->hasNEON()) {
637641
addDRTypeForNEON(MVT::v2f32);
638642
addDRTypeForNEON(MVT::v8i8);
@@ -1542,6 +1546,7 @@ const char *ARMTargetLowering::getTargetNodeName(unsigned Opcode) const {
15421546
case ARMISD::VST2LN_UPD: return "ARMISD::VST2LN_UPD";
15431547
case ARMISD::VST3LN_UPD: return "ARMISD::VST3LN_UPD";
15441548
case ARMISD::VST4LN_UPD: return "ARMISD::VST4LN_UPD";
1549+
case ARMISD::WLS: return "ARMISD::WLS";
15451550
}
15461551
return nullptr;
15471552
}
@@ -12883,6 +12888,42 @@ SDValue ARMTargetLowering::PerformCMOVToBFICombine(SDNode *CMOV, SelectionDAG &D
1288312888
return V;
1288412889
}
1288512890

12891+
/// Target DAG combine for ISD::BRCOND: fold a branch whose condition is
/// derived from the test.set.loop.iterations intrinsic into an ARMISD::WLS
/// node.
///
/// On a match, the WLS node is built from the branch's chain, operand 2 of the
/// intrinsic node (the element count) and the branch's destination operand.
/// Uses of result #1 of the intrinsic node are redirected to the intrinsic's
/// operand 0 (presumably its output and input chains respectively).
///
/// \param N   The BRCOND node being combined.
/// \param DCI Combiner state; its DAG is used to build the replacement node.
/// \param ST  Subtarget (currently unused).
/// \returns the new WLS node on a match, or an empty SDValue to leave \p N
///          untouched.
static SDValue PerformHWLoopCombine(SDNode *N,
                                    TargetLowering::DAGCombinerInfo &DCI,
                                    const ARMSubtarget *ST) {
  // Look for (brcond (xor test.set.loop.iterations, -1)); a setcc of the
  // intrinsic is accepted as well.
  // NOTE(review): the pattern above says -1 while the assertion below checks
  // isOne(); presumably these coincide for an i1 operand - confirm.
  SDValue CC = N->getOperand(1);

  if (CC->getOpcode() != ISD::XOR && CC->getOpcode() != ISD::SETCC)
    return SDValue();

  if (CC->getOperand(0)->getOpcode() != ISD::INTRINSIC_W_CHAIN)
    return SDValue();

  // Operand 1 of the intrinsic node carries the intrinsic ID.
  SDValue Int = CC->getOperand(0);
  unsigned IntOp = cast<ConstantSDNode>(Int.getOperand(1))->getZExtValue();
  if (IntOp != Intrinsic::test_set_loop_iterations)
    return SDValue();

  // The intrinsic is expected to be compared against the constant 1. Test for
  // a non-constant operand inside the assertion itself, so that a violation
  // reports the failed expectation instead of dereferencing a null pointer.
  assert(isa<ConstantSDNode>(CC->getOperand(1)) &&
         cast<ConstantSDNode>(CC->getOperand(1))->isOne() &&
         "Expected to compare against 1");

  SDLoc dl(Int);
  SDValue Chain = N->getOperand(0);
  SDValue Elements = Int.getOperand(2);
  SDValue ExitBlock = N->getOperand(2);

  // TODO: Once we start supporting tail predication, we can add another
  // operand to WLS for the number of elements processed in a vector loop.

  SDValue Ops[] = { Chain, Elements, ExitBlock };
  SDValue Res = DCI.DAG.getNode(ARMISD::WLS, dl, MVT::Other, Ops);
  // The intrinsic node is being bypassed, so hand its operand 0 to anything
  // that was using its result #1.
  DCI.DAG.ReplaceAllUsesOfValueWith(Int.getValue(1), Int.getOperand(0));
  return Res;
}
12926+
1288612927
/// PerformBRCONDCombine - Target-specific DAG combining for ARMISD::BRCOND.
1288712928
SDValue
1288812929
ARMTargetLowering::PerformBRCONDCombine(SDNode *N, SelectionDAG &DAG) const {
@@ -13114,6 +13155,7 @@ SDValue ARMTargetLowering::PerformDAGCombine(SDNode *N,
1311413155
case ISD::OR: return PerformORCombine(N, DCI, Subtarget);
1311513156
case ISD::XOR: return PerformXORCombine(N, DCI, Subtarget);
1311613157
case ISD::AND: return PerformANDCombine(N, DCI, Subtarget);
13158+
case ISD::BRCOND: return PerformHWLoopCombine(N, DCI, Subtarget);
1311713159
case ARMISD::ADDC:
1311813160
case ARMISD::SUBC: return PerformAddcSubcCombine(N, DCI, Subtarget);
1311913161
case ARMISD::SUBE: return PerformAddeSubeCombine(N, DCI, Subtarget);

llvm/lib/Target/ARM/ARMISelLowering.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,8 @@ class VectorType;
125125
WIN__CHKSTK, // Windows' __chkstk call to do stack probing.
126126
WIN__DBZCHK, // Windows' divide by zero check
127127

128+
WLS, // Low-overhead loops, While Loop Start
129+
128130
VCEQ, // Vector compare equal.
129131
VCEQZ, // Vector compare equal to zero.
130132
VCGE, // Vector compare greater than or equal.

llvm/lib/Target/ARM/ARMInstrInfo.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,11 @@ def SDT_ARMIntShiftParts : SDTypeProfile<2, 3, [SDTCisSameAs<0, 1>,
106106
SDTCisInt<0>,
107107
SDTCisInt<4>]>;
108108

109+
// TODO Add another operand for 'Size' so that we can re-use this node when we
110+
// start supporting *TP versions.
111+
def SDT_ARMWhileLoop : SDTypeProfile<0, 2, [SDTCisVT<0, i32>,
112+
SDTCisVT<1, OtherVT>]>;
113+
109114
def ARMSmlald : SDNode<"ARMISD::SMLALD", SDT_LongMac>;
110115
def ARMSmlaldx : SDNode<"ARMISD::SMLALDX", SDT_LongMac>;
111116
def ARMSmlsld : SDNode<"ARMISD::SMLSLD", SDT_LongMac>;
@@ -244,6 +249,9 @@ def SDTARMVGETLN : SDTypeProfile<1, 2, [SDTCisVT<0, i32>, SDTCisInt<1>,
244249
def ARMvgetlaneu : SDNode<"ARMISD::VGETLANEu", SDTARMVGETLN>;
245250
def ARMvgetlanes : SDNode<"ARMISD::VGETLANEs", SDTARMVGETLN>;
246251

252+
def ARMWLS : SDNode<"ARMISD::WLS", SDT_ARMWhileLoop,
253+
[SDNPHasChain]>;
254+
247255
//===----------------------------------------------------------------------===//
248256
// ARM Flag Definitions.
249257

llvm/lib/Target/ARM/ARMInstrThumb2.td

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5216,11 +5216,19 @@ def t2LoopDec :
52165216
t2PseudoInst<(outs GPRlr:$Rm), (ins GPRlr:$Rn, imm0_7:$size),
52175217
4, IIC_Br, []>, Sched<[WriteBr]>;
52185218

5219-
let isBranch = 1, isTerminator = 1, hasSideEffects = 1 in
5219+
let isBranch = 1, isTerminator = 1, hasSideEffects = 1 in {
5220+
def t2WhileLoopStart :
5221+
t2PseudoInst<(outs),
5222+
(ins rGPR:$elts, brtarget:$target),
5223+
4, IIC_Br, []>,
5224+
Sched<[WriteBr]>;
5225+
52205226
def t2LoopEnd :
52215227
t2PseudoInst<(outs), (ins GPRlr:$elts, brtarget:$target),
52225228
8, IIC_Br, []>, Sched<[WriteBr]>;
52235229

5230+
} // end isBranch, isTerminator, hasSideEffects
5231+
52245232
} // end isNotDuplicable
52255233

52265234
class CS<string iname, bits<4> opcode, list<dag> pattern=[]>

llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp

Lines changed: 89 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -105,15 +105,20 @@ bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
105105
LLVM_DEBUG(dbgs() << "ARM Loops: Processing " << *ML);
106106

107107
auto IsLoopStart = [](MachineInstr &MI) {
108-
return MI.getOpcode() == ARM::t2DoLoopStart;
108+
return MI.getOpcode() == ARM::t2DoLoopStart ||
109+
MI.getOpcode() == ARM::t2WhileLoopStart;
109110
};
110111

111-
auto SearchForStart =
112-
[&IsLoopStart](MachineBasicBlock *MBB) -> MachineInstr* {
112+
// Search the given block for a loop start instruction. If one isn't found,
113+
// and there's only one predecessor block, search that one too.
114+
std::function<MachineInstr*(MachineBasicBlock*)> SearchForStart =
115+
[&IsLoopStart, &SearchForStart](MachineBasicBlock *MBB) -> MachineInstr* {
113116
for (auto &MI : *MBB) {
114117
if (IsLoopStart(MI))
115118
return &MI;
116119
}
120+
if (MBB->pred_size() == 1)
121+
return SearchForStart(*MBB->pred_begin());
117122
return nullptr;
118123
};
119124

@@ -122,8 +127,28 @@ bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
122127
MachineInstr *End = nullptr;
123128
bool Revert = false;
124129

125-
if (auto *Preheader = ML->getLoopPreheader())
130+
// Search the preheader for the start intrinsic, or look through the
131+
// predecessors of the header to find exactly one set.iterations intrinsic.
132+
// FIXME: I don't see why we shouldn't be supporting multiple predecessors
133+
// with potentially multiple set.loop.iterations, so we need to enable this.
134+
if (auto *Preheader = ML->getLoopPreheader()) {
126135
Start = SearchForStart(Preheader);
136+
} else {
137+
LLVM_DEBUG(dbgs() << "ARM Loops: Failed to find loop preheader!\n"
138+
<< " - Performing manual predecessor search.\n");
139+
MachineBasicBlock *Pred = nullptr;
140+
for (auto *MBB : ML->getHeader()->predecessors()) {
141+
if (!ML->contains(MBB)) {
142+
if (Pred) {
143+
LLVM_DEBUG(dbgs() << " - Found multiple out-of-loop preds.\n");
144+
Start = nullptr;
145+
break;
146+
}
147+
Pred = MBB;
148+
Start = SearchForStart(MBB);
149+
}
150+
}
151+
}
127152

128153
// Find the low-overhead loop components and decide whether or not to fall
129154
// back to a normal loop.
@@ -158,12 +183,11 @@ bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
158183
break;
159184
}
160185

161-
if (Start || Dec || End) {
162-
if (!Start || !Dec || !End)
163-
report_fatal_error("Failed to find all loop components");
164-
} else {
186+
if (!Start && !Dec && !End) {
165187
LLVM_DEBUG(dbgs() << "ARM Loops: Not a low-overhead loop.\n");
166188
return Changed;
189+
} if (!(Start && Dec && End)) {
190+
report_fatal_error("Failed to find all loop components");
167191
}
168192

169193
if (!End->getOperand(1).isMBB() ||
@@ -212,15 +236,21 @@ void ARMLowOverheadLoops::Expand(MachineLoop *ML, MachineInstr *Start,
212236
break;
213237
}
214238

239+
unsigned Opc = Start->getOpcode() == ARM::t2DoLoopStart ?
240+
ARM::t2DLS : ARM::t2WLS;
215241
MachineInstrBuilder MIB =
216-
BuildMI(*MBB, InsertPt, InsertPt->getDebugLoc(), TII->get(ARM::t2DLS));
217-
if (InsertPt != Start)
218-
InsertPt->eraseFromParent();
242+
BuildMI(*MBB, InsertPt, InsertPt->getDebugLoc(), TII->get(Opc));
219243

220244
MIB.addDef(ARM::LR);
221245
MIB.add(Start->getOperand(0));
222-
LLVM_DEBUG(dbgs() << "ARM Loops: Inserted DLS: " << *MIB);
246+
if (Opc == ARM::t2WLS)
247+
MIB.add(Start->getOperand(1));
248+
249+
if (InsertPt != Start)
250+
InsertPt->eraseFromParent();
223251
Start->eraseFromParent();
252+
LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB);
253+
return &*MIB;
224254
};
225255

226256
// Combine the LoopDec and LoopEnd instructions into LE(TP).
@@ -234,24 +264,15 @@ void ARMLowOverheadLoops::Expand(MachineLoop *ML, MachineInstr *Start,
234264
MIB.add(End->getOperand(1));
235265
LLVM_DEBUG(dbgs() << "ARM Loops: Inserted LE: " << *MIB);
236266

237-
// If there is a branch after loop end, which branches to the fallthrough
238-
// block, remove the branch.
239-
MachineBasicBlock *Latch = End->getParent();
240-
MachineInstr *Terminator = &Latch->instr_back();
241-
if (End != Terminator) {
242-
MachineBasicBlock *Exit = ML->getExitBlock();
243-
if (Latch->isLayoutSuccessor(Exit)) {
244-
LLVM_DEBUG(dbgs() << "ARM Loops: Removing loop exit branch: "
245-
<< *Terminator);
246-
Terminator->eraseFromParent();
247-
}
248-
}
249267
End->eraseFromParent();
250268
Dec->eraseFromParent();
269+
return &*MIB;
251270
};
252271

253272
// Generate a subs, or sub and cmp, and a branch instead of an LE.
254273
// TODO: Check flags so that we can possibly generate a subs.
274+
// FIXME: Need to check that we're not trashing the CPSR when generating
275+
// the cmp.
255276
auto ExpandBranch = [this](MachineInstr *Dec, MachineInstr *End) {
256277
LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to sub, cmp, br.\n");
257278
// Create sub
@@ -282,12 +303,53 @@ void ARMLowOverheadLoops::Expand(MachineLoop *ML, MachineInstr *Start,
282303
Dec->eraseFromParent();
283304
};
284305

306+
// WhileLoopStart holds the exit block, so produce a cmp lr, 0 and then a
307+
// beq that branches to the exit branch.
308+
// FIXME: Need to check that we're not trashing the CPSR when generating the
309+
// cmp. We could also try to generate a cbz if the value in LR is also in
310+
// another low register.
311+
auto ExpandStart = [this](MachineInstr *MI) {
312+
MachineBasicBlock *MBB = MI->getParent();
313+
MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
314+
TII->get(ARM::t2CMPri));
315+
MIB.addReg(ARM::LR);
316+
MIB.addImm(0);
317+
MIB.addImm(ARMCC::AL);
318+
MIB.addReg(ARM::CPSR);
319+
320+
MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(ARM::t2Bcc));
321+
MIB.add(MI->getOperand(1)); // branch target
322+
MIB.addImm(ARMCC::EQ); // condition code
323+
MIB.addReg(ARM::CPSR);
324+
};
325+
326+
// TODO: We should be able to automatically remove these branches before we
327+
// get here - probably by teaching analyzeBranch about the pseudo
328+
// instructions.
329+
// If there is an unconditional branch, after I, that just branches to the
330+
// next block, remove it.
331+
auto RemoveDeadBranch = [](MachineInstr *I) {
332+
MachineBasicBlock *BB = I->getParent();
333+
MachineInstr *Terminator = &BB->instr_back();
334+
if (Terminator->isUnconditionalBranch() && I != Terminator) {
335+
MachineBasicBlock *Succ = Terminator->getOperand(0).getMBB();
336+
if (BB->isLayoutSuccessor(Succ)) {
337+
LLVM_DEBUG(dbgs() << "ARM Loops: Removing branch: " << *Terminator);
338+
Terminator->eraseFromParent();
339+
}
340+
}
341+
};
342+
285343
if (Revert) {
286-
Start->eraseFromParent();
344+
if (Start->getOpcode() == ARM::t2WhileLoopStart)
345+
ExpandStart(Start);
287346
ExpandBranch(Dec, End);
347+
Start->eraseFromParent();
288348
} else {
289-
ExpandLoopStart(ML, Start);
290-
ExpandLoopEnd(ML, Dec, End);
349+
Start = ExpandLoopStart(ML, Start);
350+
RemoveDeadBranch(Start);
351+
End = ExpandLoopEnd(ML, Dec, End);
352+
RemoveDeadBranch(End);
291353
}
292354
}
293355

llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -806,6 +806,7 @@ bool ARMTTIImpl::isHardwareLoopProfitable(Loop *L, ScalarEvolution &SE,
806806
default:
807807
break;
808808
case Intrinsic::set_loop_iterations:
809+
case Intrinsic::test_set_loop_iterations:
809810
case Intrinsic::loop_decrement:
810811
case Intrinsic::loop_decrement_reg:
811812
return true;
@@ -841,6 +842,7 @@ bool ARMTTIImpl::isHardwareLoopProfitable(Loop *L, ScalarEvolution &SE,
841842
LLVMContext &C = L->getHeader()->getContext();
842843
HWLoopInfo.CounterInReg = true;
843844
HWLoopInfo.IsNestingLegal = false;
845+
HWLoopInfo.PerformEntryTest = true;
844846
HWLoopInfo.CountType = Type::getInt32Ty(C);
845847
HWLoopInfo.LoopDecrement = ConstantInt::get(HWLoopInfo.CountType, 1);
846848
return true;

0 commit comments

Comments
 (0)