
Commit af024dd

[AArch64] Combine getActiveLaneMask with vector_extract
... into a `whilelo` instruction with a pair of predicate registers.
1 parent 3ad6359 commit af024dd
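
For reference, a minimal sketch of the IR pattern this combine targets (value names are illustrative; the shape mirrors @f1 in the new test added below): a scalable get.active.lane.mask whose only two uses extract the low and high halves of the predicate.

  %mask = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 %i, i64 %n)
  %lo = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %mask, i64 0)
  %hi = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %mask, i64 8)

With +sve2p1 or +sme2, this pair now selects to a single two-register form, e.g. whilelo { p0.h, p1.h }, x0, x1, instead of a whilelo followed by punpklo/punpkhi.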

3 files changed: +177 -35 lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 86 additions & 35 deletions
@@ -1813,8 +1813,8 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT) {
 
 bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
                                                            EVT OpVT) const {
-  // Only SVE has a 1:1 mapping from intrinsic -> instruction (whilelo).
-  if (!Subtarget->hasSVE())
+  // Only SVE/SME has a 1:1 mapping from intrinsic -> instruction (whilelo).
+  if (!Subtarget->hasSVEorSME())
     return true;
 
   // We can only support legal predicate result types. We can use the SVE
@@ -20004,47 +20004,98 @@ static SDValue convertMergedOpToPredOp(SDNode *N, unsigned Opc,
   return SDValue();
 }
 
-static SDValue performIntrinsicCombine(SDNode *N,
-                                       TargetLowering::DAGCombinerInfo &DCI,
-                                       const AArch64Subtarget *Subtarget) {
+static SDValue tryCombineGetActiveLaneMask(SDNode *N,
+                                           TargetLowering::DAGCombinerInfo &DCI,
+                                           const AArch64Subtarget *Subtarget) {
   SelectionDAG &DAG = DCI.DAG;
-  unsigned IID = getIntrinsicID(N);
-  switch (IID) {
-  default:
-    break;
-  case Intrinsic::get_active_lane_mask: {
-    SDValue Res = SDValue();
-    EVT VT = N->getValueType(0);
-    if (VT.isFixedLengthVector()) {
-      // We can use the SVE whilelo instruction to lower this intrinsic by
-      // creating the appropriate sequence of scalable vector operations and
-      // then extracting a fixed-width subvector from the scalable vector.
+  EVT VT = N->getValueType(0);
+  if (VT.isFixedLengthVector()) {
+    // We can use the SVE whilelo instruction to lower this intrinsic by
+    // creating the appropriate sequence of scalable vector operations and
+    // then extracting a fixed-width subvector from the scalable vector.
+    SDLoc DL(N);
+    SDValue ID =
+        DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64);
 
-      SDLoc DL(N);
-      SDValue ID =
-          DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64);
+    EVT WhileVT =
+        EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+                         ElementCount::getScalable(VT.getVectorNumElements()));
 
-      EVT WhileVT = EVT::getVectorVT(
-          *DAG.getContext(), MVT::i1,
-          ElementCount::getScalable(VT.getVectorNumElements()));
+    // Get promoted scalable vector VT, i.e. promote nxv4i1 -> nxv4i32.
+    EVT PromVT = getPromotedVTForPredicate(WhileVT);
 
-      // Get promoted scalable vector VT, i.e. promote nxv4i1 -> nxv4i32.
-      EVT PromVT = getPromotedVTForPredicate(WhileVT);
+    // Get the fixed-width equivalent of PromVT for extraction.
+    EVT ExtVT =
+        EVT::getVectorVT(*DAG.getContext(), PromVT.getVectorElementType(),
+                         VT.getVectorElementCount());
 
-      // Get the fixed-width equivalent of PromVT for extraction.
-      EVT ExtVT =
-          EVT::getVectorVT(*DAG.getContext(), PromVT.getVectorElementType(),
-                           VT.getVectorElementCount());
+    SDValue Res = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, WhileVT, ID,
+                              N->getOperand(1), N->getOperand(2));
+    Res = DAG.getNode(ISD::SIGN_EXTEND, DL, PromVT, Res);
+    Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtVT, Res,
+                      DAG.getConstant(0, DL, MVT::i64));
+    Res = DAG.getNode(ISD::TRUNCATE, DL, VT, Res);
 
-      Res = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, WhileVT, ID,
-                        N->getOperand(1), N->getOperand(2));
-      Res = DAG.getNode(ISD::SIGN_EXTEND, DL, PromVT, Res);
-      Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtVT, Res,
-                        DAG.getConstant(0, DL, MVT::i64));
-      Res = DAG.getNode(ISD::TRUNCATE, DL, VT, Res);
-    }
     return Res;
   }
+
+  if (!Subtarget->hasSVE2p1() && !Subtarget->hasSME2())
+    return SDValue();
+
+  if (!N->hasNUsesOfValue(2, 0))
+    return SDValue();
+
+  auto It = N->use_begin();
+  SDNode *Lo = *It++;
+  SDNode *Hi = *It;
+
+  const uint64_t HalfSize = VT.getVectorMinNumElements() / 2;
+  uint64_t OffLo, OffHi;
+  if (Lo->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+      !isIntImmediate(Lo->getOperand(1).getNode(), OffLo) ||
+      (OffLo != 0 && OffLo != HalfSize) ||
+      Hi->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+      !isIntImmediate(Hi->getOperand(1).getNode(), OffHi) ||
+      (OffHi != 0 && OffHi != HalfSize))
+    return SDValue();
+
+  if (OffLo > OffHi) {
+    std::swap(Lo, Hi);
+    std::swap(OffLo, OffHi);
+  }
+
+  if (OffLo != 0 || OffHi != HalfSize)
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue ID =
+      DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
+  SDValue Idx = N->getOperand(1);
+  SDValue TC = N->getOperand(2);
+  if (Idx.getValueType() != MVT::i64) {
+    Idx = DAG.getZExtOrTrunc(Idx, DL, MVT::i64);
+    TC = DAG.getZExtOrTrunc(TC, DL, MVT::i64);
+  }
+  auto R =
+      DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL,
+                  {Lo->getValueType(0), Hi->getValueType(0)}, {ID, Idx, TC});
+
+  DCI.CombineTo(Lo, R.getValue(0));
+  DCI.CombineTo(Hi, R.getValue(1));
+
+  return SDValue(N, 0);
+}
+
+static SDValue performIntrinsicCombine(SDNode *N,
+                                       TargetLowering::DAGCombinerInfo &DCI,
+                                       const AArch64Subtarget *Subtarget) {
+  SelectionDAG &DAG = DCI.DAG;
+  unsigned IID = getIntrinsicID(N);
+  switch (IID) {
+  default:
+    break;
+  case Intrinsic::get_active_lane_mask:
+    return tryCombineGetActiveLaneMask(N, DCI, Subtarget);
   case Intrinsic::aarch64_neon_vcvtfxs2fp:
   case Intrinsic::aarch64_neon_vcvtfxu2fp:
     return tryCombineFixedPointConvert(N, DCI, DAG);

llvm/test/CodeGen/AArch64/active_lane_mask.ll

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme < %s | FileCheck %s
 
 ; == Scalable ==
 
Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mattr=+sve < %s | FileCheck %s -check-prefix CHECK-SVE
+; RUN: llc -mattr=+sve2p1 < %s | FileCheck %s -check-prefix CHECK-SVE2p1
+; RUN: llc -mattr=+sme2 < %s | FileCheck %s -check-prefix CHECK-SME2
+target triple = "aarch64-linux"
+
+; Test combining of getActiveLaneMask with a pair of extract_vector operations.
+
+define void @f0(i32 %i, i32 %n, ptr %p0, ptr %p1) #0 {
+; CHECK-SVE-LABEL: f0:
+; CHECK-SVE: // %bb.0:
+; CHECK-SVE-NEXT: whilelo p0.b, w0, w1
+; CHECK-SVE-NEXT: ptrue p1.h
+; CHECK-SVE-NEXT: punpklo p2.h, p0.b
+; CHECK-SVE-NEXT: punpkhi p0.h, p0.b
+; CHECK-SVE-NEXT: and p2.b, p2/z, p2.b, p1.b
+; CHECK-SVE-NEXT: and p0.b, p0/z, p0.b, p1.b
+; CHECK-SVE-NEXT: str p2, [x2]
+; CHECK-SVE-NEXT: str p0, [x3]
+; CHECK-SVE-NEXT: ret
+;
+; CHECK-SVE2p1-LABEL: f0:
+; CHECK-SVE2p1: // %bb.0:
+; CHECK-SVE2p1-NEXT: mov w8, w1
+; CHECK-SVE2p1-NEXT: mov w9, w0
+; CHECK-SVE2p1-NEXT: whilelo { p0.h, p1.h }, x9, x8
+; CHECK-SVE2p1-NEXT: str p0, [x2]
+; CHECK-SVE2p1-NEXT: str p1, [x3]
+; CHECK-SVE2p1-NEXT: ret
+;
+; CHECK-SME2-LABEL: f0:
+; CHECK-SME2: // %bb.0:
+; CHECK-SME2-NEXT: mov w8, w1
+; CHECK-SME2-NEXT: mov w9, w0
+; CHECK-SME2-NEXT: whilelo { p0.h, p1.h }, x9, x8
+; CHECK-SME2-NEXT: str p0, [x2]
+; CHECK-SME2-NEXT: str p1, [x3]
+; CHECK-SME2-NEXT: ret
+  %r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i32 %i, i32 %n)
+  %v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
+  %v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
+  %pg0 = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %v0)
+  %pg1 = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %v1)
+  store <vscale x 16 x i1> %pg0, ptr %p0
+  store <vscale x 16 x i1> %pg1, ptr %p1
+  ret void
+}
+
+define void @f1(i64 %i, i64 %n, ptr %p0, ptr %p1) #0 {
+; CHECK-SVE-LABEL: f1:
+; CHECK-SVE: // %bb.0:
+; CHECK-SVE-NEXT: whilelo p0.b, x0, x1
+; CHECK-SVE-NEXT: ptrue p1.h
+; CHECK-SVE-NEXT: punpklo p2.h, p0.b
+; CHECK-SVE-NEXT: punpkhi p0.h, p0.b
+; CHECK-SVE-NEXT: and p2.b, p2/z, p2.b, p1.b
+; CHECK-SVE-NEXT: and p0.b, p0/z, p0.b, p1.b
+; CHECK-SVE-NEXT: str p2, [x2]
+; CHECK-SVE-NEXT: str p0, [x3]
+; CHECK-SVE-NEXT: ret
+;
+; CHECK-SVE2p1-LABEL: f1:
+; CHECK-SVE2p1: // %bb.0:
+; CHECK-SVE2p1-NEXT: whilelo { p0.h, p1.h }, x0, x1
+; CHECK-SVE2p1-NEXT: str p0, [x2]
+; CHECK-SVE2p1-NEXT: str p1, [x3]
+; CHECK-SVE2p1-NEXT: ret
+;
+; CHECK-SME2-LABEL: f1:
+; CHECK-SME2: // %bb.0:
+; CHECK-SME2-NEXT: whilelo { p0.h, p1.h }, x0, x1
+; CHECK-SME2-NEXT: str p0, [x2]
+; CHECK-SME2-NEXT: str p1, [x3]
+; CHECK-SME2-NEXT: ret
+  %r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 %i, i64 %n)
+  %v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
+  %v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
+  %pg0 = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %v0)
+  %pg1 = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %v1)
+  store <vscale x 16 x i1> %pg0, ptr %p0
+  store <vscale x 16 x i1> %pg1, ptr %p1
+  ret void
+}
+
+declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1>)
+declare <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1>, i64)
+declare <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i32, i32)
+declare <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64, i64)
+
+attributes #0 = { nounwind }
