Skip to content

Commit 9b0c52a

Browse files
committed
[AArch64][SVE2] Lower read-after-write mask to whilerw
This patch extends the whilewr matching to also match a read-after-write mask and lower it to a whilerw.
1 parent d4630ae commit 9b0c52a

File tree

2 files changed

+159
-8
lines changed

2 files changed

+159
-8
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14189,7 +14189,16 @@ SDValue tryWhileWRFromOR(SDValue Op, SelectionDAG &DAG,
1418914189
return SDValue();
1419014190

1419114191
SDValue Diff = Cmp.getOperand(0);
14192-
if (Diff.getOpcode() != ISD::SUB || Diff.getValueType() != MVT::i64)
14192+
SDValue NonAbsDiff = Diff;
14193+
bool WriteAfterRead = true;
14194+
// A read-after-write will have an abs call on the diff
14195+
if (Diff.getOpcode() == ISD::ABS) {
14196+
NonAbsDiff = Diff.getOperand(0);
14197+
WriteAfterRead = false;
14198+
}
14199+
14200+
if (NonAbsDiff.getOpcode() != ISD::SUB ||
14201+
NonAbsDiff.getValueType() != MVT::i64)
1419314202
return SDValue();
1419414203

1419514204
if (!isNullConstant(LaneMask.getOperand(1)) ||
@@ -14210,8 +14219,13 @@ SDValue tryWhileWRFromOR(SDValue Op, SelectionDAG &DAG,
1421014219
// it's positive, otherwise the difference plus the element size if it's
1421114220
// negative: pos_diff = diff < 0 ? (diff + 7) : diff
1421214221
SDValue Select = DiffDiv.getOperand(0);
14222+
SDValue SelectOp3 = Select.getOperand(3);
14223+
// Check for an abs in the case of a read-after-write
14224+
if (!WriteAfterRead && SelectOp3.getOpcode() == ISD::ABS)
14225+
SelectOp3 = SelectOp3.getOperand(0);
14226+
1421314227
// Make sure the difference is being compared by the select
14214-
if (Select.getOpcode() != ISD::SELECT_CC || Select.getOperand(3) != Diff)
14228+
if (Select.getOpcode() != ISD::SELECT_CC || SelectOp3 != NonAbsDiff)
1421514229
return SDValue();
1421614230
// Make sure it's checking if the difference is less than 0
1421714231
if (!isNullConstant(Select.getOperand(1)) ||
@@ -14243,22 +14257,26 @@ SDValue tryWhileWRFromOR(SDValue Op, SelectionDAG &DAG,
1424314257
} else if (LaneMask.getOperand(2) != Diff)
1424414258
return SDValue();
1424514259

14246-
SDValue StorePtr = Diff.getOperand(0);
14247-
SDValue ReadPtr = Diff.getOperand(1);
14260+
SDValue StorePtr = NonAbsDiff.getOperand(0);
14261+
SDValue ReadPtr = NonAbsDiff.getOperand(1);
1424814262

1424914263
unsigned IntrinsicID = 0;
1425014264
switch (EltSize) {
1425114265
case 1:
14252-
IntrinsicID = Intrinsic::aarch64_sve_whilewr_b;
14266+
IntrinsicID = WriteAfterRead ? Intrinsic::aarch64_sve_whilewr_b
14267+
: Intrinsic::aarch64_sve_whilerw_b;
1425314268
break;
1425414269
case 2:
14255-
IntrinsicID = Intrinsic::aarch64_sve_whilewr_h;
14270+
IntrinsicID = WriteAfterRead ? Intrinsic::aarch64_sve_whilewr_h
14271+
: Intrinsic::aarch64_sve_whilerw_h;
1425614272
break;
1425714273
case 4:
14258-
IntrinsicID = Intrinsic::aarch64_sve_whilewr_s;
14274+
IntrinsicID = WriteAfterRead ? Intrinsic::aarch64_sve_whilewr_s
14275+
: Intrinsic::aarch64_sve_whilerw_s;
1425914276
break;
1426014277
case 8:
14261-
IntrinsicID = Intrinsic::aarch64_sve_whilewr_d;
14278+
IntrinsicID = WriteAfterRead ? Intrinsic::aarch64_sve_whilewr_d
14279+
: Intrinsic::aarch64_sve_whilerw_d;
1426214280
break;
1426314281
default:
1426414282
return SDValue();

llvm/test/CodeGen/AArch64/whilewr.ll

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,36 @@ entry:
3030
ret <vscale x 16 x i1> %active.lane.mask.alias
3131
}
3232

33+
define <vscale x 16 x i1> @whilerw_8(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
34+
; CHECK-SVE2-LABEL: whilerw_8:
35+
; CHECK-SVE2: // %bb.0: // %entry
36+
; CHECK-SVE2-NEXT: whilerw p0.b, x2, x1
37+
; CHECK-SVE2-NEXT: ret
38+
;
39+
; CHECK-NOSVE2-LABEL: whilerw_8:
40+
; CHECK-NOSVE2: // %bb.0: // %entry
41+
; CHECK-NOSVE2-NEXT: subs x8, x2, x1
42+
; CHECK-NOSVE2-NEXT: cneg x8, x8, mi
43+
; CHECK-NOSVE2-NEXT: cmp x8, #1
44+
; CHECK-NOSVE2-NEXT: cset w9, lt
45+
; CHECK-NOSVE2-NEXT: whilelo p0.b, xzr, x8
46+
; CHECK-NOSVE2-NEXT: sbfx x8, x9, #0, #1
47+
; CHECK-NOSVE2-NEXT: whilelo p1.b, xzr, x8
48+
; CHECK-NOSVE2-NEXT: sel p0.b, p0, p0.b, p1.b
49+
; CHECK-NOSVE2-NEXT: ret
50+
entry:
51+
%b24 = ptrtoint ptr %b to i64
52+
%c25 = ptrtoint ptr %c to i64
53+
%sub.diff = sub i64 %c25, %b24
54+
%0 = tail call i64 @llvm.abs.i64(i64 %sub.diff, i1 false)
55+
%neg.compare = icmp slt i64 %0, 1
56+
%.splatinsert = insertelement <vscale x 16 x i1> poison, i1 %neg.compare, i64 0
57+
%.splat = shufflevector <vscale x 16 x i1> %.splatinsert, <vscale x 16 x i1> poison, <vscale x 16 x i32> zeroinitializer
58+
%ptr.diff.lane.mask = tail call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 0, i64 %0)
59+
%active.lane.mask.alias = or <vscale x 16 x i1> %ptr.diff.lane.mask, %.splat
60+
ret <vscale x 16 x i1> %active.lane.mask.alias
61+
}
62+
3363
define <vscale x 16 x i1> @whilewr_commutative(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
3464
; CHECK-LABEL: whilewr_commutative:
3565
; CHECK: // %bb.0: // %entry
@@ -89,6 +119,39 @@ entry:
89119
ret <vscale x 8 x i1> %active.lane.mask.alias
90120
}
91121

122+
define <vscale x 8 x i1> @whilerw_16(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
123+
; CHECK-SVE2-LABEL: whilerw_16:
124+
; CHECK-SVE2: // %bb.0: // %entry
125+
; CHECK-SVE2-NEXT: whilerw p0.h, x2, x1
126+
; CHECK-SVE2-NEXT: ret
127+
;
128+
; CHECK-NOSVE2-LABEL: whilerw_16:
129+
; CHECK-NOSVE2: // %bb.0: // %entry
130+
; CHECK-NOSVE2-NEXT: subs x8, x2, x1
131+
; CHECK-NOSVE2-NEXT: cneg x8, x8, mi
132+
; CHECK-NOSVE2-NEXT: cmp x8, #2
133+
; CHECK-NOSVE2-NEXT: add x8, x8, x8, lsr #63
134+
; CHECK-NOSVE2-NEXT: cset w9, lt
135+
; CHECK-NOSVE2-NEXT: sbfx x9, x9, #0, #1
136+
; CHECK-NOSVE2-NEXT: asr x8, x8, #1
137+
; CHECK-NOSVE2-NEXT: whilelo p0.h, xzr, x9
138+
; CHECK-NOSVE2-NEXT: whilelo p1.h, xzr, x8
139+
; CHECK-NOSVE2-NEXT: mov p0.b, p1/m, p1.b
140+
; CHECK-NOSVE2-NEXT: ret
141+
entry:
142+
%b24 = ptrtoint ptr %b to i64
143+
%c25 = ptrtoint ptr %c to i64
144+
%sub.diff = sub i64 %c25, %b24
145+
%0 = tail call i64 @llvm.abs.i64(i64 %sub.diff, i1 false)
146+
%diff = sdiv i64 %0, 2
147+
%neg.compare = icmp slt i64 %0, 2
148+
%.splatinsert = insertelement <vscale x 8 x i1> poison, i1 %neg.compare, i64 0
149+
%.splat = shufflevector <vscale x 8 x i1> %.splatinsert, <vscale x 8 x i1> poison, <vscale x 8 x i32> zeroinitializer
150+
%ptr.diff.lane.mask = tail call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 0, i64 %diff)
151+
%active.lane.mask.alias = or <vscale x 8 x i1> %ptr.diff.lane.mask, %.splat
152+
ret <vscale x 8 x i1> %active.lane.mask.alias
153+
}
154+
92155
define <vscale x 4 x i1> @whilewr_32(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
93156
; CHECK-LABEL: whilewr_32:
94157
; CHECK: // %bb.0: // %entry
@@ -122,6 +185,41 @@ entry:
122185
ret <vscale x 4 x i1> %active.lane.mask.alias
123186
}
124187

188+
define <vscale x 4 x i1> @whilerw_32(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
189+
; CHECK-SVE2-LABEL: whilerw_32:
190+
; CHECK-SVE2: // %bb.0: // %entry
191+
; CHECK-SVE2-NEXT: whilerw p0.s, x2, x1
192+
; CHECK-SVE2-NEXT: ret
193+
;
194+
; CHECK-NOSVE2-LABEL: whilerw_32:
195+
; CHECK-NOSVE2: // %bb.0: // %entry
196+
; CHECK-NOSVE2-NEXT: subs x8, x2, x1
197+
; CHECK-NOSVE2-NEXT: cneg x8, x8, mi
198+
; CHECK-NOSVE2-NEXT: add x9, x8, #3
199+
; CHECK-NOSVE2-NEXT: cmp x8, #0
200+
; CHECK-NOSVE2-NEXT: csel x9, x9, x8, lt
201+
; CHECK-NOSVE2-NEXT: cmp x8, #4
202+
; CHECK-NOSVE2-NEXT: cset w8, lt
203+
; CHECK-NOSVE2-NEXT: asr x9, x9, #2
204+
; CHECK-NOSVE2-NEXT: sbfx x8, x8, #0, #1
205+
; CHECK-NOSVE2-NEXT: whilelo p1.s, xzr, x9
206+
; CHECK-NOSVE2-NEXT: whilelo p0.s, xzr, x8
207+
; CHECK-NOSVE2-NEXT: mov p0.b, p1/m, p1.b
208+
; CHECK-NOSVE2-NEXT: ret
209+
entry:
210+
%b24 = ptrtoint ptr %b to i64
211+
%c25 = ptrtoint ptr %c to i64
212+
%sub.diff = sub i64 %c25, %b24
213+
%0 = tail call i64 @llvm.abs.i64(i64 %sub.diff, i1 false)
214+
%diff = sdiv i64 %0, 4
215+
%neg.compare = icmp slt i64 %0, 4
216+
%.splatinsert = insertelement <vscale x 4 x i1> poison, i1 %neg.compare, i64 0
217+
%.splat = shufflevector <vscale x 4 x i1> %.splatinsert, <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer
218+
%ptr.diff.lane.mask = tail call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 0, i64 %diff)
219+
%active.lane.mask.alias = or <vscale x 4 x i1> %ptr.diff.lane.mask, %.splat
220+
ret <vscale x 4 x i1> %active.lane.mask.alias
221+
}
222+
125223
define <vscale x 2 x i1> @whilewr_64(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
126224
; CHECK-LABEL: whilewr_64:
127225
; CHECK: // %bb.0: // %entry
@@ -155,6 +253,41 @@ entry:
155253
ret <vscale x 2 x i1> %active.lane.mask.alias
156254
}
157255

256+
define <vscale x 2 x i1> @whilerw_64(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
257+
; CHECK-SVE2-LABEL: whilerw_64:
258+
; CHECK-SVE2: // %bb.0: // %entry
259+
; CHECK-SVE2-NEXT: whilerw p0.d, x2, x1
260+
; CHECK-SVE2-NEXT: ret
261+
;
262+
; CHECK-NOSVE2-LABEL: whilerw_64:
263+
; CHECK-NOSVE2: // %bb.0: // %entry
264+
; CHECK-NOSVE2-NEXT: subs x8, x2, x1
265+
; CHECK-NOSVE2-NEXT: cneg x8, x8, mi
266+
; CHECK-NOSVE2-NEXT: add x9, x8, #7
267+
; CHECK-NOSVE2-NEXT: cmp x8, #0
268+
; CHECK-NOSVE2-NEXT: csel x9, x9, x8, lt
269+
; CHECK-NOSVE2-NEXT: cmp x8, #8
270+
; CHECK-NOSVE2-NEXT: cset w8, lt
271+
; CHECK-NOSVE2-NEXT: asr x9, x9, #3
272+
; CHECK-NOSVE2-NEXT: sbfx x8, x8, #0, #1
273+
; CHECK-NOSVE2-NEXT: whilelo p1.d, xzr, x9
274+
; CHECK-NOSVE2-NEXT: whilelo p0.d, xzr, x8
275+
; CHECK-NOSVE2-NEXT: mov p0.b, p1/m, p1.b
276+
; CHECK-NOSVE2-NEXT: ret
277+
entry:
278+
%b24 = ptrtoint ptr %b to i64
279+
%c25 = ptrtoint ptr %c to i64
280+
%sub.diff = sub i64 %c25, %b24
281+
%0 = tail call i64 @llvm.abs.i64(i64 %sub.diff, i1 false)
282+
%diff = sdiv i64 %0, 8
283+
%neg.compare = icmp slt i64 %0, 8
284+
%.splatinsert = insertelement <vscale x 2 x i1> poison, i1 %neg.compare, i64 0
285+
%.splat = shufflevector <vscale x 2 x i1> %.splatinsert, <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer
286+
%ptr.diff.lane.mask = tail call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i64(i64 0, i64 %diff)
287+
%active.lane.mask.alias = or <vscale x 2 x i1> %ptr.diff.lane.mask, %.splat
288+
ret <vscale x 2 x i1> %active.lane.mask.alias
289+
}
290+
158291
define <vscale x 1 x i1> @no_whilewr_128(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
159292
; CHECK-LABEL: no_whilewr_128:
160293
; CHECK: // %bb.0: // %entry

0 commit comments

Comments
 (0)