Commit 267d6b5
[AArch64][SVE] Instcombine uzp1/reinterpret svbool to use vector.insert (#81069)

Concatenating two predicates with uzp1 after converting them to double length via sve.convert.to/from.svbool is optimized poorly in the backend, resulting in additional `and` instructions to zero the lanes. See #78623. Combine this pattern to use `llvm.vector.insert` for the concatenation and get rid of the convert to/from svbool calls.
1 parent e82659f commit 267d6b5
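
For illustration, a minimal sketch of the rewrite, adapted from the reinterpt_uzp1_1 test added below (the function name @concat_preds is only illustrative):

; Before: two nxv4i1 predicates are concatenated via svbool round-trips and uzp1.
define <vscale x 8 x i1> @concat_preds(<vscale x 4 x i1> %a, <vscale x 4 x i1> %b) {
  %1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %a)
  %2 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %1)
  %3 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %b)
  %4 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %3)
  %res = tail call <vscale x 8 x i1> @llvm.aarch64.sve.uzp1.nxv8i1(<vscale x 8 x i1> %2, <vscale x 8 x i1> %4)
  ret <vscale x 8 x i1> %res
}

; After instcombine: the same concatenation expressed with target-agnostic llvm.vector.insert,
; placing %a at element 0 and %b at element 4 of the nxv8i1 result.
define <vscale x 8 x i1> @concat_preds(<vscale x 4 x i1> %a, <vscale x 4 x i1> %b) {
  %1 = call <vscale x 8 x i1> @llvm.vector.insert.nxv8i1.nxv4i1(<vscale x 8 x i1> poison, <vscale x 4 x i1> %a, i64 0)
  %res = call <vscale x 8 x i1> @llvm.vector.insert.nxv8i1.nxv4i1(<vscale x 8 x i1> %1, <vscale x 4 x i1> %b, i64 4)
  ret <vscale x 8 x i1> %res
}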

File tree

2 files changed: +172 -0 lines changed

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 32 additions & 0 deletions
@@ -1630,6 +1630,36 @@ static std::optional<Instruction *> instCombineSVETBL(InstCombiner &IC,
   return IC.replaceInstUsesWith(II, VectorSplat);
 }
 
+static std::optional<Instruction *> instCombineSVEUzp1(InstCombiner &IC,
+                                                       IntrinsicInst &II) {
+  Value *A, *B;
+  Type *RetTy = II.getType();
+  constexpr Intrinsic::ID FromSVB = Intrinsic::aarch64_sve_convert_from_svbool;
+  constexpr Intrinsic::ID ToSVB = Intrinsic::aarch64_sve_convert_to_svbool;
+
+  // uzp1(to_svbool(A), to_svbool(B)) --> <A, B>
+  // uzp1(from_svbool(to_svbool(A)), from_svbool(to_svbool(B))) --> <A, B>
+  if ((match(II.getArgOperand(0),
+             m_Intrinsic<FromSVB>(m_Intrinsic<ToSVB>(m_Value(A)))) &&
+       match(II.getArgOperand(1),
+             m_Intrinsic<FromSVB>(m_Intrinsic<ToSVB>(m_Value(B))))) ||
+      (match(II.getArgOperand(0), m_Intrinsic<ToSVB>(m_Value(A))) &&
+       match(II.getArgOperand(1), m_Intrinsic<ToSVB>(m_Value(B))))) {
+    auto *TyA = cast<ScalableVectorType>(A->getType());
+    if (TyA == B->getType() &&
+        RetTy == ScalableVectorType::getDoubleElementsVectorType(TyA)) {
+      auto *SubVec = IC.Builder.CreateInsertVector(
+          RetTy, PoisonValue::get(RetTy), A, IC.Builder.getInt64(0));
+      auto *ConcatVec = IC.Builder.CreateInsertVector(
+          RetTy, SubVec, B, IC.Builder.getInt64(TyA->getMinNumElements()));
+      ConcatVec->takeName(&II);
+      return IC.replaceInstUsesWith(II, ConcatVec);
+    }
+  }
+
+  return std::nullopt;
+}
+
 static std::optional<Instruction *> instCombineSVEZip(InstCombiner &IC,
                                                       IntrinsicInst &II) {
   // zip1(uzp1(A, B), uzp2(A, B)) --> A
@@ -2012,6 +2042,8 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
   case Intrinsic::aarch64_sve_sunpkhi:
   case Intrinsic::aarch64_sve_sunpklo:
     return instCombineSVEUnpack(IC, II);
+  case Intrinsic::aarch64_sve_uzp1:
+    return instCombineSVEUzp1(IC, II);
   case Intrinsic::aarch64_sve_zip1:
   case Intrinsic::aarch64_sve_zip2:
     return instCombineSVEZip(IC, II);
Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
; RUN: opt -S -passes=instcombine -mtriple=aarch64 < %s | FileCheck %s

; Transform the SVE idiom used to concatenate two vectors into target agnostic IR.

define <vscale x 8 x i1> @reinterpt_uzp1_1(<vscale x 4 x i1> %cmp0, <vscale x 4 x i1> %cmp1) {
; CHECK-LABEL: define <vscale x 8 x i1> @reinterpt_uzp1_1(
; CHECK-SAME: <vscale x 4 x i1> [[CMP0:%.*]], <vscale x 4 x i1> [[CMP1:%.*]]) {
; CHECK-NEXT: [[TMP1:%.*]] = call <vscale x 8 x i1> @llvm.vector.insert.nxv8i1.nxv4i1(<vscale x 8 x i1> poison, <vscale x 4 x i1> [[CMP0]], i64 0)
; CHECK-NEXT: [[UZ1:%.*]] = call <vscale x 8 x i1> @llvm.vector.insert.nxv8i1.nxv4i1(<vscale x 8 x i1> [[TMP1]], <vscale x 4 x i1> [[CMP1]], i64 4)
; CHECK-NEXT: ret <vscale x 8 x i1> [[UZ1]]
;
  %1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %cmp0)
  %2 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %1)
  %3 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %cmp1)
  %4 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %3)
  %uz1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.uzp1.nxv8i1(<vscale x 8 x i1> %2, <vscale x 8 x i1> %4)
  ret <vscale x 8 x i1> %uz1
}

define <vscale x 8 x i1> @reinterpt_uzp1_2(<vscale x 2 x i1> %cmp0, <vscale x 2 x i1> %cmp1) {
; CHECK-LABEL: define <vscale x 8 x i1> @reinterpt_uzp1_2(
; CHECK-SAME: <vscale x 2 x i1> [[CMP0:%.*]], <vscale x 2 x i1> [[CMP1:%.*]]) {
; CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> [[CMP0]])
; CHECK-NEXT: [[TMP2:%.*]] = tail call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> [[TMP1]])
; CHECK-NEXT: [[TMP3:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> [[CMP1]])
; CHECK-NEXT: [[TMP4:%.*]] = tail call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> [[TMP3]])
; CHECK-NEXT: [[TMP5:%.*]] = call <vscale x 8 x i1> @llvm.vector.insert.nxv8i1.nxv4i1(<vscale x 8 x i1> poison, <vscale x 4 x i1> [[TMP2]], i64 0)
; CHECK-NEXT: [[UZ1:%.*]] = call <vscale x 8 x i1> @llvm.vector.insert.nxv8i1.nxv4i1(<vscale x 8 x i1> [[TMP5]], <vscale x 4 x i1> [[TMP4]], i64 4)
; CHECK-NEXT: ret <vscale x 8 x i1> [[UZ1]]
;
  %1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> %cmp0)
  %2 = tail call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> %1)
  %3 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> %cmp1)
  %4 = tail call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> %3)
  %5 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %2)
  %6 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %5)
  %7 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %4)
  %8 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %7)
  %uz1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.uzp1.nxv8i1(<vscale x 8 x i1> %6, <vscale x 8 x i1> %8)
  ret <vscale x 8 x i1> %uz1
}

define <vscale x 16 x i1> @reinterpt_uzp1_3(<vscale x 4 x i1> %cmp0, <vscale x 4 x i1> %cmp1, <vscale x 4 x i1> %cmp2, <vscale x 4 x i1> %cmp3) {
; CHECK-LABEL: define <vscale x 16 x i1> @reinterpt_uzp1_3(
; CHECK-SAME: <vscale x 4 x i1> [[CMP0:%.*]], <vscale x 4 x i1> [[CMP1:%.*]], <vscale x 4 x i1> [[CMP2:%.*]], <vscale x 4 x i1> [[CMP3:%.*]]) {
; CHECK-NEXT: [[TMP1:%.*]] = call <vscale x 8 x i1> @llvm.vector.insert.nxv8i1.nxv4i1(<vscale x 8 x i1> poison, <vscale x 4 x i1> [[CMP0]], i64 0)
; CHECK-NEXT: [[UZ1_1:%.*]] = call <vscale x 8 x i1> @llvm.vector.insert.nxv8i1.nxv4i1(<vscale x 8 x i1> [[TMP1]], <vscale x 4 x i1> [[CMP1]], i64 4)
; CHECK-NEXT: [[TMP2:%.*]] = call <vscale x 8 x i1> @llvm.vector.insert.nxv8i1.nxv4i1(<vscale x 8 x i1> poison, <vscale x 4 x i1> [[CMP2]], i64 0)
; CHECK-NEXT: [[UZ1_2:%.*]] = call <vscale x 8 x i1> @llvm.vector.insert.nxv8i1.nxv4i1(<vscale x 8 x i1> [[TMP2]], <vscale x 4 x i1> [[CMP3]], i64 4)
; CHECK-NEXT: [[TMP3:%.*]] = call <vscale x 16 x i1> @llvm.vector.insert.nxv16i1.nxv8i1(<vscale x 16 x i1> poison, <vscale x 8 x i1> [[UZ1_1]], i64 0)
; CHECK-NEXT: [[UZ3:%.*]] = call <vscale x 16 x i1> @llvm.vector.insert.nxv16i1.nxv8i1(<vscale x 16 x i1> [[TMP3]], <vscale x 8 x i1> [[UZ1_2]], i64 8)
; CHECK-NEXT: ret <vscale x 16 x i1> [[UZ3]]
;
  %1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %cmp0)
  %2 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %1)
  %3 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %cmp1)
  %4 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %3)
  %uz1_1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.uzp1.nxv8i1(<vscale x 8 x i1> %2, <vscale x 8 x i1> %4)
  %5 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %cmp2)
  %6 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %5)
  %7 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %cmp3)
  %8 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %7)
  %uz1_2 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.uzp1.nxv8i1(<vscale x 8 x i1> %6, <vscale x 8 x i1> %8)
  %9 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %uz1_1)
  %10 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %uz1_2)
  %uz3 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.uzp1.nxv16i1(<vscale x 16 x i1> %9, <vscale x 16 x i1> %10)
  ret <vscale x 16 x i1> %uz3
}

define <vscale x 4 x i1> @neg1(<vscale x 4 x i1> %cmp0, <vscale x 4 x i1> %cmp1) {
; CHECK-LABEL: define <vscale x 4 x i1> @neg1(
; CHECK-SAME: <vscale x 4 x i1> [[CMP0:%.*]], <vscale x 4 x i1> [[CMP1:%.*]]) {
; CHECK-NEXT: [[UZ1:%.*]] = tail call <vscale x 4 x i1> @llvm.aarch64.sve.uzp1.nxv4i1(<vscale x 4 x i1> [[CMP0]], <vscale x 4 x i1> [[CMP1]])
; CHECK-NEXT: ret <vscale x 4 x i1> [[UZ1]]
;
  %uz1 = tail call <vscale x 4 x i1> @llvm.aarch64.sve.uzp1.nxv4i1(<vscale x 4 x i1> %cmp0, <vscale x 4 x i1> %cmp1)
  ret <vscale x 4 x i1> %uz1
}

define <vscale x 8 x i1> @neg2(<vscale x 2 x i1> %cmp0, <vscale x 4 x i1> %cmp1) {
; CHECK-LABEL: define <vscale x 8 x i1> @neg2(
; CHECK-SAME: <vscale x 2 x i1> [[CMP0:%.*]], <vscale x 4 x i1> [[CMP1:%.*]]) {
; CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> [[CMP0]])
; CHECK-NEXT: [[TMP2:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[TMP1]])
; CHECK-NEXT: [[TMP3:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> [[CMP1]])
; CHECK-NEXT: [[TMP4:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[TMP3]])
; CHECK-NEXT: [[UZ1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.uzp1.nxv8i1(<vscale x 8 x i1> [[TMP2]], <vscale x 8 x i1> [[TMP4]])
; CHECK-NEXT: ret <vscale x 8 x i1> [[UZ1]]
;
  %1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> %cmp0)
  %2 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %1)
  %3 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %cmp1)
  %4 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %3)
  %uz1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.uzp1.nxv8i1(<vscale x 8 x i1> %2, <vscale x 8 x i1> %4)
  ret <vscale x 8 x i1> %uz1
}

define <vscale x 8 x i1> @neg3(<vscale x 8 x i1> %cmp0, <vscale x 4 x i1> %cmp1) {
; CHECK-LABEL: define <vscale x 8 x i1> @neg3(
; CHECK-SAME: <vscale x 8 x i1> [[CMP0:%.*]], <vscale x 4 x i1> [[CMP1:%.*]]) {
; CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> [[CMP1]])
; CHECK-NEXT: [[TMP2:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[TMP1]])
; CHECK-NEXT: [[UZ1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.uzp1.nxv8i1(<vscale x 8 x i1> [[CMP0]], <vscale x 8 x i1> [[TMP2]])
; CHECK-NEXT: ret <vscale x 8 x i1> [[UZ1]]
;
  %1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %cmp1)
  %2 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %1)
  %uz1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.uzp1.nxv8i1(<vscale x 8 x i1> %cmp0, <vscale x 8 x i1> %2)
  ret <vscale x 8 x i1> %uz1
}

define <vscale x 8 x i1> @neg4(<vscale x 2 x i1> %cmp0, <vscale x 2 x i1> %cmp1) {
; CHECK-LABEL: define <vscale x 8 x i1> @neg4(
; CHECK-SAME: <vscale x 2 x i1> [[CMP0:%.*]], <vscale x 2 x i1> [[CMP1:%.*]]) {
; CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> [[CMP0]])
; CHECK-NEXT: [[TMP2:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[TMP1]])
; CHECK-NEXT: [[TMP3:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> [[CMP1]])
; CHECK-NEXT: [[TMP4:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[TMP3]])
; CHECK-NEXT: [[UZ1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.uzp1.nxv8i1(<vscale x 8 x i1> [[TMP2]], <vscale x 8 x i1> [[TMP4]])
; CHECK-NEXT: ret <vscale x 8 x i1> [[UZ1]]
;
  %1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> %cmp0)
  %2 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %1)
  %3 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> %cmp1)
  %4 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %3)
  %uz1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.uzp1.nxv8i1(<vscale x 8 x i1> %2, <vscale x 8 x i1> %4)
  ret <vscale x 8 x i1> %uz1
}

declare <vscale x 4 x i1> @llvm.aarch64.sve.uzp1.nxv4i1(<vscale x 4 x i1>, <vscale x 4 x i1>)
declare <vscale x 8 x i1> @llvm.aarch64.sve.uzp1.nxv8i1(<vscale x 8 x i1>, <vscale x 8 x i1>)
declare <vscale x 16 x i1> @llvm.aarch64.sve.uzp1.nxv16i1(<vscale x 16 x i1>, <vscale x 16 x i1>)

declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1>)
declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1>)
declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1>)
declare <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1>)
declare <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1>)
