Skip to content

Commit 98f59b2

Browse files
committed
[RISCV] Teach doPeepholeMaskedRVV to handle FMA instructions.
This lets us remove some isel patterns. Reviewed By: fakepaper56 Differential Revision: https://reviews.llvm.org/D150463
1 parent 245cb1f commit 98f59b2

File tree

3 files changed

+27
-46
lines changed

3 files changed

+27
-46
lines changed

llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3157,37 +3157,42 @@ bool RISCVDAGToDAGISel::doPeepholeMaskedRVV(SDNode *N) {
31573157
const RISCVInstrInfo &TII = *Subtarget->getInstrInfo();
31583158
const MCInstrDesc &MaskedMCID = TII.get(N->getMachineOpcode());
31593159

3160-
bool IsTA = true;
3160+
bool UseTUPseudo = false;
31613161
if (RISCVII::hasVecPolicyOp(MaskedMCID.TSFlags)) {
3162-
TailPolicyOpIdx = getVecPolicyOpIdx(N, MaskedMCID);
3163-
if (!(N->getConstantOperandVal(*TailPolicyOpIdx) &
3164-
RISCVII::TAIL_AGNOSTIC)) {
3165-
// Keep the true-masked instruction when there is no unmasked TU
3166-
// instruction
3167-
if (I->UnmaskedTUPseudo == I->MaskedPseudo && !N->getOperand(0).isUndef())
3168-
return false;
3169-
// We can't use TA if the tie-operand is not IMPLICIT_DEF
3170-
if (!N->getOperand(0).isUndef())
3171-
IsTA = false;
3162+
// Some operations are their own TU.
3163+
if (I->UnmaskedTUPseudo == I->UnmaskedPseudo) {
3164+
UseTUPseudo = true;
3165+
} else {
3166+
TailPolicyOpIdx = getVecPolicyOpIdx(N, MaskedMCID);
3167+
if (!(N->getConstantOperandVal(*TailPolicyOpIdx) &
3168+
RISCVII::TAIL_AGNOSTIC)) {
3169+
// We can't use TA if the tie-operand is not IMPLICIT_DEF
3170+
if (!N->getOperand(0).isUndef()) {
3171+
// Keep the true-masked instruction when there is no unmasked TU
3172+
// instruction
3173+
if (I->UnmaskedTUPseudo == I->MaskedPseudo)
3174+
return false;
3175+
UseTUPseudo = true;
3176+
}
3177+
}
31723178
}
31733179
}
31743180

3175-
unsigned Opc = IsTA ? I->UnmaskedPseudo : I->UnmaskedTUPseudo;
3181+
unsigned Opc = UseTUPseudo ? I->UnmaskedTUPseudo : I->UnmaskedPseudo;
31763182

31773183
// Check that we're dropping the mask operand and any policy operand
31783184
// when we transform to this unmasked pseudo. Additionally, if this
31793185
// instruction is tail agnostic, the unmasked instruction should not have a
31803186
// merge op.
31813187
uint64_t TSFlags = TII.get(Opc).TSFlags;
3182-
assert((IsTA != RISCVII::hasMergeOp(TSFlags)) &&
3188+
assert((UseTUPseudo == RISCVII::hasMergeOp(TSFlags)) &&
31833189
RISCVII::hasDummyMaskOp(TSFlags) &&
3184-
!RISCVII::hasVecPolicyOp(TSFlags) &&
31853190
"Unexpected pseudo to transform to");
31863191
(void)TSFlags;
31873192

31883193
SmallVector<SDValue, 8> Ops;
3189-
// Skip the merge operand at index 0 if IsTA
3190-
for (unsigned I = IsTA, E = N->getNumOperands(); I != E; I++) {
3194+
// Skip the merge operand at index 0 if !UseTUPseudo.
3195+
for (unsigned I = !UseTUPseudo, E = N->getNumOperands(); I != E; I++) {
31913196
// Skip the mask, the policy, and the Glue.
31923197
SDValue Op = N->getOperand(I);
31933198
if (I == MaskOpIdx || I == TailPolicyOpIdx ||

llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -472,10 +472,12 @@ def RISCVVIntrinsicsTable : GenericTable {
472472
let PrimaryKeyName = "getRISCVVIntrinsicInfo";
473473
}
474474

475-
class RISCVMaskedPseudo<bits<4> MaskIdx, bit HasTU = true> {
475+
class RISCVMaskedPseudo<bits<4> MaskIdx, bit HasTU = true, bit IsTernary = false> {
476476
Pseudo MaskedPseudo = !cast<Pseudo>(NAME);
477477
Pseudo UnmaskedPseudo = !cast<Pseudo>(!subst("_MASK", "", NAME));
478-
Pseudo UnmaskedTUPseudo = !if(HasTU, !cast<Pseudo>(!subst("_MASK", "", NAME # "_TU")), MaskedPseudo);
478+
Pseudo UnmaskedTUPseudo = !cond(HasTU : !cast<Pseudo>(!subst("_MASK", "", NAME # "_TU")),
479+
IsTernary : UnmaskedPseudo,
480+
true : MaskedPseudo);
479481
bits<4> MaskOpIdx = MaskIdx;
480482
}
481483

@@ -3192,7 +3194,8 @@ multiclass VPseudoTernaryWithPolicy<VReg RetClass,
31923194
let VLMul = MInfo.value in {
31933195
let isCommutable = Commutable in
31943196
def "_" # MInfo.MX : VPseudoTernaryNoMaskWithPolicy<RetClass, Op1Class, Op2Class, Constraint>;
3195-
def "_" # MInfo.MX # "_MASK" : VPseudoBinaryMaskPolicy<RetClass, Op1Class, Op2Class, Constraint>;
3197+
def "_" # MInfo.MX # "_MASK" : VPseudoBinaryMaskPolicy<RetClass, Op1Class, Op2Class, Constraint>,
3198+
RISCVMaskedPseudo</*MaskOpIdx*/ 3, /*HasTU*/ false, /*IsTernary*/true>;
31963199
}
31973200
}
31983201

llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1459,26 +1459,13 @@ multiclass VPatNarrowShiftSplat_WX_WI<SDNode op, string instruction_name> {
14591459
multiclass VPatFPMulAddVL_VV_VF<SDPatternOperator vop, string instruction_name> {
14601460
foreach vti = AllFloatVectors in {
14611461
defvar suffix = vti.LMul.MX;
1462-
def : Pat<(vti.Vector (vop vti.RegClass:$rs1, vti.RegClass:$rd,
1463-
vti.RegClass:$rs2, (vti.Mask true_mask),
1464-
VLOpFrag)),
1465-
(!cast<Instruction>(instruction_name#"_VV_"# suffix)
1466-
vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
1467-
GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
14681462
def : Pat<(vti.Vector (vop vti.RegClass:$rs1, vti.RegClass:$rd,
14691463
vti.RegClass:$rs2, (vti.Mask V0),
14701464
VLOpFrag)),
14711465
(!cast<Instruction>(instruction_name#"_VV_"# suffix #"_MASK")
14721466
vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
14731467
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
14741468

1475-
def : Pat<(vti.Vector (vop (SplatFPOp vti.ScalarRegClass:$rs1),
1476-
vti.RegClass:$rd, vti.RegClass:$rs2,
1477-
(vti.Mask true_mask),
1478-
VLOpFrag)),
1479-
(!cast<Instruction>(instruction_name#"_V" # vti.ScalarSuffix # "_" # suffix)
1480-
vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
1481-
GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
14821469
def : Pat<(vti.Vector (vop (SplatFPOp vti.ScalarRegClass:$rs1),
14831470
vti.RegClass:$rd, vti.RegClass:$rs2,
14841471
(vti.Mask V0),
@@ -1492,27 +1479,13 @@ multiclass VPatFPMulAddVL_VV_VF<SDPatternOperator vop, string instruction_name>
14921479
multiclass VPatFPMulAccVL_VV_VF<PatFrag vop, string instruction_name> {
14931480
foreach vti = AllFloatVectors in {
14941481
defvar suffix = vti.LMul.MX;
1495-
def : Pat<(riscv_vp_merge_vl (vti.Mask true_mask),
1496-
(vti.Vector (vop vti.RegClass:$rs1, vti.RegClass:$rs2,
1497-
vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
1498-
vti.RegClass:$rd, VLOpFrag),
1499-
(!cast<Instruction>(instruction_name#"_VV_"# suffix)
1500-
vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
1501-
GPR:$vl, vti.Log2SEW, TAIL_UNDISTURBED_MASK_UNDISTURBED)>;
15021482
def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
15031483
(vti.Vector (vop vti.RegClass:$rs1, vti.RegClass:$rs2,
15041484
vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
15051485
vti.RegClass:$rd, VLOpFrag),
15061486
(!cast<Instruction>(instruction_name#"_VV_"# suffix #"_MASK")
15071487
vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
15081488
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_UNDISTURBED_MASK_UNDISTURBED)>;
1509-
def : Pat<(riscv_vp_merge_vl (vti.Mask true_mask),
1510-
(vti.Vector (vop (SplatFPOp vti.ScalarRegClass:$rs1), vti.RegClass:$rs2,
1511-
vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
1512-
vti.RegClass:$rd, VLOpFrag),
1513-
(!cast<Instruction>(instruction_name#"_V" # vti.ScalarSuffix # "_" # suffix)
1514-
vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
1515-
GPR:$vl, vti.Log2SEW, TAIL_UNDISTURBED_MASK_UNDISTURBED)>;
15161489
def : Pat<(riscv_vp_merge_vl (vti.Mask V0),
15171490
(vti.Vector (vop (SplatFPOp vti.ScalarRegClass:$rs1), vti.RegClass:$rs2,
15181491
vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),

0 commit comments

Comments
 (0)