Skip to content

Commit 6c189ea

Browse files
[AArch64] Add SME peephole optimizer pass (#104612)
This pass removes back-to-back smstart/smstop instruction pairs to reduce the number of streaming mode changes in a function. The implementation proposed here does not aim to solve every case yet; it notes a number of additional cases that could be optimized in the future.
1 parent bacedb5 commit 6c189ea

12 files changed

+789
-60
lines changed

llvm/lib/Target/AArch64/AArch64.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ FunctionPass *createAArch64CleanupLocalDynamicTLSPass();
5959

6060
FunctionPass *createAArch64CollectLOHPass();
6161
FunctionPass *createSMEABIPass();
62+
FunctionPass *createSMEPeepholeOptPass();
6263
ModulePass *createSVEIntrinsicOptsPass();
6364
InstructionSelector *
6465
createAArch64InstructionSelector(const AArch64TargetMachine &,
@@ -110,6 +111,7 @@ void initializeFalkorHWPFFixPass(PassRegistry&);
110111
void initializeFalkorMarkStridedAccessesLegacyPass(PassRegistry&);
111112
void initializeLDTLSCleanupPass(PassRegistry&);
112113
void initializeSMEABIPass(PassRegistry &);
114+
void initializeSMEPeepholeOptPass(PassRegistry &);
113115
void initializeSVEIntrinsicOptsPass(PassRegistry &);
114116
void initializeAArch64Arm64ECCallLoweringPass(PassRegistry &);
115117
} // end namespace llvm

llvm/lib/Target/AArch64/AArch64TargetMachine.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,11 @@ static cl::opt<bool>
167167
cl::desc("Enable SVE intrinsic opts"),
168168
cl::init(true));
169169

170+
static cl::opt<bool>
171+
EnableSMEPeepholeOpt("enable-aarch64-sme-peephole-opt", cl::init(true),
172+
cl::Hidden,
173+
cl::desc("Perform SME peephole optimization"));
174+
170175
static cl::opt<bool> EnableFalkorHWPFFix("aarch64-enable-falkor-hwpf-fix",
171176
cl::init(true), cl::Hidden);
172177

@@ -256,6 +261,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeAArch64Target() {
256261
initializeLDTLSCleanupPass(*PR);
257262
initializeKCFIPass(*PR);
258263
initializeSMEABIPass(*PR);
264+
initializeSMEPeepholeOptPass(*PR);
259265
initializeSVEIntrinsicOptsPass(*PR);
260266
initializeAArch64SpeculationHardeningPass(*PR);
261267
initializeAArch64SLSHardeningPass(*PR);
@@ -754,6 +760,9 @@ bool AArch64PassConfig::addGlobalInstructionSelect() {
754760
}
755761

756762
void AArch64PassConfig::addMachineSSAOptimization() {
763+
if (TM->getOptLevel() != CodeGenOptLevel::None && EnableSMEPeepholeOpt)
764+
addPass(createSMEPeepholeOptPass());
765+
757766
// Run default MachineSSAOptimization first.
758767
TargetPassConfig::addMachineSSAOptimization();
759768

llvm/lib/Target/AArch64/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ add_llvm_target(AArch64CodeGen
8787
AArch64TargetObjectFile.cpp
8888
AArch64TargetTransformInfo.cpp
8989
SMEABIPass.cpp
90+
SMEPeepholeOpt.cpp
9091
SVEIntrinsicOpts.cpp
9192
AArch64SIMDInstrOpt.cpp
9293

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
//===- SMEPeepholeOpt.cpp - SME peephole optimization pass-----------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// This pass tries to remove back-to-back (smstart, smstop) and
9+
// (smstop, smstart) sequences. The pass is conservative when it cannot
10+
// determine that it is safe to remove these sequences.
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "AArch64InstrInfo.h"
14+
#include "AArch64MachineFunctionInfo.h"
15+
#include "AArch64Subtarget.h"
16+
#include "Utils/AArch64SMEAttributes.h"
17+
#include "llvm/ADT/SmallVector.h"
18+
#include "llvm/CodeGen/MachineBasicBlock.h"
19+
#include "llvm/CodeGen/MachineFunctionPass.h"
20+
#include "llvm/CodeGen/MachineRegisterInfo.h"
21+
#include "llvm/CodeGen/TargetRegisterInfo.h"
22+
23+
using namespace llvm;
24+
25+
#define DEBUG_TYPE "aarch64-sme-peephole-opt"
26+
27+
namespace {

// Machine pass that removes redundant back-to-back smstart/smstop (and
// smstop/smstart) pairs within a basic block, reducing the number of
// streaming mode changes in a function. Expected to run on SSA-form MIR
// (asserted in runOnMachineFunction).
struct SMEPeepholeOpt : public MachineFunctionPass {
  static char ID; // Pass identification, replacement for typeid.

  SMEPeepholeOpt() : MachineFunctionPass(ID) {
    initializeSMEPeepholeOptPass(*PassRegistry::getPassRegistry());
  }

  bool runOnMachineFunction(MachineFunction &MF) override;

  StringRef getPassName() const override {
    return "SME Peephole Optimization pass";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // Only instructions are erased; the block structure is left intact.
    AU.setPreservesCFG();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  // Removes cancelling start/stop pairs within \p MBB. Sets
  // \p HasRemovedAllSMChanges to true iff the block contained at least one
  // streaming-mode change and all of them were removed. Returns true if any
  // instruction was erased.
  bool optimizeStartStopPairs(MachineBasicBlock &MBB,
                              bool &HasRemovedAllSMChanges) const;
};

char SMEPeepholeOpt::ID = 0;

} // end anonymous namespace
54+
55+
static bool isConditionalStartStop(const MachineInstr *MI) {
56+
return MI->getOpcode() == AArch64::MSRpstatePseudo;
57+
}
58+
59+
// Returns true when MI1 and MI2 form a cancelling (start, stop) or
// (stop, start) pair for the same PSTATE field and, for the conditional
// pseudo form, are guaranteed to expand identically.
// Operand indices relied on below (as used throughout this file):
//   0: SVCR field immediate (SM / ZA / SMZA)
//   1: start-vs-stop immediate
//   2-4: (conditional pseudo only) expected pstate.sm value, register
//        holding the incoming pstate.sm, and a register mask.
static bool isMatchingStartStopPair(const MachineInstr *MI1,
                                    const MachineInstr *MI2) {
  // We only consider the same type of streaming mode change here, i.e.
  // start/stop SM, or start/stop ZA pairs.
  if (MI1->getOperand(0).getImm() != MI2->getOperand(0).getImm())
    return false;

  // One must be 'start', the other must be 'stop'
  if (MI1->getOperand(1).getImm() == MI2->getOperand(1).getImm())
    return false;

  // A conditional start/stop can only cancel against another conditional
  // one; an unconditional pair needs no further checks.
  bool IsConditional = isConditionalStartStop(MI2);
  if (isConditionalStartStop(MI1) != IsConditional)
    return false;

  if (!IsConditional)
    return true;

  // Check to make sure the conditional start/stop pairs are identical.
  if (MI1->getOperand(2).getImm() != MI2->getOperand(2).getImm())
    return false;

  // Ensure reg masks are identical.
  // NOTE(review): this compares the mask *pointers*, not their contents;
  // distinct-but-equal masks are rejected, which is conservative (safe) —
  // confirm pointer identity is the intended check.
  if (MI1->getOperand(4).getRegMask() != MI2->getOperand(4).getRegMask())
    return false;

  // This optimisation is unlikely to happen in practice for conditional
  // smstart/smstop pairs as the virtual registers for pstate.sm will always
  // be different.
  // TODO: For this optimisation to apply to conditional smstart/smstop,
  // this pass will need to do more work to remove redundant calls to
  // __arm_sme_state.

  // Only consider conditional start/stop pairs which read the same register
  // holding the original value of pstate.sm, as some conditional start/stops
  // require the state on entry to the function.
  if (MI1->getOperand(3).isReg() && MI2->getOperand(3).isReg()) {
    Register Reg1 = MI1->getOperand(3).getReg();
    Register Reg2 = MI2->getOperand(3).getReg();
    if (Reg1.isPhysical() || Reg2.isPhysical() || Reg1 != Reg2)
      return false;
  }

  return true;
}
104+
105+
static bool ChangesStreamingMode(const MachineInstr *MI) {
106+
assert((MI->getOpcode() == AArch64::MSRpstatesvcrImm1 ||
107+
MI->getOpcode() == AArch64::MSRpstatePseudo) &&
108+
"Expected MI to be a smstart/smstop instruction");
109+
return MI->getOperand(0).getImm() == AArch64SVCR::SVCRSM ||
110+
MI->getOperand(0).getImm() == AArch64SVCR::SVCRSMZA;
111+
}
112+
113+
static bool isSVERegOp(const TargetRegisterInfo &TRI,
114+
const MachineRegisterInfo &MRI,
115+
const MachineOperand &MO) {
116+
if (!MO.isReg())
117+
return false;
118+
119+
Register R = MO.getReg();
120+
if (R.isPhysical())
121+
return llvm::any_of(TRI.subregs_inclusive(R), [](const MCPhysReg &SR) {
122+
return AArch64::ZPRRegClass.contains(SR) ||
123+
AArch64::PPRRegClass.contains(SR);
124+
});
125+
126+
const TargetRegisterClass *RC = MRI.getRegClass(R);
127+
return TRI.getCommonSubClass(&AArch64::ZPRRegClass, RC) ||
128+
TRI.getCommonSubClass(&AArch64::PPRRegClass, RC);
129+
}
130+
131+
// Walks \p MBB looking for matching smstart/smstop (or smstop/smstart)
// pairs separated only by provably mode-agnostic instructions, and erases
// both instructions of each such pair (plus any VG save/restore pseudos
// recorded in between).
//
// \param MBB                    block to optimize.
// \param HasRemovedAllSMChanges set to true iff the block contained at least
//                               one PSTATE.SM change and all of them were
//                               removed.
// \returns true if any instruction was erased.
bool SMEPeepholeOpt::optimizeStartStopPairs(
    MachineBasicBlock &MBB, bool &HasRemovedAllSMChanges) const {
  const MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
  const TargetRegisterInfo &TRI =
      *MBB.getParent()->getSubtarget().getRegisterInfo();

  bool Changed = false;
  MachineInstr *Prev = nullptr;
  SmallVector<MachineInstr *, 4> ToBeRemoved;

  // Convenience function to reset the matching of a sequence.
  auto Reset = [&]() {
    Prev = nullptr;
    ToBeRemoved.clear();
  };

  // Walk through instructions in the block trying to find pairs of smstart
  // and smstop nodes that cancel each other out. We only permit a limited
  // set of instructions to appear between them, otherwise we reset our
  // tracking.
  unsigned NumSMChanges = 0;
  unsigned NumSMChangesRemoved = 0;
  for (MachineInstr &MI : make_early_inc_range(MBB)) {
    switch (MI.getOpcode()) {
    case AArch64::MSRpstatesvcrImm1:
    case AArch64::MSRpstatePseudo: {
      if (ChangesStreamingMode(&MI))
        NumSMChanges++;

      if (!Prev)
        Prev = &MI;
      else if (isMatchingStartStopPair(Prev, &MI)) {
        // If they match, we can remove them, and possibly any instructions
        // that we marked for deletion in between.
        //
        // Only credit the removal towards the SM-change bookkeeping when the
        // pair actually toggled PSTATE.SM. Counting unconditionally would
        // also count removed ZA-only start/stop pairs, which could make the
        // (NumSMChanges == NumSMChangesRemoved) test below succeed even
        // though genuine streaming-mode changes remain in the block.
        if (ChangesStreamingMode(&MI))
          NumSMChangesRemoved += 2;
        Prev->eraseFromParent();
        MI.eraseFromParent();
        for (MachineInstr *TBR : ToBeRemoved)
          TBR->eraseFromParent();
        ToBeRemoved.clear();
        Prev = nullptr;
        Changed = true;
      } else {
        Reset();
        Prev = &MI;
      }
      continue;
    }
    default:
      if (!Prev)
        // Avoid doing expensive checks when Prev is nullptr.
        continue;
      break;
    }

    // Test if the instructions in between the start/stop sequence are agnostic
    // of streaming mode. If not, the algorithm should reset.
    switch (MI.getOpcode()) {
    default:
      Reset();
      break;
    case AArch64::COALESCER_BARRIER_FPR16:
    case AArch64::COALESCER_BARRIER_FPR32:
    case AArch64::COALESCER_BARRIER_FPR64:
    case AArch64::COALESCER_BARRIER_FPR128:
    case AArch64::COPY:
      // These instructions should be safe when executed on their own, but
      // the code remains conservative when SVE registers are used. There may
      // exist subtle cases where executing a COPY in a different mode results
      // in different behaviour, even if we can't yet come up with any
      // concrete example/test-case.
      if (isSVERegOp(TRI, MRI, MI.getOperand(0)) ||
          isSVERegOp(TRI, MRI, MI.getOperand(1)))
        Reset();
      break;
    case AArch64::ADJCALLSTACKDOWN:
    case AArch64::ADJCALLSTACKUP:
    case AArch64::ANDXri:
    case AArch64::ADDXri:
      // We permit these as they don't generate SVE/NEON instructions.
      break;
    case AArch64::VGRestorePseudo:
    case AArch64::VGSavePseudo:
      // When the smstart/smstop are removed, we should also remove
      // the pseudos that save/restore the VG value for CFI info.
      ToBeRemoved.push_back(&MI);
      break;
    case AArch64::MSRpstatesvcrImm1:
    case AArch64::MSRpstatePseudo:
      llvm_unreachable("Should have been handled");
    }
  }

  HasRemovedAllSMChanges =
      NumSMChanges && (NumSMChanges == NumSMChangesRemoved);
  return Changed;
}
228+
229+
// Register the pass with the legacy pass manager so it can be constructed
// and referenced by its command-line name.
INITIALIZE_PASS(SMEPeepholeOpt, "aarch64-sme-peephole-opt",
                "SME Peephole Optimization", false, false)
231+
232+
// Pass entry point: optimize every block, and clear the function's
// hasStreamingModeChanges flag when no streaming mode change survives.
// Returns true if any instruction was removed.
bool SMEPeepholeOpt::runOnMachineFunction(MachineFunction &MF) {
  if (skipFunction(MF.getFunction()))
    return false;

  // smstart/smstop only occur when SME is available.
  if (!MF.getSubtarget<AArch64Subtarget>().hasSME())
    return false;

  assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");

  bool Changed = false;
  bool FunctionHasAllSMChangesRemoved = false;

  // Even if the block lives in a function with no SME attributes attached we
  // still have to analyze all the blocks because we may call a streaming
  // function that requires smstart/smstop pairs.
  for (MachineBasicBlock &MBB : MF) {
    bool BlockHasAllSMChangesRemoved;
    Changed |= optimizeStartStopPairs(MBB, BlockHasAllSMChangesRemoved);
    FunctionHasAllSMChangesRemoved |= BlockHasAllSMChangesRemoved;
  }

  // The per-block flag cannot distinguish "this block never had a streaming
  // mode change" from "another block still has one", so OR-ing the flags can
  // be set while some other block still contains an smstart/smstop. Clearing
  // hasStreamingModeChanges in that situation would misinform later frame
  // lowering, so re-scan the function and only drop the flag when no
  // streaming mode change remains anywhere.
  if (FunctionHasAllSMChangesRemoved) {
    auto IsSMChange = [](const MachineInstr &MI) {
      return (MI.getOpcode() == AArch64::MSRpstatesvcrImm1 ||
              MI.getOpcode() == AArch64::MSRpstatePseudo) &&
             ChangesStreamingMode(&MI);
    };
    bool AnySMChangeRemains =
        llvm::any_of(MF, [&](const MachineBasicBlock &MBB) {
          return llvm::any_of(MBB, IsSMChange);
        });
    if (!AnySMChangeRemains) {
      AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
      AFI->setHasStreamingModeChanges(false);
    }
  }

  return Changed;
}
259+
260+
// Factory used by the AArch64 target machine to add this pass to the
// codegen pipeline.
FunctionPass *llvm::createSMEPeepholeOptPass() { return new SMEPeepholeOpt(); }

llvm/test/CodeGen/AArch64/O3-pipeline.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@
122122
; CHECK-NEXT: MachineDominator Tree Construction
123123
; CHECK-NEXT: AArch64 Local Dynamic TLS Access Clean-up
124124
; CHECK-NEXT: Finalize ISel and expand pseudo-instructions
125+
; CHECK-NEXT: SME Peephole Optimization pass
125126
; CHECK-NEXT: Lazy Machine Block Frequency Analysis
126127
; CHECK-NEXT: Early Tail Duplication
127128
; CHECK-NEXT: Optimize machine instruction PHIs

llvm/test/CodeGen/AArch64/sme-darwin-sve-vg.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
; RUN: llc -mtriple=aarch64-darwin -mattr=+sve -mattr=+sme -verify-machineinstrs < %s | FileCheck %s
1+
; RUN: llc -mtriple=aarch64-darwin -mattr=+sve -mattr=+sme -enable-aarch64-sme-peephole-opt=false -verify-machineinstrs < %s | FileCheck %s
22

33
declare void @normal_callee();
44

0 commit comments

Comments
 (0)