-
Notifications
You must be signed in to change notification settings - Fork 13.8k
[AArch64][NFC] Move getPartialReductionCost into cpp file #123370
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
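@llvm/pr-subscribers-backend-aarch64 Author: David Sherwood (david-arm) Changes: The function getPartialReductionCost is already quite large and is likely to grow in size as we add support for more cases in future. Therefore, I think it's best to move this into the cpp file. Full diff: https://github.com/llvm/llvm-project/pull/123370.diff 2 Files Affected: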
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 7f10bfed739b41..ba26af129f2757 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5573,3 +5573,64 @@ bool AArch64TTIImpl::isProfitableToSinkOperands(
}
return false;
}
+
+InstructionCost
+AArch64TTIImpl::getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
+ Type *AccumType, ElementCount VF,
+ TTI::PartialReductionExtendKind OpAExtend,
+ TTI::PartialReductionExtendKind OpBExtend,
+ std::optional<unsigned> BinOp) const {
+ InstructionCost Invalid = InstructionCost::getInvalid();
+ InstructionCost Cost(TTI::TCC_Basic);
+
+ if (Opcode != Instruction::Add)
+ return Invalid;
+
+ if (InputTypeA != InputTypeB)
+ return Invalid;
+
+ EVT InputEVT = EVT::getEVT(InputTypeA);
+ EVT AccumEVT = EVT::getEVT(AccumType);
+
+ if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
+ return Invalid;
+ if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
+ return Invalid;
+
+ if (InputEVT == MVT::i8) {
+ switch (VF.getKnownMinValue()) {
+ default:
+ return Invalid;
+ case 8:
+ if (AccumEVT == MVT::i32)
+ Cost *= 2;
+ else if (AccumEVT != MVT::i64)
+ return Invalid;
+ break;
+ case 16:
+ if (AccumEVT == MVT::i64)
+ Cost *= 2;
+ else if (AccumEVT != MVT::i32)
+ return Invalid;
+ break;
+ }
+ } else if (InputEVT == MVT::i16) {
+ // FIXME: Allow i32 accumulator but increase cost, as we would extend
+ // it to i64.
+ if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64)
+ return Invalid;
+ } else
+ return Invalid;
+
+ // AArch64 supports lowering mixed extensions to a usdot but only if the
+ // i8mm or sve/streaming features are available.
+ if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
+ (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
+ !ST->isSVEorStreamingSVEAvailable()))
+ return Invalid;
+
+ if (!BinOp || *BinOp != Instruction::Mul)
+ return Invalid;
+
+ return Cost;
+}
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 1eb805ae00b1bb..b65e3c7a1ab20e 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -367,62 +367,7 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
Type *AccumType, ElementCount VF,
TTI::PartialReductionExtendKind OpAExtend,
TTI::PartialReductionExtendKind OpBExtend,
- std::optional<unsigned> BinOp) const {
-
- InstructionCost Invalid = InstructionCost::getInvalid();
- InstructionCost Cost(TTI::TCC_Basic);
-
- if (Opcode != Instruction::Add)
- return Invalid;
-
- if (InputTypeA != InputTypeB)
- return Invalid;
-
- EVT InputEVT = EVT::getEVT(InputTypeA);
- EVT AccumEVT = EVT::getEVT(AccumType);
-
- if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
- return Invalid;
- if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
- return Invalid;
-
- if (InputEVT == MVT::i8) {
- switch (VF.getKnownMinValue()) {
- default:
- return Invalid;
- case 8:
- if (AccumEVT == MVT::i32)
- Cost *= 2;
- else if (AccumEVT != MVT::i64)
- return Invalid;
- break;
- case 16:
- if (AccumEVT == MVT::i64)
- Cost *= 2;
- else if (AccumEVT != MVT::i32)
- return Invalid;
- break;
- }
- } else if (InputEVT == MVT::i16) {
- // FIXME: Allow i32 accumulator but increase cost, as we would extend
- // it to i64.
- if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64)
- return Invalid;
- } else
- return Invalid;
-
- // AArch64 supports lowering mixed extensions to a usdot but only if the
- // i8mm or sve/streaming features are available.
- if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
- (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
- !ST->isSVEorStreamingSVEAvailable()))
- return Invalid;
-
- if (!BinOp || *BinOp != Instruction::Mul)
- return Invalid;
-
- return Cost;
- }
+ std::optional<unsigned> BinOp) const;
bool enableOrderedReductions() const { return true; }
✅ With the latest revision this PR passed the C/C++ code formatter.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps put the code close to getArithmeticReductionCost, as they are conceptually similar?
Either way LGTM too.
The function getPartialReductionCost is already quite large and is likely to grow in size as we add support for more cases in future. Therefore, I think it's best to move this into the cpp file.
f92cbd2
to
4da541f
Compare
The function getPartialReductionCost is already quite large and
is likely to grow in size as we add support for more cases in
future. Therefore, I think it's best to move this into the cpp
file.