Skip to content

Commit 4da541f

Browse files
committed
Address review comment
1 parent 681994c commit 4da541f

File tree

1 file changed

+60
-61
lines changed

1 file changed

+60
-61
lines changed

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 60 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -4664,6 +4664,66 @@ InstructionCost AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index) {
46644664
return LegalizationCost * LT.first;
46654665
}
46664666

4667+
InstructionCost AArch64TTIImpl::getPartialReductionCost(
4668+
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
4669+
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
4670+
TTI::PartialReductionExtendKind OpBExtend,
4671+
std::optional<unsigned> BinOp) const {
4672+
InstructionCost Invalid = InstructionCost::getInvalid();
4673+
InstructionCost Cost(TTI::TCC_Basic);
4674+
4675+
if (Opcode != Instruction::Add)
4676+
return Invalid;
4677+
4678+
if (InputTypeA != InputTypeB)
4679+
return Invalid;
4680+
4681+
EVT InputEVT = EVT::getEVT(InputTypeA);
4682+
EVT AccumEVT = EVT::getEVT(AccumType);
4683+
4684+
if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
4685+
return Invalid;
4686+
if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
4687+
return Invalid;
4688+
4689+
if (InputEVT == MVT::i8) {
4690+
switch (VF.getKnownMinValue()) {
4691+
default:
4692+
return Invalid;
4693+
case 8:
4694+
if (AccumEVT == MVT::i32)
4695+
Cost *= 2;
4696+
else if (AccumEVT != MVT::i64)
4697+
return Invalid;
4698+
break;
4699+
case 16:
4700+
if (AccumEVT == MVT::i64)
4701+
Cost *= 2;
4702+
else if (AccumEVT != MVT::i32)
4703+
return Invalid;
4704+
break;
4705+
}
4706+
} else if (InputEVT == MVT::i16) {
4707+
// FIXME: Allow i32 accumulator but increase cost, as we would extend
4708+
// it to i64.
4709+
if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64)
4710+
return Invalid;
4711+
} else
4712+
return Invalid;
4713+
4714+
// AArch64 supports lowering mixed extensions to a usdot but only if the
4715+
// i8mm or sve/streaming features are available.
4716+
if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
4717+
(OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
4718+
!ST->isSVEorStreamingSVEAvailable()))
4719+
return Invalid;
4720+
4721+
if (!BinOp || *BinOp != Instruction::Mul)
4722+
return Invalid;
4723+
4724+
return Cost;
4725+
}
4726+
46674727
InstructionCost AArch64TTIImpl::getShuffleCost(
46684728
TTI::ShuffleKind Kind, VectorType *Tp, ArrayRef<int> Mask,
46694729
TTI::TargetCostKind CostKind, int Index, VectorType *SubTp,
@@ -5573,64 +5633,3 @@ bool AArch64TTIImpl::isProfitableToSinkOperands(
55735633
}
55745634
return false;
55755635
}
5576-
5577-
InstructionCost
5578-
AArch64TTIImpl::getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
5579-
Type *AccumType, ElementCount VF,
5580-
TTI::PartialReductionExtendKind OpAExtend,
5581-
TTI::PartialReductionExtendKind OpBExtend,
5582-
std::optional<unsigned> BinOp) const {
5583-
InstructionCost Invalid = InstructionCost::getInvalid();
5584-
InstructionCost Cost(TTI::TCC_Basic);
5585-
5586-
if (Opcode != Instruction::Add)
5587-
return Invalid;
5588-
5589-
if (InputTypeA != InputTypeB)
5590-
return Invalid;
5591-
5592-
EVT InputEVT = EVT::getEVT(InputTypeA);
5593-
EVT AccumEVT = EVT::getEVT(AccumType);
5594-
5595-
if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
5596-
return Invalid;
5597-
if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
5598-
return Invalid;
5599-
5600-
if (InputEVT == MVT::i8) {
5601-
switch (VF.getKnownMinValue()) {
5602-
default:
5603-
return Invalid;
5604-
case 8:
5605-
if (AccumEVT == MVT::i32)
5606-
Cost *= 2;
5607-
else if (AccumEVT != MVT::i64)
5608-
return Invalid;
5609-
break;
5610-
case 16:
5611-
if (AccumEVT == MVT::i64)
5612-
Cost *= 2;
5613-
else if (AccumEVT != MVT::i32)
5614-
return Invalid;
5615-
break;
5616-
}
5617-
} else if (InputEVT == MVT::i16) {
5618-
// FIXME: Allow i32 accumulator but increase cost, as we would extend
5619-
// it to i64.
5620-
if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64)
5621-
return Invalid;
5622-
} else
5623-
return Invalid;
5624-
5625-
// AArch64 supports lowering mixed extensions to a usdot but only if the
5626-
// i8mm or sve/streaming features are available.
5627-
if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
5628-
(OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
5629-
!ST->isSVEorStreamingSVEAvailable()))
5630-
return Invalid;
5631-
5632-
if (!BinOp || *BinOp != Instruction::Mul)
5633-
return Invalid;
5634-
5635-
return Cost;
5636-
}

0 commit comments

Comments
 (0)