Skip to content

[AArch64][NFC] Move getPartialReductionCost into cpp file #123370

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 20, 2025

Conversation

david-arm
Copy link
Contributor

The function getPartialReductionCost is already quite large and
is likely to grow in size as we add support for more cases in
future. Therefore, I think it's best to move this into the cpp
file.

@llvmbot
Copy link
Member

llvmbot commented Jan 17, 2025

@llvm/pr-subscribers-backend-aarch64

Author: David Sherwood (david-arm)

Changes

The function getPartialReductionCost is already quite large and
is likely to grow in size as we add support for more cases in
future. Therefore, I think it's best to move this into the cpp
file.


Full diff: https://github.com/llvm/llvm-project/pull/123370.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+61)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h (+1-56)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 7f10bfed739b41..ba26af129f2757 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5573,3 +5573,64 @@ bool AArch64TTIImpl::isProfitableToSinkOperands(
   }
   return false;
 }
+
+InstructionCost
+AArch64TTIImpl::getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
+                        Type *AccumType, ElementCount VF,
+                        TTI::PartialReductionExtendKind OpAExtend,
+                        TTI::PartialReductionExtendKind OpBExtend,
+                        std::optional<unsigned> BinOp) const {
+  InstructionCost Invalid = InstructionCost::getInvalid();
+  InstructionCost Cost(TTI::TCC_Basic);
+
+  if (Opcode != Instruction::Add)
+    return Invalid;
+
+  if (InputTypeA != InputTypeB)
+    return Invalid;
+
+  EVT InputEVT = EVT::getEVT(InputTypeA);
+  EVT AccumEVT = EVT::getEVT(AccumType);
+
+  if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
+    return Invalid;
+  if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
+    return Invalid;
+
+  if (InputEVT == MVT::i8) {
+    switch (VF.getKnownMinValue()) {
+    default:
+      return Invalid;
+    case 8:
+      if (AccumEVT == MVT::i32)
+        Cost *= 2;
+      else if (AccumEVT != MVT::i64)
+        return Invalid;
+      break;
+    case 16:
+      if (AccumEVT == MVT::i64)
+        Cost *= 2;
+      else if (AccumEVT != MVT::i32)
+        return Invalid;
+      break;
+    }
+  } else if (InputEVT == MVT::i16) {
+    // FIXME: Allow i32 accumulator but increase cost, as we would extend
+    //        it to i64.
+    if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64)
+      return Invalid;
+  } else
+    return Invalid;
+
+  // AArch64 supports lowering mixed extensions to a usdot but only if the
+  // i8mm or sve/streaming features are available.
+  if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
+      (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
+       !ST->isSVEorStreamingSVEAvailable()))
+    return Invalid;
+
+  if (!BinOp || *BinOp != Instruction::Mul)
+    return Invalid;
+
+  return Cost;
+}
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 1eb805ae00b1bb..b65e3c7a1ab20e 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -367,62 +367,7 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
                           Type *AccumType, ElementCount VF,
                           TTI::PartialReductionExtendKind OpAExtend,
                           TTI::PartialReductionExtendKind OpBExtend,
-                          std::optional<unsigned> BinOp) const {
-
-    InstructionCost Invalid = InstructionCost::getInvalid();
-    InstructionCost Cost(TTI::TCC_Basic);
-
-    if (Opcode != Instruction::Add)
-      return Invalid;
-
-    if (InputTypeA != InputTypeB)
-      return Invalid;
-
-    EVT InputEVT = EVT::getEVT(InputTypeA);
-    EVT AccumEVT = EVT::getEVT(AccumType);
-
-    if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
-      return Invalid;
-    if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
-      return Invalid;
-
-    if (InputEVT == MVT::i8) {
-      switch (VF.getKnownMinValue()) {
-      default:
-        return Invalid;
-      case 8:
-        if (AccumEVT == MVT::i32)
-          Cost *= 2;
-        else if (AccumEVT != MVT::i64)
-          return Invalid;
-        break;
-      case 16:
-        if (AccumEVT == MVT::i64)
-          Cost *= 2;
-        else if (AccumEVT != MVT::i32)
-          return Invalid;
-        break;
-      }
-    } else if (InputEVT == MVT::i16) {
-      // FIXME: Allow i32 accumulator but increase cost, as we would extend
-      //        it to i64.
-      if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64)
-        return Invalid;
-    } else
-      return Invalid;
-
-    // AArch64 supports lowering mixed extensions to a usdot but only if the
-    // i8mm or sve/streaming features are available.
-    if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
-        (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
-         !ST->isSVEorStreamingSVEAvailable()))
-      return Invalid;
-
-    if (!BinOp || *BinOp != Instruction::Mul)
-      return Invalid;
-
-    return Cost;
-  }
+                          std::optional<unsigned> BinOp) const;
 
   bool enableOrderedReductions() const { return true; }
 

Copy link

github-actions bot commented Jan 17, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Collaborator

@davemgreen davemgreen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps put the code close to getArithmeticReductionCost, as they are conceptually similar?

Either way LGTM too.

The function getPartialReductionCost is already quite large and
is likely to grow in size as we add support for more cases in
future. Therefore, I think it's best to move this into the cpp
file.
@david-arm david-arm merged commit a733c1f into llvm:main Jan 20, 2025
6 of 7 checks passed
@david-arm david-arm deleted the part_reduc_nfc2 branch January 28, 2025 11:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants