
Commit 718962f
[SystemZ] Provide improved cost estimates (#83873)
This commit provides better cost estimates for the llvm.vector.reduce.add intrinsic on SystemZ. These apply to all vector lengths and to integer element types up to i128; for element types larger than i128, we fall back to the default cost estimate. The effect is to lower the estimated costs of most common instances of the intrinsic. The expected performance impact is minimal, with a tendency to slightly improve some benchmarks. This commit also adds a test that checks the computation of the new estimates, as well as the fallback for types larger than i128.
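As an illustration of the new estimate: reducing <16 x i32> needs V = 4 vector registers, and summing the lanes of the final register takes S = 2 * log2(4) = 4 instructions, giving an estimated cost of V + S - 1 = 7; this is the value the new test checks for %R16_32.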
1 parent 3b30559 commit 718962f

File tree

2 files changed: +158 −3 lines

llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp

Lines changed: 30 additions & 3 deletions
@@ -20,6 +20,8 @@
 #include "llvm/CodeGen/TargetLowering.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+
 using namespace llvm;
 
 #define DEBUG_TYPE "systemztti"
@@ -1284,17 +1286,42 @@ InstructionCost SystemZTTIImpl::getInterleavedMemoryOpCost(
   return NumVectorMemOps + NumPermutes;
 }
 
-static int getVectorIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy) {
+static int
+getVectorIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy,
+                            const SmallVectorImpl<Type *> &ParamTys) {
   if (RetTy->isVectorTy() && ID == Intrinsic::bswap)
     return getNumVectorRegs(RetTy); // VPERM
+
+  if (ID == Intrinsic::vector_reduce_add) {
+    // Retrieve number and size of elements for the vector op.
+    auto *VTy = cast<FixedVectorType>(ParamTys.front());
+    unsigned NumElements = VTy->getNumElements();
+    unsigned ScalarSize = VTy->getScalarSizeInBits();
+    // For scalar sizes >128 bits, we fall back to the generic cost estimate.
+    if (ScalarSize > SystemZ::VectorBits)
+      return -1;
+    // A single vector register can hold this many elements.
+    unsigned MaxElemsPerVector = SystemZ::VectorBits / ScalarSize;
+    // This many vector regs are needed to represent the input elements (V).
+    unsigned VectorRegsNeeded = getNumVectorRegs(VTy);
+    // This many instructions are needed for the final sum of vector elems (S).
+    unsigned LastVectorHandling =
+        2 * Log2_32_Ceil(std::min(NumElements, MaxElemsPerVector));
+    // We use vector adds to create a sum vector, which takes
+    // V/2 + V/4 + ... = V - 1 operations.
+    // Then, we need S operations to sum up the elements of that sum vector,
+    // for a total of V + S - 1 operations.
+    int Cost = VectorRegsNeeded + LastVectorHandling - 1;
+    return Cost;
+  }
   return -1;
 }
 
 InstructionCost
 SystemZTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
                                       TTI::TargetCostKind CostKind) {
-  InstructionCost Cost =
-      getVectorIntrinsicInstrCost(ICA.getID(), ICA.getReturnType());
+  InstructionCost Cost = getVectorIntrinsicInstrCost(
+      ICA.getID(), ICA.getReturnType(), ICA.getArgTypes());
   if (Cost != -1)
     return Cost;
   return BaseT::getIntrinsicInstrCost(ICA, CostKind);
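
As a cross-check of the comment above, here is a minimal standalone sketch of the new cost formula in plain C++, independent of LLVM. It assumes SystemZ::VectorBits == 128 and models getNumVectorRegs() as a ceiling division of the vector's total bit width by 128; log2Ceil is a hypothetical stand-in for llvm::Log2_32_Ceil. This is an illustration of the cost arithmetic, not the implementation itself.

// Standalone sketch of the vector.reduce.add cost formula.
// Assumptions: SystemZ::VectorBits == 128, and getNumVectorRegs()
// behaves as ceil(total bits / 128).
#include <algorithm>
#include <cassert>
#include <cstdio>

// ceil(log2(X)) for X >= 1; stands in for llvm::Log2_32_Ceil.
static unsigned log2Ceil(unsigned X) {
  unsigned Bits = 0;
  while ((1u << Bits) < X)
    ++Bits;
  return Bits;
}

// Returns the estimated cost, or -1 to request the generic fallback.
static int reduceAddCost(unsigned NumElements, unsigned ScalarSize) {
  const unsigned VectorBits = 128; // SystemZ vector register width.
  if (ScalarSize > VectorBits)
    return -1; // e.g. i256: defer to BaseT::getIntrinsicInstrCost.
  unsigned MaxElemsPerVector = VectorBits / ScalarSize;
  // V: vector registers needed to hold all input elements.
  unsigned VectorRegsNeeded =
      (NumElements * ScalarSize + VectorBits - 1) / VectorBits;
  // S: instructions to sum the lanes of the final sum vector.
  unsigned LastVectorHandling =
      2 * log2Ceil(std::min(NumElements, MaxElemsPerVector));
  return VectorRegsNeeded + LastVectorHandling - 1; // V + S - 1
}

int main() {
  assert(reduceAddCost(16, 32) == 7);  // matches %R16_32 in the test below
  assert(reduceAddCost(128, 8) == 15); // matches %R128_8
  assert(reduceAddCost(4, 256) == -1); // i256 takes the fallback path
  std::printf("all checks passed\n");
  return 0;
}

The three asserts reproduce CHECK values from the new test: cost 7 for v16i32, cost 15 for v128i8, and -1 for the i256 case, which means the generic estimate (20 in the test) is used instead.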
Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
; RUN: opt < %s -mtriple=systemz-unknown -mcpu=z13 -passes="print<cost-model>" -cost-kind=throughput 2>&1 -disable-output | FileCheck %s

define void @reduce(ptr %src, ptr %dst) {
; CHECK-LABEL: 'reduce'
; CHECK: Cost Model: Found an estimated cost of 2 for instruction: %R2_64 = call i64 @llvm.vector.reduce.add.v2i64(<2 x i64> %V2_64)
; CHECK: Cost Model: Found an estimated cost of 3 for instruction: %R4_64 = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> %V4_64)
; CHECK: Cost Model: Found an estimated cost of 5 for instruction: %R8_64 = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> %V8_64)
; CHECK: Cost Model: Found an estimated cost of 9 for instruction: %R16_64 = call i64 @llvm.vector.reduce.add.v16i64(<16 x i64> %V16_64)
; CHECK: Cost Model: Found an estimated cost of 2 for instruction: %R2_32 = call i32 @llvm.vector.reduce.add.v2i32(<2 x i32> %V2_32)
; CHECK: Cost Model: Found an estimated cost of 4 for instruction: %R4_32 = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %V4_32)
; CHECK: Cost Model: Found an estimated cost of 5 for instruction: %R8_32 = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %V8_32)
; CHECK: Cost Model: Found an estimated cost of 7 for instruction: %R16_32 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %V16_32)
; CHECK: Cost Model: Found an estimated cost of 2 for instruction: %R2_16 = call i16 @llvm.vector.reduce.add.v2i16(<2 x i16> %V2_16)
; CHECK: Cost Model: Found an estimated cost of 4 for instruction: %R4_16 = call i16 @llvm.vector.reduce.add.v4i16(<4 x i16> %V4_16)
; CHECK: Cost Model: Found an estimated cost of 6 for instruction: %R8_16 = call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> %V8_16)
; CHECK: Cost Model: Found an estimated cost of 7 for instruction: %R16_16 = call i16 @llvm.vector.reduce.add.v16i16(<16 x i16> %V16_16)
; CHECK: Cost Model: Found an estimated cost of 2 for instruction: %R2_8 = call i8 @llvm.vector.reduce.add.v2i8(<2 x i8> %V2_8)
; CHECK: Cost Model: Found an estimated cost of 4 for instruction: %R4_8 = call i8 @llvm.vector.reduce.add.v4i8(<4 x i8> %V4_8)
; CHECK: Cost Model: Found an estimated cost of 6 for instruction: %R8_8 = call i8 @llvm.vector.reduce.add.v8i8(<8 x i8> %V8_8)
; CHECK: Cost Model: Found an estimated cost of 8 for instruction: %R16_8 = call i8 @llvm.vector.reduce.add.v16i8(<16 x i8> %V16_8)
;
; CHECK: Cost Model: Found an estimated cost of 15 for instruction: %R128_8 = call i8 @llvm.vector.reduce.add.v128i8(<128 x i8> %V128_8)
; CHECK: Cost Model: Found an estimated cost of 20 for instruction: %R4_256 = call i256 @llvm.vector.reduce.add.v4i256(<4 x i256> %V4_256)

; REDUCEADD64

  %V2_64 = load <2 x i64>, ptr %src, align 8
  %R2_64 = call i64 @llvm.vector.reduce.add.v2i64(<2 x i64> %V2_64)
  store volatile i64 %R2_64, ptr %dst, align 4

  %V4_64 = load <4 x i64>, ptr %src, align 8
  %R4_64 = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> %V4_64)
  store volatile i64 %R4_64, ptr %dst, align 4

  %V8_64 = load <8 x i64>, ptr %src, align 8
  %R8_64 = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> %V8_64)
  store volatile i64 %R8_64, ptr %dst, align 4

  %V16_64 = load <16 x i64>, ptr %src, align 8
  %R16_64 = call i64 @llvm.vector.reduce.add.v16i64(<16 x i64> %V16_64)
  store volatile i64 %R16_64, ptr %dst, align 4

; REDUCEADD32

  %V2_32 = load <2 x i32>, ptr %src, align 8
  %R2_32 = call i32 @llvm.vector.reduce.add.v2i32(<2 x i32> %V2_32)
  store volatile i32 %R2_32, ptr %dst, align 4

  %V4_32 = load <4 x i32>, ptr %src, align 8
  %R4_32 = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %V4_32)
  store volatile i32 %R4_32, ptr %dst, align 4

  %V8_32 = load <8 x i32>, ptr %src, align 8
  %R8_32 = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %V8_32)
  store volatile i32 %R8_32, ptr %dst, align 4

  %V16_32 = load <16 x i32>, ptr %src, align 8
  %R16_32 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %V16_32)
  store volatile i32 %R16_32, ptr %dst, align 4

; REDUCEADD16

  %V2_16 = load <2 x i16>, ptr %src, align 8
  %R2_16 = call i16 @llvm.vector.reduce.add.v2i16(<2 x i16> %V2_16)
  store volatile i16 %R2_16, ptr %dst, align 4

  %V4_16 = load <4 x i16>, ptr %src, align 8
  %R4_16 = call i16 @llvm.vector.reduce.add.v4i16(<4 x i16> %V4_16)
  store volatile i16 %R4_16, ptr %dst, align 4

  %V8_16 = load <8 x i16>, ptr %src, align 8
  %R8_16 = call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> %V8_16)
  store volatile i16 %R8_16, ptr %dst, align 4

  %V16_16 = load <16 x i16>, ptr %src, align 8
  %R16_16 = call i16 @llvm.vector.reduce.add.v16i16(<16 x i16> %V16_16)
  store volatile i16 %R16_16, ptr %dst, align 4

; REDUCEADD8

  %V2_8 = load <2 x i8>, ptr %src, align 8
  %R2_8 = call i8 @llvm.vector.reduce.add.v2i8(<2 x i8> %V2_8)
  store volatile i8 %R2_8, ptr %dst, align 4

  %V4_8 = load <4 x i8>, ptr %src, align 8
  %R4_8 = call i8 @llvm.vector.reduce.add.v4i8(<4 x i8> %V4_8)
  store volatile i8 %R4_8, ptr %dst, align 4

  %V8_8 = load <8 x i8>, ptr %src, align 8
  %R8_8 = call i8 @llvm.vector.reduce.add.v8i8(<8 x i8> %V8_8)
  store volatile i8 %R8_8, ptr %dst, align 4

  %V16_8 = load <16 x i8>, ptr %src, align 8
  %R16_8 = call i8 @llvm.vector.reduce.add.v16i8(<16 x i8> %V16_8)
  store volatile i8 %R16_8, ptr %dst, align 4

; EXTREME VALUES

  %V128_8 = load <128 x i8>, ptr %src, align 8
  %R128_8 = call i8 @llvm.vector.reduce.add.v128i8(<128 x i8> %V128_8)
  store volatile i8 %R128_8, ptr %dst, align 4

  %V4_256 = load <4 x i256>, ptr %src, align 8
  %R4_256 = call i256 @llvm.vector.reduce.add.v4i256(<4 x i256> %V4_256)
  store volatile i256 %R4_256, ptr %dst, align 8

  ret void
}

declare i64 @llvm.vector.reduce.add.v2i64(<2 x i64>)
declare i64 @llvm.vector.reduce.add.v4i64(<4 x i64>)
declare i64 @llvm.vector.reduce.add.v8i64(<8 x i64>)
declare i64 @llvm.vector.reduce.add.v16i64(<16 x i64>)
declare i32 @llvm.vector.reduce.add.v2i32(<2 x i32>)
declare i32 @llvm.vector.reduce.add.v4i32(<4 x i32>)
declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32>)
declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>)
declare i16 @llvm.vector.reduce.add.v2i16(<2 x i16>)
declare i16 @llvm.vector.reduce.add.v4i16(<4 x i16>)
declare i16 @llvm.vector.reduce.add.v8i16(<8 x i16>)
declare i16 @llvm.vector.reduce.add.v16i16(<16 x i16>)
declare i8 @llvm.vector.reduce.add.v2i8(<2 x i8>)
declare i8 @llvm.vector.reduce.add.v4i8(<4 x i8>)
declare i8 @llvm.vector.reduce.add.v8i8(<8 x i8>)
declare i8 @llvm.vector.reduce.add.v16i8(<16 x i8>)

declare i8 @llvm.vector.reduce.add.v128i8(<128 x i8>)
declare i256 @llvm.vector.reduce.add.v4i256(<4 x i256>)
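
The two extreme values exercise both paths: <128 x i8> still uses the new formula (V = 8 registers, S = 2 * log2(16) = 8, so 8 + 8 - 1 = 15), while <4 x i256> has a scalar size above 128 bits and therefore takes the generic fallback, which yields 20 here.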
