
[RISCV] Recursively split concat_vector into smaller LMULs #83035


Merged
merged 3 commits on Mar 7, 2024
60 changes: 56 additions & 4 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -15283,13 +15283,62 @@ static SDValue performINSERT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
}

// Recursively split up concat_vectors with more than 2 operands:
//
// concat_vector op1, op2, op3, op4
// ->
// concat_vector (concat_vector op1, op2), (concat_vector op3, op4)
//
// This reduces the length of the chain of vslideups and allows us to perform
// the vslideups at a smaller LMUL, limited to MF2.
//
// We do this as a DAG combine rather than during lowering so that any undef
// operands can get combined away.
static SDValue
performCONCAT_VECTORSSplitCombine(SDNode *N, SelectionDAG &DAG,
const RISCVTargetLowering &TLI) {
SDLoc DL(N);

if (N->getNumOperands() <= 2)
return SDValue();

if (!TLI.isTypeLegal(N->getValueType(0)))
return SDValue();
MVT VT = N->getSimpleValueType(0);

// Don't split any further than MF2.
MVT ContainerVT = VT;
if (VT.isFixedLengthVector())
ContainerVT = getContainerForFixedLengthVector(DAG, VT, TLI.getSubtarget());
if (ContainerVT.bitsLT(getLMUL1VT(ContainerVT)))
return SDValue();

MVT HalfVT = VT.getHalfNumVectorElementsVT();
assert(isPowerOf2_32(N->getNumOperands()));
size_t HalfNumOps = N->getNumOperands() / 2;
SDValue Lo = DAG.getNode(ISD::CONCAT_VECTORS, DL, HalfVT,
N->ops().take_front(HalfNumOps));
SDValue Hi = DAG.getNode(ISD::CONCAT_VECTORS, DL, HalfVT,
N->ops().drop_front(HalfNumOps));

// Lower to an insert_subvector directly so the concat_vectors don't get
// recombined.
SDValue Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, DAG.getUNDEF(VT), Lo,
DAG.getVectorIdxConstant(0, DL));
Vec = DAG.getNode(
ISD::INSERT_SUBVECTOR, DL, VT, Vec, Hi,
DAG.getVectorIdxConstant(HalfVT.getVectorMinNumElements(), DL));
return Vec;
}

// If we're concatenating a series of vector loads like
// concat_vectors (load v4i8, p+0), (load v4i8, p+n), (load v4i8, p+n*2) ...
// Then we can turn this into a strided load by widening the vector elements
// vlse32 p, stride=n
static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget,
const RISCVTargetLowering &TLI) {
static SDValue
performCONCAT_VECTORSStridedLoadCombine(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget,
const RISCVTargetLowering &TLI) {
SDLoc DL(N);
EVT VT = N->getValueType(0);

@@ -16394,7 +16443,10 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
return V;
break;
case ISD::CONCAT_VECTORS:
if (SDValue V = performCONCAT_VECTORSCombine(N, DAG, Subtarget, *this))
if (SDValue V =
performCONCAT_VECTORSStridedLoadCombine(N, DAG, Subtarget, *this))
return V;
if (SDValue V = performCONCAT_VECTORSSplitCombine(N, DAG, *this))
return V;
break;
case ISD::INSERT_VECTOR_ELT:
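
The rewrite performed by the new combine is just a recursive halving of the operand list: a flat N-operand concat_vectors becomes a balanced binary tree of two-operand concats, built one level at a time through the insert_subvector pair above. Below is a minimal standalone C++ sketch of that shape, for illustration only; the Expr type and splitConcat helper are invented here and are not SelectionDAG APIs.

#include <cassert>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Stand-in for a DAG operand: just a printable expression string.
struct Expr {
  std::string Text;
};

// Recursively halve the operand list of an N-operand concat (N a power of
// two) into a balanced tree of two-operand concats.
Expr splitConcat(const std::vector<Expr> &Ops) {
  assert(!Ops.empty() && (Ops.size() & (Ops.size() - 1)) == 0 &&
         "operand count must be a power of two");
  if (Ops.size() == 1)
    return Ops.front();
  std::size_t Half = Ops.size() / 2;
  std::vector<Expr> LoOps(Ops.begin(), Ops.begin() + Half);
  std::vector<Expr> HiOps(Ops.begin() + Half, Ops.end());
  Expr Lo = splitConcat(LoOps);
  Expr Hi = splitConcat(HiOps);
  return Expr{"concat(" + Lo.Text + ", " + Hi.Text + ")"};
}

int main() {
  std::vector<Expr> Ops = {{"op1"}, {"op2"}, {"op3"}, {"op4"}};
  // Prints: concat(concat(op1, op2), concat(op3, op4))
  std::cout << splitConcat(Ops).Text << "\n";
  return 0;
}

Running the sketch prints concat(concat(op1, op2), concat(op3, op4)), matching the example in the comment on performCONCAT_VECTORSSplitCombine.
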
91 changes: 45 additions & 46 deletions llvm/test/CodeGen/RISCV/rvv/active_lane_mask.ll
@@ -161,72 +161,71 @@ define <64 x i1> @fv64(ptr %p, i64 %index, i64 %tc) {
define <128 x i1> @fv128(ptr %p, i64 %index, i64 %tc) {
; CHECK-LABEL: fv128:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, ma
; CHECK-NEXT: lui a0, %hi(.LCPI10_0)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI10_0)
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, ma
; CHECK-NEXT: vle8.v v8, (a0)
; CHECK-NEXT: vid.v v16
; CHECK-NEXT: vsaddu.vx v16, v16, a1
; CHECK-NEXT: vmsltu.vx v0, v16, a2
; CHECK-NEXT: vsext.vf8 v16, v8
; CHECK-NEXT: vsaddu.vx v8, v16, a1
; CHECK-NEXT: vmsltu.vx v16, v8, a2
; CHECK-NEXT: vsetivli zero, 4, e8, m1, tu, ma
; CHECK-NEXT: vslideup.vi v0, v16, 2
; CHECK-NEXT: lui a0, %hi(.LCPI10_1)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI10_1)
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, ma
; CHECK-NEXT: vle8.v v8, (a0)
; CHECK-NEXT: vle8.v v9, (a0)
; CHECK-NEXT: vsext.vf8 v16, v8
; CHECK-NEXT: vsaddu.vx v8, v16, a1
; CHECK-NEXT: vmsltu.vx v16, v8, a2
; CHECK-NEXT: vsetivli zero, 6, e8, m1, tu, ma
; CHECK-NEXT: vslideup.vi v0, v16, 4
; CHECK-NEXT: vsaddu.vx v16, v16, a1
; CHECK-NEXT: vmsltu.vx v10, v16, a2
; CHECK-NEXT: vsext.vf8 v16, v9
; CHECK-NEXT: vsaddu.vx v16, v16, a1
; CHECK-NEXT: vmsltu.vx v8, v16, a2
; CHECK-NEXT: vsetivli zero, 4, e8, mf2, tu, ma
; CHECK-NEXT: vslideup.vi v8, v10, 2
; CHECK-NEXT: lui a0, %hi(.LCPI10_2)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI10_2)
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, ma
; CHECK-NEXT: vle8.v v8, (a0)
; CHECK-NEXT: vsext.vf8 v16, v8
; CHECK-NEXT: vsaddu.vx v8, v16, a1
; CHECK-NEXT: vmsltu.vx v16, v8, a2
; CHECK-NEXT: vsetivli zero, 8, e8, m1, tu, ma
; CHECK-NEXT: vslideup.vi v0, v16, 6
; CHECK-NEXT: vle8.v v9, (a0)
; CHECK-NEXT: vsext.vf8 v16, v9
; CHECK-NEXT: vsaddu.vx v16, v16, a1
; CHECK-NEXT: vmsltu.vx v9, v16, a2
; CHECK-NEXT: vsetivli zero, 6, e8, mf2, tu, ma
; CHECK-NEXT: vslideup.vi v8, v9, 4
; CHECK-NEXT: lui a0, %hi(.LCPI10_3)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI10_3)
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, ma
; CHECK-NEXT: vle8.v v8, (a0)
; CHECK-NEXT: vsext.vf8 v16, v8
; CHECK-NEXT: vsaddu.vx v8, v16, a1
; CHECK-NEXT: vmsltu.vx v16, v8, a2
; CHECK-NEXT: vsetivli zero, 10, e8, m1, tu, ma
; CHECK-NEXT: vslideup.vi v0, v16, 8
; CHECK-NEXT: vle8.v v9, (a0)
; CHECK-NEXT: vsext.vf8 v16, v9
; CHECK-NEXT: vsaddu.vx v16, v16, a1
; CHECK-NEXT: vmsltu.vx v9, v16, a2
; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
; CHECK-NEXT: vslideup.vi v8, v9, 6
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, ma
; CHECK-NEXT: lui a0, %hi(.LCPI10_4)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI10_4)
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, ma
; CHECK-NEXT: vle8.v v8, (a0)
; CHECK-NEXT: vsext.vf8 v16, v8
; CHECK-NEXT: vsaddu.vx v8, v16, a1
; CHECK-NEXT: vmsltu.vx v16, v8, a2
; CHECK-NEXT: vsetivli zero, 12, e8, m1, tu, ma
; CHECK-NEXT: vslideup.vi v0, v16, 10
; CHECK-NEXT: vle8.v v9, (a0)
; CHECK-NEXT: vid.v v16
; CHECK-NEXT: vsaddu.vx v16, v16, a1
; CHECK-NEXT: vmsltu.vx v0, v16, a2
; CHECK-NEXT: vsext.vf8 v16, v9
; CHECK-NEXT: vsaddu.vx v16, v16, a1
; CHECK-NEXT: vmsltu.vx v9, v16, a2
; CHECK-NEXT: vsetivli zero, 4, e8, mf2, tu, ma
; CHECK-NEXT: vslideup.vi v0, v9, 2
; CHECK-NEXT: lui a0, %hi(.LCPI10_5)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI10_5)
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, ma
; CHECK-NEXT: vle8.v v8, (a0)
; CHECK-NEXT: vsext.vf8 v16, v8
; CHECK-NEXT: vsaddu.vx v8, v16, a1
; CHECK-NEXT: vmsltu.vx v16, v8, a2
; CHECK-NEXT: vsetivli zero, 14, e8, m1, tu, ma
; CHECK-NEXT: vslideup.vi v0, v16, 12
; CHECK-NEXT: vle8.v v9, (a0)
; CHECK-NEXT: vsext.vf8 v16, v9
; CHECK-NEXT: vsaddu.vx v16, v16, a1
; CHECK-NEXT: vmsltu.vx v9, v16, a2
; CHECK-NEXT: vsetivli zero, 6, e8, mf2, tu, ma
; CHECK-NEXT: vslideup.vi v0, v9, 4
; CHECK-NEXT: lui a0, %hi(.LCPI10_6)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI10_6)
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, ma
; CHECK-NEXT: vle8.v v8, (a0)
; CHECK-NEXT: vsext.vf8 v16, v8
; CHECK-NEXT: vsaddu.vx v8, v16, a1
; CHECK-NEXT: vmsltu.vx v16, v8, a2
; CHECK-NEXT: vsetvli zero, zero, e8, m1, ta, ma
; CHECK-NEXT: vslideup.vi v0, v16, 14
; CHECK-NEXT: vle8.v v9, (a0)
; CHECK-NEXT: vsext.vf8 v16, v9
; CHECK-NEXT: vsaddu.vx v16, v16, a1
; CHECK-NEXT: vmsltu.vx v9, v16, a2
; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
; CHECK-NEXT: vslideup.vi v0, v9, 6
; CHECK-NEXT: vsetivli zero, 16, e8, m1, ta, ma
; CHECK-NEXT: vslideup.vi v0, v8, 8
; CHECK-NEXT: ret
%mask = call <128 x i1> @llvm.get.active.lane.mask.v128i1.i64(i64 %index, i64 %tc)
ret <128 x i1> %mask
4 changes: 2 additions & 2 deletions llvm/test/CodeGen/RISCV/rvv/combine-store-extract-crash.ll
@@ -19,7 +19,7 @@ define void @test(ptr %ref_array, ptr %sad_array) {
; RV32-NEXT: th.swia a0, (a1), 4, 0
; RV32-NEXT: vsetivli zero, 4, e8, mf4, ta, ma
; RV32-NEXT: vle8.v v10, (a3)
; RV32-NEXT: vsetivli zero, 8, e8, m1, tu, ma
; RV32-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
; RV32-NEXT: vslideup.vi v10, v9, 4
; RV32-NEXT: vsetivli zero, 16, e32, m4, ta, ma
; RV32-NEXT: vzext.vf4 v12, v10
@@ -42,7 +42,7 @@ define void @test(ptr %ref_array, ptr %sad_array) {
; RV64-NEXT: th.swia a0, (a1), 4, 0
; RV64-NEXT: vsetivli zero, 4, e8, mf4, ta, ma
; RV64-NEXT: vle8.v v10, (a3)
; RV64-NEXT: vsetivli zero, 8, e8, m1, tu, ma
; RV64-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
; RV64-NEXT: vslideup.vi v10, v9, 4
; RV64-NEXT: vsetivli zero, 16, e32, m4, ta, ma
; RV64-NEXT: vzext.vf4 v12, v10
3 changes: 1 addition & 2 deletions llvm/test/CodeGen/RISCV/rvv/extract-subvector.ll
@@ -469,9 +469,8 @@ define <vscale x 6 x half> @extract_nxv6f16_nxv12f16_6(<vscale x 12 x half> %in)
; CHECK: # %bb.0:
; CHECK-NEXT: csrr a0, vlenb
; CHECK-NEXT: srli a0, a0, 2
; CHECK-NEXT: vsetvli zero, a0, e16, m1, ta, ma
; CHECK-NEXT: vslidedown.vx v13, v10, a0
; CHECK-NEXT: vsetvli a1, zero, e16, m1, ta, ma
; CHECK-NEXT: vslidedown.vx v13, v10, a0
; CHECK-NEXT: vslidedown.vx v12, v9, a0
; CHECK-NEXT: add a1, a0, a0
; CHECK-NEXT: vsetvli zero, a1, e16, m1, ta, ma