[AArch64][SME] Extend Inliner cost-model with custom penalty for calls. #68416


Merged — sdesmalen-arm merged 4 commits into llvm:main from sme-inlining-part2-cost on Oct 31, 2023

Conversation

sdesmalen-arm (Collaborator)

This is a stacked PR following on from #68415

This patch has two purposes:
(1) It makes inlining more likely when inlining avoids a streaming-mode change.
(2) It makes inlining less likely when inlining would introduce additional streaming-mode changes.

An example of (1) is:

  void streaming_compatible_bar(void);

  void foo(void) __arm_streaming {
    /* other code */
    streaming_compatible_bar();
    /* other code */
  }

  void f(void) {
    foo();            // expensive streaming mode change
  }

  ->

  void f(void) {
    /* other code */
    streaming_compatible_bar();
    /* other code */
  }

whereas foo would not have been inlined here if it were a non-streaming function (there would be no streaming-mode change to remove).

An example of (2) is:

  void streaming_bar(void) __arm_streaming;

  void foo(void) __arm_streaming {
    streaming_bar();
    streaming_bar();
  }

  void f(void) {
    foo();            // expensive streaming mode change
  }

  -> (do not inline; the result would be)

  void f(void) {
    streaming_bar();  // these are now two expensive streaming mode changes
    streaming_bar();
  }

@llvmbot added the backend:AArch64, llvm:analysis and llvm:transforms labels on Oct 6, 2023
@llvmbot (Member) commented Oct 6, 2023

@llvm/pr-subscribers-llvm-transforms
@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-backend-aarch64

Changes


---

Patch is 31.53 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/68416.diff


11 Files Affected:

- (modified) llvm/include/llvm/Analysis/InlineCost.h (+2-1) 
- (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+16) 
- (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+5) 
- (modified) llvm/lib/Analysis/InlineCost.cpp (+9-6) 
- (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+6) 
- (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+71-4) 
- (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h (+3) 
- (modified) llvm/lib/Transforms/IPO/PartialInlining.cpp (+3-3) 
- (added) llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs-low-threshold.ll (+43) 
- (modified) llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs.ll (+203-30) 
- (modified) llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll (+69-2) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/InlineCost.h b/llvm/include/llvm/Analysis/InlineCost.h
index 57f452853d2d6d6..3f0bb879e021fd7 100644
--- a/llvm/include/llvm/Analysis/InlineCost.h
+++ b/llvm/include/llvm/Analysis/InlineCost.h
@@ -259,7 +259,8 @@ InlineParams getInlineParams(unsigned OptLevel, unsigned SizeOptLevel);
 
 /// Return the cost associated with a callsite, including parameter passing
 /// and the call/return instruction.
-int getCallsiteCost(const CallBase &Call, const DataLayout &DL);
+int getCallsiteCost(const TargetTransformInfo &TTI, const CallBase &Call,
+                    const DataLayout &DL);
 
 /// Get an InlineCost object representing the cost of inlining this
 /// callsite.
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 5234ef8788d9e96..7a85c03d659232e 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1506,6 +1506,15 @@ class TargetTransformInfo {
   bool areInlineCompatible(const Function *Caller,
                            const Function *Callee) const;
 
+  /// Returns a penalty for invoking call \p Call in \p F.
+  /// For example, if a function F calls a function G, which in turn calls
+  /// function H, then getInlineCallPenalty(F, H()) would return the
+  /// penalty of calling H from F, e.g. after inlining G into F.
+  /// \p DefaultCallPenalty is passed to give a default penalty that
+  /// the target can amend or override.
+  unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
+                                unsigned DefaultCallPenalty) const;
+
   /// \returns True if the caller and callee agree on how \p Types will be
   /// passed to or returned from the callee.
   /// to the callee.
@@ -2001,6 +2010,8 @@ class TargetTransformInfo::Concept {
       std::optional<uint32_t> AtomicCpySize) const = 0;
   virtual bool areInlineCompatible(const Function *Caller,
                                    const Function *Callee) const = 0;
+  virtual unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
+                                        unsigned DefaultCallPenalty) const = 0;
   virtual bool areTypesABICompatible(const Function *Caller,
                                      const Function *Callee,
                                      const ArrayRef<Type *> &Types) const = 0;
@@ -2662,6 +2673,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
                            const Function *Callee) const override {
     return Impl.areInlineCompatible(Caller, Callee);
   }
+  unsigned
+  getInlineCallPenalty(const Function *F, const CallBase &Call,
+                       unsigned DefaultCallPenalty) const override {
+    return Impl.getInlineCallPenalty(F, Call, DefaultCallPenalty);
+  }
   bool areTypesABICompatible(const Function *Caller, const Function *Callee,
                              const ArrayRef<Type *> &Types) const override {
     return Impl.areTypesABICompatible(Caller, Callee, Types);
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index c1ff314ae51c98b..e6fc178365626ba 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -802,6 +802,11 @@ class TargetTransformInfoImplBase {
             Callee->getFnAttribute("target-features"));
   }
 
+  unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
+                                unsigned DefaultCallPenalty) const {
+    return DefaultCallPenalty;
+  }
+
   bool areTypesABICompatible(const Function *Caller, const Function *Callee,
                              const ArrayRef<Type *> &Types) const {
     return (Caller->getFnAttribute("target-cpu") ==
diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index fa0c30637633df3..7096e06d925adef 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -695,7 +695,8 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
       }
     } else
       // Otherwise simply add the cost for merely making the call.
-      addCost(CallPenalty);
+      addCost(TTI.getInlineCallPenalty(CandidateCall.getCaller(), Call,
+                                       CallPenalty));
   }
 
   void onFinalizeSwitch(unsigned JumpTableSize,
@@ -918,7 +919,7 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
     // Compute the total savings for the call site.
     auto *CallerBB = CandidateCall.getParent();
     BlockFrequencyInfo *CallerBFI = &(GetBFI(*(CallerBB->getParent())));
-    CycleSavings += getCallsiteCost(this->CandidateCall, DL);
+    CycleSavings += getCallsiteCost(TTI, this->CandidateCall, DL);
     CycleSavings *= *CallerBFI->getBlockProfileCount(CallerBB);
 
     // Remove the cost of the cold basic blocks to model the runtime cost more
@@ -1076,7 +1077,7 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
 
     // Give out bonuses for the callsite, as the instructions setting them up
     // will be gone after inlining.
-    addCost(-getCallsiteCost(this->CandidateCall, DL));
+    addCost(-getCallsiteCost(TTI, this->CandidateCall, DL));
 
     // If this function uses the coldcc calling convention, prefer not to inline
     // it.
@@ -1315,7 +1316,7 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
 
   InlineResult onAnalysisStart() override {
     increment(InlineCostFeatureIndex::callsite_cost,
-              -1 * getCallsiteCost(this->CandidateCall, DL));
+              -1 * getCallsiteCost(TTI, this->CandidateCall, DL));
 
     set(InlineCostFeatureIndex::cold_cc_penalty,
         (F.getCallingConv() == CallingConv::Cold));
@@ -2887,7 +2888,8 @@ static bool functionsHaveCompatibleAttributes(
          AttributeFuncs::areInlineCompatible(*Caller, *Callee);
 }
 
-int llvm::getCallsiteCost(const CallBase &Call, const DataLayout &DL) {
+int llvm::getCallsiteCost(const TargetTransformInfo &TTI, const CallBase &Call,
+                          const DataLayout &DL) {
   int64_t Cost = 0;
   for (unsigned I = 0, E = Call.arg_size(); I != E; ++I) {
     if (Call.isByValArgument(I)) {
@@ -2917,7 +2919,8 @@ int llvm::getCallsiteCost(const CallBase &Call, const DataLayout &DL) {
   }
   // The call instruction also disappears after inlining.
   Cost += InstrCost;
-  Cost += CallPenalty;
+  Cost += TTI.getInlineCallPenalty(Call.getCaller(), Call, CallPenalty);
+
   return std::min<int64_t>(Cost, INT_MAX);
 }
 
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index aad14f21d114619..10ed3a4437dae5f 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1133,6 +1133,12 @@ bool TargetTransformInfo::areInlineCompatible(const Function *Caller,
   return TTIImpl->areInlineCompatible(Caller, Callee);
 }
 
+unsigned TargetTransformInfo::getInlineCallPenalty(
+    const Function *F, const CallBase &Call,
+    unsigned DefaultCallPenalty) const {
+  return TTIImpl->getInlineCallPenalty(F, Call, DefaultCallPenalty);
+}
+
 bool TargetTransformInfo::areTypesABICompatible(
     const Function *Caller, const Function *Callee,
     const ArrayRef<Type *> &Types) const {
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index cded28054f59259..e107c7d8540cce5 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -190,16 +190,49 @@ static cl::opt<bool> EnableFixedwidthAutovecInStreamingMode(
 static cl::opt<bool> EnableScalableAutovecInStreamingMode(
     "enable-scalable-autovec-in-streaming-mode", cl::init(false), cl::Hidden);
 
+static bool isSMEABIRoutineCall(const CallInst &CI) {
+  const auto *F = CI.getCalledFunction();
+  return F && StringSwitch<bool>(F->getName())
+                  .Case("__arm_sme_state", true)
+                  .Case("__arm_tpidr2_save", true)
+                  .Case("__arm_tpidr2_restore", true)
+                  .Case("__arm_za_disable", true)
+                  .Default(false);
+}
+
+/// Returns true if the function has explicit operations that can only be lowered
+/// using incompatible instructions for the selected mode.
+/// This also returns true if the function F may use or modify ZA state.
+static bool hasPossibleIncompatibleOps(const Function *F) {
+  for (const BasicBlock &BB : *F) {
+    for (const Instruction &I : BB) {
+      // Be conservative for now and assume that any call to inline asm or to
+      // intrinsics could result in non-streaming ops (e.g. calls to
+      // @llvm.aarch64.* or @llvm.gather/scatter intrinsics). We can assume that
+      // all native LLVM instructions can be lowered to compatible instructions.
+      if (isa<CallInst>(I) && !I.isDebugOrPseudoInst() &&
+          (cast<CallInst>(I).isInlineAsm() || isa<IntrinsicInst>(I) ||
+           isSMEABIRoutineCall(cast<CallInst>(I))))
+        return true;
+    }
+  }
+  return false;
+}
+
 bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
                                          const Function *Callee) const {
   SMEAttrs CallerAttrs(*Caller);
   SMEAttrs CalleeAttrs(*Callee);
-  if (CallerAttrs.requiresSMChange(CalleeAttrs,
-                                   /*BodyOverridesInterface=*/true) ||
-      CallerAttrs.requiresLazySave(CalleeAttrs) ||
-      CalleeAttrs.hasNewZABody())
+  if (CalleeAttrs.hasNewZABody())
     return false;
 
+  if (CallerAttrs.requiresLazySave(CalleeAttrs) ||
+      CallerAttrs.requiresSMChange(CalleeAttrs,
+                                   /*BodyOverridesInterface=*/true)) {
+    if (hasPossibleIncompatibleOps(Callee))
+      return false;
+  }
+
   const TargetMachine &TM = getTLI()->getTargetMachine();
 
   const FeatureBitset &CallerBits =
@@ -212,6 +245,40 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
   return (CallerBits & CalleeBits) == CalleeBits;
 }
 
+unsigned
+AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
+                                     unsigned DefaultCallPenalty) const {
+  // This function calculates a penalty for executing Call in F.
+  //
+  // There are two ways this function can be called:
+  // (1)  F:
+  //       call from F -> G (the call here is Call)
+  //
+  // For (1), Call.getCaller() == F, so it will always return a high cost if
+  // a streaming-mode change is required (thus promoting the need to inline the
+  // function)
+  //
+  // (2)  F:
+  //       call from F -> G (the call here is not Call)
+  //      G:
+  //       call from G -> H (the call here is Call)
+  //
+  // For (2), if after inlining the body of G into F the call to H requires a
+  // streaming-mode change, and the call to G from F would also require a
+  // streaming-mode change, then there is benefit to do the streaming-mode
+  // change only once and avoid inlining of G into F.
+  SMEAttrs FAttrs(*F);
+  SMEAttrs CalleeAttrs(Call);
+  if (FAttrs.requiresSMChange(CalleeAttrs)) {
+    if (F == Call.getCaller())                                // (1)
+      return 5 * DefaultCallPenalty;
+    if (FAttrs.requiresSMChange(SMEAttrs(*Call.getCaller()))) // (2)
+      return 10 * DefaultCallPenalty;
+  }
+
+  return DefaultCallPenalty;
+}
+
 bool AArch64TTIImpl::shouldMaximizeVectorBandwidth(
     TargetTransformInfo::RegisterKind K) const {
   assert(K != TargetTransformInfo::RGK_Scalar);
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index a6baade412c77d2..cccce44fe35ffc0 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -77,6 +77,9 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
   bool areInlineCompatible(const Function *Caller,
                            const Function *Callee) const;
 
+  unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
+                                unsigned DefaultCallPenalty) const;
+
   /// \name Scalar TTI Implementations
   /// @{
 
diff --git a/llvm/lib/Transforms/IPO/PartialInlining.cpp b/llvm/lib/Transforms/IPO/PartialInlining.cpp
index 25da06add24f031..aa4f205ec5bdf1e 100644
--- a/llvm/lib/Transforms/IPO/PartialInlining.cpp
+++ b/llvm/lib/Transforms/IPO/PartialInlining.cpp
@@ -767,7 +767,7 @@ bool PartialInlinerImpl::shouldPartialInline(
   const DataLayout &DL = Caller->getParent()->getDataLayout();
 
   // The savings of eliminating the call:
-  int NonWeightedSavings = getCallsiteCost(CB, DL);
+  int NonWeightedSavings = getCallsiteCost(CalleeTTI, CB, DL);
   BlockFrequency NormWeightedSavings(NonWeightedSavings);
 
   // Weighted saving is smaller than weighted cost, return false
@@ -842,12 +842,12 @@ PartialInlinerImpl::computeBBInlineCost(BasicBlock *BB,
     }
 
     if (CallInst *CI = dyn_cast<CallInst>(&I)) {
-      InlineCost += getCallsiteCost(*CI, DL);
+      InlineCost += getCallsiteCost(*TTI, *CI, DL);
       continue;
     }
 
     if (InvokeInst *II = dyn_cast<InvokeInst>(&I)) {
-      InlineCost += getCallsiteCost(*II, DL);
+      InlineCost += getCallsiteCost(*TTI, *II, DL);
       continue;
     }
 
diff --git a/llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs-low-threshold.ll b/llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs-low-threshold.ll
new file mode 100644
index 000000000000000..f207efaeaad36b8
--- /dev/null
+++ b/llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs-low-threshold.ll
@@ -0,0 +1,43 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
+; RUN: opt < %s -mtriple=aarch64-unknown-linux-gnu -mattr=+sme -S -passes=inline  -inlinedefault-threshold=1 | FileCheck %s
+
+; This test sets the inline-threshold to 1 such that by default the call to @streaming_callee is not inlined.
+; However, if the call to @streaming_callee requires a streaming-mode change, it should always inline the call because the streaming-mode change is more expensive.
+target triple = "aarch64"
+
+declare void @streaming_compatible_f() "aarch64_pstate_sm_compatible"
+
+define void @streaming_callee() "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: define void @streaming_callee
+; CHECK-SAME: () #[[ATTR1:[0-9]+]] {
+; CHECK-NEXT:    call void @streaming_compatible_f()
+; CHECK-NEXT:    call void @streaming_compatible_f()
+; CHECK-NEXT:    ret void
+;
+  call void @streaming_compatible_f()
+  call void @streaming_compatible_f()
+  ret void
+}
+
+; Inline call to @streaming_callee to remove a streaming mode change.
+define void @non_streaming_caller_inline() {
+; CHECK-LABEL: define void @non_streaming_caller_inline
+; CHECK-SAME: () #[[ATTR2:[0-9]+]] {
+; CHECK-NEXT:    call void @streaming_compatible_f()
+; CHECK-NEXT:    call void @streaming_compatible_f()
+; CHECK-NEXT:    ret void
+;
+  call void @streaming_callee()
+  ret void
+}
+
+; Don't inline call to @streaming_callee when the inline-threshold is set to 1, because it does not eliminate a streaming-mode change.
+define void @streaming_caller_dont_inline() "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: define void @streaming_caller_dont_inline
+; CHECK-SAME: () #[[ATTR1]] {
+; CHECK-NEXT:    call void @streaming_callee()
+; CHECK-NEXT:    ret void
+;
+  call void @streaming_callee()
+  ret void
+}
diff --git a/llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs.ll b/llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs.ll
index 3df5400875ae288..d6b1f3ef45e7655 100644
--- a/llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs.ll
+++ b/llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs.ll
@@ -102,11 +102,11 @@ entry:
 ; [ ] N  -> SC
 ; [ ] N  -> N + B
 ; [ ] N  -> SC + B
-define void @normal_caller_streaming_callee_dont_inline() {
-; CHECK-LABEL: define void @normal_caller_streaming_callee_dont_inline
+define void @normal_caller_streaming_callee_inline() {
+; CHECK-LABEL: define void @normal_caller_streaming_callee_inline
 ; CHECK-SAME: () #[[ATTR1]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @streaming_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -136,11 +136,11 @@ entry:
 ; [ ] N  -> SC
 ; [x] N  -> N + B
 ; [ ] N  -> SC + B
-define void @normal_caller_locally_streaming_callee_dont_inline() {
-; CHECK-LABEL: define void @normal_caller_locally_streaming_callee_dont_inline
+define void @normal_caller_locally_streaming_callee_inline() {
+; CHECK-LABEL: define void @normal_caller_locally_streaming_callee_inline
 ; CHECK-SAME: () #[[ATTR1]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @locally_streaming_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -153,11 +153,11 @@ entry:
 ; [ ] N  -> SC
 ; [ ] N  -> N + B
 ; [x] N  -> SC + B
-define void @normal_caller_streaming_compatible_locally_streaming_callee_dont_inline() {
-; CHECK-LABEL: define void @normal_caller_streaming_compatible_locally_streaming_callee_dont_inline
+define void @normal_caller_streaming_compatible_locally_streaming_callee_inline() {
+; CHECK-LABEL: define void @normal_caller_streaming_compatible_locally_streaming_callee_inline
 ; CHECK-SAME: () #[[ATTR1]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @streaming_compatible_locally_streaming_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -170,11 +170,11 @@ entry:
 ; [ ] S  -> SC
 ; [ ] S  -> N + B
 ; [ ] S  -> SC + B
-define void @streaming_caller_normal_callee_dont_inline() "aarch64_pstate_sm_enabled" {
-; CHECK-LABEL: define void @streaming_caller_normal_callee_dont_inline
+define void @streaming_caller_normal_callee_inline() "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: define void @streaming_caller_normal_callee_inline
 ; CHECK-SAME: () #[[ATTR2]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @normal_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -255,11 +255,11 @@ entry:
 ; [ ] N + B -> SC
 ; [ ] N + B -> N + B
 ; [ ] N + B -> SC + B
-define void @locally_streaming_caller_normal_callee_dont_inline() "aarch64_pstate_sm_body" {
-; CHECK-LABEL: define void @locally_streaming_caller_normal_callee_dont_inline
+define void @locally_streaming_caller_normal_callee_inline() "aarch64_pstate_sm_body" {
+; CHECK-LABEL: define void @locally_streaming_caller_normal_callee_inline
 ; CHECK-SAME: () #[[ATTR3]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @normal_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -340,11 +340,11 @@ entry:
 ; [ ] SC -> SC
 ; [ ] SC -> N + B
 ; [ ] SC -> SC + B
-define void @streaming_compatible_caller_normal_callee_dont_inline() "aarch64_pstate_sm_compatible" {
-; CHECK-LABEL: define void @streaming_compatible_caller_normal_callee_dont_inline
+define void @streaming_compatible_caller_normal_callee_inline() "aarch64_pstate_sm_compatible" {
+; CHECK-LABEL: define void @streaming_compatible_caller_normal_callee_inline
 ; CHECK-SAME: () #[[ATTR0]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @normal_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -357,11 +357,11 @@ entry:
 ; [ ] SC -> SC
 ; [ ] SC -> N + B
 ; [ ] SC -> SC + B
-define void @streaming_compatible_caller_streaming_callee_dont_inline() "aarch64_pstate_sm_compatible" {
-; CHECK-LABEL: define void @streaming_compatible_caller_streaming_callee_dont_inline
+define void @streaming_compatible_caller_streaming_callee_inline() "aarch64_pstate_sm_compatible" {
+; CHECK-LABEL: define void @streaming_compatible_caller_streaming_callee_inline
 ; CHECK-SAME: () #[[ATTR0]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    call void @streaming_callee()
+; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -391,11 +391,11 @@ entry:
 ; [ ] SC -> SC
 ; [x] SC -> N + B
 ; [ ] SC -> SC + B
-define void @streaming_compatible_caller_locally_streaming_callee_dont_inline() "aarch64_pstate_sm_compatible" {
-; CHECK-LABEL: define void @streaming_compatible_caller_locally_streaming_callee_dont_inline
+define void @streaming_compatible_caller_locally_streaming_callee_inline() "aar...
[truncated]
``````````

github-actions bot commented Oct 6, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@jroelofs (Contributor) left a comment

LGTM

@jroelofs (Contributor)

gentle ping

@sdesmalen-arm force-pushed the sme-inlining-part2-cost branch from e98d96e to b2b29af on October 30, 2023 11:06

declare void @streaming_compatible_f() "aarch64_pstate_sm_compatible"

define void @streaming_callee() "aarch64_pstate_sm_enabled" {

A comment would be helpful here to explain that @streaming_callee doesn't contain any operations that use ZA state, and therefore can be legally inlined into a normal function.

@sdesmalen-arm merged commit 00a8314 into llvm:main on Oct 31, 2023
@sdesmalen-arm deleted the sme-inlining-part2-cost branch on February 23, 2024 11:37