@@ -13576,15 +13576,6 @@ struct NodeExtensionHelper {
13576
13576
/// This boolean captures whether we care if this operand would still be
13577
13577
/// around after the folding happens.
13578
13578
bool EnforceOneUse;
13579
- /// Records if this operand's mask needs to match the mask of the operation
13580
- /// that it will fold into.
13581
- bool CheckMask;
13582
- /// Value of the Mask for this operand.
13583
- /// It may be SDValue().
13584
- SDValue Mask;
13585
- /// Value of the vector length operand.
13586
- /// It may be SDValue().
13587
- SDValue VL;
13588
13579
/// Original value that this NodeExtensionHelper represents.
13589
13580
SDValue OrigOperand;
13590
13581
@@ -13789,8 +13780,10 @@ struct NodeExtensionHelper {
13789
13780
SupportsSExt = false;
13790
13781
SupportsFPExt = false;
13791
13782
EnforceOneUse = true;
13792
- CheckMask = true;
13793
13783
unsigned Opc = OrigOperand.getOpcode();
13784
+ // For the nodes we handle below, we end up using their inputs directly: see
13785
+ // getSource(). However since they either don't have a passthru or we check
13786
+ // that their passthru is undef, we can safely ignore their mask and VL.
13794
13787
switch (Opc) {
13795
13788
case ISD::ZERO_EXTEND:
13796
13789
case ISD::SIGN_EXTEND: {
@@ -13806,32 +13799,21 @@ struct NodeExtensionHelper {
13806
13799
13807
13800
SupportsZExt = Opc == ISD::ZERO_EXTEND;
13808
13801
SupportsSExt = Opc == ISD::SIGN_EXTEND;
13809
-
13810
- SDLoc DL(Root);
13811
- std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
13812
13802
break;
13813
13803
}
13814
13804
case RISCVISD::VZEXT_VL:
13815
13805
SupportsZExt = true;
13816
- Mask = OrigOperand.getOperand(1);
13817
- VL = OrigOperand.getOperand(2);
13818
13806
break;
13819
13807
case RISCVISD::VSEXT_VL:
13820
13808
SupportsSExt = true;
13821
- Mask = OrigOperand.getOperand(1);
13822
- VL = OrigOperand.getOperand(2);
13823
13809
break;
13824
13810
case RISCVISD::FP_EXTEND_VL:
13825
13811
SupportsFPExt = true;
13826
- Mask = OrigOperand.getOperand(1);
13827
- VL = OrigOperand.getOperand(2);
13828
13812
break;
13829
13813
case RISCVISD::VMV_V_X_VL: {
13830
13814
// Historically, we didn't care about splat values not disappearing during
13831
13815
// combines.
13832
13816
EnforceOneUse = false;
13833
- CheckMask = false;
13834
- VL = OrigOperand.getOperand(2);
13835
13817
13836
13818
// The operand is a splat of a scalar.
13837
13819
@@ -13930,8 +13912,6 @@ struct NodeExtensionHelper {
13930
13912
Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL;
13931
13913
SupportsFPExt =
13932
13914
Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL;
13933
- std::tie(Mask, VL) = getMaskAndVL(Root, DAG, Subtarget);
13934
- CheckMask = true;
13935
13915
// There's no existing extension here, so we don't have to worry about
13936
13916
// making sure it gets removed.
13937
13917
EnforceOneUse = false;
@@ -13944,16 +13924,6 @@ struct NodeExtensionHelper {
13944
13924
}
13945
13925
}
13946
13926
13947
- /// Check if this operand is compatible with the given vector length \p VL.
13948
- bool isVLCompatible(SDValue VL) const {
13949
- return this->VL != SDValue() && this->VL == VL;
13950
- }
13951
-
13952
- /// Check if this operand is compatible with the given \p Mask.
13953
- bool isMaskCompatible(SDValue Mask) const {
13954
- return !CheckMask || (this->Mask != SDValue() && this->Mask == Mask);
13955
- }
13956
-
13957
13927
/// Helper function to get the Mask and VL from \p Root.
13958
13928
static std::pair<SDValue, SDValue>
13959
13929
getMaskAndVL(const SDNode *Root, SelectionDAG &DAG,
@@ -13973,13 +13943,6 @@ struct NodeExtensionHelper {
13973
13943
}
13974
13944
}
13975
13945
13976
- /// Check if the Mask and VL of this operand are compatible with \p Root.
13977
- bool areVLAndMaskCompatible(SDNode *Root, SelectionDAG &DAG,
13978
- const RISCVSubtarget &Subtarget) const {
13979
- auto [Mask, VL] = getMaskAndVL(Root, DAG, Subtarget);
13980
- return isMaskCompatible(Mask) && isVLCompatible(VL);
13981
- }
13982
-
13983
13946
/// Helper function to check if \p N is commutative with respect to the
13984
13947
/// foldings that are supported by this class.
13985
13948
static bool isCommutative(const SDNode *N) {
@@ -14079,9 +14042,6 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
14079
14042
const NodeExtensionHelper &RHS,
14080
14043
uint8_t AllowExtMask, SelectionDAG &DAG,
14081
14044
const RISCVSubtarget &Subtarget) {
14082
- if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
14083
- !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
14084
- return std::nullopt;
14085
14045
if ((AllowExtMask & ExtKind::ZExt) && LHS.SupportsZExt && RHS.SupportsZExt)
14086
14046
return CombineResult(NodeExtensionHelper::getZExtOpcode(Root->getOpcode()),
14087
14047
Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
@@ -14120,9 +14080,6 @@ static std::optional<CombineResult>
14120
14080
canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
14121
14081
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
14122
14082
const RISCVSubtarget &Subtarget) {
14123
- if (!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
14124
- return std::nullopt;
14125
-
14126
14083
if (RHS.SupportsFPExt)
14127
14084
return CombineResult(
14128
14085
NodeExtensionHelper::getWOpcode(Root->getOpcode(), ExtKind::FPExt),
@@ -14190,9 +14147,6 @@ canFoldToVW_SU(SDNode *Root, const NodeExtensionHelper &LHS,
14190
14147
14191
14148
if (!LHS.SupportsSExt || !RHS.SupportsZExt)
14192
14149
return std::nullopt;
14193
- if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
14194
- !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
14195
- return std::nullopt;
14196
14150
return CombineResult(NodeExtensionHelper::getSUOpcode(Root->getOpcode()),
14197
14151
Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
14198
14152
/*RHSExt=*/{ExtKind::ZExt});
0 commit comments