Commit 769c22f
[VectorCombine] Fold reduce(trunc(x)) -> trunc(reduce(x)) iff cost effective (#81852)
Vector truncations can be pretty expensive, especially on X86, whilst scalar truncations are often free. If the add/mul/and/or/xor reduction is cheap enough to perform on the pre-truncated type, avoid the vector truncation entirely.

Fixes #81469
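A minimal before/after sketch in LLVM IR, mirroring the first test case added below (an <8 x i64> add reduction truncated to i32), shows the shape of the fold when the wide reduction is the cheaper option:

  ; before: truncate the vector, then reduce in the narrow type
  %tr  = trunc <8 x i64> %a0 to <8 x i32>
  %red = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %tr)

  ; after: reduce in the wide type, then truncate the scalar result
  %r   = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> %a0)
  %red = trunc i64 %r to i32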
1 parent 21431e0 commit 769c22f

File tree

2 files changed: +180 -0 lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 57 additions & 0 deletions
@@ -29,6 +29,7 @@
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
 #include <numeric>
 #include <queue>

@@ -111,6 +112,7 @@ class VectorCombine {
   bool scalarizeLoadExtract(Instruction &I);
   bool foldShuffleOfBinops(Instruction &I);
   bool foldShuffleFromReductions(Instruction &I);
+  bool foldTruncFromReductions(Instruction &I);
   bool foldSelectShuffle(Instruction &I, bool FromReduction = false);

   void replaceValue(Value &Old, Value &New) {

@@ -1526,6 +1528,60 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
   return foldSelectShuffle(*Shuffle, true);
 }

+/// Determine if it's more efficient to fold:
+/// reduce(trunc(x)) -> trunc(reduce(x)).
+bool VectorCombine::foldTruncFromReductions(Instruction &I) {
+  auto *II = dyn_cast<IntrinsicInst>(&I);
+  if (!II)
+    return false;
+
+  Intrinsic::ID IID = II->getIntrinsicID();
+  switch (IID) {
+  case Intrinsic::vector_reduce_add:
+  case Intrinsic::vector_reduce_mul:
+  case Intrinsic::vector_reduce_and:
+  case Intrinsic::vector_reduce_or:
+  case Intrinsic::vector_reduce_xor:
+    break;
+  default:
+    return false;
+  }
+
+  unsigned ReductionOpc = getArithmeticReductionInstruction(IID);
+  Value *ReductionSrc = I.getOperand(0);
+
+  Value *TruncSrc;
+  if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(TruncSrc)))))
+    return false;
+
+  auto *Trunc = cast<CastInst>(ReductionSrc);
+  auto *TruncSrcTy = cast<VectorType>(TruncSrc->getType());
+  auto *ReductionSrcTy = cast<VectorType>(ReductionSrc->getType());
+  Type *ResultTy = I.getType();
+
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  InstructionCost OldCost =
+      TTI.getCastInstrCost(Instruction::Trunc, ReductionSrcTy, TruncSrcTy,
+                           TTI::CastContextHint::None, CostKind, Trunc) +
+      TTI.getArithmeticReductionCost(ReductionOpc, ReductionSrcTy, std::nullopt,
+                                     CostKind);
+  InstructionCost NewCost =
+      TTI.getArithmeticReductionCost(ReductionOpc, TruncSrcTy, std::nullopt,
+                                     CostKind) +
+      TTI.getCastInstrCost(Instruction::Trunc, ResultTy,
+                           ReductionSrcTy->getScalarType(),
+                           TTI::CastContextHint::None, CostKind);
+
+  if (OldCost <= NewCost || !NewCost.isValid())
+    return false;
+
+  Value *NewReduction = Builder.CreateIntrinsic(
+      TruncSrcTy->getScalarType(), II->getIntrinsicID(), {TruncSrc});
+  Value *NewTruncation = Builder.CreateTrunc(NewReduction, ResultTy);
+  replaceValue(I, *NewTruncation);
+  return true;
+}
+
 /// This method looks for groups of shuffles acting on binops, of the form:
 ///  %x = shuffle ...
 ///  %y = shuffle ...

@@ -1917,6 +1973,7 @@ bool VectorCombine::run() {
     switch (Opcode) {
    case Instruction::Call:
      MadeChange |= foldShuffleFromReductions(I);
+      MadeChange |= foldTruncFromReductions(I);
      break;
    case Instruction::ICmp:
    case Instruction::FCmp:

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -S --passes=vector-combine -mtriple=x86_64-- -mcpu=x86-64 | FileCheck %s --check-prefixes=CHECK,X64
+; RUN: opt < %s -S --passes=vector-combine -mtriple=x86_64-- -mcpu=x86-64-v2 | FileCheck %s --check-prefixes=CHECK,X64
+; RUN: opt < %s -S --passes=vector-combine -mtriple=x86_64-- -mcpu=x86-64-v3 | FileCheck %s --check-prefixes=CHECK,X64
+; RUN: opt < %s -S --passes=vector-combine -mtriple=x86_64-- -mcpu=x86-64-v4 | FileCheck %s --check-prefixes=CHECK,AVX512
+
+;
+; Fold reduce(trunc(X)) -> trunc(reduce(X)) if more cost efficient
+;
+
+; Cheap AVX512 v8i64 -> v8i32 truncation
+define i32 @reduce_add_trunc_v8i64_i32(<8 x i64> %a0) {
+; X64-LABEL: @reduce_add_trunc_v8i64_i32(
+; X64-NEXT: [[TMP1:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[A0:%.*]])
+; X64-NEXT: [[RED:%.*]] = trunc i64 [[TMP1]] to i32
+; X64-NEXT: ret i32 [[RED]]
+;
+; AVX512-LABEL: @reduce_add_trunc_v8i64_i32(
+; AVX512-NEXT: [[TR:%.*]] = trunc <8 x i64> [[A0:%.*]] to <8 x i32>
+; AVX512-NEXT: [[RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TR]])
+; AVX512-NEXT: ret i32 [[RED]]
+;
+  %tr = trunc <8 x i64> %a0 to <8 x i32>
+  %red = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %tr)
+  ret i32 %red
+}
+declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32>)
+
+; No legal vXi8 multiplication so vXi16 is always cheaper
+define i8 @reduce_mul_trunc_v16i16_i8(<16 x i16> %a0) {
+; CHECK-LABEL: @reduce_mul_trunc_v16i16_i8(
+; CHECK-NEXT: [[TMP1:%.*]] = call i16 @llvm.vector.reduce.mul.v16i16(<16 x i16> [[A0:%.*]])
+; CHECK-NEXT: [[RED:%.*]] = trunc i16 [[TMP1]] to i8
+; CHECK-NEXT: ret i8 [[RED]]
+;
+  %tr = trunc <16 x i16> %a0 to <16 x i8>
+  %red = tail call i8 @llvm.vector.reduce.mul.v16i8(<16 x i8> %tr)
+  ret i8 %red
+}
+declare i8 @llvm.vector.reduce.mul.v16i8(<16 x i8>)
+
+define i8 @reduce_or_trunc_v8i32_i8(<8 x i32> %a0) {
+; CHECK-LABEL: @reduce_or_trunc_v8i32_i8(
+; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[A0:%.*]])
+; CHECK-NEXT: [[RED:%.*]] = trunc i32 [[TMP1]] to i8
+; CHECK-NEXT: ret i8 [[RED]]
+;
+  %tr = trunc <8 x i32> %a0 to <8 x i8>
+  %red = tail call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> %tr)
+  ret i8 %red
+}
+declare i8 @llvm.vector.reduce.or.v8i8(<8 x i8>)
+
+define i8 @reduce_xor_trunc_v16i64_i8(<16 x i64> %a0) {
+; CHECK-LABEL: @reduce_xor_trunc_v16i64_i8(
+; CHECK-NEXT: [[TMP1:%.*]] = call i64 @llvm.vector.reduce.xor.v16i64(<16 x i64> [[A0:%.*]])
+; CHECK-NEXT: [[RED:%.*]] = trunc i64 [[TMP1]] to i8
+; CHECK-NEXT: ret i8 [[RED]]
+;
+  %tr = trunc <16 x i64> %a0 to <16 x i8>
+  %red = tail call i8 @llvm.vector.reduce.xor.v16i8(<16 x i8> %tr)
+  ret i8 %red
+}
+declare i8 @llvm.vector.reduce.xor.v16i8(<16 x i8>)
+
+; Truncation source has other uses - OK to truncate reduction
+define i16 @reduce_and_trunc_v16i64_i16(<16 x i64> %a0) {
+; CHECK-LABEL: @reduce_and_trunc_v16i64_i16(
+; CHECK-NEXT: [[TMP1:%.*]] = call i64 @llvm.vector.reduce.and.v16i64(<16 x i64> [[A0:%.*]])
+; CHECK-NEXT: [[RED:%.*]] = trunc i64 [[TMP1]] to i16
+; CHECK-NEXT: call void @use_v16i64(<16 x i64> [[A0]])
+; CHECK-NEXT: ret i16 [[RED]]
+;
+  %tr = trunc <16 x i64> %a0 to <16 x i16>
+  %red = tail call i16 @llvm.vector.reduce.and.v16i16(<16 x i16> %tr)
+  call void @use_v16i64(<16 x i64> %a0)
+  ret i16 %red
+}
+declare i16 @llvm.vector.reduce.and.v16i16(<16 x i16>)
+
+; Negative Test: vXi16 multiply is much cheaper than vXi64
+define i16 @reduce_mul_trunc_v8i64_i16(<8 x i64> %a0) {
+; CHECK-LABEL: @reduce_mul_trunc_v8i64_i16(
+; CHECK-NEXT: [[TR:%.*]] = trunc <8 x i64> [[A0:%.*]] to <8 x i16>
+; CHECK-NEXT: [[RED:%.*]] = tail call i16 @llvm.vector.reduce.mul.v8i16(<8 x i16> [[TR]])
+; CHECK-NEXT: ret i16 [[RED]]
+;
+  %tr = trunc <8 x i64> %a0 to <8 x i16>
+  %red = tail call i16 @llvm.vector.reduce.mul.v8i16(<8 x i16> %tr)
+  ret i16 %red
+}
+declare i16 @llvm.vector.reduce.mul.v8i16(<8 x i16>)
+
+; Negative Test: min/max reductions can't use pre-truncated types.
+define i8 @reduce_smin_trunc_v16i16_i8(<16 x i16> %a0) {
+; CHECK-LABEL: @reduce_smin_trunc_v16i16_i8(
+; CHECK-NEXT: [[TR:%.*]] = trunc <16 x i16> [[A0:%.*]] to <16 x i8>
+; CHECK-NEXT: [[RED:%.*]] = tail call i8 @llvm.vector.reduce.smin.v16i8(<16 x i8> [[TR]])
+; CHECK-NEXT: ret i8 [[RED]]
+;
+  %tr = trunc <16 x i16> %a0 to <16 x i8>
+  %red = tail call i8 @llvm.vector.reduce.smin.v16i8(<16 x i8> %tr)
+  ret i8 %red
+}
+declare i8 @llvm.vector.reduce.smin.v16i8(<16 x i8>)
+
+; Negative Test: Truncation has other uses.
+define i16 @reduce_and_trunc_v16i64_i16_multiuse(<16 x i64> %a0) {
+; CHECK-LABEL: @reduce_and_trunc_v16i64_i16_multiuse(
+; CHECK-NEXT: [[TR:%.*]] = trunc <16 x i64> [[A0:%.*]] to <16 x i16>
+; CHECK-NEXT: [[RED:%.*]] = tail call i16 @llvm.vector.reduce.and.v16i16(<16 x i16> [[TR]])
+; CHECK-NEXT: call void @use_v16i16(<16 x i16> [[TR]])
+; CHECK-NEXT: ret i16 [[RED]]
+;
+  %tr = trunc <16 x i64> %a0 to <16 x i16>
+  %red = tail call i16 @llvm.vector.reduce.and.v16i16(<16 x i16> %tr)
+  call void @use_v16i16(<16 x i16> %tr)
+  ret i16 %red
+}
+
+declare void @use_v16i64(<16 x i64>)
+declare void @use_v16i16(<16 x i16>)
+
