
[SelectionDAG][RISCV] Add support for splitting vp.splice #145184

Merged on Jun 23, 2025 (3 commits)
1 change: 1 addition & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -985,6 +985,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
void SplitVecRes_VECTOR_INTERLEAVE(SDNode *N);
void SplitVecRes_VAARG(SDNode *N, SDValue &Lo, SDValue &Hi);
void SplitVecRes_FP_TO_XINT_SAT(SDNode *N, SDValue &Lo, SDValue &Hi);
void SplitVecRes_VP_SPLICE(SDNode *N, SDValue &Lo, SDValue &Hi);
void SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo, SDValue &Hi);
void SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N, SDValue &Lo, SDValue &Hi);
void SplitVecRes_GET_ACTIVE_LANE_MASK(SDNode *N, SDValue &Lo, SDValue &Hi);
75 changes: 75 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1382,6 +1382,9 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
case ISD::UDIVFIXSAT:
SplitVecRes_FIX(N, Lo, Hi);
break;
case ISD::EXPERIMENTAL_VP_SPLICE:
SplitVecRes_VP_SPLICE(N, Lo, Hi);
break;
case ISD::EXPERIMENTAL_VP_REVERSE:
SplitVecRes_VP_REVERSE(N, Lo, Hi);
break;
@@ -3209,6 +3212,78 @@ void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
std::tie(Lo, Hi) = DAG.SplitVector(Load, DL);
}

void DAGTypeLegalizer::SplitVecRes_VP_SPLICE(SDNode *N, SDValue &Lo,
SDValue &Hi) {
EVT VT = N->getValueType(0);
SDValue V1 = N->getOperand(0);
SDValue V2 = N->getOperand(1);
int64_t Imm = cast<ConstantSDNode>(N->getOperand(2))->getSExtValue();
SDValue Mask = N->getOperand(3);
SDValue EVL1 = N->getOperand(4);
SDValue EVL2 = N->getOperand(5);
SDLoc DL(N);

// EVL2 is treated as the real VL, so it was already promoted during
// SelectionDAG building. Promote EVL1 here if needed.
if (getTypeAction(EVL1.getValueType()) == TargetLowering::TypePromoteInteger)
EVL1 = ZExtPromotedInteger(EVL1);

Align Alignment = DAG.getReducedAlign(VT, /*UseABI=*/false);

EVT MemVT = EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(),
VT.getVectorElementCount() * 2);
SDValue StackPtr = DAG.CreateStackTemporary(MemVT.getStoreSize(), Alignment);
EVT PtrVT = StackPtr.getValueType();
auto &MF = DAG.getMachineFunction();
auto FrameIndex = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex();
auto PtrInfo = MachinePointerInfo::getFixedStack(MF, FrameIndex);

MachineMemOperand *StoreMMO = DAG.getMachineFunction().getMachineMemOperand(
PtrInfo, MachineMemOperand::MOStore, LocationSize::beforeOrAfterPointer(),
Alignment);
MachineMemOperand *LoadMMO = DAG.getMachineFunction().getMachineMemOperand(
PtrInfo, MachineMemOperand::MOLoad, LocationSize::beforeOrAfterPointer(),
Alignment);

SDValue StackPtr2 = TLI.getVectorElementPointer(DAG, StackPtr, VT, EVL1);

SDValue TrueMask = DAG.getBoolConstant(true, DL, Mask.getValueType(), VT);
SDValue StoreV1 = DAG.getStoreVP(DAG.getEntryNode(), DL, V1, StackPtr,
DAG.getUNDEF(PtrVT), TrueMask, EVL1,
V1.getValueType(), StoreMMO, ISD::UNINDEXED);

SDValue StoreV2 =
DAG.getStoreVP(StoreV1, DL, V2, StackPtr2, DAG.getUNDEF(PtrVT), TrueMask,
EVL2, V2.getValueType(), StoreMMO, ISD::UNINDEXED);

SDValue Load;
if (Imm >= 0) {
StackPtr = TLI.getVectorElementPointer(DAG, StackPtr, VT, N->getOperand(2));
Load = DAG.getLoadVP(VT, DL, StoreV2, StackPtr, Mask, EVL2, LoadMMO);
} else {
uint64_t TrailingElts = -Imm;
unsigned EltWidth = VT.getScalarSizeInBits() / 8;
SDValue TrailingBytes = DAG.getConstant(TrailingElts * EltWidth, DL, PtrVT);

// Make sure TrailingBytes doesn't exceed the bytes stored for V1, i.e. the
// offset at which V2's data begins.
SDValue OffsetToV2 = DAG.getNode(ISD::SUB, DL, PtrVT, StackPtr2, StackPtr);
TrailingBytes =
DAG.getNode(ISD::UMIN, DL, PtrVT, TrailingBytes, OffsetToV2);

// Calculate the start address of the spliced result.
StackPtr2 = DAG.getNode(ISD::SUB, DL, PtrVT, StackPtr2, TrailingBytes);
Load = DAG.getLoadVP(VT, DL, StoreV2, StackPtr2, Mask, EVL2, LoadMMO);
}

EVT LoVT, HiVT;
std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT);
Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, LoVT, Load,
DAG.getVectorIdxConstant(0, DL));
Hi =
DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HiVT, Load,
DAG.getVectorIdxConstant(LoVT.getVectorMinNumElements(), DL));
}
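
For intuition, here is a minimal standalone sketch of what this lowering computes, modeled on plain arrays instead of SelectionDAG nodes: the first EVL1 lanes of V1 and the first EVL2 lanes of V2 are written back to back into a temporary buffer, and the result is read starting Imm lanes in (or, for negative Imm, a clamped number of lanes before V2's data). The helper name `spliceViaStack` and the concrete sizes are illustrative assumptions, not part of the patch.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Model of vp.splice lowered through a stack temporary: store V1's first EVL1
// lanes, then V2's first EVL2 lanes directly after them, and load the result
// back starting at Imm (or at EVL1 - min(-Imm, EVL1) for Imm < 0).
static std::vector<int64_t> spliceViaStack(const std::vector<int64_t> &V1,
                                           const std::vector<int64_t> &V2,
                                           int64_t Imm, unsigned EVL1,
                                           unsigned EVL2) {
  std::vector<int64_t> Buf(V1.size() + V2.size(), 0); // the stack temporary
  std::copy_n(V1.begin(), EVL1, Buf.begin());         // VP store of V1
  std::copy_n(V2.begin(), EVL2, Buf.begin() + EVL1);  // VP store of V2
  size_t Start;
  if (Imm >= 0) {
    Start = static_cast<size_t>(Imm);
  } else {
    // Clamp the trailing-lane count so the load never starts before V1's data
    // (the UMIN on TrailingBytes above, in lane units here).
    uint64_t Trailing = std::min<uint64_t>(uint64_t(-Imm), EVL1);
    Start = EVL1 - Trailing;
  }
  // VP load of the result: EVL2 lanes are read back.
  return {Buf.begin() + Start, Buf.begin() + Start + EVL2};
}

int main() {
  std::vector<int64_t> V1{0, 1, 2, 3, 4, 5, 6, 7};
  std::vector<int64_t> V2{10, 11, 12, 13, 14, 15, 16, 17};
  // Imm = 5: the result starts at lane 5 of the concatenated active lanes.
  for (int64_t E : spliceViaStack(V1, V2, 5, 6, 4))
    printf("%lld ", (long long)E); // 5 10 11 12
  printf("\n");
  // Imm = -2: keep the last two active lanes of V1, then V2's lanes.
  for (int64_t E : spliceViaStack(V1, V2, -2, 6, 4))
    printf("%lld ", (long long)E); // 4 5 10 11
  printf("\n");
}
```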

void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N, SDValue &Lo,
SDValue &Hi) {
SDLoc DL(N);
141 changes: 141 additions & 0 deletions llvm/test/CodeGen/RISCV/rvv/vp-splice.ll
@@ -286,3 +286,144 @@ define <vscale x 2 x float> @test_vp_splice_nxv2f32_masked(<vscale x 2 x float>
%v = call <vscale x 2 x float> @llvm.experimental.vp.splice.nxv2f32(<vscale x 2 x float> %va, <vscale x 2 x float> %vb, i32 5, <vscale x 2 x i1> %mask, i32 %evla, i32 %evlb)
ret <vscale x 2 x float> %v
}

define <vscale x 16 x i64> @test_vp_splice_nxv16i64(<vscale x 16 x i64> %va, <vscale x 16 x i64> %vb, i32 zeroext %evla, i32 zeroext %evlb) nounwind {
; CHECK-LABEL: test_vp_splice_nxv16i64:
; CHECK: # %bb.0:
; CHECK-NEXT: csrr a4, vlenb
; CHECK-NEXT: slli a5, a4, 1
; CHECK-NEXT: addi a5, a5, -1
; CHECK-NEXT: slli a1, a4, 3
; CHECK-NEXT: mv a7, a2
; CHECK-NEXT: bltu a2, a5, .LBB21_2
; CHECK-NEXT: # %bb.1:
; CHECK-NEXT: mv a7, a5
; CHECK-NEXT: .LBB21_2:
; CHECK-NEXT: addi sp, sp, -80
; CHECK-NEXT: sd ra, 72(sp) # 8-byte Folded Spill
; CHECK-NEXT: sd s0, 64(sp) # 8-byte Folded Spill
; CHECK-NEXT: addi s0, sp, 80
; CHECK-NEXT: csrr a5, vlenb
; CHECK-NEXT: slli a5, a5, 5
; CHECK-NEXT: sub sp, sp, a5
; CHECK-NEXT: andi sp, sp, -64
; CHECK-NEXT: add a5, a0, a1
; CHECK-NEXT: slli a7, a7, 3
; CHECK-NEXT: addi a6, sp, 64
; CHECK-NEXT: mv t0, a2
; CHECK-NEXT: bltu a2, a4, .LBB21_4
; CHECK-NEXT: # %bb.3:
; CHECK-NEXT: mv t0, a4
; CHECK-NEXT: .LBB21_4:
; CHECK-NEXT: vl8re64.v v24, (a5)
; CHECK-NEXT: add a5, a6, a7
; CHECK-NEXT: vl8re64.v v0, (a0)
; CHECK-NEXT: vsetvli zero, t0, e64, m8, ta, ma
; CHECK-NEXT: vse64.v v8, (a6)
; CHECK-NEXT: sub a0, a2, a4
; CHECK-NEXT: sltu a2, a2, a0
; CHECK-NEXT: addi a2, a2, -1
; CHECK-NEXT: and a0, a2, a0
; CHECK-NEXT: add a6, a6, a1
; CHECK-NEXT: vsetvli zero, a0, e64, m8, ta, ma
; CHECK-NEXT: vse64.v v16, (a6)
; CHECK-NEXT: mv a0, a3
; CHECK-NEXT: bltu a3, a4, .LBB21_6
; CHECK-NEXT: # %bb.5:
; CHECK-NEXT: mv a0, a4
; CHECK-NEXT: .LBB21_6:
; CHECK-NEXT: vsetvli zero, a0, e64, m8, ta, ma
; CHECK-NEXT: vse64.v v0, (a5)
; CHECK-NEXT: sub a2, a3, a4
; CHECK-NEXT: add a5, a5, a1
; CHECK-NEXT: sltu a3, a3, a2
; CHECK-NEXT: addi a3, a3, -1
; CHECK-NEXT: and a2, a3, a2
; CHECK-NEXT: addi a3, sp, 104
; CHECK-NEXT: add a1, a3, a1
; CHECK-NEXT: vsetvli zero, a2, e64, m8, ta, ma
; CHECK-NEXT: vse64.v v24, (a5)
; CHECK-NEXT: vle64.v v16, (a1)
; CHECK-NEXT: vsetvli zero, a0, e64, m8, ta, ma
; CHECK-NEXT: vle64.v v8, (a3)
; CHECK-NEXT: addi sp, s0, -80
; CHECK-NEXT: ld ra, 72(sp) # 8-byte Folded Reload
; CHECK-NEXT: ld s0, 64(sp) # 8-byte Folded Reload
; CHECK-NEXT: addi sp, sp, 80
; CHECK-NEXT: ret
%v = call <vscale x 16 x i64> @llvm.experimental.vp.splice.nxv16i64(<vscale x 16 x i64> %va, <vscale x 16 x i64> %vb, i32 5, <vscale x 16 x i1> splat (i1 1), i32 %evla, i32 %evlb)
ret <vscale x 16 x i64> %v
}
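
The `sub`/`sltu`/`addi`/`and` run checked above is a branchless unsigned saturating subtract; it computes the EVL for the high half of the split operands as max(evl - vlenb, 0). A small sketch of the same trick, assuming a hypothetical helper name:

```cpp
#include <cstdint>

// Branchless unsigned saturating subtraction, matching the sub/sltu/addi/and
// sequence in the test above. The name usubSat is illustrative.
constexpr uint64_t usubSat(uint64_t A, uint64_t B) {
  uint64_t D = A - B;         // sub: wraps around if B > A
  uint64_t Borrow = A < D;    // sltu: 1 exactly when the subtraction wrapped
  uint64_t Keep = Borrow - 1; // addi -1: all-ones if no wrap, zero otherwise
  return Keep & D;            // and: zero out a wrapped result
}

static_assert(usubSat(10, 3) == 7, "no saturation");
static_assert(usubSat(3, 10) == 0, "saturates to zero");
```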

define <vscale x 16 x i64> @test_vp_splice_nxv16i64_negative_offset(<vscale x 16 x i64> %va, <vscale x 16 x i64> %vb, i32 zeroext %evla, i32 zeroext %evlb) nounwind {
; CHECK-LABEL: test_vp_splice_nxv16i64_negative_offset:
; CHECK: # %bb.0:
; CHECK-NEXT: csrr a5, vlenb
; CHECK-NEXT: slli a6, a5, 1
; CHECK-NEXT: addi a6, a6, -1
; CHECK-NEXT: slli a1, a5, 3
; CHECK-NEXT: mv a4, a2
; CHECK-NEXT: bltu a2, a6, .LBB22_2
; CHECK-NEXT: # %bb.1:
; CHECK-NEXT: mv a4, a6
; CHECK-NEXT: .LBB22_2:
; CHECK-NEXT: addi sp, sp, -80
; CHECK-NEXT: sd ra, 72(sp) # 8-byte Folded Spill
; CHECK-NEXT: sd s0, 64(sp) # 8-byte Folded Spill
; CHECK-NEXT: addi s0, sp, 80
; CHECK-NEXT: csrr a6, vlenb
; CHECK-NEXT: slli a6, a6, 5
; CHECK-NEXT: sub sp, sp, a6
; CHECK-NEXT: andi sp, sp, -64
; CHECK-NEXT: add a6, a0, a1
; CHECK-NEXT: slli a4, a4, 3
; CHECK-NEXT: addi a7, sp, 64
; CHECK-NEXT: mv t0, a2
; CHECK-NEXT: bltu a2, a5, .LBB22_4
; CHECK-NEXT: # %bb.3:
; CHECK-NEXT: mv t0, a5
; CHECK-NEXT: .LBB22_4:
; CHECK-NEXT: vl8re64.v v24, (a6)
; CHECK-NEXT: add a6, a7, a4
; CHECK-NEXT: vl8re64.v v0, (a0)
; CHECK-NEXT: vsetvli zero, t0, e64, m8, ta, ma
; CHECK-NEXT: vse64.v v8, (a7)
; CHECK-NEXT: sub a0, a2, a5
; CHECK-NEXT: sltu a2, a2, a0
; CHECK-NEXT: addi a2, a2, -1
; CHECK-NEXT: and a0, a2, a0
; CHECK-NEXT: add a7, a7, a1
; CHECK-NEXT: vsetvli zero, a0, e64, m8, ta, ma
; CHECK-NEXT: vse64.v v16, (a7)
; CHECK-NEXT: mv a0, a3
; CHECK-NEXT: bltu a3, a5, .LBB22_6
; CHECK-NEXT: # %bb.5:
; CHECK-NEXT: mv a0, a5
; CHECK-NEXT: .LBB22_6:
; CHECK-NEXT: vsetvli zero, a0, e64, m8, ta, ma
; CHECK-NEXT: vse64.v v0, (a6)
; CHECK-NEXT: sub a2, a3, a5
; CHECK-NEXT: add a5, a6, a1
; CHECK-NEXT: sltu a3, a3, a2
; CHECK-NEXT: addi a3, a3, -1
; CHECK-NEXT: and a2, a3, a2
; CHECK-NEXT: li a3, 8
; CHECK-NEXT: vsetvli zero, a2, e64, m8, ta, ma
; CHECK-NEXT: vse64.v v24, (a5)
; CHECK-NEXT: bltu a4, a3, .LBB22_8
; CHECK-NEXT: # %bb.7:
; CHECK-NEXT: li a4, 8
; CHECK-NEXT: .LBB22_8:
; CHECK-NEXT: sub a2, a6, a4
; CHECK-NEXT: add a1, a2, a1
; CHECK-NEXT: vle64.v v16, (a1)
; CHECK-NEXT: vsetvli zero, a0, e64, m8, ta, ma
; CHECK-NEXT: vle64.v v8, (a2)
; CHECK-NEXT: addi sp, s0, -80
; CHECK-NEXT: ld ra, 72(sp) # 8-byte Folded Reload
; CHECK-NEXT: ld s0, 64(sp) # 8-byte Folded Reload
; CHECK-NEXT: addi sp, sp, 80
; CHECK-NEXT: ret
%v = call <vscale x 16 x i64> @llvm.experimental.vp.splice.nxv16i64(<vscale x 16 x i64> %va, <vscale x 16 x i64> %vb, i32 -1, <vscale x 16 x i1> splat (i1 1), i32 %evla, i32 %evlb)
ret <vscale x 16 x i64> %v
}
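
The `li a3, 8`/`bltu`/`sub` sequence in this test is the scalar expansion of the UMIN clamp on TrailingBytes from SplitVecRes_VP_SPLICE. A hedged sketch of that address arithmetic, with illustrative names not taken from the patch:

```cpp
#include <algorithm>
#include <cstdint>

// Byte offset of the result load for a negative splice offset Imm, assuming
// the buffer layout used by the lowering: V1's EVL1 lanes, then V2's lanes.
constexpr uint64_t negativeSpliceLoadOffset(int64_t Imm, uint64_t EVL1,
                                            uint64_t EltBytes) {
  uint64_t TrailingBytes = uint64_t(-Imm) * EltBytes; // bytes kept from V1
  uint64_t OffsetToV2 = EVL1 * EltBytes; // where V2's lanes begin in the buffer
  // Clamp so the load never starts before the stack temporary.
  TrailingBytes = std::min(TrailingBytes, OffsetToV2);
  return OffsetToV2 - TrailingBytes;
}

// For the test above: Imm = -1 with i64 elements (8 bytes). With EVL1 >= 1 the
// load starts 8 bytes before V2's data; with EVL1 == 0 it clamps to offset 0.
static_assert(negativeSpliceLoadOffset(-1, 4, 8) == 24, "one trailing lane");
static_assert(negativeSpliceLoadOffset(-1, 0, 8) == 0, "clamped to buffer start");
```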