Commit d30236d

[VectorCombine] Fold reduce(trunc(x)) -> trunc(reduce(x)) iff cost effective
Vector truncations can be pretty expensive, especially on X86, whilst scalar truncations are often free. If performing the add/mul/and/or/xor reduction on the pre-truncated type is cheap enough, avoid the vector truncation entirely.

Fixes #81469
1 parent 7de4f82 commit d30236d
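
For illustration, here is a minimal before/after sketch of the rewrite this commit performs (the function names are hypothetical; whether the fold actually fires depends on the target's reported costs):

; Before: truncate the vector, then reduce on the narrow type.
define i8 @example_reduce_or_trunc(<8 x i32> %a0) {
  %tr  = trunc <8 x i32> %a0 to <8 x i8>
  %red = call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> %tr)
  ret i8 %red
}

; After (when the cost model prefers it): reduce on the original wide type,
; then truncate the single scalar result.
define i8 @example_reduce_or_trunc_folded(<8 x i32> %a0) {
  %red = call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> %a0)
  %tr  = trunc i32 %red to i8
  ret i8 %tr
}

declare i8 @llvm.vector.reduce.or.v8i8(<8 x i8>)
declare i32 @llvm.vector.reduce.or.v8i32(<8 x i32>)

The X86 tests below exercise this same pattern; since the fold lives in VectorCombine, it can be reproduced by running opt -passes=vector-combine on IR like the above.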

File tree

2 files changed (+76, -17 lines)

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 57 additions & 0 deletions
@@ -29,6 +29,7 @@
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
 #include <numeric>
 #include <queue>

@@ -111,6 +112,7 @@ class VectorCombine {
   bool scalarizeLoadExtract(Instruction &I);
   bool foldShuffleOfBinops(Instruction &I);
   bool foldShuffleFromReductions(Instruction &I);
+  bool foldTruncFromReductions(Instruction &I);
   bool foldSelectShuffle(Instruction &I, bool FromReduction = false);

   void replaceValue(Value &Old, Value &New) {
@@ -1526,6 +1528,60 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
   return foldSelectShuffle(*Shuffle, true);
 }

+/// Determine if its more efficient to fold:
+/// reduce(trunc(x)) -> trunc(reduce(x)).
+bool VectorCombine::foldTruncFromReductions(Instruction &I) {
+  auto *II = dyn_cast<IntrinsicInst>(&I);
+  if (!II)
+    return false;
+
+  Intrinsic::ID IID = II->getIntrinsicID();
+  switch (IID) {
+  case Intrinsic::vector_reduce_add:
+  case Intrinsic::vector_reduce_mul:
+  case Intrinsic::vector_reduce_and:
+  case Intrinsic::vector_reduce_or:
+  case Intrinsic::vector_reduce_xor:
+    break;
+  default:
+    return false;
+  }
+
+  unsigned ReductionOpc = getArithmeticReductionInstruction(IID);
+  Value *ReductionSrc = I.getOperand(0);
+
+  Value *TruncSrc;
+  if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(TruncSrc)))))
+    return false;
+
+  auto *Trunc = cast<CastInst>(ReductionSrc);
+  auto *TruncSrcTy = cast<VectorType>(TruncSrc->getType());
+  auto *ReductionSrcTy = cast<VectorType>(ReductionSrc->getType());
+  Type *ResultTy = I.getType();
+
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  InstructionCost OldCost =
+      TTI.getCastInstrCost(Instruction::Trunc, ReductionSrcTy, TruncSrcTy,
+                           TTI::CastContextHint::None, CostKind, Trunc) +
+      TTI.getArithmeticReductionCost(ReductionOpc, ReductionSrcTy, std::nullopt,
+                                     CostKind);
+  InstructionCost NewCost =
+      TTI.getArithmeticReductionCost(ReductionOpc, TruncSrcTy, std::nullopt,
+                                     CostKind) +
+      TTI.getCastInstrCost(Instruction::Trunc, ResultTy,
+                           ReductionSrcTy->getScalarType(),
+                           TTI::CastContextHint::None, CostKind);
+
+  if (OldCost <= NewCost || !NewCost.isValid())
+    return false;
+
+  Value *NewReduction = Builder.CreateIntrinsic(
+      TruncSrcTy->getScalarType(), II->getIntrinsicID(), {TruncSrc});
+  Value *NewTruncation = Builder.CreateTrunc(NewReduction, ResultTy);
+  replaceValue(I, *NewTruncation);
+  return true;
+}
+
 /// This method looks for groups of shuffles acting on binops, of the form:
 ///  %x = shuffle ...
 ///  %y = shuffle ...
@@ -1917,6 +1973,7 @@ bool VectorCombine::run() {
       switch (Opcode) {
       case Instruction::Call:
         MadeChange |= foldShuffleFromReductions(I);
+        MadeChange |= foldTruncFromReductions(I);
        break;
      case Instruction::ICmp:
      case Instruction::FCmp:

llvm/test/Transforms/VectorCombine/X86/reduction-of-truncations.ll

Lines changed: 19 additions & 17 deletions
@@ -8,24 +8,29 @@
 ; Fold reduce(trunc(X)) -> trunc(reduce(X)) if more cost efficient
 ;

-; TODO: Cheap AVX512 v8i64 -> v8i32 truncation
+; Cheap AVX512 v8i64 -> v8i32 truncation
 define i32 @reduce_add_trunc_v8i64_i32(<8 x i64> %a0) {
-; CHECK-LABEL: @reduce_add_trunc_v8i64_i32(
-; CHECK-NEXT:    [[TR:%.*]] = trunc <8 x i64> [[A0:%.*]] to <8 x i32>
-; CHECK-NEXT:    [[RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TR]])
-; CHECK-NEXT:    ret i32 [[RED]]
+; X64-LABEL: @reduce_add_trunc_v8i64_i32(
+; X64-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[A0:%.*]])
+; X64-NEXT:    [[RED:%.*]] = trunc i64 [[TMP1]] to i32
+; X64-NEXT:    ret i32 [[RED]]
+;
+; AVX512-LABEL: @reduce_add_trunc_v8i64_i32(
+; AVX512-NEXT:    [[TR:%.*]] = trunc <8 x i64> [[A0:%.*]] to <8 x i32>
+; AVX512-NEXT:    [[RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TR]])
+; AVX512-NEXT:    ret i32 [[RED]]
 ;
   %tr = trunc <8 x i64> %a0 to <8 x i32>
   %red = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %tr)
   ret i32 %red
 }
 declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32>)

-; TODO: No legal vXi8 multiplication so vXi16 is always cheaper
+; No legal vXi8 multiplication so vXi16 is always cheaper
 define i8 @reduce_mul_trunc_v16i16_i8(<16 x i16> %a0) {
 ; CHECK-LABEL: @reduce_mul_trunc_v16i16_i8(
-; CHECK-NEXT:    [[TR:%.*]] = trunc <16 x i16> [[A0:%.*]] to <16 x i8>
-; CHECK-NEXT:    [[RED:%.*]] = tail call i8 @llvm.vector.reduce.mul.v16i8(<16 x i8> [[TR]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i16 @llvm.vector.reduce.mul.v16i16(<16 x i16> [[A0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = trunc i16 [[TMP1]] to i8
 ; CHECK-NEXT:    ret i8 [[RED]]
 ;
   %tr = trunc <16 x i16> %a0 to <16 x i8>
@@ -36,8 +41,8 @@ declare i8 @llvm.vector.reduce.mul.v16i8(<16 x i8>)

 define i8 @reduce_or_trunc_v8i32_i8(<8 x i32> %a0) {
 ; CHECK-LABEL: @reduce_or_trunc_v8i32_i8(
-; CHECK-NEXT:    [[TR:%.*]] = trunc <8 x i32> [[A0:%.*]] to <8 x i8>
-; CHECK-NEXT:    [[RED:%.*]] = tail call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> [[TR]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[A0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = trunc i32 [[TMP1]] to i8
 ; CHECK-NEXT:    ret i8 [[RED]]
 ;
   %tr = trunc <8 x i32> %a0 to <8 x i8>
@@ -48,8 +53,8 @@ declare i32 @llvm.vector.reduce.or.v8i8(<8 x i8>)

 define i8 @reduce_xor_trunc_v16i64_i8(<16 x i64> %a0) {
 ; CHECK-LABEL: @reduce_xor_trunc_v16i64_i8(
-; CHECK-NEXT:    [[TR:%.*]] = trunc <16 x i64> [[A0:%.*]] to <16 x i8>
-; CHECK-NEXT:    [[RED:%.*]] = tail call i8 @llvm.vector.reduce.xor.v16i8(<16 x i8> [[TR]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vector.reduce.xor.v16i64(<16 x i64> [[A0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = trunc i64 [[TMP1]] to i8
 ; CHECK-NEXT:    ret i8 [[RED]]
 ;
   %tr = trunc <16 x i64> %a0 to <16 x i8>
@@ -61,8 +66,8 @@ declare i8 @llvm.vector.reduce.xor.v16i8(<16 x i8>)
 ; Truncation source has other uses - OK to truncate reduction
 define i16 @reduce_and_trunc_v16i64_i16(<16 x i64> %a0) {
 ; CHECK-LABEL: @reduce_and_trunc_v16i64_i16(
-; CHECK-NEXT:    [[TR:%.*]] = trunc <16 x i64> [[A0:%.*]] to <16 x i16>
-; CHECK-NEXT:    [[RED:%.*]] = tail call i16 @llvm.vector.reduce.and.v16i16(<16 x i16> [[TR]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vector.reduce.and.v16i64(<16 x i64> [[A0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = trunc i64 [[TMP1]] to i16
 ; CHECK-NEXT:    call void @use_v16i64(<16 x i64> [[A0]])
 ; CHECK-NEXT:    ret i16 [[RED]]
 ;
@@ -116,6 +121,3 @@ define i16 @reduce_and_trunc_v16i64_i16_multiuse(<16 x i64> %a0) {
 declare void @use_v16i64(<16 x i64>)
 declare void @use_v16i16(<16 x i16>)

-;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
-; AVX512: {{.*}}
-; X64: {{.*}}
