|
| 1 | +//===- SMEPeepholeOpt.cpp - SME peephole optimization pass-----------------===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +// This pass tries to remove back-to-back (smstart, smstop) and |
| 9 | +// (smstop, smstart) sequences. The pass is conservative when it cannot |
| 10 | +// determine that it is safe to remove these sequences. |
| 11 | +//===----------------------------------------------------------------------===// |
| 12 | + |
| 13 | +#include "AArch64InstrInfo.h" |
| 14 | +#include "AArch64MachineFunctionInfo.h" |
| 15 | +#include "AArch64Subtarget.h" |
| 16 | +#include "Utils/AArch64SMEAttributes.h" |
| 17 | +#include "llvm/ADT/SmallVector.h" |
| 18 | +#include "llvm/CodeGen/MachineBasicBlock.h" |
| 19 | +#include "llvm/CodeGen/MachineFunctionPass.h" |
| 20 | +#include "llvm/CodeGen/MachineRegisterInfo.h" |
| 21 | +#include "llvm/CodeGen/TargetRegisterInfo.h" |
| 22 | + |
| 23 | +using namespace llvm; |
| 24 | + |
| 25 | +#define DEBUG_TYPE "aarch64-sme-peephole-opt" |
| 26 | + |
| 27 | +namespace { |
| 28 | + |
// Machine-function pass that scans each basic block for back-to-back
// streaming-mode changes (smstart/smstop pairs) that cancel out and
// removes them.
struct SMEPeepholeOpt : public MachineFunctionPass {
  static char ID; // Pass identification, replacement for typeid.

  SMEPeepholeOpt() : MachineFunctionPass(ID) {
    initializeSMEPeepholeOptPass(*PassRegistry::getPassRegistry());
  }

  bool runOnMachineFunction(MachineFunction &MF) override;

  StringRef getPassName() const override {
    return "SME Peephole Optimization pass";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // Only instructions within a block are erased; the CFG is never changed.
    AU.setPreservesCFG();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  // Removes cancelling start/stop pairs in \p MBB. Returns true if anything
  // was erased. \p HasRemovedAllSMChanges is set when, after optimisation,
  // the block contains no streaming-mode (SM) changes at all.
  bool optimizeStartStopPairs(MachineBasicBlock &MBB,
                              bool &HasRemovedAllSMChanges) const;
};
| 50 | + |
| 51 | +char SMEPeepholeOpt::ID = 0; |
| 52 | + |
| 53 | +} // end anonymous namespace |
| 54 | + |
| 55 | +static bool isConditionalStartStop(const MachineInstr *MI) { |
| 56 | + return MI->getOpcode() == AArch64::MSRpstatePseudo; |
| 57 | +} |
| 58 | + |
| 59 | +static bool isMatchingStartStopPair(const MachineInstr *MI1, |
| 60 | + const MachineInstr *MI2) { |
| 61 | + // We only consider the same type of streaming mode change here, i.e. |
| 62 | + // start/stop SM, or start/stop ZA pairs. |
| 63 | + if (MI1->getOperand(0).getImm() != MI2->getOperand(0).getImm()) |
| 64 | + return false; |
| 65 | + |
| 66 | + // One must be 'start', the other must be 'stop' |
| 67 | + if (MI1->getOperand(1).getImm() == MI2->getOperand(1).getImm()) |
| 68 | + return false; |
| 69 | + |
| 70 | + bool IsConditional = isConditionalStartStop(MI2); |
| 71 | + if (isConditionalStartStop(MI1) != IsConditional) |
| 72 | + return false; |
| 73 | + |
| 74 | + if (!IsConditional) |
| 75 | + return true; |
| 76 | + |
| 77 | + // Check to make sure the conditional start/stop pairs are identical. |
| 78 | + if (MI1->getOperand(2).getImm() != MI2->getOperand(2).getImm()) |
| 79 | + return false; |
| 80 | + |
| 81 | + // Ensure reg masks are identical. |
| 82 | + if (MI1->getOperand(4).getRegMask() != MI2->getOperand(4).getRegMask()) |
| 83 | + return false; |
| 84 | + |
| 85 | + // This optimisation is unlikely to happen in practice for conditional |
| 86 | + // smstart/smstop pairs as the virtual registers for pstate.sm will always |
| 87 | + // be different. |
| 88 | + // TODO: For this optimisation to apply to conditional smstart/smstop, |
| 89 | + // this pass will need to do more work to remove redundant calls to |
| 90 | + // __arm_sme_state. |
| 91 | + |
| 92 | + // Only consider conditional start/stop pairs which read the same register |
| 93 | + // holding the original value of pstate.sm, as some conditional start/stops |
| 94 | + // require the state on entry to the function. |
| 95 | + if (MI1->getOperand(3).isReg() && MI2->getOperand(3).isReg()) { |
| 96 | + Register Reg1 = MI1->getOperand(3).getReg(); |
| 97 | + Register Reg2 = MI2->getOperand(3).getReg(); |
| 98 | + if (Reg1.isPhysical() || Reg2.isPhysical() || Reg1 != Reg2) |
| 99 | + return false; |
| 100 | + } |
| 101 | + |
| 102 | + return true; |
| 103 | +} |
| 104 | + |
| 105 | +static bool ChangesStreamingMode(const MachineInstr *MI) { |
| 106 | + assert((MI->getOpcode() == AArch64::MSRpstatesvcrImm1 || |
| 107 | + MI->getOpcode() == AArch64::MSRpstatePseudo) && |
| 108 | + "Expected MI to be a smstart/smstop instruction"); |
| 109 | + return MI->getOperand(0).getImm() == AArch64SVCR::SVCRSM || |
| 110 | + MI->getOperand(0).getImm() == AArch64SVCR::SVCRSMZA; |
| 111 | +} |
| 112 | + |
| 113 | +static bool isSVERegOp(const TargetRegisterInfo &TRI, |
| 114 | + const MachineRegisterInfo &MRI, |
| 115 | + const MachineOperand &MO) { |
| 116 | + if (!MO.isReg()) |
| 117 | + return false; |
| 118 | + |
| 119 | + Register R = MO.getReg(); |
| 120 | + if (R.isPhysical()) |
| 121 | + return llvm::any_of(TRI.subregs_inclusive(R), [](const MCPhysReg &SR) { |
| 122 | + return AArch64::ZPRRegClass.contains(SR) || |
| 123 | + AArch64::PPRRegClass.contains(SR); |
| 124 | + }); |
| 125 | + |
| 126 | + const TargetRegisterClass *RC = MRI.getRegClass(R); |
| 127 | + return TRI.getCommonSubClass(&AArch64::ZPRRegClass, RC) || |
| 128 | + TRI.getCommonSubClass(&AArch64::PPRRegClass, RC); |
| 129 | +} |
| 130 | + |
// Scans \p MBB for matching (smstart, smstop) / (smstop, smstart) pairs with
// only mode-agnostic instructions in between, and erases both instructions
// of each pair along with any VG save/restore pseudos between them. Sets
// \p HasRemovedAllSMChanges when every streaming-mode change in the block
// has been removed. Returns true if anything was erased.
bool SMEPeepholeOpt::optimizeStartStopPairs(
    MachineBasicBlock &MBB, bool &HasRemovedAllSMChanges) const {
  const MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
  const TargetRegisterInfo &TRI =
      *MBB.getParent()->getSubtarget().getRegisterInfo();

  bool Changed = false;
  // First half of a candidate pair (the most recent unmatched start/stop).
  MachineInstr *Prev = nullptr;
  // Instructions between the pair that must go away with it (VG pseudos).
  SmallVector<MachineInstr *, 4> ToBeRemoved;

  // Convenience function to reset the matching of a sequence.
  auto Reset = [&]() {
    Prev = nullptr;
    ToBeRemoved.clear();
  };

  // Walk through instructions in the block trying to find pairs of smstart
  // and smstop nodes that cancel each other out. We only permit a limited
  // set of instructions to appear between them, otherwise we reset our
  // tracking.
  unsigned NumSMChanges = 0;
  unsigned NumSMChangesRemoved = 0;
  // early_inc_range keeps iteration valid while MI is erased below.
  for (MachineInstr &MI : make_early_inc_range(MBB)) {
    switch (MI.getOpcode()) {
    case AArch64::MSRpstatesvcrImm1:
    case AArch64::MSRpstatePseudo: {
      // Track total SM changes so the caller can tell whether any remain.
      if (ChangesStreamingMode(&MI))
        NumSMChanges++;

      if (!Prev)
        Prev = &MI;
      else if (isMatchingStartStopPair(Prev, &MI)) {
        // If they match, we can remove them, and possibly any instructions
        // that we marked for deletion in between.
        Prev->eraseFromParent();
        MI.eraseFromParent();
        for (MachineInstr *TBR : ToBeRemoved)
          TBR->eraseFromParent();
        ToBeRemoved.clear();
        Prev = nullptr;
        Changed = true;
        NumSMChangesRemoved += 2;
      } else {
        // Not a cancelling pair; restart matching from this instruction.
        Reset();
        Prev = &MI;
      }
      continue;
    }
    default:
      if (!Prev)
        // Avoid doing expensive checks when Prev is nullptr.
        continue;
      break;
    }

    // Test if the instructions in between the start/stop sequence are agnostic
    // of streaming mode. If not, the algorithm should reset.
    switch (MI.getOpcode()) {
    default:
      Reset();
      break;
    case AArch64::COALESCER_BARRIER_FPR16:
    case AArch64::COALESCER_BARRIER_FPR32:
    case AArch64::COALESCER_BARRIER_FPR64:
    case AArch64::COALESCER_BARRIER_FPR128:
    case AArch64::COPY:
      // These instructions should be safe when executed on their own, but
      // the code remains conservative when SVE registers are used. There may
      // exist subtle cases where executing a COPY in a different mode results
      // in different behaviour, even if we can't yet come up with any
      // concrete example/test-case.
      if (isSVERegOp(TRI, MRI, MI.getOperand(0)) ||
          isSVERegOp(TRI, MRI, MI.getOperand(1)))
        Reset();
      break;
    case AArch64::ADJCALLSTACKDOWN:
    case AArch64::ADJCALLSTACKUP:
    case AArch64::ANDXri:
    case AArch64::ADDXri:
      // We permit these as they don't generate SVE/NEON instructions.
      break;
    case AArch64::VGRestorePseudo:
    case AArch64::VGSavePseudo:
      // When the smstart/smstop are removed, we should also remove
      // the pseudos that save/restore the VG value for CFI info.
      ToBeRemoved.push_back(&MI);
      break;
    case AArch64::MSRpstatesvcrImm1:
    case AArch64::MSRpstatePseudo:
      // Start/stop opcodes are fully handled by the first switch above.
      llvm_unreachable("Should have been handled");
    }
  }

  // All SM changes were removed only if there was at least one and the
  // removed count matches the total.
  HasRemovedAllSMChanges =
      NumSMChanges && (NumSMChanges == NumSMChangesRemoved);
  return Changed;
}
| 228 | + |
// Register the pass with LLVM's pass registry under the command-line name
// "aarch64-sme-peephole-opt" (not CFG-only, not an analysis).
INITIALIZE_PASS(SMEPeepholeOpt, "aarch64-sme-peephole-opt",
                "SME Peephole Optimization", false, false)
| 231 | + |
| 232 | +bool SMEPeepholeOpt::runOnMachineFunction(MachineFunction &MF) { |
| 233 | + if (skipFunction(MF.getFunction())) |
| 234 | + return false; |
| 235 | + |
| 236 | + if (!MF.getSubtarget<AArch64Subtarget>().hasSME()) |
| 237 | + return false; |
| 238 | + |
| 239 | + assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!"); |
| 240 | + |
| 241 | + bool Changed = false; |
| 242 | + bool FunctionHasAllSMChangesRemoved = false; |
| 243 | + |
| 244 | + // Even if the block lives in a function with no SME attributes attached we |
| 245 | + // still have to analyze all the blocks because we may call a streaming |
| 246 | + // function that requires smstart/smstop pairs. |
| 247 | + for (MachineBasicBlock &MBB : MF) { |
| 248 | + bool BlockHasAllSMChangesRemoved; |
| 249 | + Changed |= optimizeStartStopPairs(MBB, BlockHasAllSMChangesRemoved); |
| 250 | + FunctionHasAllSMChangesRemoved |= BlockHasAllSMChangesRemoved; |
| 251 | + } |
| 252 | + |
| 253 | + AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>(); |
| 254 | + if (FunctionHasAllSMChangesRemoved) |
| 255 | + AFI->setHasStreamingModeChanges(false); |
| 256 | + |
| 257 | + return Changed; |
| 258 | +} |
| 259 | + |
| 260 | +FunctionPass *llvm::createSMEPeepholeOptPass() { return new SMEPeepholeOpt(); } |
0 commit comments