
[RISCV] Combine trunc (srl zext (x), zext (y)) to srl (x, umin (y, scalarsizeinbits(y) - 1)) #69092


Closed
wants to merge 2 commits

Conversation

@LWenH (Member) commented Oct 15, 2023

Like #65728, for i8/i16 element-wise vector logical right shifts, the source value and the shift amount are first zero-extended to i32, the vsrl instruction is performed, and the result is truncated to obtain the final value. This is later expanded into a series of "vsetvli" and "vnsrl" instructions. For RVV, the vsrl instruction only uses the low lg2(sew) bits of the shift amount, so we can also obtain the shift amount with umin(Y, scalarsize(Y) - 1).
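As a scalar illustration of the intended fold, here is a minimal C++ sketch (not part of the patch). It assumes the shift amount stays below the element width; the discussion below covers the out-of-range case.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  uint8_t X = 0xF0, Y = 4; // element value and shift amount

  // Current lowering: zero-extend both operands to i32, shift, then truncate.
  uint8_t Widened =
      static_cast<uint8_t>(static_cast<uint32_t>(X) >> static_cast<uint32_t>(Y));

  // Proposed combine: shift at the narrow width, with the amount clamped to
  // scalarsize - 1 (7 for i8), mirroring umin(Y, scalarsize(Y) - 1).
  uint8_t Clamped = static_cast<uint8_t>(X >> std::min<uint32_t>(Y, 7));

  assert(Widened == Clamped && Widened == 0x0F);
  return 0;
}
```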

LWenH added 2 commits October 15, 2023 14:10
…alarsize(Y) - 1)

Like llvm#65728, for i8/i16 element-wise vector logical right shifts,
the source value and the shift amount are first zero-extended to i32
to perform the vsrl instruction, followed by a trunc to get the final
result. This is later expanded into a series of "vsetvli" and "vnsrl"
instructions. For RVV, the vsrl instruction only uses the low lg2(sew)
bits of the shift amount, so we can compute the shift amount with
umin(Y, scalarsize(Y) - 1).
@llvmbot (Member) commented Oct 15, 2023

@llvm/pr-subscribers-backend-risc-v

Author: Vettel (LWenH)

Changes

Like #65728, for i8/i16 element-wise vector logical right shifts, the source value and the shift amount are first zero-extended to i32, the vsrl instruction is performed, and the result is truncated to obtain the final value. This is later expanded into a series of "vsetvli" and "vnsrl" instructions. For RVV, the vsrl instruction only uses the low lg2(sew) bits of the shift amount, so we can also obtain the shift amount with umin(Y, scalarsize(Y) - 1).


Full diff: https://github.com/llvm/llvm-project/pull/69092.diff

2 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+21)
  • (modified) llvm/test/CodeGen/RISCV/rvv/vsrl-sdnode.ll (+150)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index d7552317fd8bc69..036e5655a2984cc 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -14303,6 +14303,27 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
         }
       }
     }
+
+    // Similarly, we can also optimize the zext nodes for the srl here
+    // trunc (srl zext (X), zext (Y)) -> srl (X, umin (Y, scalarsize(Y) - 1))
+    if (Op.getOpcode() == ISD::SRL && Op.hasOneUse()) {
+      SDValue N0 = Op.getOperand(0);
+      SDValue N1 = Op.getOperand(1);
+      if (N0.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse() &&
+          N1.getOpcode() == ISD::ZERO_EXTEND && N1.hasOneUse()) {
+        SDValue N00 = N0.getOperand(0);
+        SDValue N10 = N1.getOperand(0);
+        if (N00.getValueType().isVector() &&
+            N00.getValueType() == N10.getValueType() &&
+            N->getValueType(0) == N10.getValueType()) {
+          unsigned MaxShAmt = N10.getValueType().getScalarSizeInBits() - 1;
+          SDValue UMin = DAG.getNode(
+              ISD::UMIN, SDLoc(N1), N->getValueType(0), N10,
+              DAG.getConstant(MaxShAmt, SDLoc(N1), N->getValueType(0)));
+          return DAG.getNode(ISD::SRL, SDLoc(N), N->getValueType(0), N00, UMin);
+        }
+      }
+    }
     break;
   }
   case ISD::TRUNCATE:
diff --git a/llvm/test/CodeGen/RISCV/rvv/vsrl-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vsrl-sdnode.ll
index be70b20181b1484..8b2201e147ffe1e 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vsrl-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vsrl-sdnode.ll
@@ -26,6 +26,21 @@ define <vscale x 1 x i8> @vsrl_vx_nxv1i8_0(<vscale x 1 x i8> %va) {
   ret <vscale x 1 x i8> %vc
 }
 
+define <vscale x 1 x i8> @vsrl_vv_nxv1i8_zext_zext(<vscale x 1 x i8> %va, <vscale x 1 x i8> %vb) {
+; CHECK-LABEL: vsrl_vv_nxv1i8_zext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 7
+; CHECK-NEXT:    vsetvli a1, zero, e8, mf8, ta, ma
+; CHECK-NEXT:    vminu.vx v9, v9, a0
+; CHECK-NEXT:    vsrl.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %zexted_va = zext <vscale x 1 x i8> %va to <vscale x 1 x i32>
+  %zexted_vb = zext <vscale x 1 x i8> %vb to <vscale x 1 x i32>
+  %expand = lshr <vscale x 1 x i32> %zexted_va, %zexted_vb
+  %vc = trunc <vscale x 1 x i32> %expand to <vscale x 1 x i8>
+  ret <vscale x 1 x i8> %vc
+}
+
 define <vscale x 2 x i8> @vsrl_vx_nxv2i8(<vscale x 2 x i8> %va, i8 signext %b) {
 ; CHECK-LABEL: vsrl_vx_nxv2i8:
 ; CHECK:       # %bb.0:
@@ -50,6 +65,21 @@ define <vscale x 2 x i8> @vsrl_vx_nxv2i8_0(<vscale x 2 x i8> %va) {
   ret <vscale x 2 x i8> %vc
 }
 
+define <vscale x 2 x i8> @vsrl_vv_nxv2i8_zext_zext(<vscale x 2 x i8> %va, <vscale x 2 x i8> %vb) {
+; CHECK-LABEL: vsrl_vv_nxv2i8_zext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 7
+; CHECK-NEXT:    vsetvli a1, zero, e8, mf4, ta, ma
+; CHECK-NEXT:    vminu.vx v9, v9, a0
+; CHECK-NEXT:    vsrl.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %zexted_va = zext <vscale x 2 x i8> %va to <vscale x 2 x i32>
+  %zexted_vb = zext <vscale x 2 x i8> %vb to <vscale x 2 x i32>
+  %expand = lshr <vscale x 2 x i32> %zexted_va, %zexted_vb
+  %vc = trunc <vscale x 2 x i32> %expand to <vscale x 2 x i8>
+  ret <vscale x 2 x i8> %vc
+}
+
 define <vscale x 4 x i8> @vsrl_vx_nxv4i8(<vscale x 4 x i8> %va, i8 signext %b) {
 ; CHECK-LABEL: vsrl_vx_nxv4i8:
 ; CHECK:       # %bb.0:
@@ -74,6 +104,21 @@ define <vscale x 4 x i8> @vsrl_vx_nxv4i8_0(<vscale x 4 x i8> %va) {
   ret <vscale x 4 x i8> %vc
 }
 
+define <vscale x 4 x i8> @vsrl_vv_nxv4i8_zext_zext(<vscale x 4 x i8> %va, <vscale x 4 x i8> %vb) {
+; CHECK-LABEL: vsrl_vv_nxv4i8_zext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 7
+; CHECK-NEXT:    vsetvli a1, zero, e8, mf2, ta, ma
+; CHECK-NEXT:    vminu.vx v9, v9, a0
+; CHECK-NEXT:    vsrl.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %zexted_va = zext <vscale x 4 x i8> %va to <vscale x 4 x i32>
+  %zexted_vb = zext <vscale x 4 x i8> %vb to <vscale x 4 x i32>
+  %expand = lshr <vscale x 4 x i32> %zexted_va, %zexted_vb
+  %vc = trunc <vscale x 4 x i32> %expand to <vscale x 4 x i8>
+  ret <vscale x 4 x i8> %vc
+}
+
 define <vscale x 8 x i8> @vsrl_vx_nxv8i8(<vscale x 8 x i8> %va, i8 signext %b) {
 ; CHECK-LABEL: vsrl_vx_nxv8i8:
 ; CHECK:       # %bb.0:
@@ -98,6 +143,21 @@ define <vscale x 8 x i8> @vsrl_vx_nxv8i8_0(<vscale x 8 x i8> %va) {
   ret <vscale x 8 x i8> %vc
 }
 
+define <vscale x 8 x i8> @vsrl_vv_nxv8i8_zext_zext(<vscale x 8 x i8> %va, <vscale x 8 x i8> %vb) {
+; CHECK-LABEL: vsrl_vv_nxv8i8_zext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 7
+; CHECK-NEXT:    vsetvli a1, zero, e8, m1, ta, ma
+; CHECK-NEXT:    vminu.vx v9, v9, a0
+; CHECK-NEXT:    vsrl.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %zexted_va = zext <vscale x 8 x i8> %va to <vscale x 8 x i32>
+  %zexted_vb = zext <vscale x 8 x i8> %vb to <vscale x 8 x i32>
+  %expand = lshr <vscale x 8 x i32> %zexted_va, %zexted_vb
+  %vc = trunc <vscale x 8 x i32> %expand to <vscale x 8 x i8>
+  ret <vscale x 8 x i8> %vc
+}
+
 define <vscale x 16 x i8> @vsrl_vx_nxv16i8(<vscale x 16 x i8> %va, i8 signext %b) {
 ; CHECK-LABEL: vsrl_vx_nxv16i8:
 ; CHECK:       # %bb.0:
@@ -122,6 +182,21 @@ define <vscale x 16 x i8> @vsrl_vx_nxv16i8_0(<vscale x 16 x i8> %va) {
   ret <vscale x 16 x i8> %vc
 }
 
+define <vscale x 16 x i8> @vsrl_vv_nxv16i8_zext_zext(<vscale x 16 x i8> %va, <vscale x 16 x i8> %vb) {
+; CHECK-LABEL: vsrl_vv_nxv16i8_zext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 7
+; CHECK-NEXT:    vsetvli a1, zero, e8, m2, ta, ma
+; CHECK-NEXT:    vminu.vx v10, v10, a0
+; CHECK-NEXT:    vsrl.vv v8, v8, v10
+; CHECK-NEXT:    ret
+  %zexted_va = zext <vscale x 16 x i8> %va to <vscale x 16 x i32>
+  %zexted_vb = zext <vscale x 16 x i8> %vb to <vscale x 16 x i32>
+  %expand = lshr <vscale x 16 x i32> %zexted_va, %zexted_vb
+  %vc = trunc <vscale x 16 x i32> %expand to <vscale x 16 x i8>
+  ret <vscale x 16 x i8> %vc
+}
+
 define <vscale x 32 x i8> @vsrl_vx_nxv32i8(<vscale x 32 x i8> %va, i8 signext %b) {
 ; CHECK-LABEL: vsrl_vx_nxv32i8:
 ; CHECK:       # %bb.0:
@@ -194,6 +269,21 @@ define <vscale x 1 x i16> @vsrl_vx_nxv1i16_0(<vscale x 1 x i16> %va) {
   ret <vscale x 1 x i16> %vc
 }
 
+define <vscale x 1 x i16> @vsrl_vv_nxv1i16_zext_zext(<vscale x 1 x i16> %va, <vscale x 1 x i16> %vb) {
+; CHECK-LABEL: vsrl_vv_nxv1i16_zext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 15
+; CHECK-NEXT:    vsetvli a1, zero, e16, mf4, ta, ma
+; CHECK-NEXT:    vminu.vx v9, v9, a0
+; CHECK-NEXT:    vsrl.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %zexted_va = zext <vscale x 1 x i16> %va to <vscale x 1 x i32>
+  %zexted_vb = zext <vscale x 1 x i16> %vb to <vscale x 1 x i32>
+  %expand = lshr <vscale x 1 x i32> %zexted_va, %zexted_vb
+  %vc = trunc <vscale x 1 x i32> %expand to <vscale x 1 x i16>
+  ret <vscale x 1 x i16> %vc
+}
+
 define <vscale x 2 x i16> @vsrl_vx_nxv2i16(<vscale x 2 x i16> %va, i16 signext %b) {
 ; CHECK-LABEL: vsrl_vx_nxv2i16:
 ; CHECK:       # %bb.0:
@@ -218,6 +308,21 @@ define <vscale x 2 x i16> @vsrl_vx_nxv2i16_0(<vscale x 2 x i16> %va) {
   ret <vscale x 2 x i16> %vc
 }
 
+define <vscale x 2 x i16> @vsrl_vv_nxv2i16_zext_zext(<vscale x 2 x i16> %va, <vscale x 2 x i16> %vb) {
+; CHECK-LABEL: vsrl_vv_nxv2i16_zext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 15
+; CHECK-NEXT:    vsetvli a1, zero, e16, mf2, ta, ma
+; CHECK-NEXT:    vminu.vx v9, v9, a0
+; CHECK-NEXT:    vsrl.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %zexted_va = zext <vscale x 2 x i16> %va to <vscale x 2 x i32>
+  %zexted_vb = zext <vscale x 2 x i16> %vb to <vscale x 2 x i32>
+  %expand = lshr <vscale x 2 x i32> %zexted_va, %zexted_vb
+  %vc = trunc <vscale x 2 x i32> %expand to <vscale x 2 x i16>
+  ret <vscale x 2 x i16> %vc
+}
+
 define <vscale x 4 x i16> @vsrl_vx_nxv4i16(<vscale x 4 x i16> %va, i16 signext %b) {
 ; CHECK-LABEL: vsrl_vx_nxv4i16:
 ; CHECK:       # %bb.0:
@@ -242,6 +347,21 @@ define <vscale x 4 x i16> @vsrl_vx_nxv4i16_0(<vscale x 4 x i16> %va) {
   ret <vscale x 4 x i16> %vc
 }
 
+define <vscale x 4 x i16> @vsrl_vv_nxv4i16_zext_zext(<vscale x 4 x i16> %va, <vscale x 4 x i16> %vb) {
+; CHECK-LABEL: vsrl_vv_nxv4i16_zext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 15
+; CHECK-NEXT:    vsetvli a1, zero, e16, m1, ta, ma
+; CHECK-NEXT:    vminu.vx v9, v9, a0
+; CHECK-NEXT:    vsrl.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %zexted_va = zext <vscale x 4 x i16> %va to <vscale x 4 x i32>
+  %zexted_vb = zext <vscale x 4 x i16> %vb to <vscale x 4 x i32>
+  %expand = lshr <vscale x 4 x i32> %zexted_va, %zexted_vb
+  %vc = trunc <vscale x 4 x i32> %expand to <vscale x 4 x i16>
+  ret <vscale x 4 x i16> %vc
+}
+
 define <vscale x 8 x i16> @vsrl_vx_nxv8i16(<vscale x 8 x i16> %va, i16 signext %b) {
 ; CHECK-LABEL: vsrl_vx_nxv8i16:
 ; CHECK:       # %bb.0:
@@ -266,6 +386,21 @@ define <vscale x 8 x i16> @vsrl_vx_nxv8i16_0(<vscale x 8 x i16> %va) {
   ret <vscale x 8 x i16> %vc
 }
 
+define <vscale x 8 x i16> @vsrl_vv_nxv8i16_zext_zext(<vscale x 8 x i16> %va, <vscale x 8 x i16> %vb) {
+; CHECK-LABEL: vsrl_vv_nxv8i16_zext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 15
+; CHECK-NEXT:    vsetvli a1, zero, e16, m2, ta, ma
+; CHECK-NEXT:    vminu.vx v10, v10, a0
+; CHECK-NEXT:    vsrl.vv v8, v8, v10
+; CHECK-NEXT:    ret
+  %zexted_va = zext <vscale x 8 x i16> %va to <vscale x 8 x i32>
+  %zexted_vb = zext <vscale x 8 x i16> %vb to <vscale x 8 x i32>
+  %expand = lshr <vscale x 8 x i32> %zexted_va, %zexted_vb
+  %vc = trunc <vscale x 8 x i32> %expand to <vscale x 8 x i16>
+  ret <vscale x 8 x i16> %vc
+}
+
 define <vscale x 16 x i16> @vsrl_vx_nxv16i16(<vscale x 16 x i16> %va, i16 signext %b) {
 ; CHECK-LABEL: vsrl_vx_nxv16i16:
 ; CHECK:       # %bb.0:
@@ -290,6 +425,21 @@ define <vscale x 16 x i16> @vsrl_vx_nxv16i16_0(<vscale x 16 x i16> %va) {
   ret <vscale x 16 x i16> %vc
 }
 
+define <vscale x 16 x i16> @vsrl_vv_nxv16i16_zext_zext(<vscale x 16 x i16> %va, <vscale x 16 x i16> %vb) {
+; CHECK-LABEL: vsrl_vv_nxv16i16_zext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 15
+; CHECK-NEXT:    vsetvli a1, zero, e16, m4, ta, ma
+; CHECK-NEXT:    vminu.vx v12, v12, a0
+; CHECK-NEXT:    vsrl.vv v8, v8, v12
+; CHECK-NEXT:    ret
+  %zexted_va = zext <vscale x 16 x i16> %va to <vscale x 16 x i32>
+  %zexted_vb = zext <vscale x 16 x i16> %vb to <vscale x 16 x i32>
+  %expand = lshr <vscale x 16 x i32> %zexted_va, %zexted_vb
+  %vc = trunc <vscale x 16 x i32> %expand to <vscale x 16 x i16>
+  ret <vscale x 16 x i16> %vc
+}
+
 define <vscale x 32 x i16> @vsrl_vx_nxv32i16(<vscale x 32 x i16> %va, i16 signext %b) {
 ; CHECK-LABEL: vsrl_vx_nxv32i16:
 ; CHECK:       # %bb.0:

@lukel97 (Contributor) commented Oct 16, 2023

Is this valid for y > 7? E.g. (trunc (srl (zext 128 to i32), (zext 8 to i32)) to i8) = 0, but if we clip y to umin(y, scalarsizeinbits(y) - 1) then we get (srl 128, 7) = 1

https://alive2.llvm.org/ce/z/T28oxz
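For concreteness, the two computations in that example can be checked with a small standalone C++ snippet (illustrative only; it just uses ordinary fixed-width unsigned arithmetic):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  // Original pattern: trunc (srl (zext 128 to i32), (zext 8 to i32)) to i8.
  uint32_t Wide =
      static_cast<uint32_t>(uint8_t{128}) >> static_cast<uint32_t>(uint8_t{8});
  uint8_t Truncated = static_cast<uint8_t>(Wide); // 128 >> 8 == 0

  // Proposed combine: srl (128, umin (8, 7)) at the narrow width.
  uint8_t Clamped =
      static_cast<uint8_t>(128u >> std::min<uint32_t>(8, 7)); // 128 >> 7 == 1

  assert(Truncated == 0 && Clamped == 1); // the two forms disagree once y > 7
  return 0;
}
```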

@LWenH (Member, Author) commented Oct 16, 2023

Is this valid for y > 7? E.g. (trunc (srl (zext 128 to i32), (zext 8 to i32)) to i8) = 0, but if we clip y to umin(y, scalarsizeinbits(y) - 1) then we get (srl 128, 7) = 1

https://alive2.llvm.org/ce/z/T28oxz

Yeah, I think you are right. Even though RVV only uses the low lg2(sew) bits as the shift amount, the results still differ after the zext: the transformation only works for arithmetic right shift, and for vsrl it is inconsistent when y > 7. To keep logical right shifts on C/C++ int8/int16 values correct, even for RVV, it is wise to zero-extend to i32 first.
Thank you for verifying this case.

  • Close this pull request.

@LWenH closed this Oct 16, 2023
@lukel97 (Contributor) commented Oct 16, 2023

Yeah, but it looks like GCC is still able to perform the vsrl on the narrower SEW with vminu.vv from this C example: https://gcc.godbolt.org/z/r5refhWc9
Is there another combine we're able to add for this?

@topperc (Collaborator) commented Oct 16, 2023

Is #65728 also broken?

@LWenH (Member, Author) commented Oct 16, 2023

#65728 still works. This is actually quite a tricky but interesting question about the RVV spec, I think: such an optimization can't work for srl, but it does work for sra. For a logical right shift, the useful shift amount isn't bounded by lg2(sew) bits; as in the example above, a shift amount of 8 is still meaningful, whereas for vsra, 7 is effectively the maximum shift amount.

I think the RVV spec could better separate this lg2(sew) restriction between vsrl and vsra.

@lukel97 (Contributor) commented Oct 16, 2023

Specifically I think #65728 is fine because sra has the property that ∀ n : nat. (sra x, sew-1) = (sra x, sew-1+n), but srl doesn't i.e. (srl x, sew-1) != (srl x, sew)
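A quick scalar sketch of that property (plain C++, assuming sew = 8 for illustration; shifts are done on the promoted 32-bit value):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // sra saturates: once the shift amount reaches sew - 1, every larger shift
  // just reproduces the sign bit, so clamping the amount is harmless.
  int32_t X = int8_t{-128}; // sign-extended i8 element
  assert((X >> 7) == (X >> 8) && (X >> 8) == (X >> 15)); // all equal -1

  // srl has no such property: shifting by sew - 1 and by sew differ.
  uint32_t U = uint8_t{128}; // zero-extended i8 element
  assert((U >> 7) == 1 && (U >> 8) == 0);
  return 0;
}
```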

@LWenH (Member, Author) commented Oct 16, 2023

Specifically I think #65728 is fine because sra has the property that ∀ n : nat. (sra x, sew-1) = (sra x, sew-1+n), but srl doesn't i.e. (srl x, sew-1) != (srl x, sew)

Yeah, there is still a big difference between vsrl and vsra. But from this I think lg2(sew) bits are not enough for vsrl: a logical right shift can meaningfully reach sew bits. That's interesting.
