Skip to content

Commit 00a8314

Browse files
[AArch64][SME] Extend Inliner cost-model with custom penalty for calls. (#68416)
This is a stacked PR following on from #68415 This patch has two purposes: (1) It tries to make inlining more likely when it can avoid a streaming-mode change. (2) It avoids inlining when inlining causes more streaming-mode changes. An example of (1) is: ``` void streaming_compatible_bar(void); void foo(void) __arm_streaming { /* other code */ streaming_compatible_bar(); /* other code */ } void f(void) { foo(); // expensive streaming mode change } -> void f(void) { /* other code */ streaming_compatible_bar(); /* other code */ } ``` where it wouldn't have inlined the function when foo would be a non-streaming function. An example of (2) is: ``` void streaming_bar(void) __arm_streaming; void foo(void) __arm_streaming { streaming_bar(); streaming_bar(); } void f(void) { foo(); // expensive streaming mode change } -> (do not inline into) void f(void) { streaming_bar(); // these are now two expensive streaming mode changes streaming_bar(); }```
1 parent b8d3ccd commit 00a8314

File tree

10 files changed

+227
-10
lines changed

10 files changed

+227
-10
lines changed

llvm/include/llvm/Analysis/InlineCost.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,8 @@ InlineParams getInlineParams(unsigned OptLevel, unsigned SizeOptLevel);
259259

260260
/// Return the cost associated with a callsite, including parameter passing
261261
/// and the call/return instruction.
262-
int getCallsiteCost(const CallBase &Call, const DataLayout &DL);
262+
int getCallsiteCost(const TargetTransformInfo &TTI, const CallBase &Call,
263+
const DataLayout &DL);
263264

264265
/// Get an InlineCost object representing the cost of inlining this
265266
/// callsite.

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1517,6 +1517,15 @@ class TargetTransformInfo {
15171517
bool areInlineCompatible(const Function *Caller,
15181518
const Function *Callee) const;
15191519

1520+
/// Returns a penalty for invoking call \p Call in \p F.
1521+
/// For example, if a function F calls a function G, which in turn calls
1522+
/// function H, then getInlineCallPenalty(F, H()) would return the
1523+
/// penalty of calling H from F, e.g. after inlining G into F.
1524+
/// \p DefaultCallPenalty is passed to give a default penalty that
1525+
/// the target can amend or override.
1526+
unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
1527+
unsigned DefaultCallPenalty) const;
1528+
15201529
/// \returns True if the caller and callee agree on how \p Types will be
15211530
/// passed to or returned from the callee.
15221531
/// to the callee.
@@ -2012,6 +2021,8 @@ class TargetTransformInfo::Concept {
20122021
std::optional<uint32_t> AtomicCpySize) const = 0;
20132022
virtual bool areInlineCompatible(const Function *Caller,
20142023
const Function *Callee) const = 0;
2024+
virtual unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
2025+
unsigned DefaultCallPenalty) const = 0;
20152026
virtual bool areTypesABICompatible(const Function *Caller,
20162027
const Function *Callee,
20172028
const ArrayRef<Type *> &Types) const = 0;
@@ -2673,6 +2684,10 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
26732684
const Function *Callee) const override {
26742685
return Impl.areInlineCompatible(Caller, Callee);
26752686
}
2687+
unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
2688+
unsigned DefaultCallPenalty) const override {
2689+
return Impl.getInlineCallPenalty(F, Call, DefaultCallPenalty);
2690+
}
26762691
bool areTypesABICompatible(const Function *Caller, const Function *Callee,
26772692
const ArrayRef<Type *> &Types) const override {
26782693
return Impl.areTypesABICompatible(Caller, Callee, Types);

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -802,6 +802,11 @@ class TargetTransformInfoImplBase {
802802
Callee->getFnAttribute("target-features"));
803803
}
804804

805+
unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
806+
unsigned DefaultCallPenalty) const {
807+
return DefaultCallPenalty;
808+
}
809+
805810
bool areTypesABICompatible(const Function *Caller, const Function *Callee,
806811
const ArrayRef<Type *> &Types) const {
807812
return (Caller->getFnAttribute("target-cpu") ==

llvm/lib/Analysis/InlineCost.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -695,7 +695,8 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
695695
}
696696
} else
697697
// Otherwise simply add the cost for merely making the call.
698-
addCost(CallPenalty);
698+
addCost(TTI.getInlineCallPenalty(CandidateCall.getCaller(), Call,
699+
CallPenalty));
699700
}
700701

701702
void onFinalizeSwitch(unsigned JumpTableSize,
@@ -918,7 +919,7 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
918919
// Compute the total savings for the call site.
919920
auto *CallerBB = CandidateCall.getParent();
920921
BlockFrequencyInfo *CallerBFI = &(GetBFI(*(CallerBB->getParent())));
921-
CycleSavings += getCallsiteCost(this->CandidateCall, DL);
922+
CycleSavings += getCallsiteCost(TTI, this->CandidateCall, DL);
922923
CycleSavings *= *CallerBFI->getBlockProfileCount(CallerBB);
923924

924925
// Remove the cost of the cold basic blocks to model the runtime cost more
@@ -1076,7 +1077,7 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
10761077

10771078
// Give out bonuses for the callsite, as the instructions setting them up
10781079
// will be gone after inlining.
1079-
addCost(-getCallsiteCost(this->CandidateCall, DL));
1080+
addCost(-getCallsiteCost(TTI, this->CandidateCall, DL));
10801081

10811082
// If this function uses the coldcc calling convention, prefer not to inline
10821083
// it.
@@ -1315,7 +1316,7 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
13151316

13161317
InlineResult onAnalysisStart() override {
13171318
increment(InlineCostFeatureIndex::callsite_cost,
1318-
-1 * getCallsiteCost(this->CandidateCall, DL));
1319+
-1 * getCallsiteCost(TTI, this->CandidateCall, DL));
13191320

13201321
set(InlineCostFeatureIndex::cold_cc_penalty,
13211322
(F.getCallingConv() == CallingConv::Cold));
@@ -2887,7 +2888,8 @@ static bool functionsHaveCompatibleAttributes(
28872888
AttributeFuncs::areInlineCompatible(*Caller, *Callee);
28882889
}
28892890

2890-
int llvm::getCallsiteCost(const CallBase &Call, const DataLayout &DL) {
2891+
int llvm::getCallsiteCost(const TargetTransformInfo &TTI, const CallBase &Call,
2892+
const DataLayout &DL) {
28912893
int64_t Cost = 0;
28922894
for (unsigned I = 0, E = Call.arg_size(); I != E; ++I) {
28932895
if (Call.isByValArgument(I)) {
@@ -2917,7 +2919,8 @@ int llvm::getCallsiteCost(const CallBase &Call, const DataLayout &DL) {
29172919
}
29182920
// The call instruction also disappears after inlining.
29192921
Cost += InstrCost;
2920-
Cost += CallPenalty;
2922+
Cost += TTI.getInlineCallPenalty(Call.getCaller(), Call, CallPenalty);
2923+
29212924
return std::min<int64_t>(Cost, INT_MAX);
29222925
}
29232926

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1133,6 +1133,13 @@ bool TargetTransformInfo::areInlineCompatible(const Function *Caller,
11331133
return TTIImpl->areInlineCompatible(Caller, Callee);
11341134
}
11351135

1136+
unsigned
1137+
TargetTransformInfo::getInlineCallPenalty(const Function *F,
1138+
const CallBase &Call,
1139+
unsigned DefaultCallPenalty) const {
1140+
return TTIImpl->getInlineCallPenalty(F, Call, DefaultCallPenalty);
1141+
}
1142+
11361143
bool TargetTransformInfo::areTypesABICompatible(
11371144
const Function *Caller, const Function *Callee,
11381145
const ArrayRef<Type *> &Types) const {

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,15 @@ static cl::opt<unsigned>
4646
NeonNonConstStrideOverhead("neon-nonconst-stride-overhead", cl::init(10),
4747
cl::Hidden);
4848

49+
static cl::opt<unsigned> CallPenaltyChangeSM(
50+
"call-penalty-sm-change", cl::init(5), cl::Hidden,
51+
cl::desc(
52+
"Penalty of calling a function that requires a change to PSTATE.SM"));
53+
54+
static cl::opt<unsigned> InlineCallPenaltyChangeSM(
55+
"inline-call-penalty-sm-change", cl::init(10), cl::Hidden,
56+
cl::desc("Penalty of inlining a call that requires a change to PSTATE.SM"));
57+
4958
namespace {
5059
class TailFoldingOption {
5160
// These bitfields will only ever be set to something non-zero in operator=,
@@ -269,6 +278,40 @@ bool AArch64TTIImpl::areTypesABICompatible(
269278
return true;
270279
}
271280

281+
unsigned
282+
AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
283+
unsigned DefaultCallPenalty) const {
284+
// This function calculates a penalty for executing Call in F.
285+
//
286+
// There are two ways this function can be called:
287+
// (1) F:
288+
// call from F -> G (the call here is Call)
289+
//
290+
// For (1), Call.getCaller() == F, so it will always return a high cost if
291+
// a streaming-mode change is required (thus promoting the need to inline the
292+
// function)
293+
//
294+
// (2) F:
295+
// call from F -> G (the call here is not Call)
296+
// G:
297+
// call from G -> H (the call here is Call)
298+
//
299+
// For (2), if after inlining the body of G into F the call to H requires a
300+
// streaming-mode change, and the call to G from F would also require a
301+
// streaming-mode change, then there is benefit to do the streaming-mode
302+
// change only once and avoid inlining of G into F.
303+
SMEAttrs FAttrs(*F);
304+
SMEAttrs CalleeAttrs(Call);
305+
if (FAttrs.requiresSMChange(CalleeAttrs)) {
306+
if (F == Call.getCaller()) // (1)
307+
return CallPenaltyChangeSM * DefaultCallPenalty;
308+
if (FAttrs.requiresSMChange(SMEAttrs(*Call.getCaller()))) // (2)
309+
return InlineCallPenaltyChangeSM * DefaultCallPenalty;
310+
}
311+
312+
return DefaultCallPenalty;
313+
}
314+
272315
bool AArch64TTIImpl::shouldMaximizeVectorBandwidth(
273316
TargetTransformInfo::RegisterKind K) const {
274317
assert(K != TargetTransformInfo::RGK_Scalar);

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
8080
bool areTypesABICompatible(const Function *Caller, const Function *Callee,
8181
const ArrayRef<Type *> &Types) const;
8282

83+
unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
84+
unsigned DefaultCallPenalty) const;
85+
8386
/// \name Scalar TTI Implementations
8487
/// @{
8588

llvm/lib/Transforms/IPO/PartialInlining.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -767,7 +767,7 @@ bool PartialInlinerImpl::shouldPartialInline(
767767
const DataLayout &DL = Caller->getParent()->getDataLayout();
768768

769769
// The savings of eliminating the call:
770-
int NonWeightedSavings = getCallsiteCost(CB, DL);
770+
int NonWeightedSavings = getCallsiteCost(CalleeTTI, CB, DL);
771771
BlockFrequency NormWeightedSavings(NonWeightedSavings);
772772

773773
// Weighted saving is smaller than weighted cost, return false
@@ -842,12 +842,12 @@ PartialInlinerImpl::computeBBInlineCost(BasicBlock *BB,
842842
}
843843

844844
if (CallInst *CI = dyn_cast<CallInst>(&I)) {
845-
InlineCost += getCallsiteCost(*CI, DL);
845+
InlineCost += getCallsiteCost(*TTI, *CI, DL);
846846
continue;
847847
}
848848

849849
if (InvokeInst *II = dyn_cast<InvokeInst>(&I)) {
850-
InlineCost += getCallsiteCost(*II, DL);
850+
InlineCost += getCallsiteCost(*TTI, *II, DL);
851851
continue;
852852
}
853853

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
2+
; RUN: opt < %s -mtriple=aarch64-unknown-linux-gnu -mattr=+sme -S -passes=inline -inlinedefault-threshold=1 | FileCheck %s
3+
4+
; This test sets the inline-threshold to 1 such that by default the call to @streaming_callee is not inlined.
5+
; However, if the call to @streaming_callee requires a streaming-mode change, it should always inline the call because the streaming-mode change is more expensive.
6+
target triple = "aarch64"
7+
8+
declare void @streaming_compatible_f() "aarch64_pstate_sm_compatible"
9+
10+
; Function @streaming_callee doesn't contain any operations that may use ZA
11+
; state and therefore can be legally inlined into a normal function.
12+
define void @streaming_callee() "aarch64_pstate_sm_enabled" {
13+
; CHECK-LABEL: define void @streaming_callee
14+
; CHECK-SAME: () #[[ATTR1:[0-9]+]] {
15+
; CHECK-NEXT: call void @streaming_compatible_f()
16+
; CHECK-NEXT: call void @streaming_compatible_f()
17+
; CHECK-NEXT: ret void
18+
;
19+
call void @streaming_compatible_f()
20+
call void @streaming_compatible_f()
21+
ret void
22+
}
23+
24+
; Inline call to @streaming_callee to remove a streaming mode change.
25+
define void @non_streaming_caller_inline() {
26+
; CHECK-LABEL: define void @non_streaming_caller_inline
27+
; CHECK-SAME: () #[[ATTR2:[0-9]+]] {
28+
; CHECK-NEXT: call void @streaming_compatible_f()
29+
; CHECK-NEXT: call void @streaming_compatible_f()
30+
; CHECK-NEXT: ret void
31+
;
32+
call void @streaming_callee()
33+
ret void
34+
}
35+
36+
; Don't inline call to @streaming_callee when the inline-threshold is set to 1, because it does not eliminate a streaming-mode change.
37+
define void @streaming_caller_dont_inline() "aarch64_pstate_sm_enabled" {
38+
; CHECK-LABEL: define void @streaming_caller_dont_inline
39+
; CHECK-SAME: () #[[ATTR1]] {
40+
; CHECK-NEXT: call void @streaming_callee()
41+
; CHECK-NEXT: ret void
42+
;
43+
call void @streaming_callee()
44+
ret void
45+
}

llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs.ll

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,3 +581,98 @@ entry:
581581
%res = call i64 @normal_callee_call_sme_state()
582582
ret i64 %res
583583
}
584+
585+
586+
587+
declare void @streaming_body() "aarch64_pstate_sm_enabled"
588+
589+
define void @streaming_caller_single_streaming_callee() "aarch64_pstate_sm_enabled" {
590+
; CHECK-LABEL: define void @streaming_caller_single_streaming_callee
591+
; CHECK-SAME: () #[[ATTR2]] {
592+
; CHECK-NEXT: call void @streaming_body()
593+
; CHECK-NEXT: ret void
594+
;
595+
call void @streaming_body()
596+
ret void
597+
}
598+
599+
define void @streaming_caller_multiple_streaming_callees() "aarch64_pstate_sm_enabled" {
600+
; CHECK-LABEL: define void @streaming_caller_multiple_streaming_callees
601+
; CHECK-SAME: () #[[ATTR2]] {
602+
; CHECK-NEXT: call void @streaming_body()
603+
; CHECK-NEXT: call void @streaming_body()
604+
; CHECK-NEXT: ret void
605+
;
606+
call void @streaming_body()
607+
call void @streaming_body()
608+
ret void
609+
}
610+
611+
; Allow inlining, as inline it would not increase the number of streaming-mode changes.
612+
define void @streaming_caller_single_streaming_callee_inline() {
613+
; CHECK-LABEL: define void @streaming_caller_single_streaming_callee_inline
614+
; CHECK-SAME: () #[[ATTR1]] {
615+
; CHECK-NEXT: call void @streaming_body()
616+
; CHECK-NEXT: ret void
617+
;
618+
call void @streaming_caller_single_streaming_callee()
619+
ret void
620+
}
621+
622+
; Prevent inlining, as inline it would lead to multiple streaming-mode changes.
623+
define void @streaming_caller_multiple_streaming_callees_dont_inline() {
624+
; CHECK-LABEL: define void @streaming_caller_multiple_streaming_callees_dont_inline
625+
; CHECK-SAME: () #[[ATTR1]] {
626+
; CHECK-NEXT: call void @streaming_caller_multiple_streaming_callees()
627+
; CHECK-NEXT: ret void
628+
;
629+
call void @streaming_caller_multiple_streaming_callees()
630+
ret void
631+
}
632+
633+
declare void @streaming_compatible_body() "aarch64_pstate_sm_compatible"
634+
635+
define void @streaming_caller_single_streaming_compatible_callee() "aarch64_pstate_sm_enabled" {
636+
; CHECK-LABEL: define void @streaming_caller_single_streaming_compatible_callee
637+
; CHECK-SAME: () #[[ATTR2]] {
638+
; CHECK-NEXT: call void @streaming_compatible_body()
639+
; CHECK-NEXT: ret void
640+
;
641+
call void @streaming_compatible_body()
642+
ret void
643+
}
644+
645+
define void @streaming_caller_multiple_streaming_compatible_callees() "aarch64_pstate_sm_enabled" {
646+
; CHECK-LABEL: define void @streaming_caller_multiple_streaming_compatible_callees
647+
; CHECK-SAME: () #[[ATTR2]] {
648+
; CHECK-NEXT: call void @streaming_compatible_body()
649+
; CHECK-NEXT: call void @streaming_compatible_body()
650+
; CHECK-NEXT: ret void
651+
;
652+
call void @streaming_compatible_body()
653+
call void @streaming_compatible_body()
654+
ret void
655+
}
656+
657+
; Allow inlining, as inline would remove a streaming-mode change.
658+
define void @streaming_caller_single_streaming_compatible_callee_inline() {
659+
; CHECK-LABEL: define void @streaming_caller_single_streaming_compatible_callee_inline
660+
; CHECK-SAME: () #[[ATTR1]] {
661+
; CHECK-NEXT: call void @streaming_compatible_body()
662+
; CHECK-NEXT: ret void
663+
;
664+
call void @streaming_caller_single_streaming_compatible_callee()
665+
ret void
666+
}
667+
668+
; Allow inlining, as inline would remove several stremaing-mode changes.
669+
define void @streaming_caller_multiple_streaming_compatible_callees_inline() {
670+
; CHECK-LABEL: define void @streaming_caller_multiple_streaming_compatible_callees_inline
671+
; CHECK-SAME: () #[[ATTR1]] {
672+
; CHECK-NEXT: call void @streaming_compatible_body()
673+
; CHECK-NEXT: call void @streaming_compatible_body()
674+
; CHECK-NEXT: ret void
675+
;
676+
call void @streaming_caller_multiple_streaming_compatible_callees()
677+
ret void
678+
}

0 commit comments

Comments
 (0)