|
18 | 18 |
|
19 | 19 | #include "RISCV.h"
|
20 | 20 | #include "RISCVSubtarget.h"
|
| 21 | +#include "RISCVISelDAGToDAG.h" |
21 | 22 | #include "llvm/CodeGen/MachineFunctionPass.h"
|
22 | 23 | #include "llvm/CodeGen/MachineRegisterInfo.h"
|
23 | 24 | #include "llvm/CodeGen/TargetInstrInfo.h"
|
@@ -48,6 +49,7 @@ class RISCVFoldMasks : public MachineFunctionPass {
|
48 | 49 | StringRef getPassName() const override { return "RISC-V Fold Masks"; }
|
49 | 50 |
|
50 | 51 | private:
|
| 52 | + bool convertToUnmasked(MachineInstr &MI, MachineInstr *MaskDef); |
51 | 53 | bool convertVMergeToVMv(MachineInstr &MI, MachineInstr *MaskDef);
|
52 | 54 |
|
53 | 55 | bool isAllOnesMask(MachineInstr *MaskDef);
|
@@ -132,6 +134,49 @@ bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI, MachineInstr *V0Def) {
|
132 | 134 | return true;
|
133 | 135 | }
|
134 | 136 |
|
| 137 | +bool RISCVFoldMasks::convertToUnmasked(MachineInstr &MI, |
| 138 | + MachineInstr *MaskDef) { |
| 139 | + const RISCV::RISCVMaskedPseudoInfo *I = |
| 140 | + RISCV::getMaskedPseudoInfo(MI.getOpcode()); |
| 141 | + if (!I) |
| 142 | + return false; |
| 143 | + |
| 144 | + if (!isAllOnesMask(MaskDef)) |
| 145 | + return false; |
| 146 | + |
| 147 | + // There are two classes of pseudos in the table - compares and |
| 148 | + // everything else. See the comment on RISCVMaskedPseudo for details. |
| 149 | + const unsigned Opc = I->UnmaskedPseudo; |
| 150 | + const MCInstrDesc &MCID = TII->get(Opc); |
| 151 | + const bool HasPolicyOp = RISCVII::hasVecPolicyOp(MCID.TSFlags); |
| 152 | + const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(MCID); |
| 153 | +#ifndef NDEBUG |
| 154 | + const MCInstrDesc &MaskedMCID = TII->get(MI.getOpcode()); |
| 155 | + assert(RISCVII::hasVecPolicyOp(MaskedMCID.TSFlags) == |
| 156 | + RISCVII::hasVecPolicyOp(MCID.TSFlags) && |
| 157 | + "Masked and unmasked pseudos are inconsistent"); |
| 158 | + assert(HasPolicyOp == HasPassthru && "Unexpected pseudo structure"); |
| 159 | +#endif |
| 160 | + |
| 161 | + MI.setDesc(MCID); |
| 162 | + |
| 163 | + // TODO: Increment all MaskOpIdxs in tablegen by num of explicit defs? |
| 164 | + unsigned MaskOpIdx = I->MaskOpIdx + MI.getNumExplicitDefs(); |
| 165 | + MI.removeOperand(MaskOpIdx); |
| 166 | + |
| 167 | + // The unmasked pseudo will no longer be constrained to the vrnov0 reg class, |
| 168 | + // so try and relax it to vr. |
| 169 | + MRI->recomputeRegClass(MI.getOperand(0).getReg()); |
| 170 | + unsigned PassthruOpIdx = MI.getNumExplicitDefs(); |
| 171 | + if (HasPassthru) { |
| 172 | + if (MI.getOperand(PassthruOpIdx).getReg() != RISCV::NoRegister) |
| 173 | + MRI->recomputeRegClass(MI.getOperand(PassthruOpIdx).getReg()); |
| 174 | + } else |
| 175 | + MI.removeOperand(PassthruOpIdx); |
| 176 | + |
| 177 | + return true; |
| 178 | +} |
| 179 | + |
135 | 180 | bool RISCVFoldMasks::runOnMachineFunction(MachineFunction &MF) {
|
136 | 181 | if (skipFunction(MF.getFunction()))
|
137 | 182 | return false;
|
@@ -159,6 +204,7 @@ bool RISCVFoldMasks::runOnMachineFunction(MachineFunction &MF) {
|
159 | 204 | CurrentV0Def = nullptr;
|
160 | 205 | for (MachineInstr &MI : MBB) {
|
161 | 206 | unsigned BaseOpc = RISCV::getRVVMCOpcode(MI.getOpcode());
|
| 207 | + Changed |= convertToUnmasked(MI, CurrentV0Def); |
162 | 208 | if (BaseOpc == RISCV::VMERGE_VVM)
|
163 | 209 | Changed |= convertVMergeToVMv(MI, CurrentV0Def);
|
164 | 210 |
|
|
0 commit comments