Skip to content

[SelectionDAG] Support integer promotion for VP_STORE #81299

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
case ISD::EXTRACT_VECTOR_ELT:
Res = PromoteIntRes_EXTRACT_VECTOR_ELT(N); break;
case ISD::LOAD: Res = PromoteIntRes_LOAD(cast<LoadSDNode>(N)); break;
case ISD::VP_LOAD:
Res = PromoteIntRes_VP_LOAD(cast<VPLoadSDNode>(N));
break;
case ISD::MLOAD: Res = PromoteIntRes_MLOAD(cast<MaskedLoadSDNode>(N));
break;
case ISD::MGATHER: Res = PromoteIntRes_MGATHER(cast<MaskedGatherSDNode>(N));
Expand Down Expand Up @@ -957,6 +960,23 @@ SDValue DAGTypeLegalizer::PromoteIntRes_LOAD(LoadSDNode *N) {
return Res;
}

SDValue DAGTypeLegalizer::PromoteIntRes_VP_LOAD(VPLoadSDNode *N) {
assert(!N->isIndexed() && "Indexed vp_load during type legalization!");
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
ISD::LoadExtType ExtType = (N->getExtensionType() == ISD::NON_EXTLOAD)
? ISD::EXTLOAD
: N->getExtensionType();
SDLoc dl(N);
SDValue Res =
DAG.getLoadVP(N->getAddressingMode(), ExtType, NVT, dl, N->getChain(),
N->getBasePtr(), N->getOffset(), N->getMask(),
N->getVectorLength(), N->getMemoryVT(), N->getMemOperand());
// Legalize the chain result - switch anything that used the old chain to
// use the new one.
ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
return Res;
}

SDValue DAGTypeLegalizer::PromoteIntRes_MLOAD(MaskedLoadSDNode *N) {
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
SDValue ExtPassThru = GetPromotedInteger(N->getPassThru());
Expand Down Expand Up @@ -1957,6 +1977,9 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
case ISD::STRICT_SINT_TO_FP: Res = PromoteIntOp_STRICT_SINT_TO_FP(N); break;
case ISD::STORE: Res = PromoteIntOp_STORE(cast<StoreSDNode>(N),
OpNo); break;
case ISD::VP_STORE:
Res = PromoteIntOp_VP_STORE(cast<VPStoreSDNode>(N), OpNo);
break;
case ISD::MSTORE: Res = PromoteIntOp_MSTORE(cast<MaskedStoreSDNode>(N),
OpNo); break;
case ISD::MLOAD: Res = PromoteIntOp_MLOAD(cast<MaskedLoadSDNode>(N),
Expand Down Expand Up @@ -2378,6 +2401,19 @@ SDValue DAGTypeLegalizer::PromoteIntOp_STORE(StoreSDNode *N, unsigned OpNo){
N->getMemoryVT(), N->getMemOperand());
}

SDValue DAGTypeLegalizer::PromoteIntOp_VP_STORE(VPStoreSDNode *N,
unsigned OpNo) {

assert(OpNo == 1 && "Unexpected operand for promotion");
assert(!N->isIndexed() && "expecting unindexed vp_store!");

SDValue DataOp = GetPromotedInteger(N->getValue());
return DAG.getTruncStoreVP(N->getChain(), SDLoc(N), DataOp, N->getBasePtr(),
N->getMask(), N->getVectorLength(),
N->getMemoryVT(), N->getMemOperand(),
N->isCompressingStore());
}

SDValue DAGTypeLegalizer::PromoteIntOp_MSTORE(MaskedStoreSDNode *N,
unsigned OpNo) {
SDValue DataOp = N->getValue();
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue PromoteIntRes_FREEZE(SDNode *N);
SDValue PromoteIntRes_INT_EXTEND(SDNode *N);
SDValue PromoteIntRes_LOAD(LoadSDNode *N);
SDValue PromoteIntRes_VP_LOAD(VPLoadSDNode *N);
SDValue PromoteIntRes_MLOAD(MaskedLoadSDNode *N);
SDValue PromoteIntRes_MGATHER(MaskedGatherSDNode *N);
SDValue PromoteIntRes_VECTOR_COMPRESS(SDNode *N);
Expand Down Expand Up @@ -420,6 +421,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue PromoteIntOp_ExpOp(SDNode *N);
SDValue PromoteIntOp_VECREDUCE(SDNode *N);
SDValue PromoteIntOp_VP_REDUCE(SDNode *N, unsigned OpNo);
SDValue PromoteIntOp_VP_STORE(VPStoreSDNode *N, unsigned OpNo);
SDValue PromoteIntOp_SET_ROUNDING(SDNode *N);
SDValue PromoteIntOp_STACKMAP(SDNode *N, unsigned OpNo);
SDValue PromoteIntOp_PATCHPOINT(SDNode *N, unsigned OpNo);
Expand Down
16 changes: 14 additions & 2 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vpstore.ll
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@ define void @vpstore_v4i8(<4 x i8> %val, ptr %ptr, <4 x i1> %m, i32 zeroext %evl
ret void
}

declare void @llvm.vp.store.v8i7.v8i7.p0(<8 x i7>, ptr, <8 x i1>, i32)

define void @vpstore_v8i7(<8 x i7> %val, ptr %ptr, <8 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: vpstore_v8i7:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a1, e8, mf2, ta, ma
; CHECK-NEXT: vse8.v v8, (a0), v0.t
; CHECK-NEXT: ret
call void @llvm.vp.store.v8i7.v8i7.p0(<8 x i7> %val, ptr %ptr, <8 x i1> %m, i32 %evl)
ret void
}

declare void @llvm.vp.store.v8i8.p0(<8 x i8>, ptr, <8 x i1>, i32)

define void @vpstore_v8i8(<8 x i8> %val, ptr %ptr, <8 x i1> %m, i32 zeroext %evl) {
Expand Down Expand Up @@ -285,10 +297,10 @@ define void @vpstore_v32f64(<32 x double> %val, ptr %ptr, <32 x i1> %m, i32 zero
; CHECK: # %bb.0:
; CHECK-NEXT: li a3, 16
; CHECK-NEXT: mv a2, a1
; CHECK-NEXT: bltu a1, a3, .LBB23_2
; CHECK-NEXT: bltu a1, a3, .LBB24_2
; CHECK-NEXT: # %bb.1:
; CHECK-NEXT: li a2, 16
; CHECK-NEXT: .LBB23_2:
; CHECK-NEXT: .LBB24_2:
; CHECK-NEXT: vsetvli zero, a2, e64, m8, ta, ma
; CHECK-NEXT: vse64.v v8, (a0), v0.t
; CHECK-NEXT: addi a2, a1, -16
Expand Down
28 changes: 20 additions & 8 deletions llvm/test/CodeGen/RISCV/rvv/vpload.ll
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,18 @@ define <vscale x 3 x i8> @vpload_nxv3i8(ptr %ptr, <vscale x 3 x i1> %m, i32 zero
ret <vscale x 3 x i8> %load
}

declare <vscale x 4 x i6> @llvm.vp.load.nxv4i6.nxv4i6.p0(<vscale x 4 x i6>*, <vscale x 4 x i1>, i32)

define <vscale x 4 x i6> @vpload_nxv4i6(<vscale x 4 x i6>* %ptr, <vscale x 4 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: vpload_nxv4i6:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a1, e8, mf2, ta, ma
; CHECK-NEXT: vle8.v v8, (a0), v0.t
; CHECK-NEXT: ret
%load = call <vscale x 4 x i6> @llvm.vp.load.nxv4i6.nxv4i6.p0(<vscale x 4 x i6>* %ptr, <vscale x 4 x i1> %m, i32 %evl)
ret <vscale x 4 x i6> %load
}

declare <vscale x 4 x i8> @llvm.vp.load.nxv4i8.p0(ptr, <vscale x 4 x i1>, i32)

define <vscale x 4 x i8> @vpload_nxv4i8(ptr %ptr, <vscale x 4 x i1> %m, i32 zeroext %evl) {
Expand Down Expand Up @@ -523,10 +535,10 @@ define <vscale x 16 x double> @vpload_nxv16f64(ptr %ptr, <vscale x 16 x i1> %m,
; CHECK-NEXT: add a4, a0, a4
; CHECK-NEXT: vsetvli zero, a3, e64, m8, ta, ma
; CHECK-NEXT: vle64.v v16, (a4), v0.t
; CHECK-NEXT: bltu a1, a2, .LBB43_2
; CHECK-NEXT: bltu a1, a2, .LBB44_2
; CHECK-NEXT: # %bb.1:
; CHECK-NEXT: mv a1, a2
; CHECK-NEXT: .LBB43_2:
; CHECK-NEXT: .LBB44_2:
; CHECK-NEXT: vmv1r.v v0, v8
; CHECK-NEXT: vsetvli zero, a1, e64, m8, ta, ma
; CHECK-NEXT: vle64.v v8, (a0), v0.t
Expand All @@ -553,10 +565,10 @@ define <vscale x 16 x double> @vpload_nxv17f64(ptr %ptr, ptr %out, <vscale x 17
; CHECK-NEXT: csrr a3, vlenb
; CHECK-NEXT: slli a5, a3, 1
; CHECK-NEXT: mv a4, a2
; CHECK-NEXT: bltu a2, a5, .LBB44_2
; CHECK-NEXT: bltu a2, a5, .LBB45_2
; CHECK-NEXT: # %bb.1:
; CHECK-NEXT: mv a4, a5
; CHECK-NEXT: .LBB44_2:
; CHECK-NEXT: .LBB45_2:
; CHECK-NEXT: sub a6, a4, a3
; CHECK-NEXT: slli a7, a3, 3
; CHECK-NEXT: srli t0, a3, 3
Expand All @@ -572,21 +584,21 @@ define <vscale x 16 x double> @vpload_nxv17f64(ptr %ptr, ptr %out, <vscale x 17
; CHECK-NEXT: sltu a2, a2, a5
; CHECK-NEXT: addi a2, a2, -1
; CHECK-NEXT: and a2, a2, a5
; CHECK-NEXT: bltu a2, a3, .LBB44_4
; CHECK-NEXT: bltu a2, a3, .LBB45_4
; CHECK-NEXT: # %bb.3:
; CHECK-NEXT: mv a2, a3
; CHECK-NEXT: .LBB44_4:
; CHECK-NEXT: .LBB45_4:
; CHECK-NEXT: slli a5, a3, 4
; CHECK-NEXT: srli a6, a3, 2
; CHECK-NEXT: vsetvli a7, zero, e8, mf2, ta, ma
; CHECK-NEXT: vslidedown.vx v0, v8, a6
; CHECK-NEXT: add a5, a0, a5
; CHECK-NEXT: vsetvli zero, a2, e64, m8, ta, ma
; CHECK-NEXT: vle64.v v24, (a5), v0.t
; CHECK-NEXT: bltu a4, a3, .LBB44_6
; CHECK-NEXT: bltu a4, a3, .LBB45_6
; CHECK-NEXT: # %bb.5:
; CHECK-NEXT: mv a4, a3
; CHECK-NEXT: .LBB44_6:
; CHECK-NEXT: .LBB45_6:
; CHECK-NEXT: vmv1r.v v0, v8
; CHECK-NEXT: vsetvli zero, a4, e64, m8, ta, ma
; CHECK-NEXT: vle64.v v8, (a0), v0.t
Expand Down
28 changes: 20 additions & 8 deletions llvm/test/CodeGen/RISCV/rvv/vpstore.ll
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,18 @@ define void @vpstore_nxv4i16(<vscale x 4 x i16> %val, ptr %ptr, <vscale x 4 x i1
ret void
}

declare void @llvm.vp.store.nxv8i12.nxv8i12.p0(<vscale x 8 x i12>, <vscale x 8 x i12>*, <vscale x 8 x i1>, i32)

define void @vpstore_nxv8i12(<vscale x 8 x i12> %val, <vscale x 8 x i12>* %ptr, <vscale x 8 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: vpstore_nxv8i12:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a1, e16, m2, ta, ma
; CHECK-NEXT: vse16.v v8, (a0), v0.t
; CHECK-NEXT: ret
call void @llvm.vp.store.nxv8i12.nxv8i12.p0(<vscale x 8 x i12> %val, <vscale x 8 x i12>* %ptr, <vscale x 8 x i1> %m, i32 %evl)
ret void
}

declare void @llvm.vp.store.nxv8i16.p0(<vscale x 8 x i16>, ptr, <vscale x 8 x i1>, i32)

define void @vpstore_nxv8i16(<vscale x 8 x i16> %val, ptr %ptr, <vscale x 8 x i1> %m, i32 zeroext %evl) {
Expand Down Expand Up @@ -421,10 +433,10 @@ define void @vpstore_nxv16f64(<vscale x 16 x double> %val, ptr %ptr, <vscale x 1
; CHECK: # %bb.0:
; CHECK-NEXT: csrr a2, vlenb
; CHECK-NEXT: mv a3, a1
; CHECK-NEXT: bltu a1, a2, .LBB34_2
; CHECK-NEXT: bltu a1, a2, .LBB35_2
; CHECK-NEXT: # %bb.1:
; CHECK-NEXT: mv a3, a2
; CHECK-NEXT: .LBB34_2:
; CHECK-NEXT: .LBB35_2:
; CHECK-NEXT: vsetvli zero, a3, e64, m8, ta, ma
; CHECK-NEXT: vse64.v v8, (a0), v0.t
; CHECK-NEXT: srli a3, a2, 3
Expand Down Expand Up @@ -462,15 +474,15 @@ define void @vpstore_nxv17f64(<vscale x 17 x double> %val, ptr %ptr, <vscale x 1
; CHECK-NEXT: csrr a3, vlenb
; CHECK-NEXT: slli a4, a3, 1
; CHECK-NEXT: mv a5, a2
; CHECK-NEXT: bltu a2, a4, .LBB35_2
; CHECK-NEXT: bltu a2, a4, .LBB36_2
; CHECK-NEXT: # %bb.1:
; CHECK-NEXT: mv a5, a4
; CHECK-NEXT: .LBB35_2:
; CHECK-NEXT: .LBB36_2:
; CHECK-NEXT: mv a6, a5
; CHECK-NEXT: bltu a5, a3, .LBB35_4
; CHECK-NEXT: bltu a5, a3, .LBB36_4
; CHECK-NEXT: # %bb.3:
; CHECK-NEXT: mv a6, a3
; CHECK-NEXT: .LBB35_4:
; CHECK-NEXT: .LBB36_4:
; CHECK-NEXT: vmv1r.v v0, v24
; CHECK-NEXT: vl8re64.v v16, (a0)
; CHECK-NEXT: vsetvli zero, a6, e64, m8, ta, ma
Expand All @@ -492,10 +504,10 @@ define void @vpstore_nxv17f64(<vscale x 17 x double> %val, ptr %ptr, <vscale x 1
; CHECK-NEXT: vl8r.v v8, (a2) # Unknown-size Folded Reload
; CHECK-NEXT: vsetvli zero, a5, e64, m8, ta, ma
; CHECK-NEXT: vse64.v v8, (a6), v0.t
; CHECK-NEXT: bltu a0, a3, .LBB35_6
; CHECK-NEXT: bltu a0, a3, .LBB36_6
; CHECK-NEXT: # %bb.5:
; CHECK-NEXT: mv a0, a3
; CHECK-NEXT: .LBB35_6:
; CHECK-NEXT: .LBB36_6:
; CHECK-NEXT: slli a2, a3, 4
; CHECK-NEXT: srli a3, a3, 2
; CHECK-NEXT: vsetvli a4, zero, e8, mf2, ta, ma
Expand Down
Loading