-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[RISCV] Don't require mask or VL to be the same in combineBinOp_VLToVWBinOp_VL #87997
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RISCV] Don't require mask or VL to be the same in combineBinOp_VLToVWBinOp_VL #87997
Conversation
@llvm/pr-subscribers-backend-risc-v Author: Luke Lau (lukel97) Changes
Full diff: https://github.com/llvm/llvm-project/pull/87997.diff 2 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b426f1a7b3791d..e13b3f3ca109fb 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13576,15 +13576,6 @@ struct NodeExtensionHelper {
/// This boolean captures whether we care if this operand would still be
/// around after the folding happens.
bool EnforceOneUse;
- /// Records if this operand's mask needs to match the mask of the operation
- /// that it will fold into.
- bool CheckMask;
- /// Value of the Mask for this operand.
- /// It may be SDValue().
- SDValue Mask;
- /// Value of the vector length operand.
- /// It may be SDValue().
- SDValue VL;
/// Original value that this NodeExtensionHelper represents.
SDValue OrigOperand;
@@ -13789,8 +13780,10 @@ struct NodeExtensionHelper {
SupportsSExt = false;
SupportsFPExt = false;
EnforceOneUse = true;
- CheckMask = true;
unsigned Opc = OrigOperand.getOpcode();
+ // For the nodes we handle below, we end up using their inputs directly: see
+ // getSource(). However since they either don't have a passthru or we check
+ // that their passthru is undef, we can safely ignore their mask and VL.
switch (Opc) {
case ISD::ZERO_EXTEND:
case ISD::SIGN_EXTEND: {
@@ -13806,32 +13799,21 @@ struct NodeExtensionHelper {
SupportsZExt = Opc == ISD::ZERO_EXTEND;
SupportsSExt = Opc == ISD::SIGN_EXTEND;
-
- SDLoc DL(Root);
- std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
break;
}
case RISCVISD::VZEXT_VL:
SupportsZExt = true;
- Mask = OrigOperand.getOperand(1);
- VL = OrigOperand.getOperand(2);
break;
case RISCVISD::VSEXT_VL:
SupportsSExt = true;
- Mask = OrigOperand.getOperand(1);
- VL = OrigOperand.getOperand(2);
break;
case RISCVISD::FP_EXTEND_VL:
SupportsFPExt = true;
- Mask = OrigOperand.getOperand(1);
- VL = OrigOperand.getOperand(2);
break;
case RISCVISD::VMV_V_X_VL: {
// Historically, we didn't care about splat values not disappearing during
// combines.
EnforceOneUse = false;
- CheckMask = false;
- VL = OrigOperand.getOperand(2);
// The operand is a splat of a scalar.
@@ -13930,8 +13912,6 @@ struct NodeExtensionHelper {
Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL;
SupportsFPExt =
Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL;
- std::tie(Mask, VL) = getMaskAndVL(Root, DAG, Subtarget);
- CheckMask = true;
// There's no existing extension here, so we don't have to worry about
// making sure it gets removed.
EnforceOneUse = false;
@@ -13944,16 +13924,6 @@ struct NodeExtensionHelper {
}
}
- /// Check if this operand is compatible with the given vector length \p VL.
- bool isVLCompatible(SDValue VL) const {
- return this->VL != SDValue() && this->VL == VL;
- }
-
- /// Check if this operand is compatible with the given \p Mask.
- bool isMaskCompatible(SDValue Mask) const {
- return !CheckMask || (this->Mask != SDValue() && this->Mask == Mask);
- }
-
/// Helper function to get the Mask and VL from \p Root.
static std::pair<SDValue, SDValue>
getMaskAndVL(const SDNode *Root, SelectionDAG &DAG,
@@ -13973,13 +13943,6 @@ struct NodeExtensionHelper {
}
}
- /// Check if the Mask and VL of this operand are compatible with \p Root.
- bool areVLAndMaskCompatible(SDNode *Root, SelectionDAG &DAG,
- const RISCVSubtarget &Subtarget) const {
- auto [Mask, VL] = getMaskAndVL(Root, DAG, Subtarget);
- return isMaskCompatible(Mask) && isVLCompatible(VL);
- }
-
/// Helper function to check if \p N is commutative with respect to the
/// foldings that are supported by this class.
static bool isCommutative(const SDNode *N) {
@@ -14079,9 +14042,6 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS,
uint8_t AllowExtMask, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
- if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
- !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
- return std::nullopt;
if ((AllowExtMask & ExtKind::ZExt) && LHS.SupportsZExt && RHS.SupportsZExt)
return CombineResult(NodeExtensionHelper::getZExtOpcode(Root->getOpcode()),
Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
@@ -14120,9 +14080,6 @@ static std::optional<CombineResult>
canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
- if (!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
- return std::nullopt;
-
if (RHS.SupportsFPExt)
return CombineResult(
NodeExtensionHelper::getWOpcode(Root->getOpcode(), ExtKind::FPExt),
@@ -14190,9 +14147,6 @@ canFoldToVW_SU(SDNode *Root, const NodeExtensionHelper &LHS,
if (!LHS.SupportsSExt || !RHS.SupportsZExt)
return std::nullopt;
- if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
- !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
- return std::nullopt;
return CombineResult(NodeExtensionHelper::getSUOpcode(Root->getOpcode()),
Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
/*RHSExt=*/{ExtKind::ZExt});
diff --git a/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll
index a0b7726d3cb5e6..433f5d2717e48e 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll
@@ -41,3 +41,61 @@ declare <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i8(<vscale x 2 x i8>, <vsca
declare <vscale x 2 x i32> @llvm.vp.zext.nxv2i32.nxv2i8(<vscale x 2 x i8>, <vscale x 2 x i1>, i32)
declare <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i1>, i32)
declare <vscale x 2 x i32> @llvm.vp.merge.nxv2i32(<vscale x 2 x i1>, <vscale x 2 x i32>, <vscale x 2 x i32>, i32)
+
+define <vscale x 2 x i32> @vwadd_vv_vpnxv2i32_vpnxv2i16_vpnxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
+; CHECK-LABEL: vwadd_vv_vpnxv2i32_vpnxv2i16_vpnxv2i16:
+; CHECK: # %bb.0:
+; CHECK-NEXT: slli a0, a0, 32
+; CHECK-NEXT: srli a0, a0, 32
+; CHECK-NEXT: vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT: vwadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT: vmv1r.v v8, v10
+; CHECK-NEXT: ret
+ %x.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1> %m, i32 %evl)
+ %y.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 %evl)
+ %add = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %x.sext, <vscale x 2 x i32> %y.sext, <vscale x 2 x i1> %m, i32 %evl)
+ ret <vscale x 2 x i32> %add
+}
+
+define <vscale x 2 x i32> @vwadd_vv_vpnxv2i32_vpnxv2i16_nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
+; CHECK-LABEL: vwadd_vv_vpnxv2i32_vpnxv2i16_nxv2i16:
+; CHECK: # %bb.0:
+; CHECK-NEXT: slli a0, a0, 32
+; CHECK-NEXT: srli a0, a0, 32
+; CHECK-NEXT: vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT: vwadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT: vmv1r.v v8, v10
+; CHECK-NEXT: ret
+ %x.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1> %m, i32 %evl)
+ %y.sext = sext <vscale x 2 x i16> %y to <vscale x 2 x i32>
+ %add = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %x.sext, <vscale x 2 x i32> %y.sext, <vscale x 2 x i1> %m, i32 %evl)
+ ret <vscale x 2 x i32> %add
+}
+
+define <vscale x 2 x i32> @vwadd_vv_vpnxv2i32_nxv2i16_nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
+; CHECK-LABEL: vwadd_vv_vpnxv2i32_nxv2i16_nxv2i16:
+; CHECK: # %bb.0:
+; CHECK-NEXT: slli a0, a0, 32
+; CHECK-NEXT: srli a0, a0, 32
+; CHECK-NEXT: vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT: vwadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT: vmv1r.v v8, v10
+; CHECK-NEXT: ret
+ %x.sext = sext <vscale x 2 x i16> %x to <vscale x 2 x i32>
+ %y.sext = sext <vscale x 2 x i16> %y to <vscale x 2 x i32>
+ %add = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %x.sext, <vscale x 2 x i32> %y.sext, <vscale x 2 x i1> %m, i32 %evl)
+ ret <vscale x 2 x i32> %add
+}
+
+define <vscale x 2 x i32> @vwadd_vv_nxv2i32_vpnxv2i16_vpnxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
+; CHECK-LABEL: vwadd_vv_nxv2i32_vpnxv2i16_vpnxv2i16:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetvli a0, zero, e16, mf2, ta, ma
+; CHECK-NEXT: vwadd.vv v10, v8, v9
+; CHECK-NEXT: vmv1r.v v8, v10
+; CHECK-NEXT: ret
+ %x.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1> %m, i32 %evl)
+ %y.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 %evl)
+ %add = add <vscale x 2 x i32> %x.sext, %y.sext
+ ret <vscale x 2 x i32> %add
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
…WBinOp_VL In NodeExtensionHelper we keep track of the VL and mask of the operand being extended and check that they are the same as the root node's. However for the nodes that we support, none of them have a passthru operand with the exception of RISCVISD::VMV_V_X_VL, but we check that its passthru is undef anyway. So it's safe to just discard the extend node's VL and mask and just use the root's instead. (This is the same type of reasoning we use to treat any vmset_vl as an all-ones mask.) This allows us to match some more cases where we mix VP/non-VP/VL nodes, but these don't seem to appear in practice. The main benefit from this would be to simplify the code.
af02242
to
cfbd6b0
Compare
In NodeExtensionHelper we keep track of the VL and mask of the operand being extended and check that they are the same as the root node's. However for the nodes that we support, none of them have a passthru operand with the exception of RISCVISD::VMV_V_X_VL, but we check that its passthru is undef anyway.
So it's safe to just discard the extend node's VL and mask and just use the root's instead. (This is the same type of reasoning we use to treat any vmset_vl as an all ones mask)
This allows us to match some more cases where we mix VP/non-VP/VL nodes, but these don't seem to appear in practice. The main benefit from this would be to simplify the code.