Skip to content

[AArch64][SME] Extend Inliner cost-model with custom penalty for calls. #68416

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion llvm/include/llvm/Analysis/InlineCost.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ InlineParams getInlineParams(unsigned OptLevel, unsigned SizeOptLevel);

/// Return the cost associated with a callsite, including parameter passing
/// and the call/return instruction.
int getCallsiteCost(const CallBase &Call, const DataLayout &DL);
int getCallsiteCost(const TargetTransformInfo &TTI, const CallBase &Call,
const DataLayout &DL);

/// Get an InlineCost object representing the cost of inlining this
/// callsite.
Expand Down
15 changes: 15 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -1506,6 +1506,15 @@ class TargetTransformInfo {
bool areInlineCompatible(const Function *Caller,
const Function *Callee) const;

/// Returns a penalty for invoking call \p Call in \p F.
/// For example, if a function F calls a function G, which in turn calls
/// function H, then getInlineCallPenalty(F, H()) would return the
/// penalty of calling H from F, e.g. after inlining G into F.
/// \p DefaultCallPenalty is passed to give a default penalty that
/// the target can amend or override.
unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
unsigned DefaultCallPenalty) const;

/// \returns True if the caller and callee agree on how \p Types will be
/// passed to or returned from the callee.
/// to the callee.
Expand Down Expand Up @@ -2001,6 +2010,8 @@ class TargetTransformInfo::Concept {
std::optional<uint32_t> AtomicCpySize) const = 0;
virtual bool areInlineCompatible(const Function *Caller,
const Function *Callee) const = 0;
virtual unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
unsigned DefaultCallPenalty) const = 0;
virtual bool areTypesABICompatible(const Function *Caller,
const Function *Callee,
const ArrayRef<Type *> &Types) const = 0;
Expand Down Expand Up @@ -2662,6 +2673,10 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
const Function *Callee) const override {
return Impl.areInlineCompatible(Caller, Callee);
}
unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
unsigned DefaultCallPenalty) const override {
return Impl.getInlineCallPenalty(F, Call, DefaultCallPenalty);
}
bool areTypesABICompatible(const Function *Caller, const Function *Callee,
const ArrayRef<Type *> &Types) const override {
return Impl.areTypesABICompatible(Caller, Callee, Types);
Expand Down
5 changes: 5 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,11 @@ class TargetTransformInfoImplBase {
Callee->getFnAttribute("target-features"));
}

unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
unsigned DefaultCallPenalty) const {
return DefaultCallPenalty;
}

bool areTypesABICompatible(const Function *Caller, const Function *Callee,
const ArrayRef<Type *> &Types) const {
return (Caller->getFnAttribute("target-cpu") ==
Expand Down
15 changes: 9 additions & 6 deletions llvm/lib/Analysis/InlineCost.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,8 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
}
} else
// Otherwise simply add the cost for merely making the call.
addCost(CallPenalty);
addCost(TTI.getInlineCallPenalty(CandidateCall.getCaller(), Call,
CallPenalty));
}

void onFinalizeSwitch(unsigned JumpTableSize,
Expand Down Expand Up @@ -918,7 +919,7 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
// Compute the total savings for the call site.
auto *CallerBB = CandidateCall.getParent();
BlockFrequencyInfo *CallerBFI = &(GetBFI(*(CallerBB->getParent())));
CycleSavings += getCallsiteCost(this->CandidateCall, DL);
CycleSavings += getCallsiteCost(TTI, this->CandidateCall, DL);
CycleSavings *= *CallerBFI->getBlockProfileCount(CallerBB);

// Remove the cost of the cold basic blocks to model the runtime cost more
Expand Down Expand Up @@ -1076,7 +1077,7 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {

// Give out bonuses for the callsite, as the instructions setting them up
// will be gone after inlining.
addCost(-getCallsiteCost(this->CandidateCall, DL));
addCost(-getCallsiteCost(TTI, this->CandidateCall, DL));

// If this function uses the coldcc calling convention, prefer not to inline
// it.
Expand Down Expand Up @@ -1315,7 +1316,7 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {

InlineResult onAnalysisStart() override {
increment(InlineCostFeatureIndex::callsite_cost,
-1 * getCallsiteCost(this->CandidateCall, DL));
-1 * getCallsiteCost(TTI, this->CandidateCall, DL));

set(InlineCostFeatureIndex::cold_cc_penalty,
(F.getCallingConv() == CallingConv::Cold));
Expand Down Expand Up @@ -2887,7 +2888,8 @@ static bool functionsHaveCompatibleAttributes(
AttributeFuncs::areInlineCompatible(*Caller, *Callee);
}

int llvm::getCallsiteCost(const CallBase &Call, const DataLayout &DL) {
int llvm::getCallsiteCost(const TargetTransformInfo &TTI, const CallBase &Call,
const DataLayout &DL) {
int64_t Cost = 0;
for (unsigned I = 0, E = Call.arg_size(); I != E; ++I) {
if (Call.isByValArgument(I)) {
Expand Down Expand Up @@ -2917,7 +2919,8 @@ int llvm::getCallsiteCost(const CallBase &Call, const DataLayout &DL) {
}
// The call instruction also disappears after inlining.
Cost += InstrCost;
Cost += CallPenalty;
Cost += TTI.getInlineCallPenalty(Call.getCaller(), Call, CallPenalty);

return std::min<int64_t>(Cost, INT_MAX);
}

Expand Down
7 changes: 7 additions & 0 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1133,6 +1133,13 @@ bool TargetTransformInfo::areInlineCompatible(const Function *Caller,
return TTIImpl->areInlineCompatible(Caller, Callee);
}

unsigned
TargetTransformInfo::getInlineCallPenalty(const Function *F,
const CallBase &Call,
unsigned DefaultCallPenalty) const {
return TTIImpl->getInlineCallPenalty(F, Call, DefaultCallPenalty);
}

bool TargetTransformInfo::areTypesABICompatible(
const Function *Caller, const Function *Callee,
const ArrayRef<Type *> &Types) const {
Expand Down
43 changes: 43 additions & 0 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ static cl::opt<unsigned>
NeonNonConstStrideOverhead("neon-nonconst-stride-overhead", cl::init(10),
cl::Hidden);

static cl::opt<unsigned> CallPenaltyChangeSM(
"call-penalty-sm-change", cl::init(5), cl::Hidden,
cl::desc(
"Penalty of calling a function that requires a change to PSTATE.SM"));

static cl::opt<unsigned> InlineCallPenaltyChangeSM(
"inline-call-penalty-sm-change", cl::init(10), cl::Hidden,
cl::desc("Penalty of inlining a call that requires a change to PSTATE.SM"));

namespace {
class TailFoldingOption {
// These bitfields will only ever be set to something non-zero in operator=,
Expand Down Expand Up @@ -269,6 +278,40 @@ bool AArch64TTIImpl::areTypesABICompatible(
return true;
}

unsigned
AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
unsigned DefaultCallPenalty) const {
// This function calculates a penalty for executing Call in F.
//
// There are two ways this function can be called:
// (1) F:
// call from F -> G (the call here is Call)
//
// For (1), Call.getCaller() == F, so it will always return a high cost if
// a streaming-mode change is required (thus promoting the need to inline the
// function)
//
// (2) F:
// call from F -> G (the call here is not Call)
// G:
// call from G -> H (the call here is Call)
//
// For (2), if after inlining the body of G into F the call to H requires a
// streaming-mode change, and the call to G from F would also require a
// streaming-mode change, then there is benefit to do the streaming-mode
// change only once and avoid inlining of G into F.
SMEAttrs FAttrs(*F);
SMEAttrs CalleeAttrs(Call);
if (FAttrs.requiresSMChange(CalleeAttrs)) {
if (F == Call.getCaller()) // (1)
return CallPenaltyChangeSM * DefaultCallPenalty;
if (FAttrs.requiresSMChange(SMEAttrs(*Call.getCaller()))) // (2)
return InlineCallPenaltyChangeSM * DefaultCallPenalty;
}

return DefaultCallPenalty;
}

bool AArch64TTIImpl::shouldMaximizeVectorBandwidth(
TargetTransformInfo::RegisterKind K) const {
assert(K != TargetTransformInfo::RGK_Scalar);
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
bool areTypesABICompatible(const Function *Caller, const Function *Callee,
const ArrayRef<Type *> &Types) const;

unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
unsigned DefaultCallPenalty) const;

/// \name Scalar TTI Implementations
/// @{

Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/Transforms/IPO/PartialInlining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,7 @@ bool PartialInlinerImpl::shouldPartialInline(
const DataLayout &DL = Caller->getParent()->getDataLayout();

// The savings of eliminating the call:
int NonWeightedSavings = getCallsiteCost(CB, DL);
int NonWeightedSavings = getCallsiteCost(CalleeTTI, CB, DL);
BlockFrequency NormWeightedSavings(NonWeightedSavings);

// Weighted saving is smaller than weighted cost, return false
Expand Down Expand Up @@ -842,12 +842,12 @@ PartialInlinerImpl::computeBBInlineCost(BasicBlock *BB,
}

if (CallInst *CI = dyn_cast<CallInst>(&I)) {
InlineCost += getCallsiteCost(*CI, DL);
InlineCost += getCallsiteCost(*TTI, *CI, DL);
continue;
}

if (InvokeInst *II = dyn_cast<InvokeInst>(&I)) {
InlineCost += getCallsiteCost(*II, DL);
InlineCost += getCallsiteCost(*TTI, *II, DL);
continue;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
; RUN: opt < %s -mtriple=aarch64-unknown-linux-gnu -mattr=+sme -S -passes=inline -inlinedefault-threshold=1 | FileCheck %s

; This test sets the inline-threshold to 1 such that by default the call to @streaming_callee is not inlined.
; However, if the call to @streaming_callee requires a streaming-mode change, it should always inline the call because the streaming-mode change is more expensive.
target triple = "aarch64"

declare void @streaming_compatible_f() "aarch64_pstate_sm_compatible"

; Function @streaming_callee doesn't contain any operations that may use ZA
; state and therefore can be legally inlined into a normal function.
define void @streaming_callee() "aarch64_pstate_sm_enabled" {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A comment would be helpful here to explain that @streaming_callee doesn't contain any operations that use ZA state, and therefore can be legally inlined into a normal function.

; CHECK-LABEL: define void @streaming_callee
; CHECK-SAME: () #[[ATTR1:[0-9]+]] {
; CHECK-NEXT: call void @streaming_compatible_f()
; CHECK-NEXT: call void @streaming_compatible_f()
; CHECK-NEXT: ret void
;
call void @streaming_compatible_f()
call void @streaming_compatible_f()
ret void
}

; Inline call to @streaming_callee to remove a streaming mode change.
define void @non_streaming_caller_inline() {
; CHECK-LABEL: define void @non_streaming_caller_inline
; CHECK-SAME: () #[[ATTR2:[0-9]+]] {
; CHECK-NEXT: call void @streaming_compatible_f()
; CHECK-NEXT: call void @streaming_compatible_f()
; CHECK-NEXT: ret void
;
call void @streaming_callee()
ret void
}

; Don't inline call to @streaming_callee when the inline-threshold is set to 1, because it does not eliminate a streaming-mode change.
define void @streaming_caller_dont_inline() "aarch64_pstate_sm_enabled" {
; CHECK-LABEL: define void @streaming_caller_dont_inline
; CHECK-SAME: () #[[ATTR1]] {
; CHECK-NEXT: call void @streaming_callee()
; CHECK-NEXT: ret void
;
call void @streaming_callee()
ret void
}
95 changes: 95 additions & 0 deletions llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs.ll
Original file line number Diff line number Diff line change
Expand Up @@ -581,3 +581,98 @@ entry:
%res = call i64 @normal_callee_call_sme_state()
ret i64 %res
}



declare void @streaming_body() "aarch64_pstate_sm_enabled"

define void @streaming_caller_single_streaming_callee() "aarch64_pstate_sm_enabled" {
; CHECK-LABEL: define void @streaming_caller_single_streaming_callee
; CHECK-SAME: () #[[ATTR2]] {
; CHECK-NEXT: call void @streaming_body()
; CHECK-NEXT: ret void
;
call void @streaming_body()
ret void
}

define void @streaming_caller_multiple_streaming_callees() "aarch64_pstate_sm_enabled" {
; CHECK-LABEL: define void @streaming_caller_multiple_streaming_callees
; CHECK-SAME: () #[[ATTR2]] {
; CHECK-NEXT: call void @streaming_body()
; CHECK-NEXT: call void @streaming_body()
; CHECK-NEXT: ret void
;
call void @streaming_body()
call void @streaming_body()
ret void
}

; Allow inlining, as inline it would not increase the number of streaming-mode changes.
define void @streaming_caller_single_streaming_callee_inline() {
; CHECK-LABEL: define void @streaming_caller_single_streaming_callee_inline
; CHECK-SAME: () #[[ATTR1]] {
; CHECK-NEXT: call void @streaming_body()
; CHECK-NEXT: ret void
;
call void @streaming_caller_single_streaming_callee()
ret void
}

; Prevent inlining, as inline it would lead to multiple streaming-mode changes.
define void @streaming_caller_multiple_streaming_callees_dont_inline() {
; CHECK-LABEL: define void @streaming_caller_multiple_streaming_callees_dont_inline
; CHECK-SAME: () #[[ATTR1]] {
; CHECK-NEXT: call void @streaming_caller_multiple_streaming_callees()
; CHECK-NEXT: ret void
;
call void @streaming_caller_multiple_streaming_callees()
ret void
}

declare void @streaming_compatible_body() "aarch64_pstate_sm_compatible"

define void @streaming_caller_single_streaming_compatible_callee() "aarch64_pstate_sm_enabled" {
; CHECK-LABEL: define void @streaming_caller_single_streaming_compatible_callee
; CHECK-SAME: () #[[ATTR2]] {
; CHECK-NEXT: call void @streaming_compatible_body()
; CHECK-NEXT: ret void
;
call void @streaming_compatible_body()
ret void
}

define void @streaming_caller_multiple_streaming_compatible_callees() "aarch64_pstate_sm_enabled" {
; CHECK-LABEL: define void @streaming_caller_multiple_streaming_compatible_callees
; CHECK-SAME: () #[[ATTR2]] {
; CHECK-NEXT: call void @streaming_compatible_body()
; CHECK-NEXT: call void @streaming_compatible_body()
; CHECK-NEXT: ret void
;
call void @streaming_compatible_body()
call void @streaming_compatible_body()
ret void
}

; Allow inlining, as inline would remove a streaming-mode change.
define void @streaming_caller_single_streaming_compatible_callee_inline() {
; CHECK-LABEL: define void @streaming_caller_single_streaming_compatible_callee_inline
; CHECK-SAME: () #[[ATTR1]] {
; CHECK-NEXT: call void @streaming_compatible_body()
; CHECK-NEXT: ret void
;
call void @streaming_caller_single_streaming_compatible_callee()
ret void
}

; Allow inlining, as inline would remove several stremaing-mode changes.
define void @streaming_caller_multiple_streaming_compatible_callees_inline() {
; CHECK-LABEL: define void @streaming_caller_multiple_streaming_compatible_callees_inline
; CHECK-SAME: () #[[ATTR1]] {
; CHECK-NEXT: call void @streaming_compatible_body()
; CHECK-NEXT: call void @streaming_compatible_body()
; CHECK-NEXT: ret void
;
call void @streaming_caller_multiple_streaming_compatible_callees()
ret void
}