-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[AArch64] Add SME peephole optimizer pass #104612
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
d1d21e6
Precommit test
sdesmalen-arm ebdf9f7
[AArch64] Add SME peephole optimizer pass
sdesmalen-arm b577d4d
Relax restrictions for COPY and refactor
sdesmalen-arm 0ebdba4
Fix issue where HasStreamingModeChanges may be set to 'false' incorre…
sdesmalen-arm 4c09faf
Fix failing test before committing
sdesmalen-arm File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,260 @@ | ||
//===- SMEPeepholeOpt.cpp - SME peephole optimization pass-----------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// This pass tries to remove back-to-back (smstart, smstop) and | ||
// (smstop, smstart) sequences. The pass is conservative when it cannot | ||
// determine that it is safe to remove these sequences. | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "AArch64InstrInfo.h" | ||
#include "AArch64MachineFunctionInfo.h" | ||
#include "AArch64Subtarget.h" | ||
#include "Utils/AArch64SMEAttributes.h" | ||
#include "llvm/ADT/SmallVector.h" | ||
#include "llvm/CodeGen/MachineBasicBlock.h" | ||
#include "llvm/CodeGen/MachineFunctionPass.h" | ||
#include "llvm/CodeGen/MachineRegisterInfo.h" | ||
#include "llvm/CodeGen/TargetRegisterInfo.h" | ||
|
||
using namespace llvm; | ||
|
||
#define DEBUG_TYPE "aarch64-sme-peephole-opt" | ||
|
||
namespace { | ||
|
||
struct SMEPeepholeOpt : public MachineFunctionPass { | ||
static char ID; | ||
|
||
SMEPeepholeOpt() : MachineFunctionPass(ID) { | ||
initializeSMEPeepholeOptPass(*PassRegistry::getPassRegistry()); | ||
} | ||
|
||
bool runOnMachineFunction(MachineFunction &MF) override; | ||
|
||
StringRef getPassName() const override { | ||
return "SME Peephole Optimization pass"; | ||
} | ||
|
||
void getAnalysisUsage(AnalysisUsage &AU) const override { | ||
AU.setPreservesCFG(); | ||
MachineFunctionPass::getAnalysisUsage(AU); | ||
} | ||
|
||
bool optimizeStartStopPairs(MachineBasicBlock &MBB, | ||
bool &HasRemovedAllSMChanges) const; | ||
}; | ||
|
||
char SMEPeepholeOpt::ID = 0; | ||
|
||
} // end anonymous namespace | ||
|
||
static bool isConditionalStartStop(const MachineInstr *MI) { | ||
return MI->getOpcode() == AArch64::MSRpstatePseudo; | ||
} | ||
|
||
static bool isMatchingStartStopPair(const MachineInstr *MI1, | ||
const MachineInstr *MI2) { | ||
// We only consider the same type of streaming mode change here, i.e. | ||
// start/stop SM, or start/stop ZA pairs. | ||
if (MI1->getOperand(0).getImm() != MI2->getOperand(0).getImm()) | ||
return false; | ||
|
||
// One must be 'start', the other must be 'stop' | ||
if (MI1->getOperand(1).getImm() == MI2->getOperand(1).getImm()) | ||
return false; | ||
|
||
bool IsConditional = isConditionalStartStop(MI2); | ||
if (isConditionalStartStop(MI1) != IsConditional) | ||
return false; | ||
|
||
if (!IsConditional) | ||
return true; | ||
|
||
// Check to make sure the conditional start/stop pairs are identical. | ||
if (MI1->getOperand(2).getImm() != MI2->getOperand(2).getImm()) | ||
return false; | ||
|
||
// Ensure reg masks are identical. | ||
if (MI1->getOperand(4).getRegMask() != MI2->getOperand(4).getRegMask()) | ||
return false; | ||
|
||
// This optimisation is unlikely to happen in practice for conditional | ||
// smstart/smstop pairs as the virtual registers for pstate.sm will always | ||
// be different. | ||
// TODO: For this optimisation to apply to conditional smstart/smstop, | ||
// this pass will need to do more work to remove redundant calls to | ||
// __arm_sme_state. | ||
|
||
// Only consider conditional start/stop pairs which read the same register | ||
// holding the original value of pstate.sm, as some conditional start/stops | ||
// require the state on entry to the function. | ||
if (MI1->getOperand(3).isReg() && MI2->getOperand(3).isReg()) { | ||
Register Reg1 = MI1->getOperand(3).getReg(); | ||
Register Reg2 = MI2->getOperand(3).getReg(); | ||
if (Reg1.isPhysical() || Reg2.isPhysical() || Reg1 != Reg2) | ||
return false; | ||
} | ||
|
||
return true; | ||
} | ||
|
||
static bool ChangesStreamingMode(const MachineInstr *MI) { | ||
assert((MI->getOpcode() == AArch64::MSRpstatesvcrImm1 || | ||
MI->getOpcode() == AArch64::MSRpstatePseudo) && | ||
"Expected MI to be a smstart/smstop instruction"); | ||
return MI->getOperand(0).getImm() == AArch64SVCR::SVCRSM || | ||
MI->getOperand(0).getImm() == AArch64SVCR::SVCRSMZA; | ||
} | ||
|
||
static bool isSVERegOp(const TargetRegisterInfo &TRI, | ||
const MachineRegisterInfo &MRI, | ||
const MachineOperand &MO) { | ||
if (!MO.isReg()) | ||
return false; | ||
|
||
Register R = MO.getReg(); | ||
if (R.isPhysical()) | ||
return llvm::any_of(TRI.subregs_inclusive(R), [](const MCPhysReg &SR) { | ||
return AArch64::ZPRRegClass.contains(SR) || | ||
AArch64::PPRRegClass.contains(SR); | ||
}); | ||
|
||
const TargetRegisterClass *RC = MRI.getRegClass(R); | ||
return TRI.getCommonSubClass(&AArch64::ZPRRegClass, RC) || | ||
TRI.getCommonSubClass(&AArch64::PPRRegClass, RC); | ||
} | ||
|
||
/// Removes back-to-back smstart/smstop (or smstop/smstart) pairs from \p MBB.
///
/// The block is walked once. After a smstart/smstop has been seen, only a
/// small set of instructions is tolerated before the next one; anything else
/// abandons the candidate pair.
///
/// \param MBB the block to optimise; matching pairs (and any VG save/restore
///        pseudos between them) are erased from it.
/// \param HasRemovedAllSMChanges always written before returning. Set to true
///        only if the block contained at least one instruction that changes
///        streaming mode and every such instruction was removed.
/// \returns true if any instruction was erased.
bool SMEPeepholeOpt::optimizeStartStopPairs(
    MachineBasicBlock &MBB, bool &HasRemovedAllSMChanges) const {
  const MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
  const TargetRegisterInfo &TRI =
      *MBB.getParent()->getSubtarget().getRegisterInfo();

  bool Changed = false;
  MachineInstr *Prev = nullptr;
  SmallVector<MachineInstr *, 4> ToBeRemoved;

  // Convenience function to reset the matching of a sequence.
  auto Reset = [&]() {
    Prev = nullptr;
    ToBeRemoved.clear();
  };

  // Walk through instructions in the block trying to find pairs of smstart
  // and smstop nodes that cancel each other out. We only permit a limited
  // set of instructions to appear between them, otherwise we reset our
  // tracking.
  unsigned NumSMChanges = 0;
  unsigned NumSMChangesRemoved = 0;
  for (MachineInstr &MI : make_early_inc_range(MBB)) {
    switch (MI.getOpcode()) {
    case AArch64::MSRpstatesvcrImm1:
    case AArch64::MSRpstatePseudo: {
      // Query this up front: MI may be erased further down.
      const bool IsSMChange = ChangesStreamingMode(&MI);
      if (IsSMChange)
        NumSMChanges++;

      if (!Prev)
        Prev = &MI;
      else if (isMatchingStartStopPair(Prev, &MI)) {
        // If they match, we can remove them, and possibly any instructions
        // that we marked for deletion in between.
        Prev->eraseFromParent();
        MI.eraseFromParent();
        for (MachineInstr *TBR : ToBeRemoved)
          TBR->eraseFromParent();
        ToBeRemoved.clear();
        Prev = nullptr;
        Changed = true;
        // Both halves of a matching pair have the same SVCR operand, so
        // either both changed streaming mode or neither did. Only count the
        // pair as removed in the former case; otherwise removing e.g. a
        // ZA-only pair would be mistaken for removing streaming-mode changes
        // and make the comparison against NumSMChanges below meaningless.
        if (IsSMChange)
          NumSMChangesRemoved += 2;
      } else {
        Reset();
        Prev = &MI;
      }
      continue;
    }
    default:
      if (!Prev)
        // Avoid doing expensive checks when Prev is nullptr.
        continue;
      break;
    }

    // Test if the instructions in between the start/stop sequence are agnostic
    // of streaming mode. If not, the algorithm should reset.
    switch (MI.getOpcode()) {
    default:
      Reset();
      break;
    case AArch64::COALESCER_BARRIER_FPR16:
    case AArch64::COALESCER_BARRIER_FPR32:
    case AArch64::COALESCER_BARRIER_FPR64:
    case AArch64::COALESCER_BARRIER_FPR128:
    case AArch64::COPY:
      // These instructions should be safe when executed on their own, but
      // the code remains conservative when SVE registers are used. There may
      // exist subtle cases where executing a COPY in a different mode results
      // in different behaviour, even if we can't yet come up with any
      // concrete example/test-case.
      if (isSVERegOp(TRI, MRI, MI.getOperand(0)) ||
          isSVERegOp(TRI, MRI, MI.getOperand(1)))
        Reset();
      break;
    case AArch64::ADJCALLSTACKDOWN:
    case AArch64::ADJCALLSTACKUP:
    case AArch64::ANDXri:
    case AArch64::ADDXri:
      // We permit these as they don't generate SVE/NEON instructions.
      break;
    case AArch64::VGRestorePseudo:
    case AArch64::VGSavePseudo:
      // When the smstart/smstop are removed, we should also remove
      // the pseudos that save/restore the VG value for CFI info.
      ToBeRemoved.push_back(&MI);
      break;
    case AArch64::MSRpstatesvcrImm1:
    case AArch64::MSRpstatePseudo:
      llvm_unreachable("Should have been handled");
    }
  }

  // True only when there was something to remove and nothing is left.
  HasRemovedAllSMChanges =
      NumSMChanges && (NumSMChanges == NumSMChangesRemoved);
  return Changed;
}
|
||
// Pass registration for SMEPeepholeOpt under the argument string
// "aarch64-sme-peephole-opt" with the description "SME Peephole Optimization".
// NOTE(review): presumably this is what provides initializeSMEPeepholeOptPass,
// which the pass constructor calls -- confirm against the LLVM pass macros.
INITIALIZE_PASS(SMEPeepholeOpt, "aarch64-sme-peephole-opt", | ||
"SME Peephole Optimization", false, false) | ||
|
||
/// Runs the smstart/smstop peephole over every basic block of \p MF.
///
/// Does nothing if the function is to be skipped or the subtarget lacks SME.
/// If the peephole removed streaming-mode changes and none are left anywhere
/// in the function, the function info is updated to record that the function
/// no longer has streaming-mode changes.
///
/// \returns true if any instruction was removed.
bool SMEPeepholeOpt::runOnMachineFunction(MachineFunction &MF) {
  if (skipFunction(MF.getFunction()))
    return false;

  if (!MF.getSubtarget<AArch64Subtarget>().hasSME())
    return false;

  assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");

  bool Changed = false;
  bool AnyBlockHadAllSMChangesRemoved = false;

  // Even if the block lives in a function with no SME attributes attached we
  // still have to analyze all the blocks because we may call a streaming
  // function that requires smstart/smstop pairs.
  for (MachineBasicBlock &MBB : MF) {
    bool BlockHasAllSMChangesRemoved = false;
    Changed |= optimizeStartStopPairs(MBB, BlockHasAllSMChangesRemoved);
    AnyBlockHadAllSMChangesRemoved |= BlockHasAllSMChangesRemoved;
  }

  // Nothing relevant was removed, so the recorded state is still accurate.
  if (!AnyBlockHadAllSMChangesRemoved)
    return Changed;

  // One block being fully cleaned up says nothing about the other blocks:
  // another block may still contain an smstart/smstop that changes streaming
  // mode. Only clear the flag when the whole function is free of them.
  auto IsSMChange = [](const MachineInstr &MI) {
    switch (MI.getOpcode()) {
    case AArch64::MSRpstatesvcrImm1:
    case AArch64::MSRpstatePseudo:
      return ChangesStreamingMode(&MI);
    default:
      return false;
    }
  };
  const bool HasRemainingSMChange =
      llvm::any_of(MF, [&](const MachineBasicBlock &MBB) {
        return llvm::any_of(MBB, IsSMChange);
      });

  if (!HasRemainingSMChange) {
    AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
    AFI->setHasStreamingModeChanges(false);
  }

  return Changed;
}
|
||
/// Factory for the SME peephole optimisation pass; returns a newly allocated
/// SMEPeepholeOpt instance.
FunctionPass *llvm::createSMEPeepholeOptPass() {
  return new SMEPeepholeOpt();
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I might not be following this correctly, but if each block only has one smstart/smstop, does this ever get reached, and so does HasRemainingSMChange even get set? Is it relying on always having another instruction in the block (say a terminator)?
Like I said I may not understand the reasoning, and the pass otherwise looks good.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In `isMatchingStartStopPair` we test if these instructions are a matching `smstart, smstop` or `smstop, smstart` pair, based on the operands to this pseudo (we don't want the algorithm to fold away `smstart sm, smstart za`, for example). If we don't find a matching pseudo, it ends up in this block and the algorithm resets.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see. My point was maybe quite contrived. If you have a block like:
It doesn't have anything else in it (maybe a copy or another "allowed" instruction), and no terminator. There would be another block with the corresponding smstop, and another block with a smstart/smstop pair that gets removed (to set `Changed=true`). In the block above it would start out as `HasRemainingSMChange = false;`, look at the smstart to set `if (!Prev) Prev = &MI;`, then return because there were no changes. Nothing set `HasRemainingSMChange`, so the runOnMachineFunction function sets `AFI->setHasStreamingModeChanges(FunctionHasRemainingSMChange);` (= false). It has to be a bit contrived to set `Changed` but not `FunctionHasRemainingSMChange`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for clarifying. Yes, that's an unusual case that wouldn't ever happen with the way the smstart/smstop nodes are emitted, but I agree that the pass should be written with this case in mind. I've fixed the issue but I am unable to write an explicit test for it.