
Commit ddae50d

[RISCV] Combine trunc (sra sext (x), zext (y)) to sra (x, smin (y, scalarsizeinbits(y) - 1)) (#65728)
For RVV, when an i8 or i16 element-wise arithmetic right shift is written at the C/C++ level, the value to be shifted is first sign-extended to i32 and the shift amount is zero-extended to i32 so the shift can be done with vsra.vv, and the result is then truncated back to the narrow type. Because the RVV spec only supports 2*SEW -> SEW narrowing, that truncate is later expanded into a series of "vsetvli" and "vnsrl" instructions. However, for vectors the shift amount can instead be clamped with smin(Y, ScalarSizeInBits(Y) - 1): the vsra instruction only uses the low log2(SEW) bits of the shift amount, so the widen-shift-truncate sequence can be folded into a single narrow shift.
- Alive2: https://alive2.llvm.org/ce/z/u3-Zdr
- C++ test cases: https://gcc.godbolt.org/z/q1qE7fbha
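For context (not part of this commit), the C/C++ pattern described above looks roughly like the loop below; the function name and signature are made up for illustration. Under the usual integer promotions the i8 operands are widened to i32 for the shift and the result is truncated back to i8, which is exactly the trunc (sra (sext, zext)) DAG this combine matches once the loop is vectorized for RVV:

#include <cstddef>
#include <cstdint>

// Hypothetical element-wise i8 arithmetic right shift. src[i] and shamt[i]
// are promoted to int (i32) before the shift; the cast back to int8_t is the
// truncate.
void ashr_v_i8(int8_t *dst, const int8_t *src, const uint8_t *shamt, size_t n) {
  for (size_t i = 0; i < n; ++i)
    dst[i] = static_cast<int8_t>(src[i] >> shamt[i]);
}

Before this change such a loop lowers through the widening shift plus "vnsrl" narrowing sequence; with the combine it can stay at SEW=8 as a vmin.vx followed by vsra.vv, as the test diffs below show.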
1 parent 2861ec8 commit ddae50d

File tree

2 files changed: +200 -0 lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 50 additions & 0 deletions
@@ -13749,6 +13749,56 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
       }
     }
     return SDValue();
+  case RISCVISD::TRUNCATE_VECTOR_VL: {
+    // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
+    // This would be benefit for the cases where X and Y are both the same value
+    // type of low precision vectors. Since the truncate would be lowered into
+    // n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
+    // restriction, such pattern would be expanded into a series of "vsetvli"
+    // and "vnsrl" instructions later to reach this point.
+    auto IsTruncNode = [](SDValue V) {
+      if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
+        return false;
+      SDValue VL = V.getOperand(2);
+      auto *C = dyn_cast<ConstantSDNode>(VL);
+      // Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
+      bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
+                             (isa<RegisterSDNode>(VL) &&
+                              cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
+      return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL &&
+             IsVLMAXForVMSET;
+    };
+
+    SDValue Op = N->getOperand(0);
+
+    // We need to first find the inner level of TRUNCATE_VECTOR_VL node
+    // to distinguish such pattern.
+    while (IsTruncNode(Op)) {
+      if (!Op.hasOneUse())
+        return SDValue();
+      Op = Op.getOperand(0);
+    }
+
+    if (Op.getOpcode() == ISD::SRA && Op.hasOneUse()) {
+      SDValue N0 = Op.getOperand(0);
+      SDValue N1 = Op.getOperand(1);
+      if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
+          N1.getOpcode() == ISD::ZERO_EXTEND && N1.hasOneUse()) {
+        SDValue N00 = N0.getOperand(0);
+        SDValue N10 = N1.getOperand(0);
+        if (N00.getValueType().isVector() &&
+            N00.getValueType() == N10.getValueType() &&
+            N->getValueType(0) == N10.getValueType()) {
+          unsigned MaxShAmt = N10.getValueType().getScalarSizeInBits() - 1;
+          SDValue SMin = DAG.getNode(
+              ISD::SMIN, SDLoc(N1), N->getValueType(0), N10,
+              DAG.getConstant(MaxShAmt, SDLoc(N1), N->getValueType(0)));
+          return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
+        }
+      }
+    }
+    break;
+  }
   case ISD::TRUNCATE:
     return performTRUNCATECombine(N, DAG, Subtarget);
   case ISD::SELECT:

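The rewrite is justified by a scalar identity that the Alive2 link above verifies in vector form: for an N-bit element, truncating sext(x) >> zext(y) back to N bits equals x >> min(y, N - 1), because an arithmetic shift by N - 1 or more already replicates the sign bit into every position. A minimal brute-force sketch of the i8 case, purely illustrative and not part of the commit:

#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  // Exhaustively compare trunc(sext(x) >> y) with x >> min(y, 7) for every i8
  // value. y is limited to [0, 31] so the i32 shift itself is well defined.
  for (int x = -128; x <= 127; ++x) {
    for (int y = 0; y <= 31; ++y) {
      int32_t wide = static_cast<int32_t>(x) >> y;              // sra(sext(x), zext(y))
      int8_t narrow = static_cast<int8_t>(x >> std::min(y, 7)); // sra(x, smin(y, 7))
      assert(static_cast<int8_t>(wide) == narrow);
    }
  }
  return 0;
}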
llvm/test/CodeGen/RISCV/rvv/vsra-sdnode.ll

Lines changed: 150 additions & 0 deletions
@@ -12,6 +12,21 @@ define <vscale x 1 x i8> @vsra_vv_nxv1i8(<vscale x 1 x i8> %va, <vscale x 1 x i8
   ret <vscale x 1 x i8> %vc
 }
 
+define <vscale x 1 x i8> @vsra_vv_nxv1i8_sext_zext(<vscale x 1 x i8> %va, <vscale x 1 x i8> %vb) {
+; CHECK-LABEL: vsra_vv_nxv1i8_sext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 7
+; CHECK-NEXT:    vsetvli a1, zero, e8, mf8, ta, ma
+; CHECK-NEXT:    vmin.vx v9, v8, a0
+; CHECK-NEXT:    vsra.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %sexted_va = sext <vscale x 1 x i8> %va to <vscale x 1 x i32>
+  %zexted_vb = zext <vscale x 1 x i8> %va to <vscale x 1 x i32>
+  %expand = ashr <vscale x 1 x i32> %sexted_va, %zexted_vb
+  %vc = trunc <vscale x 1 x i32> %expand to <vscale x 1 x i8>
+  ret <vscale x 1 x i8> %vc
+}
+
 define <vscale x 1 x i8> @vsra_vx_nxv1i8(<vscale x 1 x i8> %va, i8 signext %b) {
 ; CHECK-LABEL: vsra_vx_nxv1i8:
 ; CHECK:       # %bb.0:

@@ -46,6 +61,21 @@ define <vscale x 2 x i8> @vsra_vv_nxv2i8(<vscale x 2 x i8> %va, <vscale x 2 x i8
   ret <vscale x 2 x i8> %vc
 }
 
+define <vscale x 2 x i8> @vsra_vv_nxv2i8_sext_zext(<vscale x 2 x i8> %va, <vscale x 2 x i8> %vb) {
+; CHECK-LABEL: vsra_vv_nxv2i8_sext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 7
+; CHECK-NEXT:    vsetvli a1, zero, e8, mf4, ta, ma
+; CHECK-NEXT:    vmin.vx v9, v8, a0
+; CHECK-NEXT:    vsra.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %sexted_va = sext <vscale x 2 x i8> %va to <vscale x 2 x i32>
+  %zexted_vb = zext <vscale x 2 x i8> %va to <vscale x 2 x i32>
+  %expand = ashr <vscale x 2 x i32> %sexted_va, %zexted_vb
+  %vc = trunc <vscale x 2 x i32> %expand to <vscale x 2 x i8>
+  ret <vscale x 2 x i8> %vc
+}
+
 define <vscale x 2 x i8> @vsra_vx_nxv2i8(<vscale x 2 x i8> %va, i8 signext %b) {
 ; CHECK-LABEL: vsra_vx_nxv2i8:
 ; CHECK:       # %bb.0:

@@ -80,6 +110,21 @@ define <vscale x 4 x i8> @vsra_vv_nxv4i8(<vscale x 4 x i8> %va, <vscale x 4 x i8
   ret <vscale x 4 x i8> %vc
 }
 
+define <vscale x 4 x i8> @vsra_vv_nxv4i8_sext_zext(<vscale x 4 x i8> %va, <vscale x 4 x i8> %vb) {
+; CHECK-LABEL: vsra_vv_nxv4i8_sext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 7
+; CHECK-NEXT:    vsetvli a1, zero, e8, mf2, ta, ma
+; CHECK-NEXT:    vmin.vx v9, v8, a0
+; CHECK-NEXT:    vsra.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %sexted_va = sext <vscale x 4 x i8> %va to <vscale x 4 x i32>
+  %zexted_vb = zext <vscale x 4 x i8> %va to <vscale x 4 x i32>
+  %expand = ashr <vscale x 4 x i32> %sexted_va, %zexted_vb
+  %vc = trunc <vscale x 4 x i32> %expand to <vscale x 4 x i8>
+  ret <vscale x 4 x i8> %vc
+}
+
 define <vscale x 4 x i8> @vsra_vx_nxv4i8(<vscale x 4 x i8> %va, i8 signext %b) {
 ; CHECK-LABEL: vsra_vx_nxv4i8:
 ; CHECK:       # %bb.0:

@@ -114,6 +159,21 @@ define <vscale x 8 x i8> @vsra_vv_nxv8i8(<vscale x 8 x i8> %va, <vscale x 8 x i8
   ret <vscale x 8 x i8> %vc
 }
 
+define <vscale x 8 x i8> @vsra_vv_nxv8i8_sext_zext(<vscale x 8 x i8> %va, <vscale x 8 x i8> %vb) {
+; CHECK-LABEL: vsra_vv_nxv8i8_sext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 7
+; CHECK-NEXT:    vsetvli a1, zero, e8, m1, ta, ma
+; CHECK-NEXT:    vmin.vx v9, v8, a0
+; CHECK-NEXT:    vsra.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %sexted_va = sext <vscale x 8 x i8> %va to <vscale x 8 x i32>
+  %zexted_vb = zext <vscale x 8 x i8> %va to <vscale x 8 x i32>
+  %expand = ashr <vscale x 8 x i32> %sexted_va, %zexted_vb
+  %vc = trunc <vscale x 8 x i32> %expand to <vscale x 8 x i8>
+  ret <vscale x 8 x i8> %vc
+}
+
 define <vscale x 8 x i8> @vsra_vx_nxv8i8(<vscale x 8 x i8> %va, i8 signext %b) {
 ; CHECK-LABEL: vsra_vx_nxv8i8:
 ; CHECK:       # %bb.0:

@@ -148,6 +208,21 @@ define <vscale x 16 x i8> @vsra_vv_nxv16i8(<vscale x 16 x i8> %va, <vscale x 16
   ret <vscale x 16 x i8> %vc
 }
 
+define <vscale x 16 x i8> @vsra_vv_nxv16i8_sext_zext(<vscale x 16 x i8> %va, <vscale x 16 x i8> %vb) {
+; CHECK-LABEL: vsra_vv_nxv16i8_sext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 7
+; CHECK-NEXT:    vsetvli a1, zero, e8, m2, ta, ma
+; CHECK-NEXT:    vmin.vx v10, v8, a0
+; CHECK-NEXT:    vsra.vv v8, v8, v10
+; CHECK-NEXT:    ret
+  %sexted_va = sext <vscale x 16 x i8> %va to <vscale x 16 x i32>
+  %zexted_vb = zext <vscale x 16 x i8> %va to <vscale x 16 x i32>
+  %expand = ashr <vscale x 16 x i32> %sexted_va, %zexted_vb
+  %vc = trunc <vscale x 16 x i32> %expand to <vscale x 16 x i8>
+  ret <vscale x 16 x i8> %vc
+}
+
 define <vscale x 16 x i8> @vsra_vx_nxv16i8(<vscale x 16 x i8> %va, i8 signext %b) {
 ; CHECK-LABEL: vsra_vx_nxv16i8:
 ; CHECK:       # %bb.0:

@@ -250,6 +325,21 @@ define <vscale x 1 x i16> @vsra_vv_nxv1i16(<vscale x 1 x i16> %va, <vscale x 1 x
   ret <vscale x 1 x i16> %vc
 }
 
+define <vscale x 1 x i16> @vsra_vv_nxv1i16_sext_zext(<vscale x 1 x i16> %va, <vscale x 1 x i16> %vb) {
+; CHECK-LABEL: vsra_vv_nxv1i16_sext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 15
+; CHECK-NEXT:    vsetvli a1, zero, e16, mf4, ta, ma
+; CHECK-NEXT:    vmin.vx v9, v8, a0
+; CHECK-NEXT:    vsra.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %sexted_va = sext <vscale x 1 x i16> %va to <vscale x 1 x i32>
+  %zexted_vb = zext <vscale x 1 x i16> %va to <vscale x 1 x i32>
+  %expand = ashr <vscale x 1 x i32> %sexted_va, %zexted_vb
+  %vc = trunc <vscale x 1 x i32> %expand to <vscale x 1 x i16>
+  ret <vscale x 1 x i16> %vc
+}
+
 define <vscale x 1 x i16> @vsra_vx_nxv1i16(<vscale x 1 x i16> %va, i16 signext %b) {
 ; CHECK-LABEL: vsra_vx_nxv1i16:
 ; CHECK:       # %bb.0:

@@ -284,6 +374,21 @@ define <vscale x 2 x i16> @vsra_vv_nxv2i16(<vscale x 2 x i16> %va, <vscale x 2 x
   ret <vscale x 2 x i16> %vc
 }
 
+define <vscale x 2 x i16> @vsra_vv_nxv2i16_sext_zext(<vscale x 2 x i16> %va, <vscale x 2 x i16> %vb) {
+; CHECK-LABEL: vsra_vv_nxv2i16_sext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 15
+; CHECK-NEXT:    vsetvli a1, zero, e16, mf2, ta, ma
+; CHECK-NEXT:    vmin.vx v9, v8, a0
+; CHECK-NEXT:    vsra.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %sexted_va = sext <vscale x 2 x i16> %va to <vscale x 2 x i32>
+  %zexted_vb = zext <vscale x 2 x i16> %va to <vscale x 2 x i32>
+  %expand = ashr <vscale x 2 x i32> %sexted_va, %zexted_vb
+  %vc = trunc <vscale x 2 x i32> %expand to <vscale x 2 x i16>
+  ret <vscale x 2 x i16> %vc
+}
+
 define <vscale x 2 x i16> @vsra_vx_nxv2i16(<vscale x 2 x i16> %va, i16 signext %b) {
 ; CHECK-LABEL: vsra_vx_nxv2i16:
 ; CHECK:       # %bb.0:

@@ -318,6 +423,21 @@ define <vscale x 4 x i16> @vsra_vv_nxv4i16(<vscale x 4 x i16> %va, <vscale x 4 x
   ret <vscale x 4 x i16> %vc
 }
 
+define <vscale x 4 x i16> @vsra_vv_nxv4i16_sext_zext(<vscale x 4 x i16> %va, <vscale x 4 x i16> %vb) {
+; CHECK-LABEL: vsra_vv_nxv4i16_sext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 15
+; CHECK-NEXT:    vsetvli a1, zero, e16, m1, ta, ma
+; CHECK-NEXT:    vmin.vx v9, v8, a0
+; CHECK-NEXT:    vsra.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %sexted_va = sext <vscale x 4 x i16> %va to <vscale x 4 x i32>
+  %zexted_vb = zext <vscale x 4 x i16> %va to <vscale x 4 x i32>
+  %expand = ashr <vscale x 4 x i32> %sexted_va, %zexted_vb
+  %vc = trunc <vscale x 4 x i32> %expand to <vscale x 4 x i16>
+  ret <vscale x 4 x i16> %vc
+}
+
 define <vscale x 4 x i16> @vsra_vx_nxv4i16(<vscale x 4 x i16> %va, i16 signext %b) {
 ; CHECK-LABEL: vsra_vx_nxv4i16:
 ; CHECK:       # %bb.0:

@@ -352,6 +472,21 @@ define <vscale x 8 x i16> @vsra_vv_nxv8i16(<vscale x 8 x i16> %va, <vscale x 8 x
   ret <vscale x 8 x i16> %vc
 }
 
+define <vscale x 8 x i16> @vsra_vv_nxv8i16_sext_zext(<vscale x 8 x i16> %va, <vscale x 8 x i16> %vb) {
+; CHECK-LABEL: vsra_vv_nxv8i16_sext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 15
+; CHECK-NEXT:    vsetvli a1, zero, e16, m2, ta, ma
+; CHECK-NEXT:    vmin.vx v10, v8, a0
+; CHECK-NEXT:    vsra.vv v8, v8, v10
+; CHECK-NEXT:    ret
+  %sexted_va = sext <vscale x 8 x i16> %va to <vscale x 8 x i32>
+  %zexted_vb = zext <vscale x 8 x i16> %va to <vscale x 8 x i32>
+  %expand = ashr <vscale x 8 x i32> %sexted_va, %zexted_vb
+  %vc = trunc <vscale x 8 x i32> %expand to <vscale x 8 x i16>
+  ret <vscale x 8 x i16> %vc
+}
+
 define <vscale x 8 x i16> @vsra_vx_nxv8i16(<vscale x 8 x i16> %va, i16 signext %b) {
 ; CHECK-LABEL: vsra_vx_nxv8i16:
 ; CHECK:       # %bb.0:

@@ -386,6 +521,21 @@ define <vscale x 16 x i16> @vsra_vv_nxv16i16(<vscale x 16 x i16> %va, <vscale x
   ret <vscale x 16 x i16> %vc
 }
 
+define <vscale x 16 x i16> @vsra_vv_nxv16i16_sext_zext(<vscale x 16 x i16> %va, <vscale x 16 x i16> %vb) {
+; CHECK-LABEL: vsra_vv_nxv16i16_sext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 15
+; CHECK-NEXT:    vsetvli a1, zero, e16, m4, ta, ma
+; CHECK-NEXT:    vmin.vx v12, v8, a0
+; CHECK-NEXT:    vsra.vv v8, v8, v12
+; CHECK-NEXT:    ret
+  %sexted_va = sext <vscale x 16 x i16> %va to <vscale x 16 x i32>
+  %zexted_vb = zext <vscale x 16 x i16> %va to <vscale x 16 x i32>
+  %expand = ashr <vscale x 16 x i32> %sexted_va, %zexted_vb
+  %vc = trunc <vscale x 16 x i32> %expand to <vscale x 16 x i16>
+  ret <vscale x 16 x i16> %vc
+}
+
 define <vscale x 16 x i16> @vsra_vx_nxv16i16(<vscale x 16 x i16> %va, i16 signext %b) {
 ; CHECK-LABEL: vsra_vx_nxv16i16:
 ; CHECK:       # %bb.0:
