@@ -4664,6 +4664,66 @@ InstructionCost AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index) {
4664
4664
return LegalizationCost * LT.first ;
4665
4665
}
4666
4666
4667
+ InstructionCost AArch64TTIImpl::getPartialReductionCost (
4668
+ unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
4669
+ ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
4670
+ TTI::PartialReductionExtendKind OpBExtend,
4671
+ std::optional<unsigned > BinOp) const {
4672
+ InstructionCost Invalid = InstructionCost::getInvalid ();
4673
+ InstructionCost Cost (TTI::TCC_Basic);
4674
+
4675
+ if (Opcode != Instruction::Add)
4676
+ return Invalid;
4677
+
4678
+ if (InputTypeA != InputTypeB)
4679
+ return Invalid;
4680
+
4681
+ EVT InputEVT = EVT::getEVT (InputTypeA);
4682
+ EVT AccumEVT = EVT::getEVT (AccumType);
4683
+
4684
+ if (VF.isScalable () && !ST->isSVEorStreamingSVEAvailable ())
4685
+ return Invalid;
4686
+ if (VF.isFixed () && (!ST->isNeonAvailable () || !ST->hasDotProd ()))
4687
+ return Invalid;
4688
+
4689
+ if (InputEVT == MVT::i8 ) {
4690
+ switch (VF.getKnownMinValue ()) {
4691
+ default :
4692
+ return Invalid;
4693
+ case 8 :
4694
+ if (AccumEVT == MVT::i32 )
4695
+ Cost *= 2 ;
4696
+ else if (AccumEVT != MVT::i64 )
4697
+ return Invalid;
4698
+ break ;
4699
+ case 16 :
4700
+ if (AccumEVT == MVT::i64 )
4701
+ Cost *= 2 ;
4702
+ else if (AccumEVT != MVT::i32 )
4703
+ return Invalid;
4704
+ break ;
4705
+ }
4706
+ } else if (InputEVT == MVT::i16 ) {
4707
+ // FIXME: Allow i32 accumulator but increase cost, as we would extend
4708
+ // it to i64.
4709
+ if (VF.getKnownMinValue () != 8 || AccumEVT != MVT::i64 )
4710
+ return Invalid;
4711
+ } else
4712
+ return Invalid;
4713
+
4714
+ // AArch64 supports lowering mixed extensions to a usdot but only if the
4715
+ // i8mm or sve/streaming features are available.
4716
+ if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
4717
+ (OpAExtend != OpBExtend && !ST->hasMatMulInt8 () &&
4718
+ !ST->isSVEorStreamingSVEAvailable ()))
4719
+ return Invalid;
4720
+
4721
+ if (!BinOp || *BinOp != Instruction::Mul)
4722
+ return Invalid;
4723
+
4724
+ return Cost;
4725
+ }
4726
+
4667
4727
InstructionCost AArch64TTIImpl::getShuffleCost (
4668
4728
TTI::ShuffleKind Kind, VectorType *Tp, ArrayRef<int > Mask,
4669
4729
TTI::TargetCostKind CostKind, int Index, VectorType *SubTp,
@@ -5573,64 +5633,3 @@ bool AArch64TTIImpl::isProfitableToSinkOperands(
5573
5633
}
5574
5634
return false ;
5575
5635
}
5576
-
5577
- InstructionCost
5578
- AArch64TTIImpl::getPartialReductionCost (unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
5579
- Type *AccumType, ElementCount VF,
5580
- TTI::PartialReductionExtendKind OpAExtend,
5581
- TTI::PartialReductionExtendKind OpBExtend,
5582
- std::optional<unsigned > BinOp) const {
5583
- InstructionCost Invalid = InstructionCost::getInvalid ();
5584
- InstructionCost Cost (TTI::TCC_Basic);
5585
-
5586
- if (Opcode != Instruction::Add)
5587
- return Invalid;
5588
-
5589
- if (InputTypeA != InputTypeB)
5590
- return Invalid;
5591
-
5592
- EVT InputEVT = EVT::getEVT (InputTypeA);
5593
- EVT AccumEVT = EVT::getEVT (AccumType);
5594
-
5595
- if (VF.isScalable () && !ST->isSVEorStreamingSVEAvailable ())
5596
- return Invalid;
5597
- if (VF.isFixed () && (!ST->isNeonAvailable () || !ST->hasDotProd ()))
5598
- return Invalid;
5599
-
5600
- if (InputEVT == MVT::i8 ) {
5601
- switch (VF.getKnownMinValue ()) {
5602
- default :
5603
- return Invalid;
5604
- case 8 :
5605
- if (AccumEVT == MVT::i32 )
5606
- Cost *= 2 ;
5607
- else if (AccumEVT != MVT::i64 )
5608
- return Invalid;
5609
- break ;
5610
- case 16 :
5611
- if (AccumEVT == MVT::i64 )
5612
- Cost *= 2 ;
5613
- else if (AccumEVT != MVT::i32 )
5614
- return Invalid;
5615
- break ;
5616
- }
5617
- } else if (InputEVT == MVT::i16 ) {
5618
- // FIXME: Allow i32 accumulator but increase cost, as we would extend
5619
- // it to i64.
5620
- if (VF.getKnownMinValue () != 8 || AccumEVT != MVT::i64 )
5621
- return Invalid;
5622
- } else
5623
- return Invalid;
5624
-
5625
- // AArch64 supports lowering mixed extensions to a usdot but only if the
5626
- // i8mm or sve/streaming features are available.
5627
- if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
5628
- (OpAExtend != OpBExtend && !ST->hasMatMulInt8 () &&
5629
- !ST->isSVEorStreamingSVEAvailable ()))
5630
- return Invalid;
5631
-
5632
- if (!BinOp || *BinOp != Instruction::Mul)
5633
- return Invalid;
5634
-
5635
- return Cost;
5636
- }
0 commit comments