Skip to content

Commit 88da875

Browse files
[AArch64] Combine getActiveLaneMask with vector_extract (#81139)
... into a `whilelo` instruction with a pair of predicate registers.
1 parent d655054 commit 88da875

File tree

2 files changed

+218
-0
lines changed

2 files changed

+218
-0
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20535,6 +20535,66 @@ static SDValue convertMergedOpToPredOp(SDNode *N, unsigned Opc,
2053520535
return SDValue();
2053620536
}
2053720537

20538+
static SDValue tryCombineWhileLo(SDNode *N,
20539+
TargetLowering::DAGCombinerInfo &DCI,
20540+
const AArch64Subtarget *Subtarget) {
20541+
if (DCI.isBeforeLegalize())
20542+
return SDValue();
20543+
20544+
if (!Subtarget->hasSVE2p1())
20545+
return SDValue();
20546+
20547+
if (!N->hasNUsesOfValue(2, 0))
20548+
return SDValue();
20549+
20550+
const uint64_t HalfSize = N->getValueType(0).getVectorMinNumElements() / 2;
20551+
if (HalfSize < 2)
20552+
return SDValue();
20553+
20554+
auto It = N->use_begin();
20555+
SDNode *Lo = *It++;
20556+
SDNode *Hi = *It;
20557+
20558+
if (Lo->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
20559+
Hi->getOpcode() != ISD::EXTRACT_SUBVECTOR)
20560+
return SDValue();
20561+
20562+
uint64_t OffLo = Lo->getConstantOperandVal(1);
20563+
uint64_t OffHi = Hi->getConstantOperandVal(1);
20564+
20565+
if (OffLo > OffHi) {
20566+
std::swap(Lo, Hi);
20567+
std::swap(OffLo, OffHi);
20568+
}
20569+
20570+
if (OffLo != 0 || OffHi != HalfSize)
20571+
return SDValue();
20572+
20573+
EVT HalfVec = Lo->getValueType(0);
20574+
if (HalfVec != Hi->getValueType(0) ||
20575+
HalfVec.getVectorElementCount() != ElementCount::getScalable(HalfSize))
20576+
return SDValue();
20577+
20578+
SelectionDAG &DAG = DCI.DAG;
20579+
SDLoc DL(N);
20580+
SDValue ID =
20581+
DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
20582+
SDValue Idx = N->getOperand(1);
20583+
SDValue TC = N->getOperand(2);
20584+
if (Idx.getValueType() != MVT::i64) {
20585+
Idx = DAG.getZExtOrTrunc(Idx, DL, MVT::i64);
20586+
TC = DAG.getZExtOrTrunc(TC, DL, MVT::i64);
20587+
}
20588+
auto R =
20589+
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL,
20590+
{Lo->getValueType(0), Hi->getValueType(0)}, {ID, Idx, TC});
20591+
20592+
DCI.CombineTo(Lo, R.getValue(0));
20593+
DCI.CombineTo(Hi, R.getValue(1));
20594+
20595+
return SDValue(N, 0);
20596+
}
20597+
2053820598
static SDValue performIntrinsicCombine(SDNode *N,
2053920599
TargetLowering::DAGCombinerInfo &DCI,
2054020600
const AArch64Subtarget *Subtarget) {
@@ -20832,6 +20892,8 @@ static SDValue performIntrinsicCombine(SDNode *N,
2083220892
case Intrinsic::aarch64_sve_ptest_last:
2083320893
return getPTest(DAG, N->getValueType(0), N->getOperand(1), N->getOperand(2),
2083420894
AArch64CC::LAST_ACTIVE);
20895+
case Intrinsic::aarch64_sve_whilelo:
20896+
return tryCombineWhileLo(N, DCI, Subtarget);
2083520897
}
2083620898
return SDValue();
2083720899
}
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
2+
; RUN: llc -mattr=+sve < %s | FileCheck %s -check-prefix CHECK-SVE
3+
; RUN: llc -mattr=+sve2p1 < %s | FileCheck %s -check-prefix CHECK-SVE2p1
4+
target triple = "aarch64-linux"
5+
6+
; Test combining of getActiveLaneMask with a pair of extract_vector operations.
7+
8+
define void @test_2x8bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #0 {
9+
; CHECK-SVE-LABEL: test_2x8bit_mask_with_32bit_index_and_trip_count:
10+
; CHECK-SVE: // %bb.0:
11+
; CHECK-SVE-NEXT: whilelo p1.b, w0, w1
12+
; CHECK-SVE-NEXT: punpklo p0.h, p1.b
13+
; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
14+
; CHECK-SVE-NEXT: b use
15+
;
16+
; CHECK-SVE2p1-LABEL: test_2x8bit_mask_with_32bit_index_and_trip_count:
17+
; CHECK-SVE2p1: // %bb.0:
18+
; CHECK-SVE2p1-NEXT: mov w8, w1
19+
; CHECK-SVE2p1-NEXT: mov w9, w0
20+
; CHECK-SVE2p1-NEXT: whilelo { p0.h, p1.h }, x9, x8
21+
; CHECK-SVE2p1-NEXT: b use
22+
%r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i32 %i, i32 %n)
23+
%v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
24+
%v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
25+
tail call void @use(<vscale x 8 x i1> %v0, <vscale x 8 x i1> %v1)
26+
ret void
27+
}
28+
29+
define void @test_2x8bit_mask_with_64bit_index_and_trip_count(i64 %i, i64 %n) #0 {
30+
; CHECK-SVE-LABEL: test_2x8bit_mask_with_64bit_index_and_trip_count:
31+
; CHECK-SVE: // %bb.0:
32+
; CHECK-SVE-NEXT: whilelo p1.b, x0, x1
33+
; CHECK-SVE-NEXT: punpklo p0.h, p1.b
34+
; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
35+
; CHECK-SVE-NEXT: b use
36+
;
37+
; CHECK-SVE2p1-LABEL: test_2x8bit_mask_with_64bit_index_and_trip_count:
38+
; CHECK-SVE2p1: // %bb.0:
39+
; CHECK-SVE2p1-NEXT: whilelo { p0.h, p1.h }, x0, x1
40+
; CHECK-SVE2p1-NEXT: b use
41+
%r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 %i, i64 %n)
42+
%v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
43+
%v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
44+
tail call void @use(<vscale x 8 x i1> %v0, <vscale x 8 x i1> %v1)
45+
ret void
46+
}
47+
48+
define void @test_edge_case_2x1bit_mask(i64 %i, i64 %n) #0 {
49+
; CHECK-SVE-LABEL: test_edge_case_2x1bit_mask:
50+
; CHECK-SVE: // %bb.0:
51+
; CHECK-SVE-NEXT: whilelo p1.d, x0, x1
52+
; CHECK-SVE-NEXT: punpklo p0.h, p1.b
53+
; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
54+
; CHECK-SVE-NEXT: b use
55+
;
56+
; CHECK-SVE2p1-LABEL: test_edge_case_2x1bit_mask:
57+
; CHECK-SVE2p1: // %bb.0:
58+
; CHECK-SVE2p1-NEXT: whilelo p1.d, x0, x1
59+
; CHECK-SVE2p1-NEXT: punpklo p0.h, p1.b
60+
; CHECK-SVE2p1-NEXT: punpkhi p1.h, p1.b
61+
; CHECK-SVE2p1-NEXT: b use
62+
%r = call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i64(i64 %i, i64 %n)
63+
%v0 = call <vscale x 1 x i1> @llvm.vector.extract.nxv1i1.nxv2i1.i64(<vscale x 2 x i1> %r, i64 0)
64+
%v1 = call <vscale x 1 x i1> @llvm.vector.extract.nxv1i1.nxv2i1.i64(<vscale x 2 x i1> %r, i64 1)
65+
tail call void @use(<vscale x 1 x i1> %v0, <vscale x 1 x i1> %v1)
66+
ret void
67+
}
68+
69+
define void @test_boring_case_2x2bit_mask(i64 %i, i64 %n) #0 {
70+
; CHECK-SVE-LABEL: test_boring_case_2x2bit_mask:
71+
; CHECK-SVE: // %bb.0:
72+
; CHECK-SVE-NEXT: whilelo p1.s, x0, x1
73+
; CHECK-SVE-NEXT: punpklo p0.h, p1.b
74+
; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
75+
; CHECK-SVE-NEXT: b use
76+
;
77+
; CHECK-SVE2p1-LABEL: test_boring_case_2x2bit_mask:
78+
; CHECK-SVE2p1: // %bb.0:
79+
; CHECK-SVE2p1-NEXT: whilelo { p0.d, p1.d }, x0, x1
80+
; CHECK-SVE2p1-NEXT: b use
81+
%r = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 %i, i64 %n)
82+
%v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv4i1.i64(<vscale x 4 x i1> %r, i64 0)
83+
%v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv4i1.i64(<vscale x 4 x i1> %r, i64 2)
84+
tail call void @use(<vscale x 2 x i1> %v0, <vscale x 2 x i1> %v1)
85+
ret void
86+
}
87+
88+
; Negative test for when not extracting exactly two halves of the source vector
89+
define void @test_partial_extract(i64 %i, i64 %n) #0 {
90+
; CHECK-SVE-LABEL: test_partial_extract:
91+
; CHECK-SVE: // %bb.0:
92+
; CHECK-SVE-NEXT: whilelo p0.h, x0, x1
93+
; CHECK-SVE-NEXT: punpklo p1.h, p0.b
94+
; CHECK-SVE-NEXT: punpkhi p2.h, p0.b
95+
; CHECK-SVE-NEXT: punpklo p0.h, p1.b
96+
; CHECK-SVE-NEXT: punpklo p1.h, p2.b
97+
; CHECK-SVE-NEXT: b use
98+
;
99+
; CHECK-SVE2p1-LABEL: test_partial_extract:
100+
; CHECK-SVE2p1: // %bb.0:
101+
; CHECK-SVE2p1-NEXT: whilelo p0.h, x0, x1
102+
; CHECK-SVE2p1-NEXT: punpklo p1.h, p0.b
103+
; CHECK-SVE2p1-NEXT: punpkhi p2.h, p0.b
104+
; CHECK-SVE2p1-NEXT: punpklo p0.h, p1.b
105+
; CHECK-SVE2p1-NEXT: punpklo p1.h, p2.b
106+
; CHECK-SVE2p1-NEXT: b use
107+
%r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
108+
%v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
109+
%v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 4)
110+
tail call void @use(<vscale x 2 x i1> %v0, <vscale x 2 x i1> %v1)
111+
ret void
112+
}
113+
114+
;; Negative test for when extracting a fixed-length vector.
115+
define void @test_fixed_extract(i64 %i, i64 %n) #0 {
116+
; CHECK-SVE-LABEL: test_fixed_extract:
117+
; CHECK-SVE: // %bb.0:
118+
; CHECK-SVE-NEXT: whilelo p0.h, x0, x1
119+
; CHECK-SVE-NEXT: cset w8, mi
120+
; CHECK-SVE-NEXT: mov z0.h, p0/z, #1 // =0x1
121+
; CHECK-SVE-NEXT: umov w9, v0.h[4]
122+
; CHECK-SVE-NEXT: umov w10, v0.h[1]
123+
; CHECK-SVE-NEXT: umov w11, v0.h[5]
124+
; CHECK-SVE-NEXT: fmov s0, w8
125+
; CHECK-SVE-NEXT: fmov s1, w9
126+
; CHECK-SVE-NEXT: mov v0.s[1], w10
127+
; CHECK-SVE-NEXT: // kill: def $d0 killed $d0 killed $q0
128+
; CHECK-SVE-NEXT: mov v1.s[1], w11
129+
; CHECK-SVE-NEXT: // kill: def $d1 killed $d1 killed $q1
130+
; CHECK-SVE-NEXT: b use
131+
;
132+
; CHECK-SVE2p1-LABEL: test_fixed_extract:
133+
; CHECK-SVE2p1: // %bb.0:
134+
; CHECK-SVE2p1-NEXT: whilelo p0.h, x0, x1
135+
; CHECK-SVE2p1-NEXT: cset w8, mi
136+
; CHECK-SVE2p1-NEXT: mov z0.h, p0/z, #1 // =0x1
137+
; CHECK-SVE2p1-NEXT: umov w9, v0.h[4]
138+
; CHECK-SVE2p1-NEXT: umov w10, v0.h[1]
139+
; CHECK-SVE2p1-NEXT: umov w11, v0.h[5]
140+
; CHECK-SVE2p1-NEXT: fmov s0, w8
141+
; CHECK-SVE2p1-NEXT: fmov s1, w9
142+
; CHECK-SVE2p1-NEXT: mov v0.s[1], w10
143+
; CHECK-SVE2p1-NEXT: // kill: def $d0 killed $d0 killed $q0
144+
; CHECK-SVE2p1-NEXT: mov v1.s[1], w11
145+
; CHECK-SVE2p1-NEXT: // kill: def $d1 killed $d1 killed $q1
146+
; CHECK-SVE2p1-NEXT: b use
147+
%r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
148+
%v0 = call <2 x i1> @llvm.vector.extract.v2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
149+
%v1 = call <2 x i1> @llvm.vector.extract.v2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 4)
150+
tail call void @use(<2 x i1> %v0, <2 x i1> %v1)
151+
ret void
152+
}
153+
154+
declare void @use(...)
155+
156+
attributes #0 = { nounwind }

0 commit comments

Comments
 (0)