Skip to content

Commit 24e4aa3

Browse files
preames authored and yuxuanchen1997 committed
[vectorcombine] Pull sext/zext through reduce.or/and/xor (#99548)
Summary: This extends the existing foldTruncFromReductions transform to handle sext and zext as well. This is only legal for the bitwise reductions (and/or/xor) and not the arithmetic ones (add, mul). Use the same costing decision to drive whether we do the transform. Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: https://phabricator.intern.facebook.com/D60250996
1 parent b6f35fe commit 24e4aa3

File tree

2 files changed

+46
-25
lines changed

2 files changed

+46
-25
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ class VectorCombine {
117117
bool foldShuffleOfShuffles(Instruction &I);
118118
bool foldShuffleToIdentity(Instruction &I);
119119
bool foldShuffleFromReductions(Instruction &I);
120-
bool foldTruncFromReductions(Instruction &I);
120+
bool foldCastFromReductions(Instruction &I);
121121
bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
122122

123123
void replaceValue(Value &Old, Value &New) {
@@ -2113,15 +2113,20 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
21132113

21142114
/// Determine if its more efficient to fold:
21152115
/// reduce(trunc(x)) -> trunc(reduce(x)).
2116-
bool VectorCombine::foldTruncFromReductions(Instruction &I) {
2116+
/// reduce(sext(x)) -> sext(reduce(x)).
2117+
/// reduce(zext(x)) -> zext(reduce(x)).
2118+
bool VectorCombine::foldCastFromReductions(Instruction &I) {
21172119
auto *II = dyn_cast<IntrinsicInst>(&I);
21182120
if (!II)
21192121
return false;
21202122

2123+
bool TruncOnly = false;
21212124
Intrinsic::ID IID = II->getIntrinsicID();
21222125
switch (IID) {
21232126
case Intrinsic::vector_reduce_add:
21242127
case Intrinsic::vector_reduce_mul:
2128+
TruncOnly = true;
2129+
break;
21252130
case Intrinsic::vector_reduce_and:
21262131
case Intrinsic::vector_reduce_or:
21272132
case Intrinsic::vector_reduce_xor:
@@ -2133,35 +2138,37 @@ bool VectorCombine::foldTruncFromReductions(Instruction &I) {
21332138
unsigned ReductionOpc = getArithmeticReductionInstruction(IID);
21342139
Value *ReductionSrc = I.getOperand(0);
21352140

2136-
Value *TruncSrc;
2137-
if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(TruncSrc)))))
2141+
Value *Src;
2142+
if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(Src)))) &&
2143+
(TruncOnly || !match(ReductionSrc, m_OneUse(m_ZExtOrSExt(m_Value(Src))))))
21382144
return false;
21392145

2140-
auto *TruncSrcTy = cast<VectorType>(TruncSrc->getType());
2146+
auto CastOpc =
2147+
(Instruction::CastOps)cast<Instruction>(ReductionSrc)->getOpcode();
2148+
2149+
auto *SrcTy = cast<VectorType>(Src->getType());
21412150
auto *ReductionSrcTy = cast<VectorType>(ReductionSrc->getType());
21422151
Type *ResultTy = I.getType();
21432152

21442153
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
21452154
InstructionCost OldCost = TTI.getArithmeticReductionCost(
21462155
ReductionOpc, ReductionSrcTy, std::nullopt, CostKind);
2147-
if (auto *Trunc = dyn_cast<CastInst>(ReductionSrc))
2148-
OldCost +=
2149-
TTI.getCastInstrCost(Instruction::Trunc, ReductionSrcTy, TruncSrcTy,
2150-
TTI::CastContextHint::None, CostKind, Trunc);
2156+
OldCost += TTI.getCastInstrCost(CastOpc, ReductionSrcTy, SrcTy,
2157+
TTI::CastContextHint::None, CostKind,
2158+
cast<CastInst>(ReductionSrc));
21512159
InstructionCost NewCost =
2152-
TTI.getArithmeticReductionCost(ReductionOpc, TruncSrcTy, std::nullopt,
2160+
TTI.getArithmeticReductionCost(ReductionOpc, SrcTy, std::nullopt,
21532161
CostKind) +
2154-
TTI.getCastInstrCost(Instruction::Trunc, ResultTy,
2155-
ReductionSrcTy->getScalarType(),
2162+
TTI.getCastInstrCost(CastOpc, ResultTy, ReductionSrcTy->getScalarType(),
21562163
TTI::CastContextHint::None, CostKind);
21572164

21582165
if (OldCost <= NewCost || !NewCost.isValid())
21592166
return false;
21602167

2161-
Value *NewReduction = Builder.CreateIntrinsic(
2162-
TruncSrcTy->getScalarType(), II->getIntrinsicID(), {TruncSrc});
2163-
Value *NewTruncation = Builder.CreateTrunc(NewReduction, ResultTy);
2164-
replaceValue(I, *NewTruncation);
2168+
Value *NewReduction = Builder.CreateIntrinsic(SrcTy->getScalarType(),
2169+
II->getIntrinsicID(), {Src});
2170+
Value *NewCast = Builder.CreateCast(CastOpc, NewReduction, ResultTy);
2171+
replaceValue(I, *NewCast);
21652172
return true;
21662173
}
21672174

@@ -2559,7 +2566,7 @@ bool VectorCombine::run() {
25592566
switch (Opcode) {
25602567
case Instruction::Call:
25612568
MadeChange |= foldShuffleFromReductions(I);
2562-
MadeChange |= foldTruncFromReductions(I);
2569+
MadeChange |= foldCastFromReductions(I);
25632570
break;
25642571
case Instruction::ICmp:
25652572
case Instruction::FCmp:

llvm/test/Transforms/VectorCombine/RISCV/vecreduce-of-cast.ll

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ define i16 @reduce_mul_trunc_v8i64_i16(<8 x i64> %a0) {
7474

7575
define i32 @reduce_or_sext_v8i8_to_v8i32(<8 x i8> %a0) {
7676
; CHECK-LABEL: @reduce_or_sext_v8i8_to_v8i32(
77-
; CHECK-NEXT: [[TR:%.*]] = sext <8 x i8> [[A0:%.*]] to <8 x i32>
78-
; CHECK-NEXT: [[RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[TR]])
77+
; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> [[A0:%.*]])
78+
; CHECK-NEXT: [[RED:%.*]] = sext i8 [[TMP1]] to i32
7979
; CHECK-NEXT: ret i32 [[RED]]
8080
;
8181
%tr = sext <8 x i8> %a0 to <8 x i32>
@@ -85,8 +85,8 @@ define i32 @reduce_or_sext_v8i8_to_v8i32(<8 x i8> %a0) {
8585

8686
define i32 @reduce_or_sext_v8i16_to_v8i32(<8 x i16> %a0) {
8787
; CHECK-LABEL: @reduce_or_sext_v8i16_to_v8i32(
88-
; CHECK-NEXT: [[TR:%.*]] = sext <8 x i16> [[A0:%.*]] to <8 x i32>
89-
; CHECK-NEXT: [[RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[TR]])
88+
; CHECK-NEXT: [[TMP1:%.*]] = call i16 @llvm.vector.reduce.or.v8i16(<8 x i16> [[A0:%.*]])
89+
; CHECK-NEXT: [[RED:%.*]] = sext i16 [[TMP1]] to i32
9090
; CHECK-NEXT: ret i32 [[RED]]
9191
;
9292
%tr = sext <8 x i16> %a0 to <8 x i32>
@@ -96,8 +96,8 @@ define i32 @reduce_or_sext_v8i16_to_v8i32(<8 x i16> %a0) {
9696

9797
define i32 @reduce_or_zext_v8i8_to_v8i32(<8 x i8> %a0) {
9898
; CHECK-LABEL: @reduce_or_zext_v8i8_to_v8i32(
99-
; CHECK-NEXT: [[TR:%.*]] = zext <8 x i8> [[A0:%.*]] to <8 x i32>
100-
; CHECK-NEXT: [[RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[TR]])
99+
; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> [[A0:%.*]])
100+
; CHECK-NEXT: [[RED:%.*]] = zext i8 [[TMP1]] to i32
101101
; CHECK-NEXT: ret i32 [[RED]]
102102
;
103103
%tr = zext <8 x i8> %a0 to <8 x i32>
@@ -107,15 +107,29 @@ define i32 @reduce_or_zext_v8i8_to_v8i32(<8 x i8> %a0) {
107107

108108
define i32 @reduce_or_zext_v8i16_to_v8i32(<8 x i16> %a0) {
109109
; CHECK-LABEL: @reduce_or_zext_v8i16_to_v8i32(
110-
; CHECK-NEXT: [[TR:%.*]] = zext <8 x i16> [[A0:%.*]] to <8 x i32>
111-
; CHECK-NEXT: [[RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[TR]])
110+
; CHECK-NEXT: [[TMP1:%.*]] = call i16 @llvm.vector.reduce.or.v8i16(<8 x i16> [[A0:%.*]])
111+
; CHECK-NEXT: [[RED:%.*]] = zext i16 [[TMP1]] to i32
112112
; CHECK-NEXT: ret i32 [[RED]]
113113
;
114114
%tr = zext <8 x i16> %a0 to <8 x i32>
115115
%red = tail call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> %tr)
116116
ret i32 %red
117117
}
118118

119+
; Negative case - narrowing the reduce (to i8) is illegal.
120+
; TODO: We could narrow to i16 instead.
121+
define i32 @reduce_add_trunc_v8i8_to_v8i32(<8 x i8> %a0) {
122+
; CHECK-LABEL: @reduce_add_trunc_v8i8_to_v8i32(
123+
; CHECK-NEXT: [[TR:%.*]] = zext <8 x i8> [[A0:%.*]] to <8 x i32>
124+
; CHECK-NEXT: [[RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TR]])
125+
; CHECK-NEXT: ret i32 [[RED]]
126+
;
127+
%tr = zext <8 x i8> %a0 to <8 x i32>
128+
%red = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %tr)
129+
ret i32 %red
130+
}
131+
132+
119133
declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32>)
120134
declare i16 @llvm.vector.reduce.add.v8i16(<8 x i16>)
121135
declare i8 @llvm.vector.reduce.add.v8i8(<8 x i8>)

0 commit comments

Comments (0)