@@ -1395,22 +1395,46 @@ class BoUpSLP {
     return VectorizableTree.front()->Scalars;
   }
 
+  /// Returns the type/is-signed info for the root node in the graph without
+  /// casting.
+  std::optional<std::pair<Type *, bool>> getRootNodeTypeWithNoCast() const {
+    const TreeEntry &Root = *VectorizableTree.front().get();
+    if (Root.State != TreeEntry::Vectorize || Root.isAltShuffle() ||
+        !Root.Scalars.front()->getType()->isIntegerTy())
+      return std::nullopt;
+    auto It = MinBWs.find(&Root);
+    if (It != MinBWs.end())
+      return std::make_pair(IntegerType::get(Root.Scalars.front()->getContext(),
+                                             It->second.first),
+                            It->second.second);
+    if (Root.getOpcode() == Instruction::ZExt ||
+        Root.getOpcode() == Instruction::SExt)
+      return std::make_pair(cast<CastInst>(Root.getMainOp())->getSrcTy(),
+                            Root.getOpcode() == Instruction::SExt);
+    return std::nullopt;
+  }
+
   /// Checks if the root graph node can be emitted with narrower bitwidth at
   /// codegen and returns its signedness, if so.
   bool isSignedMinBitwidthRootNode() const {
     return MinBWs.at(VectorizableTree.front().get()).second;
   }
 
-  /// Returns reduction bitwidth and signedness, if it does not match the
-  /// original requested size.
-  std::optional<std::pair<unsigned, bool>> getReductionBitWidthAndSign() const {
+  /// Returns the reduction type after minbitwidth analysis.
+  FixedVectorType *getReductionType() const {
     if (ReductionBitWidth == 0 ||
+        !VectorizableTree.front()->Scalars.front()->getType()->isIntegerTy() ||
         ReductionBitWidth >=
             DL->getTypeSizeInBits(
                 VectorizableTree.front()->Scalars.front()->getType()))
-      return std::nullopt;
-    return std::make_pair(ReductionBitWidth,
-                          MinBWs.at(VectorizableTree.front().get()).second);
+      return getWidenedType(
+          VectorizableTree.front()->Scalars.front()->getType(),
+          VectorizableTree.front()->getVectorFactor());
+    return getWidenedType(
+        IntegerType::get(
+            VectorizableTree.front()->Scalars.front()->getContext(),
+            ReductionBitWidth),
+        VectorizableTree.front()->getVectorFactor());
   }
 
   /// Builds external uses of the vectorized scalars, i.e. the list of
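
A minimal standalone sketch of the fallback behavior the new getReductionType() implements, using a hypothetical VecTy struct in place of LLVM's FixedVectorType (plain C++, not part of the patch): the narrowed width is used only when minbitwidth analysis produced a strictly smaller integer width for an integer root.

    #include <cstdio>

    // Hypothetical stand-in for LLVM's FixedVectorType; for illustration only.
    struct VecTy { unsigned EltBits; unsigned NumElts; };

    // Mirrors the branch structure of BoUpSLP::getReductionType() above.
    VecTy getReductionType(bool RootIsInteger, unsigned RootEltBits,
                           unsigned VF, unsigned ReductionBitWidth) {
      if (ReductionBitWidth == 0 || !RootIsInteger ||
          ReductionBitWidth >= RootEltBits)
        return {RootEltBits, VF};     // no narrowing: widen the original type
      return {ReductionBitWidth, VF}; // narrowed reduction, e.g. <8 x i16>
    }

    int main() {
      // 8 x i32 scalars proven by minbitwidth analysis to fit in 16 bits.
      VecTy T = getReductionType(/*RootIsInteger=*/true, /*RootEltBits=*/32,
                                 /*VF=*/8, /*ReductionBitWidth=*/16);
      std::printf("<%u x i%u>\n", T.NumElts, T.EltBits); // prints <8 x i16>
    }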
@@ -11384,6 +11408,20 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
         return CommonCost;
       auto *VI = VL0->getOpcode() == Opcode ? VL0 : nullptr;
       TTI::CastContextHint CCH = GetCastContextHint(VL0->getOperand(0));
+
+      bool IsArithmeticExtendedReduction =
+          E->Idx == 0 && UserIgnoreList &&
+          all_of(*UserIgnoreList, [](Value *V) {
+            auto *I = cast<Instruction>(V);
+            return is_contained({Instruction::Add, Instruction::FAdd,
+                                 Instruction::Mul, Instruction::FMul,
+                                 Instruction::And, Instruction::Or,
+                                 Instruction::Xor},
+                                I->getOpcode());
+          });
+      if (IsArithmeticExtendedReduction &&
+          (VecOpcode == Instruction::ZExt || VecOpcode == Instruction::SExt))
+        return CommonCost;
       return CommonCost +
              TTI->getCastInstrCost(VecOpcode, VecTy, SrcVecTy, CCH, CostKind,
                                    VecOpcode == Opcode ? VI : nullptr);
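
For intuition, a standalone scalar model (plain C++, values assumed) of the operation that makes the early return sound: when every reduction user is one of the listed arithmetic/bitwise kinds, reduce.add(zext(<8 x i8>)) can be treated as one fused widening reduction, so the extension at the root is billed later through the extended-reduction hook rather than as a separate cast here.

    #include <array>
    #include <cstdint>
    #include <cstdio>

    // Scalar model of an extended reduction:
    // reduce.add(zext <8 x i8> %v to <8 x i32>). The per-lane extension is
    // implicit in the accumulation, so no separate vector ZExt needs to
    // materialize -- the reason this hunk skips its cast cost.
    std::uint32_t extendedAddReduction(const std::array<std::uint8_t, 8> &V) {
      std::uint32_t Sum = 0;
      for (std::uint8_t Lane : V)
        Sum += Lane; // integral promotion acts as the per-lane zero-extend
      return Sum;
    }

    int main() {
      std::array<std::uint8_t, 8> V = {200, 200, 200, 200, 200, 200, 200, 200};
      std::printf("%u\n", extendedAddReduction(V)); // 1600, wider than any i8
    }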
@@ -12748,32 +12786,48 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
       unsigned SrcSize = It->second.first;
       unsigned DstSize = ReductionBitWidth;
       unsigned Opcode = Instruction::Trunc;
-      if (SrcSize < DstSize)
-        Opcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
-      auto *SrcVecTy =
-          getWidenedType(Builder.getIntNTy(SrcSize), E.getVectorFactor());
-      auto *DstVecTy =
-          getWidenedType(Builder.getIntNTy(DstSize), E.getVectorFactor());
-      TTI::CastContextHint CCH = getCastContextHint(E);
-      InstructionCost CastCost;
-      switch (E.getOpcode()) {
-      case Instruction::SExt:
-      case Instruction::ZExt:
-      case Instruction::Trunc: {
-        const TreeEntry *OpTE = getOperandEntry(&E, 0);
-        CCH = getCastContextHint(*OpTE);
-        break;
-      }
-      default:
-        break;
+      if (SrcSize < DstSize) {
+        bool IsArithmeticExtendedReduction =
+            all_of(*UserIgnoreList, [](Value *V) {
+              auto *I = cast<Instruction>(V);
+              return is_contained({Instruction::Add, Instruction::FAdd,
+                                   Instruction::Mul, Instruction::FMul,
+                                   Instruction::And, Instruction::Or,
+                                   Instruction::Xor},
+                                  I->getOpcode());
+            });
+        if (IsArithmeticExtendedReduction)
+          Opcode =
+              Instruction::BitCast; // Handle it by getExtendedReductionCost
+        else
+          Opcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
+      }
+      if (Opcode != Instruction::BitCast) {
+        auto *SrcVecTy =
+            getWidenedType(Builder.getIntNTy(SrcSize), E.getVectorFactor());
+        auto *DstVecTy =
+            getWidenedType(Builder.getIntNTy(DstSize), E.getVectorFactor());
+        TTI::CastContextHint CCH = getCastContextHint(E);
+        InstructionCost CastCost;
+        switch (E.getOpcode()) {
+        case Instruction::SExt:
+        case Instruction::ZExt:
+        case Instruction::Trunc: {
+          const TreeEntry *OpTE = getOperandEntry(&E, 0);
+          CCH = getCastContextHint(*OpTE);
+          break;
+        }
+        default:
+          break;
+        }
+        CastCost += TTI->getCastInstrCost(Opcode, DstVecTy, SrcVecTy, CCH,
+                                          TTI::TCK_RecipThroughput);
+        Cost += CastCost;
+        LLVM_DEBUG(dbgs() << "SLP: Adding cost " << CastCost
+                          << " for final resize for reduction from " << SrcVecTy
+                          << " to " << DstVecTy << "\n";
+                   dbgs() << "SLP: Current total cost = " << Cost << "\n");
       }
-      CastCost += TTI->getCastInstrCost(Opcode, DstVecTy, SrcVecTy, CCH,
-                                        TTI::TCK_RecipThroughput);
-      Cost += CastCost;
-      LLVM_DEBUG(dbgs() << "SLP: Adding cost " << CastCost
-                        << " for final resize for reduction from " << SrcVecTy
-                        << " to " << DstVecTy << "\n";
-                 dbgs() << "SLP: Current total cost = " << Cost << "\n");
     }
   }
 
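
A compilable model (plain C++; an enum stands in for LLVM opcodes, for illustration only) of the sentinel scheme this hunk uses: Instruction::BitCast never reaches getCastInstrCost, it only marks "extension folded into the reduction, account for it via getExtendedReductionCost instead".

    #include <cstdio>
    #include <optional>

    enum class CastOp { Trunc, SExt, ZExt }; // opcodes that are actually costed

    // Returns the root-resize cast to bill, or nullopt when the widening is
    // deferred to the extended-reduction cost (the BitCast sentinel above).
    std::optional<CastOp> rootResizeCast(unsigned SrcSize, unsigned DstSize,
                                         bool IsSigned,
                                         bool IsArithmeticExtendedReduction) {
      if (SrcSize >= DstSize)
        return CastOp::Trunc; // shrink back to the demanded reduction width
      if (IsArithmeticExtendedReduction)
        return std::nullopt;  // handled by the extended-reduction cost later
      return IsSigned ? CastOp::SExt : CastOp::ZExt;
    }

    int main() {
      // i16 tree root feeding an i32 add-reduction: no cast cost added here.
      std::printf("%s\n",
                  rootResizeCast(16, 32, false, true) ? "cast" : "fused");
    }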
@@ -19951,8 +20005,8 @@ class HorizontalReduction {
 
       // Estimate cost.
       InstructionCost TreeCost = V.getTreeCost(VL);
-      InstructionCost ReductionCost = getReductionCost(
-          TTI, VL, IsCmpSelMinMax, RdxFMF, V.getReductionBitWidthAndSign());
+      InstructionCost ReductionCost =
+          getReductionCost(TTI, VL, IsCmpSelMinMax, RdxFMF, V);
       InstructionCost Cost = TreeCost + ReductionCost;
       LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost
                         << " for reduction\n");
@@ -20243,14 +20297,14 @@ class HorizontalReduction {
 
 private:
   /// Calculate the cost of a reduction.
-  InstructionCost getReductionCost(
-      TargetTransformInfo *TTI, ArrayRef<Value *> ReducedVals,
-      bool IsCmpSelMinMax, FastMathFlags FMF,
-      const std::optional<std::pair<unsigned, bool>> BitwidthAndSign) {
+  InstructionCost getReductionCost(TargetTransformInfo *TTI,
+                                   ArrayRef<Value *> ReducedVals,
+                                   bool IsCmpSelMinMax, FastMathFlags FMF,
+                                   const BoUpSLP &R) {
     TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
     Type *ScalarTy = ReducedVals.front()->getType();
     unsigned ReduxWidth = ReducedVals.size();
-    FixedVectorType *VectorTy = getWidenedType(ScalarTy, ReduxWidth);
+    FixedVectorType *VectorTy = R.getReductionType();
     InstructionCost VectorCost = 0, ScalarCost;
     // If all of the reduced values are constant, the vector cost is 0, since
     // the reduction value can be calculated at the compile time.
@@ -20308,21 +20362,16 @@ class HorizontalReduction {
             VecTy, APInt::getAllOnes(ScalarTyNumElements), /*Insert*/ true,
             /*Extract*/ false, TTI::TCK_RecipThroughput);
       } else {
-        auto [Bitwidth, IsSigned] =
-            BitwidthAndSign.value_or(std::make_pair(0u, false));
-        if (RdxKind == RecurKind::Add && Bitwidth == 1) {
-          // Represent vector_reduce_add(ZExt(<n x i1>)) to
-          // ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
-          auto *IntTy = IntegerType::get(ScalarTy->getContext(), ReduxWidth);
-          IntrinsicCostAttributes ICA(Intrinsic::ctpop, IntTy, {IntTy}, FMF);
-          VectorCost =
-              TTI->getCastInstrCost(Instruction::BitCast, IntTy,
-                                    getWidenedType(ScalarTy, ReduxWidth),
-                                    TTI::CastContextHint::None, CostKind) +
-              TTI->getIntrinsicInstrCost(ICA, CostKind);
-        } else {
+        Type *RedTy = VectorTy->getElementType();
+        auto [RType, IsSigned] = R.getRootNodeTypeWithNoCast().value_or(
+            std::make_pair(RedTy, true));
+        if (RType == RedTy) {
           VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
                                                        FMF, CostKind);
+        } else {
+          VectorCost = TTI->getExtendedReductionCost(
+              RdxOpcode, !IsSigned, RedTy, getWidenedType(RType, ReduxWidth),
+              FMF, CostKind);
         }
       }
     }
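
Putting the pieces together, a hedged sketch (plain C++; the Ty struct and helper are illustrative stand-ins, not TTI API) of the selection this final hunk performs, including the polarity flip: getRootNodeTypeWithNoCast() reports is-signed, while TTI::getExtendedReductionCost takes an IsUnsigned flag, hence the !IsSigned at the call site.

    #include <cstdio>

    struct Ty { unsigned Bits; }; // illustrative stand-in for llvm::Type

    enum class CostHook { ArithmeticReduction, ExtendedReduction };

    // Mirrors the if (RType == RedTy) split above: when the root's uncast
    // type matches the reduction element type, bill a plain vector
    // reduction; otherwise bill the fused ext+reduce through the
    // extended-reduction hook.
    CostHook pickReductionCostHook(Ty RedTy, Ty RType, bool IsSigned,
                                   bool *IsUnsignedOut) {
      if (RType.Bits == RedTy.Bits)
        return CostHook::ArithmeticReduction;
      // The extended-reduction hook takes IsUnsigned, while the root-type
      // query reports is-signed -- hence the inversion, as in the patch.
      *IsUnsignedOut = !IsSigned;
      return CostHook::ExtendedReduction;
    }

    int main() {
      bool IsUnsigned = false;
      // i8 payload hidden behind a zext to the i32 reduction type.
      CostHook H = pickReductionCostHook({32}, {8}, /*IsSigned=*/false,
                                         &IsUnsigned);
      std::printf("%s, IsUnsigned=%d\n",
                  H == CostHook::ExtendedReduction ? "extended" : "plain",
                  IsUnsigned);
    }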