
Commit 6194e5c

[VectorCombine] Fold reduce(trunc(x)) -> trunc(reduce(x)) iff cost effective
Vector truncations can be pretty expensive, especially on X86, whilst scalar truncations are often free. If the cost of performing the add/mul/and/or/xor reduction is cheap enough on the pre-truncated type, then avoid the vector truncation entirely. Fixes #81469
1 parent ee5e122 commit 6194e5c
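
As a quick illustration of the fold (a sketch based on the reduce_or_trunc_v8i32_i8 test below; the %red.wide value name is only illustrative), the pattern

  %tr  = trunc <8 x i32> %a0 to <8 x i8>
  %red = tail call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> %tr)

becomes, when the cost model says the wider reduction plus a scalar truncation is no more expensive,

  %red.wide = call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> %a0)
  %red      = trunc i32 %red.wide to i8

Targets where the vector truncation itself is cheap (e.g. AVX512 v8i64 -> v8i32) keep the original form, as the updated tests show.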

2 files changed: +80, -15 lines

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 63 additions & 0 deletions
@@ -111,6 +111,7 @@ class VectorCombine {
   bool scalarizeLoadExtract(Instruction &I);
   bool foldShuffleOfBinops(Instruction &I);
   bool foldShuffleFromReductions(Instruction &I);
+  bool foldTruncFromReductions(Instruction &I);
   bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
 
   void replaceValue(Value &Old, Value &New) {
@@ -1526,6 +1527,67 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
   return foldSelectShuffle(*Shuffle, true);
 }
 
+/// Determine if it's more efficient to fold:
+///    reduce(trunc(x)) -> trunc(reduce(x)).
+bool VectorCombine::foldTruncFromReductions(Instruction &I) {
+  auto *II = dyn_cast<IntrinsicInst>(&I);
+  if (!II)
+    return false;
+
+  unsigned ReductionOpc = 0;
+  switch (II->getIntrinsicID()) {
+  case Intrinsic::vector_reduce_add:
+    ReductionOpc = Instruction::Add;
+    break;
+  case Intrinsic::vector_reduce_mul:
+    ReductionOpc = Instruction::Mul;
+    break;
+  case Intrinsic::vector_reduce_and:
+    ReductionOpc = Instruction::And;
+    break;
+  case Intrinsic::vector_reduce_or:
+    ReductionOpc = Instruction::Or;
+    break;
+  case Intrinsic::vector_reduce_xor:
+    ReductionOpc = Instruction::Xor;
+    break;
+  default:
+    return false;
+  }
+  Value *ReductionSrc = I.getOperand(0);
+
+  Value *TruncSrc;
+  if (!match(ReductionSrc, m_Trunc(m_OneUse(m_Value(TruncSrc)))))
+    return false;
+
+  auto *Trunc = cast<CastInst>(ReductionSrc);
+  auto *TruncTy = cast<VectorType>(TruncSrc->getType());
+  auto *ReductionTy = cast<VectorType>(ReductionSrc->getType());
+  Type *ResultTy = I.getType();
+
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  InstructionCost OldCost =
+      TTI.getCastInstrCost(Instruction::Trunc, ReductionTy, TruncTy,
+                           TTI::CastContextHint::None, CostKind, Trunc) +
+      TTI.getArithmeticReductionCost(ReductionOpc, ReductionTy, std::nullopt,
+                                     CostKind);
+  InstructionCost NewCost =
+      TTI.getArithmeticReductionCost(ReductionOpc, TruncTy, std::nullopt,
+                                     CostKind) +
+      TTI.getCastInstrCost(Instruction::Trunc, ResultTy,
+                           ReductionTy->getScalarType(),
+                           TTI::CastContextHint::None, CostKind);
+
+  if (OldCost < NewCost || !NewCost.isValid())
+    return false;
+
+  Value *NewReduction = Builder.CreateIntrinsic(
+      TruncTy->getScalarType(), II->getIntrinsicID(), {TruncSrc});
+  Value *NewTruncation = Builder.CreateTrunc(NewReduction, ResultTy);
+  replaceValue(I, *NewTruncation);
+  return true;
+}
+
 /// This method looks for groups of shuffles acting on binops, of the form:
 ///  %x = shuffle ...
 ///  %y = shuffle ...
@@ -1917,6 +1979,7 @@ bool VectorCombine::run() {
     switch (Opcode) {
     case Instruction::Call:
       MadeChange |= foldShuffleFromReductions(I);
+      MadeChange |= foldTruncFromReductions(I);
       break;
     case Instruction::ICmp:
     case Instruction::FCmp:

llvm/test/Transforms/VectorCombine/X86/reduction-of-truncations.ll

Lines changed: 17 additions & 15 deletions
@@ -8,24 +8,29 @@
 ; Fold reduce(trunc(X)) -> trunc(reduce(X)) if more cost efficient
 ;
 
-; TODO: Cheap AVX512 v8i64 -> v8i32 truncation
+; Cheap AVX512 v8i64 -> v8i32 truncation
 define i32 @reduce_add_trunc_v8i64_i32(<8 x i64> %a0) {
-; CHECK-LABEL: @reduce_add_trunc_v8i64_i32(
-; CHECK-NEXT:    [[TR:%.*]] = trunc <8 x i64> [[A0:%.*]] to <8 x i32>
-; CHECK-NEXT:    [[RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TR]])
-; CHECK-NEXT:    ret i32 [[RED]]
+; X64-LABEL: @reduce_add_trunc_v8i64_i32(
+; X64-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[A0:%.*]])
+; X64-NEXT:    [[RED:%.*]] = trunc i64 [[TMP1]] to i32
+; X64-NEXT:    ret i32 [[RED]]
+;
+; AVX512-LABEL: @reduce_add_trunc_v8i64_i32(
+; AVX512-NEXT:    [[TR:%.*]] = trunc <8 x i64> [[A0:%.*]] to <8 x i32>
+; AVX512-NEXT:    [[RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TR]])
+; AVX512-NEXT:    ret i32 [[RED]]
 ;
   %tr = trunc <8 x i64> %a0 to <8 x i32>
   %red = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %tr)
   ret i32 %red
 }
 declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32>)
 
-; TODO: No legal vXi8 multiplication so vXi16 is always cheaper
+; No legal vXi8 multiplication so vXi16 is always cheaper
 define i8 @reduce_mul_trunc_v16i16_i8(<16 x i16> %a0) {
 ; CHECK-LABEL: @reduce_mul_trunc_v16i16_i8(
-; CHECK-NEXT:    [[TR:%.*]] = trunc <16 x i16> [[A0:%.*]] to <16 x i8>
-; CHECK-NEXT:    [[RED:%.*]] = tail call i8 @llvm.vector.reduce.mul.v16i8(<16 x i8> [[TR]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i16 @llvm.vector.reduce.mul.v16i16(<16 x i16> [[A0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = trunc i16 [[TMP1]] to i8
 ; CHECK-NEXT:    ret i8 [[RED]]
 ;
   %tr = trunc <16 x i16> %a0 to <16 x i8>
@@ -36,8 +41,8 @@ declare i8 @llvm.vector.reduce.mul.v16i8(<16 x i8>)
 
 define i8 @reduce_or_trunc_v8i32_i8(<8 x i32> %a0) {
 ; CHECK-LABEL: @reduce_or_trunc_v8i32_i8(
-; CHECK-NEXT:    [[TR:%.*]] = trunc <8 x i32> [[A0:%.*]] to <8 x i8>
-; CHECK-NEXT:    [[RED:%.*]] = tail call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> [[TR]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[A0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = trunc i32 [[TMP1]] to i8
 ; CHECK-NEXT:    ret i8 [[RED]]
 ;
   %tr = trunc <8 x i32> %a0 to <8 x i8>
@@ -48,8 +53,8 @@ declare i32 @llvm.vector.reduce.or.v8i8(<8 x i8>)
 
 define i8 @reduce_xor_trunc_v16i64_i8(<16 x i64> %a0) {
 ; CHECK-LABEL: @reduce_xor_trunc_v16i64_i8(
-; CHECK-NEXT:    [[TR:%.*]] = trunc <16 x i64> [[A0:%.*]] to <16 x i8>
-; CHECK-NEXT:    [[RED:%.*]] = tail call i8 @llvm.vector.reduce.xor.v16i8(<16 x i8> [[TR]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vector.reduce.xor.v16i64(<16 x i64> [[A0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = trunc i64 [[TMP1]] to i8
 ; CHECK-NEXT:    ret i8 [[RED]]
 ;
   %tr = trunc <16 x i64> %a0 to <16 x i8>
@@ -84,6 +89,3 @@ define i8 @reduce_smin_trunc_v16i16_i8(<16 x i16> %a0) {
 }
 declare i8 @llvm.vector.reduce.smin.v16i8(<16 x i8>)
 
-;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
-; AVX512: {{.*}}
-; X64: {{.*}}
