[AArch64] Add SME peephole optimizer pass #104612

Merged · 5 commits · Aug 21, 2024

2 changes: 2 additions & 0 deletions llvm/lib/Target/AArch64/AArch64.h
@@ -59,6 +59,7 @@ FunctionPass *createAArch64CleanupLocalDynamicTLSPass();

FunctionPass *createAArch64CollectLOHPass();
FunctionPass *createSMEABIPass();
FunctionPass *createSMEPeepholeOptPass();
ModulePass *createSVEIntrinsicOptsPass();
InstructionSelector *
createAArch64InstructionSelector(const AArch64TargetMachine &,
@@ -110,6 +111,7 @@ void initializeFalkorHWPFFixPass(PassRegistry&);
void initializeFalkorMarkStridedAccessesLegacyPass(PassRegistry&);
void initializeLDTLSCleanupPass(PassRegistry&);
void initializeSMEABIPass(PassRegistry &);
void initializeSMEPeepholeOptPass(PassRegistry &);
void initializeSVEIntrinsicOptsPass(PassRegistry &);
void initializeAArch64Arm64ECCallLoweringPass(PassRegistry &);
} // end namespace llvm
9 changes: 9 additions & 0 deletions llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
@@ -167,6 +167,11 @@ static cl::opt<bool>
cl::desc("Enable SVE intrinsic opts"),
cl::init(true));

static cl::opt<bool>
EnableSMEPeepholeOpt("enable-aarch64-sme-peephole-opt", cl::init(true),
cl::Hidden,
cl::desc("Perform SME peephole optimization"));

static cl::opt<bool> EnableFalkorHWPFFix("aarch64-enable-falkor-hwpf-fix",
cl::init(true), cl::Hidden);

@@ -256,6 +261,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeAArch64Target() {
initializeLDTLSCleanupPass(*PR);
initializeKCFIPass(*PR);
initializeSMEABIPass(*PR);
initializeSMEPeepholeOptPass(*PR);
initializeSVEIntrinsicOptsPass(*PR);
initializeAArch64SpeculationHardeningPass(*PR);
initializeAArch64SLSHardeningPass(*PR);
@@ -754,6 +760,9 @@ bool AArch64PassConfig::addGlobalInstructionSelect() {
}

void AArch64PassConfig::addMachineSSAOptimization() {
if (TM->getOptLevel() != CodeGenOptLevel::None && EnableSMEPeepholeOpt)
addPass(createSMEPeepholeOptPass());

// Run default MachineSSAOptimization first.
TargetPassConfig::addMachineSSAOptimization();

1 change: 1 addition & 0 deletions llvm/lib/Target/AArch64/CMakeLists.txt
@@ -87,6 +87,7 @@ add_llvm_target(AArch64CodeGen
AArch64TargetObjectFile.cpp
AArch64TargetTransformInfo.cpp
SMEABIPass.cpp
SMEPeepholeOpt.cpp
SVEIntrinsicOpts.cpp
AArch64SIMDInstrOpt.cpp

260 changes: 260 additions & 0 deletions llvm/lib/Target/AArch64/SMEPeepholeOpt.cpp
@@ -0,0 +1,260 @@
//===- SMEPeepholeOpt.cpp - SME peephole optimization pass-----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This pass tries to remove back-to-back (smstart, smstop) and
// (smstop, smstart) sequences. The pass is conservative when it cannot
// determine that it is safe to remove these sequences.
//===----------------------------------------------------------------------===//

#include "AArch64InstrInfo.h"
#include "AArch64MachineFunctionInfo.h"
#include "AArch64Subtarget.h"
#include "Utils/AArch64SMEAttributes.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"

using namespace llvm;

#define DEBUG_TYPE "aarch64-sme-peephole-opt"

namespace {

struct SMEPeepholeOpt : public MachineFunctionPass {
static char ID;

SMEPeepholeOpt() : MachineFunctionPass(ID) {
initializeSMEPeepholeOptPass(*PassRegistry::getPassRegistry());
}

bool runOnMachineFunction(MachineFunction &MF) override;

StringRef getPassName() const override {
return "SME Peephole Optimization pass";
}

void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesCFG();
MachineFunctionPass::getAnalysisUsage(AU);
}

bool optimizeStartStopPairs(MachineBasicBlock &MBB,
bool &HasRemovedAllSMChanges) const;
};

char SMEPeepholeOpt::ID = 0;

} // end anonymous namespace

static bool isConditionalStartStop(const MachineInstr *MI) {
return MI->getOpcode() == AArch64::MSRpstatePseudo;
}

static bool isMatchingStartStopPair(const MachineInstr *MI1,
const MachineInstr *MI2) {
// We only consider the same type of streaming mode change here, i.e.
// start/stop SM, or start/stop ZA pairs.
if (MI1->getOperand(0).getImm() != MI2->getOperand(0).getImm())
return false;

// One must be 'start', the other must be 'stop'
if (MI1->getOperand(1).getImm() == MI2->getOperand(1).getImm())
return false;

bool IsConditional = isConditionalStartStop(MI2);
if (isConditionalStartStop(MI1) != IsConditional)
return false;

if (!IsConditional)
return true;

// Check to make sure the conditional start/stop pairs are identical.
if (MI1->getOperand(2).getImm() != MI2->getOperand(2).getImm())
return false;

// Ensure reg masks are identical.
if (MI1->getOperand(4).getRegMask() != MI2->getOperand(4).getRegMask())
return false;

// This optimisation is unlikely to happen in practice for conditional
// smstart/smstop pairs as the virtual registers for pstate.sm will always
// be different.
// TODO: For this optimisation to apply to conditional smstart/smstop,
// this pass will need to do more work to remove redundant calls to
// __arm_sme_state.

// Only consider conditional start/stop pairs which read the same register
// holding the original value of pstate.sm, as some conditional start/stops
// require the state on entry to the function.
if (MI1->getOperand(3).isReg() && MI2->getOperand(3).isReg()) {
Register Reg1 = MI1->getOperand(3).getReg();
Register Reg2 = MI2->getOperand(3).getReg();
if (Reg1.isPhysical() || Reg2.isPhysical() || Reg1 != Reg2)
return false;
}

return true;
}

static bool ChangesStreamingMode(const MachineInstr *MI) {
assert((MI->getOpcode() == AArch64::MSRpstatesvcrImm1 ||
MI->getOpcode() == AArch64::MSRpstatePseudo) &&
"Expected MI to be a smstart/smstop instruction");
return MI->getOperand(0).getImm() == AArch64SVCR::SVCRSM ||
MI->getOperand(0).getImm() == AArch64SVCR::SVCRSMZA;
}

static bool isSVERegOp(const TargetRegisterInfo &TRI,
const MachineRegisterInfo &MRI,
const MachineOperand &MO) {
if (!MO.isReg())
return false;

Register R = MO.getReg();
if (R.isPhysical())
return llvm::any_of(TRI.subregs_inclusive(R), [](const MCPhysReg &SR) {
return AArch64::ZPRRegClass.contains(SR) ||
AArch64::PPRRegClass.contains(SR);
});

const TargetRegisterClass *RC = MRI.getRegClass(R);
return TRI.getCommonSubClass(&AArch64::ZPRRegClass, RC) ||
TRI.getCommonSubClass(&AArch64::PPRRegClass, RC);
}

bool SMEPeepholeOpt::optimizeStartStopPairs(
MachineBasicBlock &MBB, bool &HasRemovedAllSMChanges) const {
const MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
const TargetRegisterInfo &TRI =
*MBB.getParent()->getSubtarget().getRegisterInfo();

bool Changed = false;
MachineInstr *Prev = nullptr;
SmallVector<MachineInstr *, 4> ToBeRemoved;

// Convenience function to reset the matching of a sequence.
auto Reset = [&]() {
Prev = nullptr;
ToBeRemoved.clear();
};

// Walk through instructions in the block trying to find pairs of smstart
// and smstop nodes that cancel each other out. We only permit a limited
// set of instructions to appear between them, otherwise we reset our
// tracking.
unsigned NumSMChanges = 0;
unsigned NumSMChangesRemoved = 0;
for (MachineInstr &MI : make_early_inc_range(MBB)) {
switch (MI.getOpcode()) {
case AArch64::MSRpstatesvcrImm1:
case AArch64::MSRpstatePseudo: {
if (ChangesStreamingMode(&MI))
NumSMChanges++;

if (!Prev)
Prev = &MI;
else if (isMatchingStartStopPair(Prev, &MI)) {
// If they match, we can remove them, and possibly any instructions
// that we marked for deletion in between.
Prev->eraseFromParent();
MI.eraseFromParent();
for (MachineInstr *TBR : ToBeRemoved)
TBR->eraseFromParent();
ToBeRemoved.clear();
Prev = nullptr;
Changed = true;
NumSMChangesRemoved += 2;
} else {
Reset();

Collaborator:

I might not be following this correctly, but if each block only has one smstart/smstop, does this ever get reached, and so does HasRemainingSMChange even get set? Is it relying on always having another instruction in the block (say a terminator)?

Like I said, I may not understand the reasoning, and the pass otherwise looks good.

Collaborator (author):

In isMatchingStartStopPair we test whether these instructions form a matching smstart/smstop or smstop/smstart pair, based on the operands of the pseudo (we don't want the algorithm to fold away smstart sm, smstart za, for example). If we don't find a matching pseudo, we end up in this block and the algorithm resets.

Collaborator:

I see. My point was maybe quite contrived. If we have a block like:

bb:
  smstart

it doesn't have anything else in it (maybe a copy or another "allowed" instruction) and no terminator. There would be another block with the corresponding smstop, and another block with an smstart/smstop pair that gets removed (setting Changed = true). In the block above, the pass would start out with HasRemainingSMChange = false, look at the smstart and take the if (!Prev) Prev = &MI; path, then return because there were no changes. Nothing sets HasRemainingSMChange, so runOnMachineFunction ends up calling AFI->setHasStreamingModeChanges(FunctionHasRemainingSMChange) with a value of false. It has to be a bit contrived to set Changed but not FunctionHasRemainingSMChange.

Collaborator (author):

Thanks for clarifying. Yes, that's an unusual case that wouldn't ever happen with the way the smstart/smstop nodes are emitted, but I agree that the pass should be written with this case in mind. I've fixed the issue, but I am unable to write an explicit test for it.

Prev = &MI;
}
continue;
}
default:
if (!Prev)
// Avoid doing expensive checks when Prev is nullptr.
continue;
break;
}

// Test if the instructions in between the start/stop sequence are agnostic
// of streaming mode. If not, the algorithm should reset.
switch (MI.getOpcode()) {
default:
Reset();
break;
case AArch64::COALESCER_BARRIER_FPR16:
case AArch64::COALESCER_BARRIER_FPR32:
case AArch64::COALESCER_BARRIER_FPR64:
case AArch64::COALESCER_BARRIER_FPR128:
case AArch64::COPY:
// These instructions should be safe when executed on their own, but
// the code remains conservative when SVE registers are used. There may
// exist subtle cases where executing a COPY in a different mode results
// in different behaviour, even if we can't yet come up with any
// concrete example/test-case.
if (isSVERegOp(TRI, MRI, MI.getOperand(0)) ||
isSVERegOp(TRI, MRI, MI.getOperand(1)))
Reset();
break;
case AArch64::ADJCALLSTACKDOWN:
case AArch64::ADJCALLSTACKUP:
case AArch64::ANDXri:
case AArch64::ADDXri:
// We permit these as they don't generate SVE/NEON instructions.
break;
case AArch64::VGRestorePseudo:
case AArch64::VGSavePseudo:
// When the smstart/smstop are removed, we should also remove
// the pseudos that save/restore the VG value for CFI info.
ToBeRemoved.push_back(&MI);
break;
case AArch64::MSRpstatesvcrImm1:
case AArch64::MSRpstatePseudo:
llvm_unreachable("Should have been handled");
}
}

HasRemovedAllSMChanges =
NumSMChanges && (NumSMChanges == NumSMChangesRemoved);
return Changed;
}

INITIALIZE_PASS(SMEPeepholeOpt, "aarch64-sme-peephole-opt",
"SME Peephole Optimization", false, false)

bool SMEPeepholeOpt::runOnMachineFunction(MachineFunction &MF) {
if (skipFunction(MF.getFunction()))
return false;

if (!MF.getSubtarget<AArch64Subtarget>().hasSME())
return false;

assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");

bool Changed = false;
bool FunctionHasAllSMChangesRemoved = false;

// Even if the block lives in a function with no SME attributes attached we
// still have to analyze all the blocks because we may call a streaming
// function that requires smstart/smstop pairs.
for (MachineBasicBlock &MBB : MF) {
bool BlockHasAllSMChangesRemoved;
Changed |= optimizeStartStopPairs(MBB, BlockHasAllSMChangesRemoved);
FunctionHasAllSMChangesRemoved |= BlockHasAllSMChangesRemoved;
}

AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
if (FunctionHasAllSMChangesRemoved)
AFI->setHasStreamingModeChanges(false);

return Changed;
}

FunctionPass *llvm::createSMEPeepholeOptPass() { return new SMEPeepholeOpt(); }
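
For context, the back-to-back mode switches described in the file header above typically arise when a non-streaming caller makes consecutive calls to streaming callees: each call is bracketed by smstart/smstop, so the smstop after one call and the smstart before the next form a removable pair. A minimal sketch follows (not part of this patch; the function names are invented, and it assumes a clang build with SME enabled, e.g. -march=armv9-a+sme):

// Hypothetical illustration only: callees use the ACLE streaming interface.
void streaming_callee_a(void) __arm_streaming;
void streaming_callee_b(void) __arm_streaming;

void non_streaming_caller(void) {
  // Each call is wrapped in smstart sm / smstop sm by the compiler; the
  // smstop after the first call and the smstart before the second are the
  // kind of back-to-back pair SMEPeepholeOpt can fold away.
  streaming_callee_a();
  streaming_callee_b();
}
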
1 change: 1 addition & 0 deletions llvm/test/CodeGen/AArch64/O3-pipeline.ll
@@ -122,6 +122,7 @@
; CHECK-NEXT: MachineDominator Tree Construction
; CHECK-NEXT: AArch64 Local Dynamic TLS Access Clean-up
; CHECK-NEXT: Finalize ISel and expand pseudo-instructions
; CHECK-NEXT: SME Peephole Optimization pass
; CHECK-NEXT: Lazy Machine Block Frequency Analysis
; CHECK-NEXT: Early Tail Duplication
; CHECK-NEXT: Optimize machine instruction PHIs
2 changes: 1 addition & 1 deletion llvm/test/CodeGen/AArch64/sme-darwin-sve-vg.ll
@@ -1,4 +1,4 @@
-; RUN: llc -mtriple=aarch64-darwin -mattr=+sve -mattr=+sme -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64-darwin -mattr=+sve -mattr=+sme -enable-aarch64-sme-peephole-opt=false -verify-machineinstrs < %s | FileCheck %s

declare void @normal_callee();
