
Commit 6383785

[SVE][CodeGenPrepare] Sink address calculations that match SVE gather/scatter addressing modes. (#66996)
SVE supports scalar+vector and scalar+extw(vector) addressing modes. However, the masked gather/scatter intrinsics take a vector of addresses, which means the address computations can be hoisted out of loops. This is especially true for things like offsets, where the true size of the offsets is lost by the time you get to code generation. That is problematic because it forces the code generator to legalise towards `<vscale x 2 x ty>` vectors, which will not maximise bandwidth if the main block datatype is in fact i32 or smaller.

This patch sinks GEPs and extends in cases where one of the above addressing modes can be used.

NOTE: There are cases where it would be better to split the extend in two, with one half hoisted out of the loop and the other kept within it. Whilst true, I think this change of default is still better than before, because the extra extends are an improvement over being forced to split a gather/scatter.
1 parent a633a37 · commit 6383785
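As a concrete illustration of the motivating case, consider a loop whose gather offsets are loop invariant. This is a hypothetical sketch, not one of the tests added by this commit, and the function and value names are made up. Earlier passes hoist the sext and GEP out of the loop, so instruction selection for the loop body only sees a <vscale x 4 x ptr> value and has to legalise the gather towards narrower <vscale x 2 x ...> operations; with this change, CodeGenPrepare sinks the sext and GEP next to the gather so the offsets can instead be folded into a scalar+sxtw(vector) access.

define <vscale x 4 x float> @hoisted_gather_offsets(ptr %src, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i64 %n) {
entry:
  ; Loop-invariant address computation, hoisted out of the loop by earlier passes.
  %indices.sext = sext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
  %ptrs = getelementptr float, ptr %src, <vscale x 4 x i64> %indices.sext
  br label %loop

loop:
  %i = phi i64 [ 0, %entry ], [ %i.next, %loop ]
  %acc = phi <vscale x 4 x float> [ zeroinitializer, %entry ], [ %acc.next, %loop ]
  ; Selection for this block only sees %ptrs as a live-in vector of pointers
  ; unless the sext and GEP above are sunk here by CodeGenPrepare.
  %vals = call <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x ptr> %ptrs, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x float> poison)
  %acc.next = fadd <vscale x 4 x float> %acc, %vals
  %i.next = add i64 %i, 1
  %done = icmp eq i64 %i.next, %n
  br i1 %done, label %exit, label %loop

exit:
  ret <vscale x 4 x float> %acc
}

declare <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x ptr>, i32, <vscale x 4 x i1>, <vscale x 4 x float>)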

File tree

2 files changed: +266, -0 lines

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 35 additions & 0 deletions
@@ -14431,6 +14431,31 @@ static bool areOperandsOfVmullHighP64(Value *Op1, Value *Op2) {
   return isOperandOfVmullHighP64(Op1) && isOperandOfVmullHighP64(Op2);
 }
 
+static bool shouldSinkVectorOfPtrs(Value *Ptrs, SmallVectorImpl<Use *> &Ops) {
+  // Restrict ourselves to the form CodeGenPrepare typically constructs.
+  auto *GEP = dyn_cast<GetElementPtrInst>(Ptrs);
+  if (!GEP || GEP->getNumOperands() != 2)
+    return false;
+
+  Value *Base = GEP->getOperand(0);
+  Value *Offsets = GEP->getOperand(1);
+
+  // We only care about scalar_base+vector_offsets.
+  if (Base->getType()->isVectorTy() || !Offsets->getType()->isVectorTy())
+    return false;
+
+  // Sink extends that would allow us to use 32-bit offset vectors.
+  if (isa<SExtInst>(Offsets) || isa<ZExtInst>(Offsets)) {
+    auto *OffsetsInst = cast<Instruction>(Offsets);
+    if (OffsetsInst->getType()->getScalarSizeInBits() > 32 &&
+        OffsetsInst->getOperand(0)->getType()->getScalarSizeInBits() <= 32)
+      Ops.push_back(&GEP->getOperandUse(1));
+  }
+
+  // Sink the GEP.
+  return true;
+}
+
 /// Check if sinking \p I's operands to I's basic block is profitable, because
 /// the operands can be folded into a target instruction, e.g.
 /// shufflevectors extracts and/or sext/zext can be folded into (u,s)subl(2).
@@ -14532,6 +14557,16 @@ bool AArch64TargetLowering::shouldSinkOperands(
       Ops.push_back(&II->getArgOperandUse(0));
       Ops.push_back(&II->getArgOperandUse(1));
       return true;
+    case Intrinsic::masked_gather:
+      if (!shouldSinkVectorOfPtrs(II->getArgOperand(0), Ops))
+        return false;
+      Ops.push_back(&II->getArgOperandUse(0));
+      return true;
+    case Intrinsic::masked_scatter:
+      if (!shouldSinkVectorOfPtrs(II->getArgOperand(1), Ops))
+        return false;
+      Ops.push_back(&II->getArgOperandUse(1));
+      return true;
     default:
       return false;
     }
Lines changed: 231 additions & 0 deletions
@@ -0,0 +1,231 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 3
+; RUN: opt -S --codegenprepare < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+; Sink the GEP to make use of scalar+vector addressing modes.
+define <vscale x 4 x float> @gather_offsets_sink_gep(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i1 %cond) {
+; CHECK-LABEL: define <vscale x 4 x float> @gather_offsets_sink_gep(
+; CHECK-SAME: ptr [[BASE:%.*]], <vscale x 4 x i32> [[INDICES:%.*]], <vscale x 4 x i1> [[MASK:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br i1 [[COND]], label [[COND_BLOCK:%.*]], label [[EXIT:%.*]]
+; CHECK: cond.block:
+; CHECK-NEXT: [[TMP0:%.*]] = getelementptr float, ptr [[BASE]], <vscale x 4 x i32> [[INDICES]]
+; CHECK-NEXT: [[LOAD:%.*]] = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32.nxv4p0(<vscale x 4 x ptr> [[TMP0]], i32 4, <vscale x 4 x i1> [[MASK]], <vscale x 4 x float> poison)
+; CHECK-NEXT: ret <vscale x 4 x float> [[LOAD]]
+; CHECK: exit:
+; CHECK-NEXT: ret <vscale x 4 x float> zeroinitializer
+;
+entry:
+  %ptrs = getelementptr float, ptr %base, <vscale x 4 x i32> %indices
+  br i1 %cond, label %cond.block, label %exit
+
+cond.block:
+  %load = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x ptr> %ptrs, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x float> poison)
+  br label %exit
+
+exit:
+  %ret = phi <vscale x 4 x float> [ zeroinitializer, %entry ], [ %load, %cond.block ]
+  ret <vscale x 4 x float> %ret
+}
+
+; Sink sext to make use of scalar+sxtw(vector) addressing modes.
+define <vscale x 4 x float> @gather_offsets_sink_sext(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i1 %cond) {
+; CHECK-LABEL: define <vscale x 4 x float> @gather_offsets_sink_sext(
+; CHECK-SAME: ptr [[BASE:%.*]], <vscale x 4 x i32> [[INDICES:%.*]], <vscale x 4 x i1> [[MASK:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br i1 [[COND]], label [[COND_BLOCK:%.*]], label [[EXIT:%.*]]
+; CHECK: cond.block:
+; CHECK-NEXT: [[TMP0:%.*]] = sext <vscale x 4 x i32> [[INDICES]] to <vscale x 4 x i64>
+; CHECK-NEXT: [[PTRS:%.*]] = getelementptr float, ptr [[BASE]], <vscale x 4 x i64> [[TMP0]]
+; CHECK-NEXT: [[LOAD:%.*]] = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32.nxv4p0(<vscale x 4 x ptr> [[PTRS]], i32 4, <vscale x 4 x i1> [[MASK]], <vscale x 4 x float> poison)
+; CHECK-NEXT: ret <vscale x 4 x float> [[LOAD]]
+; CHECK: exit:
+; CHECK-NEXT: ret <vscale x 4 x float> zeroinitializer
+;
+entry:
+  %indices.sext = sext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
+  br i1 %cond, label %cond.block, label %exit
+
+cond.block:
+  %ptrs = getelementptr float, ptr %base, <vscale x 4 x i64> %indices.sext
+  %load = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x ptr> %ptrs, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x float> poison)
+  br label %exit
+
+exit:
+  %ret = phi <vscale x 4 x float> [ zeroinitializer, %entry ], [ %load, %cond.block ]
+  ret <vscale x 4 x float> %ret
+}
+
+; As above but ensure both the GEP and sext are sunk.
+define <vscale x 4 x float> @gather_offsets_sink_sext_get(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i1 %cond) {
+; CHECK-LABEL: define <vscale x 4 x float> @gather_offsets_sink_sext_get(
+; CHECK-SAME: ptr [[BASE:%.*]], <vscale x 4 x i32> [[INDICES:%.*]], <vscale x 4 x i1> [[MASK:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br i1 [[COND]], label [[COND_BLOCK:%.*]], label [[EXIT:%.*]]
+; CHECK: cond.block:
+; CHECK-NEXT: [[TMP0:%.*]] = sext <vscale x 4 x i32> [[INDICES]] to <vscale x 4 x i64>
+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr float, ptr [[BASE]], <vscale x 4 x i64> [[TMP0]]
+; CHECK-NEXT: [[LOAD:%.*]] = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32.nxv4p0(<vscale x 4 x ptr> [[TMP1]], i32 4, <vscale x 4 x i1> [[MASK]], <vscale x 4 x float> poison)
+; CHECK-NEXT: ret <vscale x 4 x float> [[LOAD]]
+; CHECK: exit:
+; CHECK-NEXT: ret <vscale x 4 x float> zeroinitializer
+;
+entry:
+  %indices.sext = sext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
+  %ptrs = getelementptr float, ptr %base, <vscale x 4 x i64> %indices.sext
+  br i1 %cond, label %cond.block, label %exit
+
+cond.block:
+  %load = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x ptr> %ptrs, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x float> poison)
+  br label %exit
+
+exit:
+  %ret = phi <vscale x 4 x float> [ zeroinitializer, %entry ], [ %load, %cond.block ]
+  ret <vscale x 4 x float> %ret
+}
+
+; Don't sink GEPs that cannot benefit from SVE's scalar+vector addressing modes.
+define <vscale x 4 x float> @gather_no_scalar_base(<vscale x 4 x ptr> %bases, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i1 %cond) {
+; CHECK-LABEL: define <vscale x 4 x float> @gather_no_scalar_base(
+; CHECK-SAME: <vscale x 4 x ptr> [[BASES:%.*]], <vscale x 4 x i32> [[INDICES:%.*]], <vscale x 4 x i1> [[MASK:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[PTRS:%.*]] = getelementptr float, <vscale x 4 x ptr> [[BASES]], <vscale x 4 x i32> [[INDICES]]
+; CHECK-NEXT: br i1 [[COND]], label [[COND_BLOCK:%.*]], label [[EXIT:%.*]]
+; CHECK: cond.block:
+; CHECK-NEXT: [[LOAD:%.*]] = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32.nxv4p0(<vscale x 4 x ptr> [[PTRS]], i32 4, <vscale x 4 x i1> [[MASK]], <vscale x 4 x float> poison)
+; CHECK-NEXT: ret <vscale x 4 x float> [[LOAD]]
+; CHECK: exit:
+; CHECK-NEXT: ret <vscale x 4 x float> zeroinitializer
+;
+entry:
+  %ptrs = getelementptr float, <vscale x 4 x ptr> %bases, <vscale x 4 x i32> %indices
+  br i1 %cond, label %cond.block, label %exit
+
+cond.block:
+  %load = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x ptr> %ptrs, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x float> poison)
+  br label %exit
+
+exit:
+  %ret = phi <vscale x 4 x float> [ zeroinitializer, %entry ], [ %load, %cond.block ]
+  ret <vscale x 4 x float> %ret
+}
+
+; Don't sink extends whose result type is already favourable for SVE's sxtw/uxtw addressing modes.
+; NOTE: We still want to sink the GEP.
+define <vscale x 4 x float> @gather_offset_type_too_small(ptr %base, <vscale x 4 x i8> %indices, <vscale x 4 x i1> %mask, i1 %cond) {
+; CHECK-LABEL: define <vscale x 4 x float> @gather_offset_type_too_small(
+; CHECK-SAME: ptr [[BASE:%.*]], <vscale x 4 x i8> [[INDICES:%.*]], <vscale x 4 x i1> [[MASK:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[INDICES_SEXT:%.*]] = sext <vscale x 4 x i8> [[INDICES]] to <vscale x 4 x i32>
+; CHECK-NEXT: br i1 [[COND]], label [[COND_BLOCK:%.*]], label [[EXIT:%.*]]
+; CHECK: cond.block:
+; CHECK-NEXT: [[TMP0:%.*]] = getelementptr float, ptr [[BASE]], <vscale x 4 x i32> [[INDICES_SEXT]]
+; CHECK-NEXT: [[LOAD:%.*]] = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32.nxv4p0(<vscale x 4 x ptr> [[TMP0]], i32 4, <vscale x 4 x i1> [[MASK]], <vscale x 4 x float> poison)
+; CHECK-NEXT: ret <vscale x 4 x float> [[LOAD]]
+; CHECK: exit:
+; CHECK-NEXT: ret <vscale x 4 x float> zeroinitializer
+;
+entry:
+  %indices.sext = sext <vscale x 4 x i8> %indices to <vscale x 4 x i32>
+  %ptrs = getelementptr float, ptr %base, <vscale x 4 x i32> %indices.sext
+  br i1 %cond, label %cond.block, label %exit
+
+cond.block:
+  %load = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x ptr> %ptrs, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x float> poison)
+  br label %exit
+
+exit:
+  %ret = phi <vscale x 4 x float> [ zeroinitializer, %entry ], [ %load, %cond.block ]
+  ret <vscale x 4 x float> %ret
+}
+
+; Don't sink extends that cannot benefit from SVE's sxtw/uxtw addressing modes.
+; NOTE: We still want to sink the GEP.
+define <vscale x 4 x float> @gather_offset_type_too_big(ptr %base, <vscale x 4 x i48> %indices, <vscale x 4 x i1> %mask, i1 %cond) {
+; CHECK-LABEL: define <vscale x 4 x float> @gather_offset_type_too_big(
+; CHECK-SAME: ptr [[BASE:%.*]], <vscale x 4 x i48> [[INDICES:%.*]], <vscale x 4 x i1> [[MASK:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[INDICES_SEXT:%.*]] = sext <vscale x 4 x i48> [[INDICES]] to <vscale x 4 x i64>
+; CHECK-NEXT: br i1 [[COND]], label [[COND_BLOCK:%.*]], label [[EXIT:%.*]]
+; CHECK: cond.block:
+; CHECK-NEXT: [[TMP0:%.*]] = getelementptr float, ptr [[BASE]], <vscale x 4 x i64> [[INDICES_SEXT]]
+; CHECK-NEXT: [[LOAD:%.*]] = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32.nxv4p0(<vscale x 4 x ptr> [[TMP0]], i32 4, <vscale x 4 x i1> [[MASK]], <vscale x 4 x float> poison)
+; CHECK-NEXT: ret <vscale x 4 x float> [[LOAD]]
+; CHECK: exit:
+; CHECK-NEXT: ret <vscale x 4 x float> zeroinitializer
+;
+entry:
+  %indices.sext = sext <vscale x 4 x i48> %indices to <vscale x 4 x i64>
+  %ptrs = getelementptr float, ptr %base, <vscale x 4 x i64> %indices.sext
+  br i1 %cond, label %cond.block, label %exit
+
+cond.block:
+  %load = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x ptr> %ptrs, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x float> poison)
+  br label %exit
+
+exit:
+  %ret = phi <vscale x 4 x float> [ zeroinitializer, %entry ], [ %load, %cond.block ]
+  ret <vscale x 4 x float> %ret
+}
+
+; Sink zext to make use of scalar+uxtw(vector) addressing modes.
+; TODO: There's an argument here to split the extend into i8->i32 and i32->i64,
+; which would be especially useful if the i8s are the result of a load because
+; it would maintain the use of sign-extending loads.
+define <vscale x 4 x float> @gather_offset_sink_zext(ptr %base, <vscale x 4 x i8> %indices, <vscale x 4 x i1> %mask, i1 %cond) {
+; CHECK-LABEL: define <vscale x 4 x float> @gather_offset_sink_zext(
+; CHECK-SAME: ptr [[BASE:%.*]], <vscale x 4 x i8> [[INDICES:%.*]], <vscale x 4 x i1> [[MASK:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br i1 [[COND]], label [[COND_BLOCK:%.*]], label [[EXIT:%.*]]
+; CHECK: cond.block:
+; CHECK-NEXT: [[TMP0:%.*]] = zext <vscale x 4 x i8> [[INDICES]] to <vscale x 4 x i64>
+; CHECK-NEXT: [[PTRS:%.*]] = getelementptr float, ptr [[BASE]], <vscale x 4 x i64> [[TMP0]]
+; CHECK-NEXT: [[LOAD:%.*]] = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32.nxv4p0(<vscale x 4 x ptr> [[PTRS]], i32 4, <vscale x 4 x i1> [[MASK]], <vscale x 4 x float> poison)
+; CHECK-NEXT: ret <vscale x 4 x float> [[LOAD]]
+; CHECK: exit:
+; CHECK-NEXT: ret <vscale x 4 x float> zeroinitializer
+;
+entry:
+  %indices.zext = zext <vscale x 4 x i8> %indices to <vscale x 4 x i64>
+  br i1 %cond, label %cond.block, label %exit
+
+cond.block:
+  %ptrs = getelementptr float, ptr %base, <vscale x 4 x i64> %indices.zext
+  %load = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x ptr> %ptrs, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x float> poison)
+  br label %exit
+
+exit:
+  %ret = phi <vscale x 4 x float> [ zeroinitializer, %entry ], [ %load, %cond.block ]
+  ret <vscale x 4 x float> %ret
+}
+
+; Ensure we support scatters as well as gathers.
+define void @scatter_offsets_sink_sext_get(<vscale x 4 x float> %data, ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i1 %cond) {
+; CHECK-LABEL: define void @scatter_offsets_sink_sext_get(
+; CHECK-SAME: <vscale x 4 x float> [[DATA:%.*]], ptr [[BASE:%.*]], <vscale x 4 x i32> [[INDICES:%.*]], <vscale x 4 x i1> [[MASK:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br i1 [[COND]], label [[COND_BLOCK:%.*]], label [[EXIT:%.*]]
+; CHECK: cond.block:
+; CHECK-NEXT: [[TMP0:%.*]] = sext <vscale x 4 x i32> [[INDICES]] to <vscale x 4 x i64>
+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr float, ptr [[BASE]], <vscale x 4 x i64> [[TMP0]]
+; CHECK-NEXT: tail call void @llvm.masked.scatter.nxv4f32.nxv4p0(<vscale x 4 x float> [[DATA]], <vscale x 4 x ptr> [[TMP1]], i32 4, <vscale x 4 x i1> [[MASK]])
+; CHECK-NEXT: ret void
+; CHECK: exit:
+; CHECK-NEXT: ret void
+;
+entry:
+  %indices.sext = sext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
+  %ptrs = getelementptr float, ptr %base, <vscale x 4 x i64> %indices.sext
+  br i1 %cond, label %cond.block, label %exit
+
+cond.block:
+  tail call void @llvm.masked.scatter.nxv4f32(<vscale x 4 x float> %data, <vscale x 4 x ptr> %ptrs, i32 4, <vscale x 4 x i1> %mask)
+  br label %exit
+
+exit:
+  ret void
+}
+
+declare <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x ptr>, i32, <vscale x 4 x i1>, <vscale x 4 x float>)
+declare void @llvm.masked.scatter.nxv4f32(<vscale x 4 x float>, <vscale x 4 x ptr>, i32, <vscale x 4 x i1>)
