
Commit 8367247

topperc and 4vtomat committed
[RISCV] Fold vp.store(vp.reverse(VAL), ADDR, MASK) -> vp.strided.store(VAL, NEW_ADDR, -1, MASK)
This was extracted from our downstream with only a quick re-review. It was originally written 1.5 years ago, so there might be existing helper functions added since then that could simplify it.

Co-authored-by: Brandon Wu <[email protected]>
1 parent ef1260a commit 8367247
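
The idea behind the fold, in scalar terms: writing reverse(VAL) contiguously at ADDR touches the same bytes as writing VAL with a negative stride starting at ADDR + (EVL - 1) * sizeof(element). A minimal standalone C++ model of this equivalence (not part of the commit; the names and values are illustrative only):

// Reference model: contiguous store of reverse(VAL) vs. negative-strided
// store of VAL from the adjusted base address.
#include <cassert>
#include <vector>

int main() {
  const unsigned EVL = 5;                   // effective vector length
  std::vector<float> Val = {1, 2, 3, 4, 5}; // VAL
  std::vector<float> A(EVL), B(EVL);        // two memory images

  // vp.store(vp.reverse(VAL), ADDR): contiguous store of the reversed value.
  for (unsigned I = 0; I < EVL; ++I)
    A[I] = Val[EVL - 1 - I];

  // vp.strided.store(VAL, ADDR + (EVL - 1) * sizeof(float), -sizeof(float)):
  // element I lands at slot EVL - 1 - I.
  float *Base = B.data() + (EVL - 1);
  for (unsigned I = 0; I < EVL; ++I)
    *(Base - I) = Val[I];

  assert(A == B); // both forms produce the same memory contents
  return 0;
}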

File tree: 2 files changed (+153, -6 lines)


llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 72 additions & 6 deletions
@@ -1524,13 +1524,17 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   setTargetDAGCombine({ISD::ZERO_EXTEND, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
                        ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT});
   if (Subtarget.hasVInstructions())
-    setTargetDAGCombine({ISD::FCOPYSIGN, ISD::MGATHER, ISD::MSCATTER,
-                         ISD::VP_GATHER, ISD::VP_SCATTER, ISD::SRA, ISD::SRL,
-                         ISD::SHL, ISD::STORE, ISD::SPLAT_VECTOR,
+    setTargetDAGCombine({ISD::FCOPYSIGN, ISD::MGATHER,
+                         ISD::MSCATTER, ISD::VP_GATHER,
+                         ISD::VP_SCATTER, ISD::SRA,
+                         ISD::SRL, ISD::SHL,
+                         ISD::STORE, ISD::SPLAT_VECTOR,
                          ISD::BUILD_VECTOR, ISD::CONCAT_VECTORS,
-                         ISD::EXPERIMENTAL_VP_REVERSE, ISD::MUL,
-                         ISD::SDIV, ISD::UDIV, ISD::SREM, ISD::UREM,
-                         ISD::INSERT_VECTOR_ELT, ISD::ABS, ISD::CTPOP});
+                         ISD::VP_STORE, ISD::EXPERIMENTAL_VP_REVERSE,
+                         ISD::MUL, ISD::SDIV,
+                         ISD::UDIV, ISD::SREM,
+                         ISD::UREM, ISD::INSERT_VECTOR_ELT,
+                         ISD::ABS, ISD::CTPOP});
   if (Subtarget.hasVendorXTHeadMemPair())
     setTargetDAGCombine({ISD::LOAD, ISD::STORE});
   if (Subtarget.useRVVForFixedLengthVectors())
@@ -16229,6 +16233,66 @@ static SDValue performBITREVERSECombine(SDNode *N, SelectionDAG &DAG,
   return DAG.getNode(RISCVISD::BREV8, DL, VT, Src.getOperand(0));
 }
 
+static SDValue performVP_STORECombine(SDNode *N, SelectionDAG &DAG,
+                                      const RISCVSubtarget &Subtarget) {
+  // Fold:
+  //   vp.store(vp.reverse(VAL), ADDR, MASK) -> vp.strided.store(VAL, NEW_ADDR,
+  //   -1, MASK)
+  auto *VPStore = cast<VPStoreSDNode>(N);
+
+  if (VPStore->getValue().getOpcode() != ISD::EXPERIMENTAL_VP_REVERSE)
+    return SDValue();
+
+  SDValue VPReverse = VPStore->getValue();
+  EVT ReverseVT = VPReverse->getValueType(0);
+
+  // We do not have a strided_store version for masks, and the evl of
+  // vp.reverse and vp.store should always be the same.
+  if (!ReverseVT.getVectorElementType().isByteSized() ||
+      VPStore->getVectorLength() != VPReverse.getOperand(2) ||
+      !VPReverse.hasOneUse())
+    return SDValue();
+
+  SDValue StoreMask = VPStore->getMask();
+  // If Mask is not all 1's, try to replace the mask if its opcode
+  // is EXPERIMENTAL_VP_REVERSE and its operand can be directly extracted.
+  if (!isOneOrOneSplat(StoreMask)) {
+    // Check that the mask of the vp.reverse in the vp.store is all 1's and
+    // that the length of the mask is the same as the evl.
+    if (StoreMask.getOpcode() != ISD::EXPERIMENTAL_VP_REVERSE ||
+        !isOneOrOneSplat(StoreMask.getOperand(1)) ||
+        StoreMask.getOperand(2) != VPStore->getVectorLength())
+      return SDValue();
+    StoreMask = StoreMask.getOperand(0);
+  }
+
+  // Base = StoreAddr + (NumElem - 1) * ElemWidthByte
+  SDLoc DL(N);
+  MVT XLenVT = Subtarget.getXLenVT();
+  SDValue NumElem = VPStore->getVectorLength();
+  uint64_t ElemWidthByte = VPReverse.getValueType().getScalarSizeInBits() / 8;
+
+  SDValue Temp1 = DAG.getNode(ISD::SUB, DL, XLenVT, NumElem,
+                              DAG.getConstant(1, DL, XLenVT));
+  SDValue Temp2 = DAG.getNode(ISD::MUL, DL, XLenVT, Temp1,
+                              DAG.getConstant(ElemWidthByte, DL, XLenVT));
+  SDValue Base =
+      DAG.getNode(ISD::ADD, DL, XLenVT, VPStore->getBasePtr(), Temp2);
+  SDValue Stride = DAG.getConstant(0 - ElemWidthByte, DL, XLenVT);
+
+  MachineFunction &MF = DAG.getMachineFunction();
+  MachinePointerInfo PtrInfo(VPStore->getAddressSpace());
+  MachineMemOperand *MMO = MF.getMachineMemOperand(
+      PtrInfo, VPStore->getMemOperand()->getFlags(),
+      LocationSize::beforeOrAfterPointer(), VPStore->getAlign());
+
+  return DAG.getStridedStoreVP(
+      VPStore->getChain(), DL, VPReverse.getOperand(0), Base,
+      VPStore->getOffset(), Stride, StoreMask, VPStore->getVectorLength(),
+      VPStore->getMemoryVT(), MMO, VPStore->getAddressingMode(),
+      VPStore->isTruncatingStore(), VPStore->isCompressingStore());
+}
+
 // Convert from one FMA opcode to another based on whether we are negating the
 // multiply result and/or the accumulator.
 // NOTE: Only supports RVV operations with VL.
@@ -18372,6 +18436,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
       }
     }
   }
+  case ISD::VP_STORE:
+    return performVP_STORECombine(N, DAG, Subtarget);
   case ISD::BITCAST: {
     assert(Subtarget.useRVVForFixedLengthVectors());
     SDValue N0 = N->getOperand(0);
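
For reference, the address arithmetic performed by the new combine (Base = StoreAddr + (NumElem - 1) * ElemWidthByte, Stride = -ElemWidthByte) can be checked in isolation. A small sketch with a hypothetical helper, not LLVM API, using made-up example values:

// Standalone sketch of the base/stride computation used by the combine.
#include <cstdint>
#include <cstdio>

struct StridedAccess {
  uint64_t Base; // address of the first element written by the strided store
  int64_t Stride; // byte distance between consecutive stores (negative here)
};

StridedAccess reversedStoreToStrided(uint64_t StoreAddr, uint64_t NumElem,
                                     uint64_t ElemWidthByte) {
  // Base = StoreAddr + (NumElem - 1) * ElemWidthByte
  StridedAccess R;
  R.Base = StoreAddr + (NumElem - 1) * ElemWidthByte;
  R.Stride = -static_cast<int64_t>(ElemWidthByte);
  return R;
}

int main() {
  // E.g. storing 8 x float at address 0x1000: the base lands on the last
  // element and the store walks backwards in 4-byte steps, matching the
  // "li a2, -4" / "vsse32.v v8, (a0), a2" sequence in the test below.
  StridedAccess R = reversedStoreToStrided(0x1000, 8, sizeof(float));
  std::printf("base = %#llx, stride = %lld\n",
              (unsigned long long)R.Base, (long long)R.Stride);
  // prints: base = 0x101c, stride = -4
  return 0;
}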
Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv64 -mattr=+f,+v -verify-machineinstrs < %s | FileCheck %s
+
+define void @test_store_reverse_combiner(<vscale x 2 x float> %val, <vscale x 2 x float>* %ptr, i32 zeroext %evl) {
+; CHECK-LABEL: test_store_reverse_combiner:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a2, a1, 2
+; CHECK-NEXT:    add a0, a2, a0
+; CHECK-NEXT:    addi a0, a0, -4
+; CHECK-NEXT:    li a2, -4
+; CHECK-NEXT:    vsetvli zero, a1, e32, m1, ta, ma
+; CHECK-NEXT:    vsse32.v v8, (a0), a2
+; CHECK-NEXT:    ret
+  %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %val, <vscale x 2 x i1> splat (i1 true), i32 %evl)
+  call void @llvm.vp.store.nxv2f32.p0nxv2f32(<vscale x 2 x float> %rev, <vscale x 2 x float>* %ptr, <vscale x 2 x i1> splat (i1 true), i32 %evl)
+  ret void
+}
+
+define void @test_store_mask_is_vp_reverse(<vscale x 2 x float> %val, <vscale x 2 x float>* %ptr, <vscale x 2 x i1> %mask, i32 zeroext %evl) {
+; CHECK-LABEL: test_store_mask_is_vp_reverse:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a2, a1, 2
+; CHECK-NEXT:    add a0, a2, a0
+; CHECK-NEXT:    addi a0, a0, -4
+; CHECK-NEXT:    li a2, -4
+; CHECK-NEXT:    vsetvli zero, a1, e32, m1, ta, ma
+; CHECK-NEXT:    vsse32.v v8, (a0), a2, v0.t
+; CHECK-NEXT:    ret
+  %storemask = call <vscale x 2 x i1> @llvm.experimental.vp.reverse.nxv2i1(<vscale x 2 x i1> %mask, <vscale x 2 x i1> splat (i1 true), i32 %evl)
+  %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %val, <vscale x 2 x i1> splat (i1 true), i32 %evl)
+  call void @llvm.vp.store.nxv2f32.p0nxv2f32(<vscale x 2 x float> %rev, <vscale x 2 x float>* %ptr, <vscale x 2 x i1> %storemask, i32 %evl)
+  ret void
+}
+
+define void @test_store_mask_not_all_one(<vscale x 2 x float> %val, <vscale x 2 x float>* %ptr, <vscale x 2 x i1> %notallones, i32 zeroext %evl) {
+; CHECK-LABEL: test_store_mask_not_all_one:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a1, e32, m1, ta, ma
+; CHECK-NEXT:    vid.v v9, v0.t
+; CHECK-NEXT:    addi a1, a1, -1
+; CHECK-NEXT:    vrsub.vx v9, v9, a1, v0.t
+; CHECK-NEXT:    vrgather.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vse32.v v10, (a0), v0.t
+; CHECK-NEXT:    ret
+  %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %val, <vscale x 2 x i1> %notallones, i32 %evl)
+  call void @llvm.vp.store.nxv2f32.p0nxv2f32(<vscale x 2 x float> %rev, <vscale x 2 x float>* %ptr, <vscale x 2 x i1> %notallones, i32 %evl)
+  ret void
+}
+
+define void @test_different_evl(<vscale x 2 x float> %val, <vscale x 2 x float>* %ptr, <vscale x 2 x i1> %mask, i32 zeroext %evl1, i32 zeroext %evl2) {
+; CHECK-LABEL: test_different_evl:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a1, e16, mf2, ta, ma
+; CHECK-NEXT:    vid.v v9
+; CHECK-NEXT:    addi a1, a1, -1
+; CHECK-NEXT:    vsetvli zero, zero, e8, mf4, ta, ma
+; CHECK-NEXT:    vmv.v.i v10, 0
+; CHECK-NEXT:    vmerge.vim v10, v10, 1, v0
+; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vid.v v11
+; CHECK-NEXT:    vsetvli zero, zero, e16, mf2, ta, ma
+; CHECK-NEXT:    vrsub.vx v9, v9, a1
+; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vrsub.vx v11, v11, a1
+; CHECK-NEXT:    vsetvli zero, zero, e8, mf4, ta, ma
+; CHECK-NEXT:    vrgatherei16.vv v12, v10, v9
+; CHECK-NEXT:    vmsne.vi v0, v12, 0
+; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vrgather.vv v9, v8, v11
+; CHECK-NEXT:    vsetvli zero, a2, e32, m1, ta, ma
+; CHECK-NEXT:    vse32.v v9, (a0), v0.t
+; CHECK-NEXT:    ret
+  %storemask = call <vscale x 2 x i1> @llvm.experimental.vp.reverse.nxv2i1(<vscale x 2 x i1> %mask, <vscale x 2 x i1> splat (i1 true), i32 %evl1)
+  %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %val, <vscale x 2 x i1> splat (i1 true), i32 %evl1)
+  call void @llvm.vp.store.nxv2f32.p0nxv2f32(<vscale x 2 x float> %rev, <vscale x 2 x float>* %ptr, <vscale x 2 x i1> %storemask, i32 %evl2)
+  ret void
+}
+
+declare <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float>, <vscale x 2 x i1>, i32)
+declare <vscale x 2 x i1> @llvm.experimental.vp.reverse.nxv2i1(<vscale x 2 x i1>, <vscale x 2 x i1>, i32)
+declare void @llvm.vp.store.nxv2f32.p0nxv2f32(<vscale x 2 x float>, <vscale x 2 x float>* nocapture, <vscale x 2 x i1>, i32)
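
The second test exercises the masked path: when the store mask is itself a vp.reverse of some mask M with the same evl, the combine keeps M as the mask of the strided store. A standalone scalar model of that equivalence (illustrative only, not taken from the test file):

// Reference model: masked contiguous store of reverse(Val) under reverse(Mask)
// writes the same elements as a masked negative-strided store of Val under Mask.
#include <array>
#include <cassert>

int main() {
  constexpr unsigned EVL = 4;
  std::array<float, EVL> Val = {1, 2, 3, 4};
  std::array<bool, EVL> Mask = {true, false, true, true}; // original mask M
  std::array<float, EVL> A{}, B{}; // zero-initialized memory images

  // Masked contiguous store of reverse(Val) under reverse(Mask).
  for (unsigned I = 0; I < EVL; ++I)
    if (Mask[EVL - 1 - I])
      A[I] = Val[EVL - 1 - I];

  // Masked strided store of Val under Mask: base = last slot, stride = -1 slot.
  for (unsigned I = 0; I < EVL; ++I)
    if (Mask[I])
      B[(EVL - 1) - I] = Val[I];

  assert(A == B); // identical memory contents, so the fold can keep M as-is
  return 0;
}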
