[RISCV] Don't require mask or VL to be the same in combineBinOp_VLToVWBinOp_VL #87997

Merged

Conversation

@lukel97 (Contributor) commented Apr 8, 2024

In NodeExtensionHelper we keep track of the VL and mask of the operand being extended and check that they are the same as the root node's. However, none of the nodes we support has a passthru operand, with the exception of RISCVISD::VMV_V_X_VL, and for that we already check that its passthru is undef.

So it's safe to discard the extend node's VL and mask and use the root's instead. (This is the same kind of reasoning we use to treat any vmset_vl as an all-ones mask.)

This allows us to match more cases where VP, non-VP, and VL nodes are mixed, but these don't seem to appear in practice. The main benefit is simpler code.
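
For illustration, here's a minimal sketch of the shape the fold helpers take once the check is gone, modeled on canFoldToVWWithSameExtensionImpl in the diff below. The standalone function name canFoldToVWZExt and the comments are illustrative, so read this as a sketch of the idea rather than the verbatim upstream code:

  // Sketch: the per-operand Mask/VL fields and areVLAndMaskCompatible()
  // are gone. A fold helper only asks whether the extends are foldable;
  // when the combine later materializes the widened node, it takes the
  // mask and VL from Root via getMaskAndVL(), and only the extends'
  // source values are used (their passthrus are absent or undef).
  static std::optional<CombineResult>
  canFoldToVWZExt(SDNode *Root, const NodeExtensionHelper &LHS,
                  const NodeExtensionHelper &RHS) {
    if (!LHS.SupportsZExt || !RHS.SupportsZExt)
      return std::nullopt;
    return CombineResult(NodeExtensionHelper::getZExtOpcode(Root->getOpcode()),
                         Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
                         /*RHSExt=*/{ExtKind::ZExt});
  }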

@llvmbot (Member) commented Apr 8, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Luke Lau (lukel97)

Changes
  • Add tests where mismatching VLs or masks prevent vwadd from being combined
  • [RISCV] Don't require mask or VL to be the same in combineBinOp_VLToVWBinOp_VL

Full diff: https://github.com/llvm/llvm-project/pull/87997.diff

2 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+3-49)
  • (modified) llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll (+58)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b426f1a7b3791d..e13b3f3ca109fb 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13576,15 +13576,6 @@ struct NodeExtensionHelper {
   /// This boolean captures whether we care if this operand would still be
   /// around after the folding happens.
   bool EnforceOneUse;
-  /// Records if this operand's mask needs to match the mask of the operation
-  /// that it will fold into.
-  bool CheckMask;
-  /// Value of the Mask for this operand.
-  /// It may be SDValue().
-  SDValue Mask;
-  /// Value of the vector length operand.
-  /// It may be SDValue().
-  SDValue VL;
   /// Original value that this NodeExtensionHelper represents.
   SDValue OrigOperand;
 
@@ -13789,8 +13780,10 @@ struct NodeExtensionHelper {
     SupportsSExt = false;
     SupportsFPExt = false;
     EnforceOneUse = true;
-    CheckMask = true;
     unsigned Opc = OrigOperand.getOpcode();
+    // For the nodes we handle below, we end up using their inputs directly: see
+    // getSource(). However since they either don't have a passthru or we check
+    // that their passthru is undef, we can safely ignore their mask and VL.
     switch (Opc) {
     case ISD::ZERO_EXTEND:
     case ISD::SIGN_EXTEND: {
@@ -13806,32 +13799,21 @@ struct NodeExtensionHelper {
 
       SupportsZExt = Opc == ISD::ZERO_EXTEND;
       SupportsSExt = Opc == ISD::SIGN_EXTEND;
-
-      SDLoc DL(Root);
-      std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
       break;
     }
     case RISCVISD::VZEXT_VL:
       SupportsZExt = true;
-      Mask = OrigOperand.getOperand(1);
-      VL = OrigOperand.getOperand(2);
       break;
     case RISCVISD::VSEXT_VL:
       SupportsSExt = true;
-      Mask = OrigOperand.getOperand(1);
-      VL = OrigOperand.getOperand(2);
       break;
     case RISCVISD::FP_EXTEND_VL:
       SupportsFPExt = true;
-      Mask = OrigOperand.getOperand(1);
-      VL = OrigOperand.getOperand(2);
       break;
     case RISCVISD::VMV_V_X_VL: {
       // Historically, we didn't care about splat values not disappearing during
       // combines.
       EnforceOneUse = false;
-      CheckMask = false;
-      VL = OrigOperand.getOperand(2);
 
       // The operand is a splat of a scalar.
 
@@ -13930,8 +13912,6 @@ struct NodeExtensionHelper {
             Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL;
         SupportsFPExt =
             Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL;
-        std::tie(Mask, VL) = getMaskAndVL(Root, DAG, Subtarget);
-        CheckMask = true;
         // There's no existing extension here, so we don't have to worry about
         // making sure it gets removed.
         EnforceOneUse = false;
@@ -13944,16 +13924,6 @@ struct NodeExtensionHelper {
     }
   }
 
-  /// Check if this operand is compatible with the given vector length \p VL.
-  bool isVLCompatible(SDValue VL) const {
-    return this->VL != SDValue() && this->VL == VL;
-  }
-
-  /// Check if this operand is compatible with the given \p Mask.
-  bool isMaskCompatible(SDValue Mask) const {
-    return !CheckMask || (this->Mask != SDValue() && this->Mask == Mask);
-  }
-
   /// Helper function to get the Mask and VL from \p Root.
   static std::pair<SDValue, SDValue>
   getMaskAndVL(const SDNode *Root, SelectionDAG &DAG,
@@ -13973,13 +13943,6 @@ struct NodeExtensionHelper {
     }
   }
 
-  /// Check if the Mask and VL of this operand are compatible with \p Root.
-  bool areVLAndMaskCompatible(SDNode *Root, SelectionDAG &DAG,
-                              const RISCVSubtarget &Subtarget) const {
-    auto [Mask, VL] = getMaskAndVL(Root, DAG, Subtarget);
-    return isMaskCompatible(Mask) && isVLCompatible(VL);
-  }
-
   /// Helper function to check if \p N is commutative with respect to the
   /// foldings that are supported by this class.
   static bool isCommutative(const SDNode *N) {
@@ -14079,9 +14042,6 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
                                  const NodeExtensionHelper &RHS,
                                  uint8_t AllowExtMask, SelectionDAG &DAG,
                                  const RISCVSubtarget &Subtarget) {
-  if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
-      !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
-    return std::nullopt;
   if ((AllowExtMask & ExtKind::ZExt) && LHS.SupportsZExt && RHS.SupportsZExt)
     return CombineResult(NodeExtensionHelper::getZExtOpcode(Root->getOpcode()),
                          Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
@@ -14120,9 +14080,6 @@ static std::optional<CombineResult>
 canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
               const NodeExtensionHelper &RHS, SelectionDAG &DAG,
               const RISCVSubtarget &Subtarget) {
-  if (!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
-    return std::nullopt;
-
   if (RHS.SupportsFPExt)
     return CombineResult(
         NodeExtensionHelper::getWOpcode(Root->getOpcode(), ExtKind::FPExt),
@@ -14190,9 +14147,6 @@ canFoldToVW_SU(SDNode *Root, const NodeExtensionHelper &LHS,
 
   if (!LHS.SupportsSExt || !RHS.SupportsZExt)
     return std::nullopt;
-  if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
-      !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
-    return std::nullopt;
   return CombineResult(NodeExtensionHelper::getSUOpcode(Root->getOpcode()),
                        Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
                        /*RHSExt=*/{ExtKind::ZExt});
diff --git a/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll
index a0b7726d3cb5e6..433f5d2717e48e 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vwadd-vp.ll
@@ -41,3 +41,61 @@ declare <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i8(<vscale x 2 x i8>, <vsca
 declare <vscale x 2 x i32> @llvm.vp.zext.nxv2i32.nxv2i8(<vscale x 2 x i8>, <vscale x 2 x i1>, i32)
 declare <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i1>, i32)
 declare <vscale x 2 x i32> @llvm.vp.merge.nxv2i32(<vscale x 2 x i1>, <vscale x 2 x i32>, <vscale x 2 x i32>, i32)
+
+define <vscale x 2 x i32> @vwadd_vv_vpnxv2i32_vpnxv2i16_vpnxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
+; CHECK-LABEL: vwadd_vv_vpnxv2i32_vpnxv2i16_vpnxv2i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a0, a0, 32
+; CHECK-NEXT:    srli a0, a0, 32
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT:    vwadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vmv1r.v v8, v10
+; CHECK-NEXT:    ret
+  %x.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1> %m, i32 %evl)
+  %y.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 %evl)
+  %add = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %x.sext, <vscale x 2 x i32> %y.sext, <vscale x 2 x i1> %m, i32 %evl)
+  ret <vscale x 2 x i32> %add
+}
+
+define <vscale x 2 x i32> @vwadd_vv_vpnxv2i32_vpnxv2i16_nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
+; CHECK-LABEL: vwadd_vv_vpnxv2i32_vpnxv2i16_nxv2i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a0, a0, 32
+; CHECK-NEXT:    srli a0, a0, 32
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT:    vwadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vmv1r.v v8, v10
+; CHECK-NEXT:    ret
+  %x.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1> %m, i32 %evl)
+  %y.sext = sext <vscale x 2 x i16> %y to <vscale x 2 x i32>
+  %add = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %x.sext, <vscale x 2 x i32> %y.sext, <vscale x 2 x i1> %m, i32 %evl)
+  ret <vscale x 2 x i32> %add
+}
+
+define <vscale x 2 x i32> @vwadd_vv_vpnxv2i32_nxv2i16_nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
+; CHECK-LABEL: vwadd_vv_vpnxv2i32_nxv2i16_nxv2i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a0, a0, 32
+; CHECK-NEXT:    srli a0, a0, 32
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf2, ta, ma
+; CHECK-NEXT:    vwadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vmv1r.v v8, v10
+; CHECK-NEXT:    ret
+  %x.sext = sext <vscale x 2 x i16> %x to <vscale x 2 x i32>
+  %y.sext = sext <vscale x 2 x i16> %y to <vscale x 2 x i32>
+  %add = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %x.sext, <vscale x 2 x i32> %y.sext, <vscale x 2 x i1> %m, i32 %evl)
+  ret <vscale x 2 x i32> %add
+}
+
+define <vscale x 2 x i32> @vwadd_vv_nxv2i32_vpnxv2i16_vpnxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 signext %evl) {
+; CHECK-LABEL: vwadd_vv_nxv2i32_vpnxv2i16_vpnxv2i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e16, mf2, ta, ma
+; CHECK-NEXT:    vwadd.vv v10, v8, v9
+; CHECK-NEXT:    vmv1r.v v8, v10
+; CHECK-NEXT:    ret
+  %x.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1> %m, i32 %evl)
+  %y.sext = call <vscale x 2 x i32> @llvm.vp.sext.nxv2i32.nxv2i16(<vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i32 %evl)
+  %add = add <vscale x 2 x i32> %x.sext, %y.sext
+  ret <vscale x 2 x i32> %add
+}

@lukel97 lukel97 changed the title combineBinOp_VLToVWBinOp_VL-removeMaskVLCheck to [RISCV] Don't require mask or VL to be the same in combineBinOp_VLToVWBinOp_VL Apr 8, 2024
@sun-jacobi (Member) left a comment

Thanks. LGTM

@lukel97 lukel97 requested a review from qcolombet April 8, 2024 15:30
@topperc (Collaborator) left a comment

LGTM

lukel97 added 3 commits April 9, 2024 15:54
@lukel97 lukel97 force-pushed the combineBinOp_VLToVWBinOp_VL-removeMaskVLCheck branch from af02242 to cfbd6b0 on April 9, 2024 07:56
@lukel97 lukel97 merged commit 0f20b9b into llvm:main Apr 9, 2024