Skip to content

Commit 681994c

Browse files
committed
[AArch64][NFC] Move getPartialReductionCost into cpp file
The function getPartialReductionCost is already quite large and is likely to grow as we add support for more cases in the future. Therefore, I think it's best to move its definition out of the header and into the .cpp file.
1 parent 5139c90 commit 681994c

File tree

2 files changed

+62
-56
lines changed

2 files changed

+62
-56
lines changed

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5573,3 +5573,64 @@ bool AArch64TTIImpl::isProfitableToSinkOperands(
55735573
}
55745574
return false;
55755575
}
5576+
5577+
/// Compute the cost of a partial reduction for the given operand types,
/// accumulator type and vectorization factor.
///
/// Returns InstructionCost::getInvalid() for every combination that is
/// rejected below; otherwise returns TTI::TCC_Basic, doubled for two specific
/// i8 input / accumulator / VF pairings.
///
/// \param Opcode      Reduction opcode; only Instruction::Add is accepted.
/// \param InputTypeA  Type of the first input; must equal \p InputTypeB.
/// \param InputTypeB  Type of the second input.
/// \param AccumType   Type of the accumulator.
/// \param VF          Vectorization factor (fixed or scalable).
/// \param OpAExtend   Extension kind applied to the first input.
/// \param OpBExtend   Extension kind applied to the second input.
/// \param BinOp       Opcode combining the inputs; must be Instruction::Mul.
InstructionCost AArch64TTIImpl::getPartialReductionCost(
    unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
    ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
    TTI::PartialReductionExtendKind OpBExtend,
    std::optional<unsigned> BinOp) const {
  const InstructionCost Invalid = InstructionCost::getInvalid();

  // Only add-reductions whose two inputs share one type are costed.
  if (Opcode != Instruction::Add || InputTypeA != InputTypeB)
    return Invalid;

  const EVT InputEVT = EVT::getEVT(InputTypeA);
  const EVT AccumEVT = EVT::getEVT(AccumType);

  // Scalable VFs need SVE (or streaming SVE); fixed VFs need NEON together
  // with the dot-product feature.
  const bool HasRequiredFeatures =
      VF.isScalable() ? ST->isSVEorStreamingSVEAvailable()
                      : (ST->isNeonAvailable() && ST->hasDotProd());
  if (!HasRequiredFeatures)
    return Invalid;

  InstructionCost Cost(TTI::TCC_Basic);
  const unsigned MinLanes = VF.getKnownMinValue();
  const bool AccumIsI32 = AccumEVT == MVT::i32;
  const bool AccumIsI64 = AccumEVT == MVT::i64;

  if (InputEVT == MVT::i8) {
    // i8 inputs are accepted for a minimum VF of 8 or 16 with an i32 or i64
    // accumulator.
    const bool SupportedVF = MinLanes == 8 || MinLanes == 16;
    if (!SupportedVF || (!AccumIsI32 && !AccumIsI64))
      return Invalid;

    // The (VF 8, i32) and (VF 16, i64) pairings cost twice the basic amount.
    const bool DoubleCost =
        (MinLanes == 8 && AccumIsI32) || (MinLanes == 16 && AccumIsI64);
    if (DoubleCost)
      Cost *= 2;
  } else if (InputEVT == MVT::i16) {
    // FIXME: Allow i32 accumulator but increase cost, as we would extend
    //        it to i64.
    if (MinLanes != 8 || !AccumIsI64)
      return Invalid;
  } else {
    return Invalid;
  }

  // Both inputs must be extended.
  if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None)
    return Invalid;

  // AArch64 supports lowering mixed extensions to a usdot but only if the
  // i8mm or sve/streaming features are available.
  const bool MixedExtends = OpAExtend != OpBExtend;
  if (MixedExtends &&
      !(ST->hasMatMulInt8() || ST->isSVEorStreamingSVEAvailable()))
    return Invalid;

  // The inputs must be combined with a multiply.
  if (!BinOp.has_value() || *BinOp != Instruction::Mul)
    return Invalid;

  return Cost;
}

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 1 addition & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -367,62 +367,7 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
367367
Type *AccumType, ElementCount VF,
368368
TTI::PartialReductionExtendKind OpAExtend,
369369
TTI::PartialReductionExtendKind OpBExtend,
370-
std::optional<unsigned> BinOp) const {
371-
372-
InstructionCost Invalid = InstructionCost::getInvalid();
373-
InstructionCost Cost(TTI::TCC_Basic);
374-
375-
if (Opcode != Instruction::Add)
376-
return Invalid;
377-
378-
if (InputTypeA != InputTypeB)
379-
return Invalid;
380-
381-
EVT InputEVT = EVT::getEVT(InputTypeA);
382-
EVT AccumEVT = EVT::getEVT(AccumType);
383-
384-
if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
385-
return Invalid;
386-
if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
387-
return Invalid;
388-
389-
if (InputEVT == MVT::i8) {
390-
switch (VF.getKnownMinValue()) {
391-
default:
392-
return Invalid;
393-
case 8:
394-
if (AccumEVT == MVT::i32)
395-
Cost *= 2;
396-
else if (AccumEVT != MVT::i64)
397-
return Invalid;
398-
break;
399-
case 16:
400-
if (AccumEVT == MVT::i64)
401-
Cost *= 2;
402-
else if (AccumEVT != MVT::i32)
403-
return Invalid;
404-
break;
405-
}
406-
} else if (InputEVT == MVT::i16) {
407-
// FIXME: Allow i32 accumulator but increase cost, as we would extend
408-
// it to i64.
409-
if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64)
410-
return Invalid;
411-
} else
412-
return Invalid;
413-
414-
// AArch64 supports lowering mixed extensions to a usdot but only if the
415-
// i8mm or sve/streaming features are available.
416-
if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
417-
(OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
418-
!ST->isSVEorStreamingSVEAvailable()))
419-
return Invalid;
420-
421-
if (!BinOp || *BinOp != Instruction::Mul)
422-
return Invalid;
423-
424-
return Cost;
425-
}
370+
std::optional<unsigned> BinOp) const;
426371

427372
/// Unconditionally answers true to the enableOrderedReductions query.
bool enableOrderedReductions() const {
  return true;
}
428373

0 commit comments

Comments
 (0)