
Commit 019f022

[AArch64][SVE] Fold gather/scatter with 32bits when possible
In AArch64ISelLowering.cpp this patch implements this fold:

  GEP (%ptr, (splat(%offset) + stepvector(A)))
    into
  GEP ((%ptr + %offset), stepvector(A))

The above transform simplifies the index operand so that it can be
expressed as i32 elements. This allows using only one gather/scatter
assembly instruction instead of two.

Patch by Paul Walker (@paulwalker-arm).

Depends on D118459

Differential Revision: https://reviews.llvm.org/D117900
1 parent 14124c3 commit 019f022
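
The fold is plain re-association of the address arithmetic: the splatted offset is scaled and added to the scalar base pointer once, so the per-lane index collapses to the step vector alone and can be expressed with i32 elements. A minimal standalone C++ sketch of that identity (illustrative only, not code from the patch; the function names are made up):

#include <cassert>
#include <cstdint>

// Address of lane I as the original GEP computes it:
//   base + (offset + I * stride) * scale
static int64_t addrOriginal(int64_t Base, int64_t Offset, int64_t Stride,
                            int64_t Scale, int64_t I) {
  return Base + (Offset + I * Stride) * Scale;
}

// Address of lane I after the fold: the scaled splat is added to the base
// pointer once, and the per-lane index is just the step vector.
static int64_t addrFolded(int64_t Base, int64_t Offset, int64_t Stride,
                          int64_t Scale, int64_t I) {
  int64_t NewBase = Base + Offset * Scale;
  return NewBase + I * Stride * Scale;
}

int main() {
  for (int64_t I = 0; I < 16; ++I)
    assert(addrOriginal(0x1000, 100, 3, 2, I) ==
           addrFolded(0x1000, 100, 3, 2, I));
}

Because the remaining per-lane value is just I * stride, the i32 range check in the combine below only has to reason about the stride and the lane count.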

File tree

2 files changed, +305 -0 lines changed


llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 93 additions & 0 deletions
@@ -889,6 +889,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setTargetDAGCombine(ISD::VECREDUCE_ADD);
   setTargetDAGCombine(ISD::STEP_VECTOR);
 
+  setTargetDAGCombine(ISD::MGATHER);
+  setTargetDAGCombine(ISD::MSCATTER);
+
   setTargetDAGCombine(ISD::FP_EXTEND);
 
   setTargetDAGCombine(ISD::GlobalAddress);
@@ -16358,6 +16361,93 @@ static SDValue performSTORECombine(SDNode *N,
   return SDValue();
 }
 
+// Analyse the specified address returning true if a more optimal addressing
+// mode is available. When returning true all parameters are updated to reflect
+// their recommended values.
+static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
+                                     SDValue &BasePtr, SDValue &Index,
+                                     ISD::MemIndexType &IndexType,
+                                     SelectionDAG &DAG) {
+  // Only consider element types that are pointer sized as smaller types can
+  // be easily promoted.
+  EVT IndexVT = Index.getValueType();
+  if (IndexVT.getVectorElementType() != MVT::i64 || IndexVT == MVT::nxv2i64)
+    return false;
+
+  int64_t Stride = 0;
+  SDLoc DL(N);
+  // Index = step(const) + splat(offset)
+  if (Index.getOpcode() == ISD::ADD &&
+      Index.getOperand(0).getOpcode() == ISD::STEP_VECTOR) {
+    SDValue StepVector = Index.getOperand(0);
+    if (auto Offset = DAG.getSplatValue(Index.getOperand(1))) {
+      Stride = cast<ConstantSDNode>(StepVector.getOperand(0))->getSExtValue();
+      Offset = DAG.getNode(ISD::MUL, DL, MVT::i64, Offset, N->getScale());
+      BasePtr = DAG.getNode(ISD::ADD, DL, MVT::i64, BasePtr, Offset);
+    }
+  }
+
+  // Return early because no supported pattern is found.
+  if (Stride == 0)
+    return false;
+
+  if (Stride < std::numeric_limits<int32_t>::min() ||
+      Stride > std::numeric_limits<int32_t>::max())
+    return false;
+
+  const auto &Subtarget =
+      static_cast<const AArch64Subtarget &>(DAG.getSubtarget());
+  unsigned MaxVScale =
+      Subtarget.getMaxSVEVectorSizeInBits() / AArch64::SVEBitsPerBlock;
+  int64_t LastElementOffset =
+      IndexVT.getVectorMinNumElements() * Stride * MaxVScale;
+
+  if (LastElementOffset < std::numeric_limits<int32_t>::min() ||
+      LastElementOffset > std::numeric_limits<int32_t>::max())
+    return false;
+
+  EVT NewIndexVT = IndexVT.changeVectorElementType(MVT::i32);
+  Index = DAG.getNode(ISD::STEP_VECTOR, DL, NewIndexVT,
+                      DAG.getTargetConstant(Stride, DL, MVT::i32));
+  return true;
+}
+
+static SDValue performMaskedGatherScatterCombine(
+    SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
+  MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
+  assert(MGS && "Can only combine gather load or scatter store nodes");
+
+  if (!DCI.isBeforeLegalize())
+    return SDValue();
+
+  SDLoc DL(MGS);
+  SDValue Chain = MGS->getChain();
+  SDValue Scale = MGS->getScale();
+  SDValue Index = MGS->getIndex();
+  SDValue Mask = MGS->getMask();
+  SDValue BasePtr = MGS->getBasePtr();
+  ISD::MemIndexType IndexType = MGS->getIndexType();
+
+  if (!findMoreOptimalIndexType(MGS, BasePtr, Index, IndexType, DAG))
+    return SDValue();
+
+  // Here we catch such cases early and change MGATHER's IndexType to allow
+  // the use of an Index that's more legalisation friendly.
+  if (auto *MGT = dyn_cast<MaskedGatherSDNode>(MGS)) {
+    SDValue PassThru = MGT->getPassThru();
+    SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
+    return DAG.getMaskedGather(
+        DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
+        Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
+  }
+  auto *MSC = cast<MaskedScatterSDNode>(MGS);
+  SDValue Data = MSC->getValue();
+  SDValue Ops[] = {Chain, Data, Mask, BasePtr, Index, Scale};
+  return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL,
+                              Ops, MSC->getMemOperand(), IndexType,
+                              MSC->isTruncatingStore());
+}
+
 /// Target-specific DAG combine function for NEON load/store intrinsics
 /// to merge base address updates.
 static SDValue performNEONPostLDSTCombine(SDNode *N,
@@ -17820,6 +17910,9 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     break;
   case ISD::STORE:
     return performSTORECombine(N, DCI, DAG, Subtarget);
+  case ISD::MGATHER:
+  case ISD::MSCATTER:
+    return performMaskedGatherScatterCombine(N, DCI, DAG);
   case ISD::VECTOR_SPLICE:
     return performSVESpliceCombine(N, DAG);
   case ISD::FP_EXTEND:
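
findMoreOptimalIndexType above only narrows the index when, among other checks, the constant stride and a conservative bound on the largest lane offset (minimum element count x stride x the subtarget's maximum vscale) both fit in a signed 32-bit value. A standalone C++ sketch of that bounds check, assuming SVE's 128-bit block size; the function name and parameters are illustrative, not the backend's API:

#include <cassert>
#include <cstdint>
#include <limits>

// Mirrors the bounds used by findMoreOptimalIndexType: the stride itself and a
// conservative bound on the largest lane offset must both be representable as
// a signed 32-bit value. MaxSVEBits stands in for the subtarget's maximum SVE
// vector length in bits (e.g. 2048 under vscale_range(1, 16)).
static bool indexFitsInI32(int64_t Stride, unsigned MinNumElts,
                           unsigned MaxSVEBits) {
  constexpr unsigned SVEBitsPerBlock = 128;
  constexpr int64_t I32Min = std::numeric_limits<int32_t>::min();
  constexpr int64_t I32Max = std::numeric_limits<int32_t>::max();

  if (Stride < I32Min || Stride > I32Max)
    return false;

  unsigned MaxVScale = MaxSVEBits / SVEBitsPerBlock;
  int64_t LastElementOffset =
      static_cast<int64_t>(MinNumElts) * Stride * static_cast<int64_t>(MaxVScale);
  return LastElementOffset >= I32Min && LastElementOffset <= I32Max;
}

int main() {
  // Boundary strides from the new test file below.
  assert(indexFitsInI32(33554431, 4, 2048));   // 4 * 33554431 * 16 = 2147483584
  assert(!indexFitsInI32(33554432, 4, 2048));  // 4 * 33554432 * 16 = 2147483648
}

With a <vscale x 4 x i64> index and vscale_range(1, 16), MaxSVEBits is 2048, so the bound is 4 x Stride x 16; that is where the 33554431 and -33554432 boundary constants in the new test file come from.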
Lines changed: 212 additions & 0 deletions
@@ -0,0 +1,212 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc < %s -mtriple=aarch64-linux-unknown | FileCheck %s


; Ensure we use a "vscale x 4" wide scatter for the maximum supported offset.
define void @scatter_i8_index_offset_maximum(i8* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x i8> %data) #0 {
; CHECK-LABEL: scatter_i8_index_offset_maximum:
; CHECK: // %bb.0:
; CHECK-NEXT: mov w8, #33554431
; CHECK-NEXT: add x9, x0, x1
; CHECK-NEXT: index z1.s, #0, w8
; CHECK-NEXT: st1b { z0.s }, p0, [x9, z1.s, sxtw]
; CHECK-NEXT: ret
  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
  %t2 = insertelement <vscale x 4 x i64> undef, i64 33554431, i32 0
  %t3 = shufflevector <vscale x 4 x i64> %t2, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
  %t4 = mul <vscale x 4 x i64> %t3, %step
  %t5 = add <vscale x 4 x i64> %t1, %t4
  %t6 = getelementptr i8, i8* %base, <vscale x 4 x i64> %t5
  call void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8> %data, <vscale x 4 x i8*> %t6, i32 2, <vscale x 4 x i1> %pg)
  ret void
}

; Ensure we use a "vscale x 4" wide scatter for the minimum supported offset.
define void @scatter_i16_index_offset_minimum(i16* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x i16> %data) #0 {
; CHECK-LABEL: scatter_i16_index_offset_minimum:
; CHECK: // %bb.0:
; CHECK-NEXT: mov w8, #-33554432
; CHECK-NEXT: add x9, x0, x1, lsl #1
; CHECK-NEXT: index z1.s, #0, w8
; CHECK-NEXT: st1h { z0.s }, p0, [x9, z1.s, sxtw #1]
; CHECK-NEXT: ret
  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
  %t2 = insertelement <vscale x 4 x i64> undef, i64 -33554432, i32 0
  %t3 = shufflevector <vscale x 4 x i64> %t2, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
  %t4 = mul <vscale x 4 x i64> %t3, %step
  %t5 = add <vscale x 4 x i64> %t1, %t4
  %t6 = getelementptr i16, i16* %base, <vscale x 4 x i64> %t5
  call void @llvm.masked.scatter.nxv4i16(<vscale x 4 x i16> %data, <vscale x 4 x i16*> %t6, i32 2, <vscale x 4 x i1> %pg)
  ret void
}

; Ensure we use a "vscale x 4" gather for an offset in the limits of 32 bits.
define <vscale x 4 x i8> @gather_i8_index_offset_8(i8* %base, i64 %offset, <vscale x 4 x i1> %pg) #0 {
; CHECK-LABEL: gather_i8_index_offset_8:
; CHECK: // %bb.0:
; CHECK-NEXT: add x8, x0, x1
; CHECK-NEXT: index z0.s, #0, #1
; CHECK-NEXT: ld1sb { z0.s }, p0/z, [x8, z0.s, sxtw]
; CHECK-NEXT: ret
  %splat.insert0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
  %splat0 = shufflevector <vscale x 4 x i64> %splat.insert0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
  %splat.insert1 = insertelement <vscale x 4 x i64> undef, i64 1, i32 0
  %splat1 = shufflevector <vscale x 4 x i64> %splat.insert1, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
  %t1 = mul <vscale x 4 x i64> %splat1, %step
  %t2 = add <vscale x 4 x i64> %splat0, %t1
  %t3 = getelementptr i8, i8* %base, <vscale x 4 x i64> %t2
  %load = call <vscale x 4 x i8> @llvm.masked.gather.nxv4i8(<vscale x 4 x i8*> %t3, i32 4, <vscale x 4 x i1> %pg, <vscale x 4 x i8> undef)
  ret <vscale x 4 x i8> %load
}

;; Negative tests

; Ensure we don't use a "vscale x 4" scatter. Cannot prove that variable stride
; will not wrap when shrunk to be i32 based.
define void @scatter_f16_index_offset_var(half* %base, i64 %offset, i64 %scale, <vscale x 4 x i1> %pg, <vscale x 4 x half> %data) #0 {
; CHECK-LABEL: scatter_f16_index_offset_var:
; CHECK: // %bb.0:
; CHECK-NEXT: index z1.d, #0, #1
; CHECK-NEXT: mov z3.d, x1
; CHECK-NEXT: mov z2.d, z1.d
; CHECK-NEXT: mov z4.d, z3.d
; CHECK-NEXT: ptrue p1.d
; CHECK-NEXT: incd z2.d
; CHECK-NEXT: mla z3.d, p1/m, z1.d, z3.d
; CHECK-NEXT: mla z4.d, p1/m, z2.d, z4.d
; CHECK-NEXT: punpklo p1.h, p0.b
; CHECK-NEXT: uunpklo z1.d, z0.s
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: uunpkhi z0.d, z0.s
; CHECK-NEXT: st1h { z1.d }, p1, [x0, z3.d, lsl #1]
; CHECK-NEXT: st1h { z0.d }, p0, [x0, z4.d, lsl #1]
; CHECK-NEXT: ret
  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
  %t2 = insertelement <vscale x 4 x i64> undef, i64 %scale, i32 0
  %t3 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
  %t4 = mul <vscale x 4 x i64> %t3, %step
  %t5 = add <vscale x 4 x i64> %t1, %t4
  %t6 = getelementptr half, half* %base, <vscale x 4 x i64> %t5
  call void @llvm.masked.scatter.nxv4f16(<vscale x 4 x half> %data, <vscale x 4 x half*> %t6, i32 2, <vscale x 4 x i1> %pg)
  ret void
}

; Ensure we don't use a "vscale x 4" wide scatter when the offset is too big.
define void @scatter_i8_index_offset_maximum_plus_one(i8* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x i8> %data) #0 {
; CHECK-LABEL: scatter_i8_index_offset_maximum_plus_one:
; CHECK: // %bb.0:
; CHECK-NEXT: rdvl x8, #1
; CHECK-NEXT: mov w9, #67108864
; CHECK-NEXT: lsr x8, x8, #4
; CHECK-NEXT: mov z1.d, x1
; CHECK-NEXT: punpklo p1.h, p0.b
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: mul x8, x8, x9
; CHECK-NEXT: mov w9, #33554432
; CHECK-NEXT: index z2.d, #0, x9
; CHECK-NEXT: mov z3.d, x8
; CHECK-NEXT: add z3.d, z2.d, z3.d
; CHECK-NEXT: add z2.d, z2.d, z1.d
; CHECK-NEXT: add z1.d, z3.d, z1.d
; CHECK-NEXT: uunpklo z3.d, z0.s
; CHECK-NEXT: uunpkhi z0.d, z0.s
; CHECK-NEXT: st1b { z3.d }, p1, [x0, z2.d]
; CHECK-NEXT: st1b { z0.d }, p0, [x0, z1.d]
; CHECK-NEXT: ret
  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
  %t2 = insertelement <vscale x 4 x i64> undef, i64 33554432, i32 0
  %t3 = shufflevector <vscale x 4 x i64> %t2, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
  %t4 = mul <vscale x 4 x i64> %t3, %step
  %t5 = add <vscale x 4 x i64> %t1, %t4
  %t6 = getelementptr i8, i8* %base, <vscale x 4 x i64> %t5
  call void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8> %data, <vscale x 4 x i8*> %t6, i32 2, <vscale x 4 x i1> %pg)
  ret void
}

; Ensure we don't use a "vscale x 4" wide scatter when the offset is too small.
define void @scatter_i8_index_offset_minimum_minus_one(i8* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x i8> %data) #0 {
; CHECK-LABEL: scatter_i8_index_offset_minimum_minus_one:
; CHECK: // %bb.0:
; CHECK-NEXT: rdvl x8, #1
; CHECK-NEXT: mov x9, #-2
; CHECK-NEXT: lsr x8, x8, #4
; CHECK-NEXT: movk x9, #64511, lsl #16
; CHECK-NEXT: mov z1.d, x1
; CHECK-NEXT: punpklo p1.h, p0.b
; CHECK-NEXT: mul x8, x8, x9
; CHECK-NEXT: mov x9, #-33554433
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: index z2.d, #0, x9
; CHECK-NEXT: mov z3.d, x8
; CHECK-NEXT: add z3.d, z2.d, z3.d
; CHECK-NEXT: add z2.d, z2.d, z1.d
; CHECK-NEXT: add z1.d, z3.d, z1.d
; CHECK-NEXT: uunpklo z3.d, z0.s
; CHECK-NEXT: uunpkhi z0.d, z0.s
; CHECK-NEXT: st1b { z3.d }, p1, [x0, z2.d]
; CHECK-NEXT: st1b { z0.d }, p0, [x0, z1.d]
; CHECK-NEXT: ret
  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
  %t2 = insertelement <vscale x 4 x i64> undef, i64 -33554433, i32 0
  %t3 = shufflevector <vscale x 4 x i64> %t2, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
  %t4 = mul <vscale x 4 x i64> %t3, %step
  %t5 = add <vscale x 4 x i64> %t1, %t4
  %t6 = getelementptr i8, i8* %base, <vscale x 4 x i64> %t5
  call void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8> %data, <vscale x 4 x i8*> %t6, i32 2, <vscale x 4 x i1> %pg)
  ret void
}

; Ensure we don't use a "vscale x 4" wide scatter when the stride is too big .
define void @scatter_i8_index_stride_too_big(i8* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x i8> %data) #0 {
; CHECK-LABEL: scatter_i8_index_stride_too_big:
; CHECK: // %bb.0:
; CHECK-NEXT: rdvl x8, #1
; CHECK-NEXT: mov x9, #-9223372036854775808
; CHECK-NEXT: lsr x8, x8, #4
; CHECK-NEXT: mov z1.d, x1
; CHECK-NEXT: punpklo p1.h, p0.b
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: mul x8, x8, x9
; CHECK-NEXT: mov x9, #4611686018427387904
; CHECK-NEXT: index z2.d, #0, x9
; CHECK-NEXT: mov z3.d, x8
; CHECK-NEXT: add z3.d, z2.d, z3.d
; CHECK-NEXT: add z2.d, z2.d, z1.d
; CHECK-NEXT: add z1.d, z3.d, z1.d
; CHECK-NEXT: uunpklo z3.d, z0.s
; CHECK-NEXT: uunpkhi z0.d, z0.s
; CHECK-NEXT: st1b { z3.d }, p1, [x0, z2.d]
; CHECK-NEXT: st1b { z0.d }, p0, [x0, z1.d]
; CHECK-NEXT: ret
  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
  %t2 = insertelement <vscale x 4 x i64> undef, i64 4611686018427387904, i32 0
  %t3 = shufflevector <vscale x 4 x i64> %t2, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
  %t4 = mul <vscale x 4 x i64> %t3, %step
  %t5 = add <vscale x 4 x i64> %t1, %t4
  %t6 = getelementptr i8, i8* %base, <vscale x 4 x i64> %t5
  call void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8> %data, <vscale x 4 x i8*> %t6, i32 2, <vscale x 4 x i1> %pg)
  ret void
}

attributes #0 = { "target-features"="+sve" vscale_range(1, 16) }

declare <vscale x 4 x i8> @llvm.masked.gather.nxv4i8(<vscale x 4 x i8*>, i32, <vscale x 4 x i1>, <vscale x 4 x i8>)
declare void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8>, <vscale x 4 x i8*>, i32, <vscale x 4 x i1>)
declare void @llvm.masked.scatter.nxv4i16(<vscale x 4 x i16>, <vscale x 4 x i16*>, i32, <vscale x 4 x i1>)
declare void @llvm.masked.scatter.nxv4f16(<vscale x 4 x half>, <vscale x 4 x half*>, i32, <vscale x 4 x i1>)
declare <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
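
The boundary constants in the tests above line up with the 32-bit range check in the combine: for <vscale x 4 x ...> types under vscale_range(1, 16), the bound is 4 x stride x 16. A quick standalone C++ check of the four boundary strides used above (illustrative only, not part of the test):

#include <cstdint>
#include <cstdio>
#include <limits>

int main() {
  constexpr int64_t MinElts = 4;    // <vscale x 4 x ...> element count per 128-bit block
  constexpr int64_t MaxVScale = 16; // from vscale_range(1, 16)
  constexpr int64_t I32Min = std::numeric_limits<int32_t>::min();
  constexpr int64_t I32Max = std::numeric_limits<int32_t>::max();

  // Strides taken from the boundary tests above.
  const int64_t Strides[] = {33554431, -33554432, 33554432, -33554433};
  for (int64_t Stride : Strides) {
    int64_t Bound = MinElts * Stride * MaxVScale;
    bool Narrowed = Bound >= I32Min && Bound <= I32Max;
    std::printf("stride %+lld -> bound %+lld (%s)\n", (long long)Stride,
                (long long)Bound, Narrowed ? "fits in i32" : "does not fit");
  }
}

The first two strides stay in range and keep the single st1b/st1h scatter shown in the positive tests; the last two fall out of range and force the split nxv2i64 scatters.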
