Commit 0f20b9b

[RISCV] Don't require mask or VL to be the same in combineBinOp_VLToVWBinOp_VL (#87997)
In NodeExtensionHelper we keep track of the VL and mask of the operand being extended and check that they are the same as the root node's. However, none of the nodes we support have a passthru operand, with the exception of RISCV::VMV_V_X_VL, and for that we already check that its passthru is undef. So it is safe to discard the extend node's VL and mask and use the root's instead. (This is the same kind of reasoning we use to treat any vmset_vl as an all-ones mask.)

This allows us to match a few more cases where VP, non-VP, and VL nodes are mixed, but those don't seem to appear in practice; the main benefit is that it simplifies the code.
1 parent d8d131d commit 0f20b9b
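
For illustration, a minimal IR sketch of the mixed VP/non-VP case that can now be matched, modeled on the vwadd-vp.ll tests added below (the function name is made up for this example and is not part of the commit):

declare <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16>, <vscale x 2 x i1>, i32)

; The sign extensions are VP nodes carrying their own mask %m and EVL %evl, while
; the add is a plain (non-VP) node. The combine can now still fold this into a
; single vwadd.vv using the root add's VL and mask (here VLMAX and no mask), as
; the CHECK lines of the new vwadd_vv_nxv2i32_vpnxv2i16_vpnxv2i16 test show.
define <vscale x 2 x i32> @mixed_vp_sext_plain_add(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 %evl) {
  %x.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1> %m, i32 %evl)
  %y.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 %evl)
  %add = add <vscale x 2 x i32> %x.sext, %y.sext
  ret <vscale x 2 x i32> %add
}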

2 files changed: +62 -50 lines changed


llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 4 additions & 50 deletions
@@ -13552,7 +13552,7 @@ enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 };
 /// NodeExtensionHelper for `a` and one for `b`.
 ///
 /// This class abstracts away how the extension is materialized and
-/// how its Mask, VL, number of users affect the combines.
+/// how its number of users affect the combines.
 ///
 /// In particular:
 /// - VWADD_W is conceptually == add(op0, sext(op1))
@@ -13576,15 +13576,6 @@ struct NodeExtensionHelper {
   /// This boolean captures whether we care if this operand would still be
   /// around after the folding happens.
   bool EnforceOneUse;
-  /// Records if this operand's mask needs to match the mask of the operation
-  /// that it will fold into.
-  bool CheckMask;
-  /// Value of the Mask for this operand.
-  /// It may be SDValue().
-  SDValue Mask;
-  /// Value of the vector length operand.
-  /// It may be SDValue().
-  SDValue VL;
   /// Original value that this NodeExtensionHelper represents.
   SDValue OrigOperand;
 
@@ -13789,8 +13780,10 @@ struct NodeExtensionHelper {
     SupportsSExt = false;
     SupportsFPExt = false;
     EnforceOneUse = true;
-    CheckMask = true;
     unsigned Opc = OrigOperand.getOpcode();
+    // For the nodes we handle below, we end up using their inputs directly: see
+    // getSource(). However since they either don't have a passthru or we check
+    // that their passthru is undef, we can safely ignore their mask and VL.
     switch (Opc) {
     case ISD::ZERO_EXTEND:
     case ISD::SIGN_EXTEND: {
@@ -13806,32 +13799,21 @@ struct NodeExtensionHelper {
 
       SupportsZExt = Opc == ISD::ZERO_EXTEND;
       SupportsSExt = Opc == ISD::SIGN_EXTEND;
-
-      SDLoc DL(Root);
-      std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
       break;
     }
     case RISCVISD::VZEXT_VL:
       SupportsZExt = true;
-      Mask = OrigOperand.getOperand(1);
-      VL = OrigOperand.getOperand(2);
       break;
     case RISCVISD::VSEXT_VL:
      SupportsSExt = true;
-      Mask = OrigOperand.getOperand(1);
-      VL = OrigOperand.getOperand(2);
       break;
     case RISCVISD::FP_EXTEND_VL:
       SupportsFPExt = true;
-      Mask = OrigOperand.getOperand(1);
-      VL = OrigOperand.getOperand(2);
       break;
     case RISCVISD::VMV_V_X_VL: {
       // Historically, we didn't care about splat values not disappearing during
       // combines.
       EnforceOneUse = false;
-      CheckMask = false;
-      VL = OrigOperand.getOperand(2);
 
       // The operand is a splat of a scalar.
 
@@ -13930,8 +13912,6 @@ struct NodeExtensionHelper {
         Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL;
     SupportsFPExt =
         Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL;
-    std::tie(Mask, VL) = getMaskAndVL(Root, DAG, Subtarget);
-    CheckMask = true;
     // There's no existing extension here, so we don't have to worry about
     // making sure it gets removed.
     EnforceOneUse = false;
@@ -13944,16 +13924,6 @@ struct NodeExtensionHelper {
     }
   }
 
-  /// Check if this operand is compatible with the given vector length \p VL.
-  bool isVLCompatible(SDValue VL) const {
-    return this->VL != SDValue() && this->VL == VL;
-  }
-
-  /// Check if this operand is compatible with the given \p Mask.
-  bool isMaskCompatible(SDValue Mask) const {
-    return !CheckMask || (this->Mask != SDValue() && this->Mask == Mask);
-  }
-
   /// Helper function to get the Mask and VL from \p Root.
   static std::pair<SDValue, SDValue>
   getMaskAndVL(const SDNode *Root, SelectionDAG &DAG,
@@ -13973,13 +13943,6 @@ struct NodeExtensionHelper {
     }
   }
 
-  /// Check if the Mask and VL of this operand are compatible with \p Root.
-  bool areVLAndMaskCompatible(SDNode *Root, SelectionDAG &DAG,
-                              const RISCVSubtarget &Subtarget) const {
-    auto [Mask, VL] = getMaskAndVL(Root, DAG, Subtarget);
-    return isMaskCompatible(Mask) && isVLCompatible(VL);
-  }
-
   /// Helper function to check if \p N is commutative with respect to the
   /// foldings that are supported by this class.
   static bool isCommutative(const SDNode *N) {
@@ -14079,9 +14042,6 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
                                  const NodeExtensionHelper &RHS,
                                  uint8_t AllowExtMask, SelectionDAG &DAG,
                                  const RISCVSubtarget &Subtarget) {
-  if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
-      !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
-    return std::nullopt;
   if ((AllowExtMask & ExtKind::ZExt) && LHS.SupportsZExt && RHS.SupportsZExt)
     return CombineResult(NodeExtensionHelper::getZExtOpcode(Root->getOpcode()),
                          Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
@@ -14120,9 +14080,6 @@ static std::optional<CombineResult>
 canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
               const NodeExtensionHelper &RHS, SelectionDAG &DAG,
               const RISCVSubtarget &Subtarget) {
-  if (!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
-    return std::nullopt;
-
   if (RHS.SupportsFPExt)
     return CombineResult(
         NodeExtensionHelper::getWOpcode(Root->getOpcode(), ExtKind::FPExt),
@@ -14190,9 +14147,6 @@ canFoldToVW_SU(SDNode *Root, const NodeExtensionHelper &LHS,
 
   if (!LHS.SupportsSExt || !RHS.SupportsZExt)
     return std::nullopt;
-  if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
-      !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
-    return std::nullopt;
   return CombineResult(NodeExtensionHelper::getSUOpcode(Root->getOpcode()),
                        Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
                        /*RHSExt=*/{ExtKind::ZExt});

llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll

Lines changed: 58 additions & 0 deletions
@@ -41,3 +41,61 @@ declare <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i8(<vscale x 2 x i8>, <vsca
 declare <vscale x 2 x i32> @llvm.vp.zext.nxv2i32.nxv2i8(<vscale x 2 x i8>, <vscale x 2 x i1>, i32)
 declare <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i1>, i32)
 declare <vscale x 2 x i32> @llvm.vp.merge.nxv2i32(<vscale x 2 x i1>, <vscale x 2 x i32>, <vscale x 2 x i32>, i32)
+
+define <vscale x 2 x i32> @vwadd_vv_vpnxv2i32_vpnxv2i16_vpnxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
+; CHECK-LABEL: vwadd_vv_vpnxv2i32_vpnxv2i16_vpnxv2i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a0, a0, 32
+; CHECK-NEXT:    srli a0, a0, 32
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT:    vwadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vmv1r.v v8, v10
+; CHECK-NEXT:    ret
+  %x.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1> %m, i32 %evl)
+  %y.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 %evl)
+  %add = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %x.sext, <vscale x 2 x i32> %y.sext, <vscale x 2 x i1> %m, i32 %evl)
+  ret <vscale x 2 x i32> %add
+}
+
+define <vscale x 2 x i32> @vwadd_vv_vpnxv2i32_vpnxv2i16_nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
+; CHECK-LABEL: vwadd_vv_vpnxv2i32_vpnxv2i16_nxv2i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a0, a0, 32
+; CHECK-NEXT:    srli a0, a0, 32
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT:    vwadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vmv1r.v v8, v10
+; CHECK-NEXT:    ret
+  %x.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1> %m, i32 %evl)
+  %y.sext = sext <vscale x 2 x i16> %y to <vscale x 2 x i32>
+  %add = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %x.sext, <vscale x 2 x i32> %y.sext, <vscale x 2 x i1> %m, i32 %evl)
+  ret <vscale x 2 x i32> %add
+}
+
+define <vscale x 2 x i32> @vwadd_vv_vpnxv2i32_nxv2i16_nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
+; CHECK-LABEL: vwadd_vv_vpnxv2i32_nxv2i16_nxv2i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a0, a0, 32
+; CHECK-NEXT:    srli a0, a0, 32
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT:    vwadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vmv1r.v v8, v10
+; CHECK-NEXT:    ret
+  %x.sext = sext <vscale x 2 x i16> %x to <vscale x 2 x i32>
+  %y.sext = sext <vscale x 2 x i16> %y to <vscale x 2 x i32>
+  %add = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %x.sext, <vscale x 2 x i32> %y.sext, <vscale x 2 x i1> %m, i32 %evl)
+  ret <vscale x 2 x i32> %add
+}
+
+define <vscale x 2 x i32> @vwadd_vv_nxv2i32_vpnxv2i16_vpnxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
+; CHECK-LABEL: vwadd_vv_nxv2i32_vpnxv2i16_vpnxv2i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e16, mf2, ta, ma
+; CHECK-NEXT:    vwadd.vv v10, v8, v9
+; CHECK-NEXT:    vmv1r.v v8, v10
+; CHECK-NEXT:    ret
+  %x.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1> %m, i32 %evl)
+  %y.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 %evl)
+  %add = add <vscale x 2 x i32> %x.sext, %y.sext
+  ret <vscale x 2 x i32> %add
+}
