Skip to content

Commit b3b2429

Browse files
DevM-ukyuxuanchen1997 authored and committed
[AArch64] Lower scalable i1 vector add reduction to cntp (#100118)
Summary: Doing an add reduction on a vector of i1 elements is the same as counting the number of set elements, so such a reduction can be lowered to a cntp instruction. This saves a number of instructions over performing a UADDV. This patch only handles straightforward cases (i.e. when vectors are not split).

Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

Differential Revision: https://phabricator.intern.facebook.com/D60251189
1 parent 7de6b2c commit b3b2429

File tree

2 files changed

+146
-0
lines changed

2 files changed

+146
-0
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27659,6 +27659,20 @@ SDValue AArch64TargetLowering::LowerReductionToSVE(unsigned Opcode,
2765927659
VecOp = convertToScalableVector(DAG, ContainerVT, VecOp);
2766027660
}
2766127661

27662+
// Lower VECREDUCE_ADD of nxv2i1-nxv16i1 to CNTP rather than UADDV.
27663+
if (ScalarOp.getOpcode() == ISD::VECREDUCE_ADD &&
27664+
VecOp.getOpcode() == ISD::ZERO_EXTEND) {
27665+
SDValue BoolVec = VecOp.getOperand(0);
27666+
if (BoolVec.getValueType().getVectorElementType() == MVT::i1) {
27667+
// CNTP(BoolVec & BoolVec) <=> CNTP(BoolVec & PTRUE)
27668+
SDValue CntpOp = DAG.getNode(
27669+
ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64,
27670+
DAG.getTargetConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64),
27671+
BoolVec, BoolVec);
27672+
return DAG.getAnyExtOrTrunc(CntpOp, DL, ScalarOp.getValueType());
27673+
}
27674+
}
27675+
2766227676
// UADDV always returns an i64 result.
2766327677
EVT ResVT = (Opcode == AArch64ISD::UADDV_PRED) ? MVT::i64 :
2766427678
SrcVT.getVectorElementType();
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s
3+
4+
; Add reduction of an nxv16i1 predicate zero-extended to i8 lanes: the sum of
; the extended lanes equals the number of active predicate elements, so this
; should lower to a single CNTP (byte granularity) instead of a UADDV chain.
define i8 @uaddv_zexti8_nxv16i1(<vscale x 16 x i1> %v) {
; CHECK-LABEL: uaddv_zexti8_nxv16i1:
; CHECK:       // %bb.0: // %entry
; CHECK-NEXT:    cntp x0, p0, p0.b
; CHECK-NEXT:    // kill: def $w0 killed $w0 killed $x0
; CHECK-NEXT:    ret
entry:
  %3 = zext <vscale x 16 x i1> %v to <vscale x 16 x i8>
  %4 = tail call i8 @llvm.vector.reduce.add.nxv16i8(<vscale x 16 x i8> %3)
  ret i8 %4
}

; Same pattern at halfword predicate granularity (nxv8i1 -> i8 lanes):
; expect a single CNTP on p0.h.
define i8 @uaddv_zexti8_nxv8i1(<vscale x 8 x i1> %v) {
; CHECK-LABEL: uaddv_zexti8_nxv8i1:
; CHECK:       // %bb.0: // %entry
; CHECK-NEXT:    cntp x0, p0, p0.h
; CHECK-NEXT:    // kill: def $w0 killed $w0 killed $x0
; CHECK-NEXT:    ret
entry:
  %3 = zext <vscale x 8 x i1> %v to <vscale x 8 x i8>
  %4 = tail call i8 @llvm.vector.reduce.add.nxv8i8(<vscale x 8 x i8> %3)
  ret i8 %4
}

; nxv8i1 extended to i16 lanes: the reduction result still fits the CNTP
; count, so the same single-instruction lowering applies.
define i16 @uaddv_zexti16_nxv8i1(<vscale x 8 x i1> %v) {
; CHECK-LABEL: uaddv_zexti16_nxv8i1:
; CHECK:       // %bb.0: // %entry
; CHECK-NEXT:    cntp x0, p0, p0.h
; CHECK-NEXT:    // kill: def $w0 killed $w0 killed $x0
; CHECK-NEXT:    ret
entry:
  %3 = zext <vscale x 8 x i1> %v to <vscale x 8 x i16>
  %4 = tail call i16 @llvm.vector.reduce.add.nxv8i16(<vscale x 8 x i16> %3)
  ret i16 %4
}

; Word-granularity predicate (nxv4i1) extended to i8 lanes: expect CNTP on p0.s.
define i8 @uaddv_zexti8_nxv4i1(<vscale x 4 x i1> %v) {
; CHECK-LABEL: uaddv_zexti8_nxv4i1:
; CHECK:       // %bb.0: // %entry
; CHECK-NEXT:    cntp x0, p0, p0.s
; CHECK-NEXT:    // kill: def $w0 killed $w0 killed $x0
; CHECK-NEXT:    ret
entry:
  %3 = zext <vscale x 4 x i1> %v to <vscale x 4 x i8>
  %4 = tail call i8 @llvm.vector.reduce.add.nxv4i8(<vscale x 4 x i8> %3)
  ret i8 %4
}

; nxv4i1 extended to i16 lanes: expect CNTP on p0.s.
define i16 @uaddv_zexti16_nxv4i1(<vscale x 4 x i1> %v) {
; CHECK-LABEL: uaddv_zexti16_nxv4i1:
; CHECK:       // %bb.0: // %entry
; CHECK-NEXT:    cntp x0, p0, p0.s
; CHECK-NEXT:    // kill: def $w0 killed $w0 killed $x0
; CHECK-NEXT:    ret
entry:
  %3 = zext <vscale x 4 x i1> %v to <vscale x 4 x i16>
  %4 = tail call i16 @llvm.vector.reduce.add.nxv4i16(<vscale x 4 x i16> %3)
  ret i16 %4
}

; nxv4i1 extended to i32 lanes: expect CNTP on p0.s.
define i32 @uaddv_zexti32_nxv4i1(<vscale x 4 x i1> %v) {
; CHECK-LABEL: uaddv_zexti32_nxv4i1:
; CHECK:       // %bb.0: // %entry
; CHECK-NEXT:    cntp x0, p0, p0.s
; CHECK-NEXT:    // kill: def $w0 killed $w0 killed $x0
; CHECK-NEXT:    ret
entry:
  %3 = zext <vscale x 4 x i1> %v to <vscale x 4 x i32>
  %4 = tail call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> %3)
  ret i32 %4
}

; Doubleword-granularity predicate (nxv2i1) extended to i8 lanes:
; expect CNTP on p0.d.
define i8 @uaddv_zexti8_nxv2i1(<vscale x 2 x i1> %v) {
; CHECK-LABEL: uaddv_zexti8_nxv2i1:
; CHECK:       // %bb.0: // %entry
; CHECK-NEXT:    cntp x0, p0, p0.d
; CHECK-NEXT:    // kill: def $w0 killed $w0 killed $x0
; CHECK-NEXT:    ret
entry:
  %3 = zext <vscale x 2 x i1> %v to <vscale x 2 x i8>
  %4 = tail call i8 @llvm.vector.reduce.add.nxv2i8(<vscale x 2 x i8> %3)
  ret i8 %4
}

; nxv2i1 extended to i16 lanes: expect CNTP on p0.d.
define i16 @uaddv_zexti16_nxv2i1(<vscale x 2 x i1> %v) {
; CHECK-LABEL: uaddv_zexti16_nxv2i1:
; CHECK:       // %bb.0: // %entry
; CHECK-NEXT:    cntp x0, p0, p0.d
; CHECK-NEXT:    // kill: def $w0 killed $w0 killed $x0
; CHECK-NEXT:    ret
entry:
  %3 = zext <vscale x 2 x i1> %v to <vscale x 2 x i16>
  %4 = tail call i16 @llvm.vector.reduce.add.nxv2i16(<vscale x 2 x i16> %3)
  ret i16 %4
}

; nxv2i1 extended to i32 lanes: expect CNTP on p0.d.
define i32 @uaddv_zexti32_nxv2i1(<vscale x 2 x i1> %v) {
; CHECK-LABEL: uaddv_zexti32_nxv2i1:
; CHECK:       // %bb.0: // %entry
; CHECK-NEXT:    cntp x0, p0, p0.d
; CHECK-NEXT:    // kill: def $w0 killed $w0 killed $x0
; CHECK-NEXT:    ret
entry:
  %3 = zext <vscale x 2 x i1> %v to <vscale x 2 x i32>
  %4 = tail call i32 @llvm.vector.reduce.add.nxv2i32(<vscale x 2 x i32> %3)
  ret i32 %4
}

; nxv2i1 extended to i64 lanes: CNTP already produces an i64 in x0, so no
; sub-register kill comment is expected here.
define i64 @uaddv_zexti64_nxv2i1(<vscale x 2 x i1> %v) {
; CHECK-LABEL: uaddv_zexti64_nxv2i1:
; CHECK:       // %bb.0: // %entry
; CHECK-NEXT:    cntp x0, p0, p0.d
; CHECK-NEXT:    ret
entry:
  %3 = zext <vscale x 2 x i1> %v to <vscale x 2 x i64>
  %4 = tail call i64 @llvm.vector.reduce.add.nxv2i64(<vscale x 2 x i64> %3)
  ret i64 %4
}

; Declarations for every llvm.vector.reduce.add overload used above.
declare i8 @llvm.vector.reduce.add.nxv16i8(<vscale x 16 x i8>)
declare i8 @llvm.vector.reduce.add.nxv8i8(<vscale x 8 x i8>)
declare i16 @llvm.vector.reduce.add.nxv8i16(<vscale x 8 x i16>)
declare i8 @llvm.vector.reduce.add.nxv4i8(<vscale x 4 x i8>)
declare i16 @llvm.vector.reduce.add.nxv4i16(<vscale x 4 x i16>)
declare i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32>)
declare i8 @llvm.vector.reduce.add.nxv2i8(<vscale x 2 x i8>)
declare i16 @llvm.vector.reduce.add.nxv2i16(<vscale x 2 x i16>)
declare i32 @llvm.vector.reduce.add.nxv2i32(<vscale x 2 x i32>)
declare i64 @llvm.vector.reduce.add.nxv2i64(<vscale x 2 x i64>)

0 commit comments

Comments
 (0)