Skip to content

Commit 18775a4

Browse files
authored
[AArch64][SVE2] Use rshrnb for masked stores (#70026)
This patch is a follow-up to https://reviews.llvm.org/D155299. It combines add+lsr into rshrnb when 'B' in `C = A + B; D = C >> Shift` is equal to (1 << (Shift-1)) and the bits in the top half of each vector element are zeroed or ignored, such as in a truncating masked store.
1 parent d1556e5 commit 18775a4

File tree

2 files changed

+45
-7
lines changed

2 files changed

+45
-7
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21002,6 +21002,12 @@ static SDValue combineBoolVectorAndTruncateStore(SelectionDAG &DAG,
2100221002
Store->getMemOperand());
2100321003
}
2100421004

21005+
// Returns true if truncating from SrcVT to DstVT exactly halves the element
// width of a scalable vector, for the three (source, destination) pairs this
// file's store combines handle: nxv8i16->nxv8i8, nxv4i32->nxv4i16 and
// nxv2i64->nxv2i32. Used to gate the add+lsr -> rshrnb combine in the
// (truncating / masked) store combiners.
// NOTE(review): marked `static` — this is a file-local helper (both callers
// are in this translation unit) and should not have external linkage.
static bool isHalvingTruncateOfLegalScalableType(EVT SrcVT, EVT DstVT) {
  return (SrcVT == MVT::nxv8i16 && DstVT == MVT::nxv8i8) ||
         (SrcVT == MVT::nxv4i32 && DstVT == MVT::nxv4i16) ||
         (SrcVT == MVT::nxv2i64 && DstVT == MVT::nxv2i32);
}
21010+
2100521011
static SDValue performSTORECombine(SDNode *N,
2100621012
TargetLowering::DAGCombinerInfo &DCI,
2100721013
SelectionDAG &DAG,
@@ -21043,16 +21049,16 @@ static SDValue performSTORECombine(SDNode *N,
2104321049
if (SDValue Store = combineBoolVectorAndTruncateStore(DAG, ST))
2104421050
return Store;
2104521051

21046-
if (ST->isTruncatingStore())
21052+
if (ST->isTruncatingStore()) {
21053+
EVT StoreVT = ST->getMemoryVT();
21054+
if (!isHalvingTruncateOfLegalScalableType(ValueVT, StoreVT))
21055+
return SDValue();
2104721056
if (SDValue Rshrnb =
2104821057
trySimplifySrlAddToRshrnb(ST->getOperand(1), DAG, Subtarget)) {
21049-
EVT StoreVT = ST->getMemoryVT();
21050-
if ((ValueVT == MVT::nxv8i16 && StoreVT == MVT::nxv8i8) ||
21051-
(ValueVT == MVT::nxv4i32 && StoreVT == MVT::nxv4i16) ||
21052-
(ValueVT == MVT::nxv2i64 && StoreVT == MVT::nxv2i32))
21053-
return DAG.getTruncStore(ST->getChain(), ST, Rshrnb, ST->getBasePtr(),
21054-
StoreVT, ST->getMemOperand());
21058+
return DAG.getTruncStore(ST->getChain(), ST, Rshrnb, ST->getBasePtr(),
21059+
StoreVT, ST->getMemOperand());
2105521060
}
21061+
}
2105621062

2105721063
return SDValue();
2105821064
}
@@ -21098,6 +21104,19 @@ static SDValue performMSTORECombine(SDNode *N,
2109821104
}
2109921105
}
2110021106

21107+
if (MST->isTruncatingStore()) {
21108+
EVT ValueVT = Value->getValueType(0);
21109+
EVT MemVT = MST->getMemoryVT();
21110+
if (!isHalvingTruncateOfLegalScalableType(ValueVT, MemVT))
21111+
return SDValue();
21112+
if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Value, DAG, Subtarget)) {
21113+
return DAG.getMaskedStore(MST->getChain(), DL, Rshrnb, MST->getBasePtr(),
21114+
MST->getOffset(), MST->getMask(),
21115+
MST->getMemoryVT(), MST->getMemOperand(),
21116+
MST->getAddressingMode(), true);
21117+
}
21118+
}
21119+
2110121120
return SDValue();
2110221121
}
2110321122

llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,3 +298,22 @@ define void @neg_add_lshr_rshrnb_s(ptr %ptr, ptr %dst, i64 %index){
298298
store <vscale x 2 x i16> %3, ptr %4, align 1
299299
ret void
300300
}
301+
302+
; Regression test for the masked-store rshrnb combine: an add of
; splat(32) == 1 << (6 - 1) followed by `lshr ..., 6` and a trunc feeding a
; truncating masked store should be selected as a single rounding shift
; right narrow (rshrnb) instead of separate add/shift instructions.
define void @masked_store_rshrnb(ptr %ptr, ptr %dst, i64 %index, <vscale x 8 x i1> %mask) { ; preds = %vector.body, %vector.ph
; CHECK-LABEL: masked_store_rshrnb:
; CHECK: // %bb.0:
; CHECK-NEXT: ld1h { z0.h }, p0/z, [x0]
; CHECK-NEXT: rshrnb z0.b, z0.h, #6
; CHECK-NEXT: st1b { z0.h }, p0, [x1, x2]
; CHECK-NEXT: ret
  ; Masked load of <vscale x 8 x i16> under %mask.
  %wide.masked.load = tail call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr %ptr, i32 2, <vscale x 8 x i1> %mask, <vscale x 8 x i16> poison)
  ; Add the rounding bias: splat of 32, i.e. 1 << (Shift - 1) for Shift == 6.
  %1 = add <vscale x 8 x i16> %wide.masked.load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 32, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
  ; Logical shift right by splat(6).
  %2 = lshr <vscale x 8 x i16> %1, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 6, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
  ; Halving truncation (i16 -> i8): high halves are discarded, so the
  ; add+lshr pair may be folded into rshrnb.
  %3 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
  %4 = getelementptr inbounds i8, ptr %dst, i64 %index
  tail call void @llvm.masked.store.nxv8i8.p0(<vscale x 8 x i8> %3, ptr %4, i32 1, <vscale x 8 x i1> %mask)
  ret void
}
317+
318+
declare void @llvm.masked.store.nxv8i8.p0(<vscale x 8 x i8>, ptr, i32, <vscale x 8 x i1>)
319+
declare <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr, i32, <vscale x 8 x i1>, <vscale x 8 x i16>)

0 commit comments

Comments
 (0)