1
- // ===- RISCVFoldMasks .cpp - MI Vector Pseudo Mask Peepholes ---------------===//
1
+ // ===- RISCVVectorPeephole .cpp - MI Vector Pseudo Peepholes ---------------===//
2
2
//
3
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
4
// See https://llvm.org/LICENSE.txt for license information.
5
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
6
//
7
- // ===---------------------------------------------------------------------===//
7
+ // ===---------------------------------------------------------------------- ===//
8
8
//
9
- // This pass performs various peephole optimisations that fold masks into vector
10
- // pseudo instructions after instruction selection.
9
+ // This pass performs various vector pseudo peephole optimisations after
10
+ // instruction selection.
11
11
//
12
- // Currently it converts
12
+ // Currently it converts vmerge.vvm to vmv.v.v
13
13
// PseudoVMERGE_VVM %false, %false, %true, %allonesmask, %vl, %sew
14
14
// ->
15
15
// PseudoVMV_V_V %false, %true, %vl, %sew
16
16
//
17
- // ===---------------------------------------------------------------------===//
17
+ // And masked pseudos to unmasked pseudos
18
+ // PseudoVADD_V_V_MASK %passthru, %a, %b, %allonesmask, %vl, sew, policy
19
+ // ->
20
+ // PseudoVADD_V_V %passthru %a, %b, %vl, sew, policy
21
+ //
22
+ // It also converts AVLs to VLMAX where possible
23
+ // %vl = VLENB * something
24
+ // PseudoVADD_V_V %passthru, %a, %b, %vl, sew, policy
25
+ // ->
26
+ // PseudoVADD_V_V %passthru, %a, %b, -1, sew, policy
27
+ //
28
+ // ===----------------------------------------------------------------------===//
18
29
19
30
#include " RISCV.h"
20
31
#include " RISCVISelDAGToDAG.h"
26
37
27
38
using namespace llvm ;
28
39
29
- #define DEBUG_TYPE " riscv-fold-masks "
40
+ #define DEBUG_TYPE " riscv-vector-peephole "
30
41
31
42
namespace {
32
43
33
- class RISCVFoldMasks : public MachineFunctionPass {
44
+ class RISCVVectorPeephole : public MachineFunctionPass {
34
45
public:
35
46
static char ID;
36
47
const TargetInstrInfo *TII;
37
48
MachineRegisterInfo *MRI;
38
49
const TargetRegisterInfo *TRI;
39
- RISCVFoldMasks () : MachineFunctionPass(ID) {}
50
+ RISCVVectorPeephole () : MachineFunctionPass(ID) {}
40
51
41
52
bool runOnMachineFunction (MachineFunction &MF) override ;
42
53
MachineFunctionProperties getRequiredProperties () const override {
@@ -47,6 +58,7 @@ class RISCVFoldMasks : public MachineFunctionPass {
47
58
StringRef getPassName () const override { return " RISC-V Fold Masks" ; }
48
59
49
60
private:
61
+ bool convertToVLMAX (MachineInstr &MI) const ;
50
62
bool convertToUnmasked (MachineInstr &MI) const ;
51
63
bool convertVMergeToVMv (MachineInstr &MI) const ;
52
64
@@ -58,11 +70,65 @@ class RISCVFoldMasks : public MachineFunctionPass {
58
70
59
71
} // namespace
60
72
61
- char RISCVFoldMasks::ID = 0 ;
73
+ char RISCVVectorPeephole::ID = 0 ;
74
+
75
+ INITIALIZE_PASS (RISCVVectorPeephole, DEBUG_TYPE, " RISC-V Fold Masks" , false ,
76
+ false )
77
+
78
+ // If an AVL is a VLENB that's possibly scaled to be equal to VLMAX, convert it
79
+ // to the VLMAX sentinel value.
80
+ bool RISCVVectorPeephole::convertToVLMAX(MachineInstr &MI) const {
81
+ if (!RISCVII::hasVLOp (MI.getDesc ().TSFlags ) ||
82
+ !RISCVII::hasSEWOp (MI.getDesc ().TSFlags ))
83
+ return false ;
84
+ MachineOperand &VL = MI.getOperand (RISCVII::getVLOpNum (MI.getDesc ()));
85
+ if (!VL.isReg ())
86
+ return false ;
87
+ MachineInstr *Def = MRI->getVRegDef (VL.getReg ());
88
+ if (!Def)
89
+ return false ;
90
+
91
+ // Fixed-point value, denominator=8
92
+ uint64_t ScaleFixed = 8 ;
93
+ // Check if the VLENB was potentially scaled with slli/srli
94
+ if (Def->getOpcode () == RISCV::SLLI) {
95
+ assert (Def->getOperand (2 ).getImm () < 64 );
96
+ ScaleFixed <<= Def->getOperand (2 ).getImm ();
97
+ Def = MRI->getVRegDef (Def->getOperand (1 ).getReg ());
98
+ } else if (Def->getOpcode () == RISCV::SRLI) {
99
+ assert (Def->getOperand (2 ).getImm () < 64 );
100
+ ScaleFixed >>= Def->getOperand (2 ).getImm ();
101
+ Def = MRI->getVRegDef (Def->getOperand (1 ).getReg ());
102
+ }
103
+
104
+ if (!Def || Def->getOpcode () != RISCV::PseudoReadVLENB)
105
+ return false ;
106
+
107
+ auto LMUL = RISCVVType::decodeVLMUL (RISCVII::getLMul (MI.getDesc ().TSFlags ));
108
+ // Fixed-point value, denominator=8
109
+ unsigned LMULFixed = LMUL.second ? (8 / LMUL.first ) : 8 * LMUL.first ;
110
+ unsigned Log2SEW = MI.getOperand (RISCVII::getSEWOpNum (MI.getDesc ())).getImm ();
111
+ // A Log2SEW of 0 is an operation on mask registers only
112
+ unsigned SEW = Log2SEW ? 1 << Log2SEW : 8 ;
113
+ assert (RISCVVType::isValidSEW (SEW) && " Unexpected SEW" );
114
+ assert (8 * LMULFixed / SEW > 0 );
62
115
63
- INITIALIZE_PASS (RISCVFoldMasks, DEBUG_TYPE, " RISC-V Fold Masks" , false , false )
116
+ // AVL = (VLENB * Scale)
117
+ //
118
+ // VLMAX = (VLENB * 8 * LMUL) / SEW
119
+ //
120
+ // AVL == VLMAX
121
+ // -> VLENB * Scale == (VLENB * 8 * LMUL) / SEW
122
+ // -> Scale == (8 * LMUL) / SEW
123
+ if (ScaleFixed != 8 * LMULFixed / SEW)
124
+ return false ;
64
125
65
- bool RISCVFoldMasks::isAllOnesMask(const MachineInstr *MaskDef) const {
126
+ VL.ChangeToImmediate (RISCV::VLMaxSentinel);
127
+
128
+ return true ;
129
+ }
130
+
131
+ bool RISCVVectorPeephole::isAllOnesMask (const MachineInstr *MaskDef) const {
66
132
assert (MaskDef && MaskDef->isCopy () &&
67
133
MaskDef->getOperand (0 ).getReg () == RISCV::V0);
68
134
Register SrcReg = TRI->lookThruCopyLike (MaskDef->getOperand (1 ).getReg (), MRI);
@@ -91,7 +157,7 @@ bool RISCVFoldMasks::isAllOnesMask(const MachineInstr *MaskDef) const {
91
157
92
158
// Transform (VMERGE_VVM_<LMUL> false, false, true, allones, vl, sew) to
93
159
// (VMV_V_V_<LMUL> false, true, vl, sew). It may decrease uses of VMSET.
94
- bool RISCVFoldMasks ::convertVMergeToVMv (MachineInstr &MI) const {
160
+ bool RISCVVectorPeephole ::convertVMergeToVMv (MachineInstr &MI) const {
95
161
#define CASE_VMERGE_TO_VMV (lmul ) \
96
162
case RISCV::PseudoVMERGE_VVM_##lmul: \
97
163
NewOpc = RISCV::PseudoVMV_V_V_##lmul; \
@@ -134,7 +200,7 @@ bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI) const {
134
200
return true ;
135
201
}
136
202
137
- bool RISCVFoldMasks ::convertToUnmasked (MachineInstr &MI) const {
203
+ bool RISCVVectorPeephole ::convertToUnmasked (MachineInstr &MI) const {
138
204
const RISCV::RISCVMaskedPseudoInfo *I =
139
205
RISCV::getMaskedPseudoInfo (MI.getOpcode ());
140
206
if (!I)
@@ -178,7 +244,7 @@ bool RISCVFoldMasks::convertToUnmasked(MachineInstr &MI) const {
178
244
return true ;
179
245
}
180
246
181
- bool RISCVFoldMasks ::runOnMachineFunction (MachineFunction &MF) {
247
+ bool RISCVVectorPeephole ::runOnMachineFunction (MachineFunction &MF) {
182
248
if (skipFunction (MF.getFunction ()))
183
249
return false ;
184
250
@@ -213,6 +279,7 @@ bool RISCVFoldMasks::runOnMachineFunction(MachineFunction &MF) {
213
279
214
280
for (MachineBasicBlock &MBB : MF) {
215
281
for (MachineInstr &MI : MBB) {
282
+ Changed |= convertToVLMAX (MI);
216
283
Changed |= convertToUnmasked (MI);
217
284
Changed |= convertVMergeToVMv (MI);
218
285
}
@@ -221,4 +288,6 @@ bool RISCVFoldMasks::runOnMachineFunction(MachineFunction &MF) {
221
288
return Changed;
222
289
}
223
290
224
- FunctionPass *llvm::createRISCVFoldMasksPass () { return new RISCVFoldMasks (); }
291
+ FunctionPass *llvm::createRISCVVectorPeepholePass () {
292
+ return new RISCVVectorPeephole ();
293
+ }
0 commit comments