@@ -117,7 +117,7 @@ class VectorCombine {
117
117
bool foldShuffleOfShuffles (Instruction &I);
118
118
bool foldShuffleToIdentity (Instruction &I);
119
119
bool foldShuffleFromReductions (Instruction &I);
120
- bool foldTruncFromReductions (Instruction &I);
120
+ bool foldCastFromReductions (Instruction &I);
121
121
bool foldSelectShuffle (Instruction &I, bool FromReduction = false );
122
122
123
123
void replaceValue (Value &Old, Value &New) {
@@ -2113,15 +2113,20 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
2113
2113
2114
2114
/// Determine if it's more efficient to fold:
2115
2115
///   reduce(trunc(x)) -> trunc(reduce(x)).
2116
- bool VectorCombine::foldTruncFromReductions (Instruction &I) {
2116
+ ///   reduce(sext(x)) -> sext(reduce(x)).
2117
+ ///   reduce(zext(x)) -> zext(reduce(x)).
2118
+ bool VectorCombine::foldCastFromReductions (Instruction &I) {
2117
2119
auto *II = dyn_cast<IntrinsicInst>(&I);
2118
2120
if (!II)
2119
2121
return false ;
2120
2122
2123
+ bool TruncOnly = false ;
2121
2124
Intrinsic::ID IID = II->getIntrinsicID ();
2122
2125
switch (IID) {
2123
2126
case Intrinsic::vector_reduce_add:
2124
2127
case Intrinsic::vector_reduce_mul:
2128
+ TruncOnly = true ;
2129
+ break ;
2125
2130
case Intrinsic::vector_reduce_and:
2126
2131
case Intrinsic::vector_reduce_or:
2127
2132
case Intrinsic::vector_reduce_xor:
@@ -2133,35 +2138,37 @@ bool VectorCombine::foldTruncFromReductions(Instruction &I) {
2133
2138
unsigned ReductionOpc = getArithmeticReductionInstruction (IID);
2134
2139
Value *ReductionSrc = I.getOperand (0 );
2135
2140
2136
- Value *TruncSrc;
2137
- if (!match (ReductionSrc, m_OneUse (m_Trunc (m_Value (TruncSrc)))))
2141
+ Value *Src;
2142
+ if (!match (ReductionSrc, m_OneUse (m_Trunc (m_Value (Src)))) &&
2143
+ (TruncOnly || !match (ReductionSrc, m_OneUse (m_ZExtOrSExt (m_Value (Src))))))
2138
2144
return false ;
2139
2145
2140
- auto *TruncSrcTy = cast<VectorType>(TruncSrc->getType ());
2146
+ auto CastOpc =
2147
+ (Instruction::CastOps)cast<Instruction>(ReductionSrc)->getOpcode ();
2148
+
2149
+ auto *SrcTy = cast<VectorType>(Src->getType ());
2141
2150
auto *ReductionSrcTy = cast<VectorType>(ReductionSrc->getType ());
2142
2151
Type *ResultTy = I.getType ();
2143
2152
2144
2153
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
2145
2154
InstructionCost OldCost = TTI.getArithmeticReductionCost (
2146
2155
ReductionOpc, ReductionSrcTy, std::nullopt, CostKind);
2147
- if (auto *Trunc = dyn_cast<CastInst>(ReductionSrc))
2148
- OldCost +=
2149
- TTI.getCastInstrCost (Instruction::Trunc, ReductionSrcTy, TruncSrcTy,
2150
- TTI::CastContextHint::None, CostKind, Trunc);
2156
+ OldCost += TTI.getCastInstrCost (CastOpc, ReductionSrcTy, SrcTy,
2157
+ TTI::CastContextHint::None, CostKind,
2158
+ cast<CastInst>(ReductionSrc));
2151
2159
InstructionCost NewCost =
2152
- TTI.getArithmeticReductionCost (ReductionOpc, TruncSrcTy , std::nullopt,
2160
+ TTI.getArithmeticReductionCost (ReductionOpc, SrcTy , std::nullopt,
2153
2161
CostKind) +
2154
- TTI.getCastInstrCost (Instruction::Trunc, ResultTy,
2155
- ReductionSrcTy->getScalarType (),
2162
+ TTI.getCastInstrCost (CastOpc, ResultTy, ReductionSrcTy->getScalarType (),
2156
2163
TTI::CastContextHint::None, CostKind);
2157
2164
2158
2165
if (OldCost <= NewCost || !NewCost.isValid ())
2159
2166
return false ;
2160
2167
2161
- Value *NewReduction = Builder.CreateIntrinsic (
2162
- TruncSrcTy-> getScalarType (), II->getIntrinsicID (), {TruncSrc });
2163
- Value *NewTruncation = Builder.CreateTrunc ( NewReduction, ResultTy);
2164
- replaceValue (I, *NewTruncation );
2168
+ Value *NewReduction = Builder.CreateIntrinsic (SrcTy-> getScalarType (),
2169
+ II->getIntrinsicID (), {Src });
2170
+ Value *NewCast = Builder.CreateCast (CastOpc, NewReduction, ResultTy);
2171
+ replaceValue (I, *NewCast );
2165
2172
return true ;
2166
2173
}
2167
2174
@@ -2559,7 +2566,7 @@ bool VectorCombine::run() {
2559
2566
switch (Opcode) {
2560
2567
case Instruction::Call:
2561
2568
MadeChange |= foldShuffleFromReductions (I);
2562
- MadeChange |= foldTruncFromReductions (I);
2569
+ MadeChange |= foldCastFromReductions (I);
2563
2570
break ;
2564
2571
case Instruction::ICmp:
2565
2572
case Instruction::FCmp:
0 commit comments