Skip to content

Commit a655973

Browse files
committed
[AArch64] Add SME peephole optimizer pass
This pass removes back-to-back smstart/smstop instruction pairs to reduce the number of streaming-mode changes in a function. The implementation does not aim to handle every case yet; it notes several patterns that could be optimized in future work.
1 parent 715e32b commit a655973

11 files changed

+241
-90
lines changed

llvm/lib/Target/AArch64/AArch64.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ FunctionPass *createAArch64CleanupLocalDynamicTLSPass();
5959

6060
FunctionPass *createAArch64CollectLOHPass();
6161
FunctionPass *createSMEABIPass();
62+
FunctionPass *createSMEPeepholeOptPass();
6263
ModulePass *createSVEIntrinsicOptsPass();
6364
InstructionSelector *
6465
createAArch64InstructionSelector(const AArch64TargetMachine &,
@@ -110,6 +111,7 @@ void initializeFalkorHWPFFixPass(PassRegistry&);
110111
void initializeFalkorMarkStridedAccessesLegacyPass(PassRegistry&);
111112
void initializeLDTLSCleanupPass(PassRegistry&);
112113
void initializeSMEABIPass(PassRegistry &);
114+
void initializeSMEPeepholeOptPass(PassRegistry &);
113115
void initializeSVEIntrinsicOptsPass(PassRegistry &);
114116
void initializeAArch64Arm64ECCallLoweringPass(PassRegistry &);
115117
} // end namespace llvm

llvm/lib/Target/AArch64/AArch64TargetMachine.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,11 @@ static cl::opt<bool>
167167
cl::desc("Enable SVE intrinsic opts"),
168168
cl::init(true));
169169

170+
static cl::opt<bool>
171+
EnableSMEPeepholeOpt("enable-aarch64-sme-peephole-opt", cl::init(true),
172+
cl::Hidden,
173+
cl::desc("Perform SME peephole optimization"));
174+
170175
static cl::opt<bool> EnableFalkorHWPFFix("aarch64-enable-falkor-hwpf-fix",
171176
cl::init(true), cl::Hidden);
172177

@@ -256,6 +261,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeAArch64Target() {
256261
initializeLDTLSCleanupPass(*PR);
257262
initializeKCFIPass(*PR);
258263
initializeSMEABIPass(*PR);
264+
initializeSMEPeepholeOptPass(*PR);
259265
initializeSVEIntrinsicOptsPass(*PR);
260266
initializeAArch64SpeculationHardeningPass(*PR);
261267
initializeAArch64SLSHardeningPass(*PR);
@@ -754,6 +760,9 @@ bool AArch64PassConfig::addGlobalInstructionSelect() {
754760
}
755761

756762
void AArch64PassConfig::addMachineSSAOptimization() {
763+
if (TM->getOptLevel() != CodeGenOptLevel::None && EnableSMEPeepholeOpt)
764+
addPass(createSMEPeepholeOptPass());
765+
757766
// Run default MachineSSAOptimization first.
758767
TargetPassConfig::addMachineSSAOptimization();
759768

llvm/lib/Target/AArch64/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ add_llvm_target(AArch64CodeGen
8787
AArch64TargetObjectFile.cpp
8888
AArch64TargetTransformInfo.cpp
8989
SMEABIPass.cpp
90+
SMEPeepholeOpt.cpp
9091
SVEIntrinsicOpts.cpp
9192
AArch64SIMDInstrOpt.cpp
9293

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
//===- SMEPeepholeOpt.cpp - SME peephole optimization pass-----------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// This pass tries to remove back-to-back (smstart, smstop) and
9+
// (smstop, smstart) sequences. The pass is conservative when it cannot
10+
// determine that it is safe to remove these sequences.
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "AArch64InstrInfo.h"
14+
#include "AArch64MachineFunctionInfo.h"
15+
#include "AArch64Subtarget.h"
16+
#include "Utils/AArch64SMEAttributes.h"
17+
#include "llvm/ADT/SmallVector.h"
18+
#include "llvm/CodeGen/MachineBasicBlock.h"
19+
#include "llvm/CodeGen/MachineFunctionPass.h"
20+
21+
using namespace llvm;
22+
23+
#define DEBUG_TYPE "aarch64-sme-peephole-opt"
24+
25+
namespace {

// Machine pass that removes redundant back-to-back smstart/smstop (and
// conditional start/stop) pairs within a single basic block.
struct SMEPeepholeOpt : public MachineFunctionPass {
  static char ID;

  SMEPeepholeOpt() : MachineFunctionPass(ID) {
    initializeSMEPeepholeOptPass(*PassRegistry::getPassRegistry());
  }

  bool runOnMachineFunction(MachineFunction &MF) override;

  StringRef getPassName() const override {
    return "SME Peephole Optimization pass";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // Only instructions inside blocks are erased; the CFG is never changed.
    AU.setPreservesCFG();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  // Scans \p MBB for cancelling start/stop pairs and erases them (together
  // with any VG save/restore pseudos in between). Sets
  // \p HasRemainingSMChange when a streaming-mode change (SM or SMZA)
  // survives in the block. Returns true if any instruction was removed.
  bool optimizeStartStopPairs(MachineBasicBlock &MBB,
                              bool &HasRemainingSMChange) const;
};

char SMEPeepholeOpt::ID = 0;

} // end anonymous namespace
52+
53+
// Returns true if \p MI is the *conditional* form of a streaming-mode change
// (MSRpstatePseudo), as opposed to the unconditional MSRpstatesvcrImm1 form.
static bool isConditionalStartStop(const MachineInstr *MI) {
  return MI->getOpcode() == AArch64::MSRpstatePseudo;
}
56+
57+
// Returns true if \p MI1 and \p MI2 form a cancelling pair, i.e. they toggle
// the same SVCR state (operand 0) in opposite directions (operand 1), and —
// for the conditional pseudo form — do so under identical conditions.
static bool isMatchingStartStopPair(const MachineInstr *MI1,
                                    const MachineInstr *MI2) {
  // We only consider the same type of streaming mode change here, i.e.
  // start/stop SM, or start/stop ZA pairs.
  if (MI1->getOperand(0).getImm() != MI2->getOperand(0).getImm())
    return false;

  // One must be 'start', the other must be 'stop'
  if (MI1->getOperand(1).getImm() == MI2->getOperand(1).getImm())
    return false;

  // Mixing conditional and unconditional forms never cancels out.
  bool IsConditional = isConditionalStartStop(MI2);
  if (isConditionalStartStop(MI1) != IsConditional)
    return false;

  if (!IsConditional)
    return true;

  // Check to make sure the conditional start/stop pairs are identical.
  if (MI1->getOperand(2).getImm() != MI2->getOperand(2).getImm())
    return false;

  // Ensure reg masks are identical. Note this compares the regmask
  // *pointers*, which is conservative: distinct but equal-content masks are
  // treated as different and the pair is kept.
  if (MI1->getOperand(4).getRegMask() != MI2->getOperand(4).getRegMask())
    return false;

  // This optimisation is unlikely to happen in practice for conditional
  // smstart/smstop pairs as the virtual registers for pstate.sm will always
  // be different.
  // TODO: For this optimisation to apply to conditional smstart/smstop,
  // this pass will need to do more work to remove redundant calls to
  // __arm_sme_state.

  // Only consider conditional start/stop pairs which read the same register
  // holding the original value of pstate.sm, as some conditional start/stops
  // require the state on entry to the function.
  if (MI1->getOperand(3).isReg() && MI2->getOperand(3).isReg()) {
    Register Reg1 = MI1->getOperand(3).getReg();
    Register Reg2 = MI2->getOperand(3).getReg();
    // Physical registers may be clobbered between the two instructions, so
    // only identical virtual registers are accepted.
    if (Reg1.isPhysical() || Reg2.isPhysical() || Reg1 != Reg2)
      return false;
  }

  return true;
}
102+
103+
static bool ChangesStreamingMode(const MachineInstr *MI) {
104+
assert((MI->getOpcode() == AArch64::MSRpstatesvcrImm1 ||
105+
MI->getOpcode() == AArch64::MSRpstatePseudo) &&
106+
"Expected MI to be a smstart/smstop instruction");
107+
return MI->getOperand(0).getImm() == AArch64SVCR::SVCRSM ||
108+
MI->getOperand(0).getImm() == AArch64SVCR::SVCRSMZA;
109+
}
110+
111+
// Scans \p MBB for cancelling smstart/smstop pairs and erases them, together
// with any VG save/restore pseudos recorded in between. On return,
// \p HasRemainingSMChange is true iff a streaming-mode change (SM/SMZA)
// survives in this block. Returns true if any instruction was removed.
bool SMEPeepholeOpt::optimizeStartStopPairs(MachineBasicBlock &MBB,
                                            bool &HasRemainingSMChange) const {
  // Pseudos (VG save/restore) queued for deletion alongside a matched pair.
  SmallVector<MachineInstr *, 4> ToBeRemoved;

  bool Changed = false;
  MachineInstr *Prev = nullptr;
  HasRemainingSMChange = false;

  // Abandon the candidate in 'Prev'. If it was a streaming-mode change it
  // will now definitely survive this pass, so record that for the caller.
  auto Reset = [&]() {
    if (Prev && ChangesStreamingMode(Prev))
      HasRemainingSMChange = true;
    Prev = nullptr;
    ToBeRemoved.clear();
  };

  // Walk through instructions in the block trying to find pairs of smstart
  // and smstop nodes that cancel each other out. We only permit a limited
  // set of instructions to appear between them, otherwise we reset our
  // tracking.
  for (MachineInstr &MI : make_early_inc_range(MBB)) {
    switch (MI.getOpcode()) {
    default:
      Reset();
      break;
    case AArch64::COPY: {
      // Permit copies of 32 and 64-bit registers.
      if (!MI.getOperand(1).isReg()) {
        Reset();
        break;
      }
      Register Reg = MI.getOperand(1).getReg();
      if (!AArch64::GPR32RegClass.contains(Reg) &&
          !AArch64::GPR64RegClass.contains(Reg))
        Reset();
      break;
    }
    case AArch64::ADJCALLSTACKDOWN:
    case AArch64::ADJCALLSTACKUP:
    case AArch64::ANDXri:
    case AArch64::ADDXri:
      // We permit these as they don't generate SVE/NEON instructions.
      break;
    case AArch64::VGRestorePseudo:
    case AArch64::VGSavePseudo:
      // When the smstart/smstop are removed, we should also remove
      // the pseudos that save/restore the VG value for CFI info.
      ToBeRemoved.push_back(&MI);
      break;
    case AArch64::MSRpstatesvcrImm1:
    case AArch64::MSRpstatePseudo: {
      if (!Prev)
        Prev = &MI;
      else if (isMatchingStartStopPair(Prev, &MI)) {
        // If they match, we can remove them, and possibly any instructions
        // that we marked for deletion in between.
        Prev->eraseFromParent();
        MI.eraseFromParent();
        for (MachineInstr *TBR : ToBeRemoved)
          TBR->eraseFromParent();
        ToBeRemoved.clear();
        Prev = nullptr;
        Changed = true;
      } else {
        Reset();
        Prev = &MI;
      }
      break;
    }
    }
  }

  // A candidate may still be pending when we fall off the end of the block
  // (e.g. an unmatched smstart at the end of a fall-through block with no
  // terminator). Flush it so a surviving streaming-mode change is not
  // missed, which would otherwise let the caller incorrectly clear the
  // function's hasStreamingModeChanges state.
  Reset();

  return Changed;
}
184+
185+
// Register the pass with LLVM's PassRegistry so it is selectable by name
// ("aarch64-sme-peephole-opt") and visible in pass-pipeline listings.
INITIALIZE_PASS(SMEPeepholeOpt, "aarch64-sme-peephole-opt",
                "SME Peephole Optimization", false, false)
187+
188+
// Pass entry point: run the peephole over every block and, when something
// was removed, refresh the function-level streaming-mode-change flag to
// reflect whether any streaming-mode change survived.
bool SMEPeepholeOpt::runOnMachineFunction(MachineFunction &MF) {
  if (skipFunction(MF.getFunction()))
    return false;

  if (!MF.getSubtarget<AArch64Subtarget>().hasSME())
    return false;

  assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");

  bool Changed = false;
  bool AnySMChangeLeft = false;

  // Even if the block lives in a function with no SME attributes attached we
  // still have to analyze all the blocks because we may call a streaming
  // function that requires smstart/smstop pairs.
  for (MachineBasicBlock &MBB : MF) {
    bool BlockSMChangeLeft = false;
    if (optimizeStartStopPairs(MBB, BlockSMChangeLeft))
      Changed = true;
    if (BlockSMChangeLeft)
      AnySMChangeLeft = true;
  }

  // Only downgrade the flag when this pass actually removed something and
  // the function previously recorded streaming-mode changes.
  AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
  if (Changed && AFI->hasStreamingModeChanges())
    AFI->setHasStreamingModeChanges(AnySMChangeLeft);

  return Changed;
}
215+
216+
// Factory used by the AArch64 pass pipeline (AArch64TargetMachine) to
// instantiate this pass.
FunctionPass *llvm::createSMEPeepholeOptPass() { return new SMEPeepholeOpt(); }

llvm/test/CodeGen/AArch64/O3-pipeline.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@
122122
; CHECK-NEXT: MachineDominator Tree Construction
123123
; CHECK-NEXT: AArch64 Local Dynamic TLS Access Clean-up
124124
; CHECK-NEXT: Finalize ISel and expand pseudo-instructions
125+
; CHECK-NEXT: SME Peephole Optimization pass
125126
; CHECK-NEXT: Lazy Machine Block Frequency Analysis
126127
; CHECK-NEXT: Early Tail Duplication
127128
; CHECK-NEXT: Optimize machine instruction PHIs

llvm/test/CodeGen/AArch64/sme-peephole-opts.ll

Lines changed: 4 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ define void @test0() nounwind {
1818
; CHECK-NEXT: str x0, [sp, #72] // 8-byte Folded Spill
1919
; CHECK-NEXT: smstart sm
2020
; CHECK-NEXT: bl callee
21-
; CHECK-NEXT: smstop sm
22-
; CHECK-NEXT: smstart sm
2321
; CHECK-NEXT: bl callee
2422
; CHECK-NEXT: smstop sm
2523
; CHECK-NEXT: ldp d9, d8, [sp, #48] // 16-byte Folded Reload
@@ -46,8 +44,6 @@ define void @test1() nounwind "aarch64_pstate_sm_enabled" {
4644
; CHECK-NEXT: str x0, [sp, #72] // 8-byte Folded Spill
4745
; CHECK-NEXT: smstop sm
4846
; CHECK-NEXT: bl callee
49-
; CHECK-NEXT: smstart sm
50-
; CHECK-NEXT: smstop sm
5147
; CHECK-NEXT: bl callee
5248
; CHECK-NEXT: smstart sm
5349
; CHECK-NEXT: ldp d9, d8, [sp, #48] // 16-byte Folded Reload
@@ -178,8 +174,6 @@ define void @test4() nounwind "aarch64_pstate_sm_enabled" {
178174
; CHECK-NEXT: smstop sm
179175
; CHECK-NEXT: fmov s0, wzr
180176
; CHECK-NEXT: bl callee_farg
181-
; CHECK-NEXT: smstart sm
182-
; CHECK-NEXT: smstop sm
183177
; CHECK-NEXT: fmov s0, wzr
184178
; CHECK-NEXT: bl callee_farg
185179
; CHECK-NEXT: smstart sm
@@ -210,8 +204,6 @@ define void @test5(float %f) nounwind "aarch64_pstate_sm_enabled" {
210204
; CHECK-NEXT: smstop sm
211205
; CHECK-NEXT: ldr s0, [sp, #12] // 4-byte Folded Reload
212206
; CHECK-NEXT: bl callee_farg
213-
; CHECK-NEXT: smstart sm
214-
; CHECK-NEXT: smstop sm
215207
; CHECK-NEXT: ldr s0, [sp, #12] // 4-byte Folded Reload
216208
; CHECK-NEXT: bl callee_farg
217209
; CHECK-NEXT: smstart sm
@@ -322,48 +314,11 @@ define void @test8() nounwind "aarch64_pstate_sm_enabled" {
322314
define void @test9() "aarch64_pstate_sm_body" {
323315
; CHECK-LABEL: test9:
324316
; CHECK: // %bb.0:
325-
; CHECK-NEXT: stp d15, d14, [sp, #-96]! // 16-byte Folded Spill
326-
; CHECK-NEXT: .cfi_def_cfa_offset 96
327-
; CHECK-NEXT: rdsvl x9, #1
328-
; CHECK-NEXT: stp d13, d12, [sp, #16] // 16-byte Folded Spill
329-
; CHECK-NEXT: lsr x9, x9, #3
330-
; CHECK-NEXT: stp d11, d10, [sp, #32] // 16-byte Folded Spill
331-
; CHECK-NEXT: stp d9, d8, [sp, #48] // 16-byte Folded Spill
332-
; CHECK-NEXT: stp x30, x9, [sp, #64] // 16-byte Folded Spill
333-
; CHECK-NEXT: bl __arm_get_current_vg
334-
; CHECK-NEXT: str x0, [sp, #80] // 8-byte Folded Spill
335-
; CHECK-NEXT: .cfi_offset vg, -16
336-
; CHECK-NEXT: .cfi_offset w30, -32
337-
; CHECK-NEXT: .cfi_offset b8, -40
338-
; CHECK-NEXT: .cfi_offset b9, -48
339-
; CHECK-NEXT: .cfi_offset b10, -56
340-
; CHECK-NEXT: .cfi_offset b11, -64
341-
; CHECK-NEXT: .cfi_offset b12, -72
342-
; CHECK-NEXT: .cfi_offset b13, -80
343-
; CHECK-NEXT: .cfi_offset b14, -88
344-
; CHECK-NEXT: .cfi_offset b15, -96
345-
; CHECK-NEXT: smstart sm
346-
; CHECK-NEXT: .cfi_offset vg, -24
347-
; CHECK-NEXT: smstop sm
317+
; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
318+
; CHECK-NEXT: .cfi_def_cfa_offset 16
319+
; CHECK-NEXT: .cfi_offset w30, -16
348320
; CHECK-NEXT: bl callee
349-
; CHECK-NEXT: smstart sm
350-
; CHECK-NEXT: .cfi_restore vg
351-
; CHECK-NEXT: smstop sm
352-
; CHECK-NEXT: ldp d9, d8, [sp, #48] // 16-byte Folded Reload
353-
; CHECK-NEXT: ldr x30, [sp, #64] // 8-byte Folded Reload
354-
; CHECK-NEXT: ldp d11, d10, [sp, #32] // 16-byte Folded Reload
355-
; CHECK-NEXT: ldp d13, d12, [sp, #16] // 16-byte Folded Reload
356-
; CHECK-NEXT: ldp d15, d14, [sp], #96 // 16-byte Folded Reload
357-
; CHECK-NEXT: .cfi_def_cfa_offset 0
358-
; CHECK-NEXT: .cfi_restore w30
359-
; CHECK-NEXT: .cfi_restore b8
360-
; CHECK-NEXT: .cfi_restore b9
361-
; CHECK-NEXT: .cfi_restore b10
362-
; CHECK-NEXT: .cfi_restore b11
363-
; CHECK-NEXT: .cfi_restore b12
364-
; CHECK-NEXT: .cfi_restore b13
365-
; CHECK-NEXT: .cfi_restore b14
366-
; CHECK-NEXT: .cfi_restore b15
321+
; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
367322
; CHECK-NEXT: ret
368323
call void @callee()
369324
ret void
@@ -395,19 +350,13 @@ define void @test10() "aarch64_pstate_sm_body" {
395350
; CHECK-NEXT: .cfi_offset b13, -80
396351
; CHECK-NEXT: .cfi_offset b14, -88
397352
; CHECK-NEXT: .cfi_offset b15, -96
398-
; CHECK-NEXT: smstart sm
399-
; CHECK-NEXT: .cfi_offset vg, -24
400-
; CHECK-NEXT: smstop sm
401353
; CHECK-NEXT: bl callee
402354
; CHECK-NEXT: smstart sm
403355
; CHECK-NEXT: .cfi_restore vg
404356
; CHECK-NEXT: bl callee
405357
; CHECK-NEXT: .cfi_offset vg, -24
406358
; CHECK-NEXT: smstop sm
407359
; CHECK-NEXT: bl callee
408-
; CHECK-NEXT: smstart sm
409-
; CHECK-NEXT: .cfi_restore vg
410-
; CHECK-NEXT: smstop sm
411360
; CHECK-NEXT: ldp d9, d8, [sp, #48] // 16-byte Folded Reload
412361
; CHECK-NEXT: ldr x30, [sp, #64] // 8-byte Folded Reload
413362
; CHECK-NEXT: ldp d11, d10, [sp, #32] // 16-byte Folded Reload

0 commit comments

Comments
 (0)