Skip to content

Commit 6d30bc0

Browse files
[AArch64][SME] Allow inlining when streaming-mode attributes don't match up. (#68415)
The use-case here is to support things like: int foo(int x, int y) __arm_streaming { return std::max<int>(x, y); } where the call to non-streaming `std::max<int>(x, y)` can be safely inlined into the streaming function. This is a first step and will need further work to allow more cases (e.g. more fine-grained analysis of the function calls to ensure they don't result in any incompatible instructions for the requested mode).
1 parent a902ca6 commit 6d30bc0

File tree

3 files changed

+218
-40
lines changed

3 files changed

+218
-40
lines changed

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -190,16 +190,49 @@ static cl::opt<bool> EnableFixedwidthAutovecInStreamingMode(
190190
static cl::opt<bool> EnableScalableAutovecInStreamingMode(
191191
"enable-scalable-autovec-in-streaming-mode", cl::init(false), cl::Hidden);
192192

193+
/// Returns true if \p CI is a call whose called function is named exactly
/// one of "__arm_sme_state", "__arm_tpidr2_save", "__arm_tpidr2_restore" or
/// "__arm_za_disable". Returns false for any other name, and also when
/// getCalledFunction() yields null (no callee function is known for the call).
static bool isSMEABIRoutineCall(const CallInst &CI) {
194+
const auto *F = CI.getCalledFunction();
195+
// The `F &&` guard keeps the name lookup from dereferencing a null callee.
return F && StringSwitch<bool>(F->getName())
196+
.Case("__arm_sme_state", true)
197+
.Case("__arm_tpidr2_save", true)
198+
.Case("__arm_tpidr2_restore", true)
199+
.Case("__arm_za_disable", true)
200+
.Default(false);
201+
}
202+
203+
/// Returns true if the function has explicit operations that can only be
204+
/// lowered using incompatible instructions for the selected mode. This also
205+
/// returns true if the function F may use or modify ZA state.
///
/// Concretely, the scan reports true on the first CallInst in \p F that is
/// not a debug/pseudo instruction and is either inline asm, an intrinsic
/// call, or a call matched by isSMEABIRoutineCall(). If no instruction
/// matches, the result is false.
206+
static bool hasPossibleIncompatibleOps(const Function *F) {
207+
for (const BasicBlock &BB : *F) {
208+
for (const Instruction &I : BB) {
209+
// Be conservative for now and assume that any call to inline asm or to
210+
// intrinsics could result in non-streaming ops (e.g. calls to
211+
// @llvm.aarch64.* or @llvm.gather/scatter intrinsics). We can assume that
212+
// all native LLVM instructions can be lowered to compatible instructions.
// NOTE(review): only CallInst is inspected here; other call-like
// instructions (e.g. invoke) are not examined -- confirm this is intended.
213+
if (isa<CallInst>(I) && !I.isDebugOrPseudoInst() &&
214+
(cast<CallInst>(I).isInlineAsm() || isa<IntrinsicInst>(I) ||
215+
isSMEABIRoutineCall(cast<CallInst>(I))))
216+
return true;
217+
}
218+
}
219+
return false;
220+
}
221+
193222
bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
194223
const Function *Callee) const {
195224
SMEAttrs CallerAttrs(*Caller);
196225
SMEAttrs CalleeAttrs(*Callee);
197-
if (CallerAttrs.requiresSMChange(CalleeAttrs,
198-
/*BodyOverridesInterface=*/true) ||
199-
CallerAttrs.requiresLazySave(CalleeAttrs) ||
200-
CalleeAttrs.hasNewZABody())
226+
if (CalleeAttrs.hasNewZABody())
201227
return false;
202228

229+
if (CallerAttrs.requiresLazySave(CalleeAttrs) ||
230+
CallerAttrs.requiresSMChange(CalleeAttrs,
231+
/*BodyOverridesInterface=*/true)) {
232+
if (hasPossibleIncompatibleOps(Callee))
233+
return false;
234+
}
235+
203236
const TargetMachine &TM = getTLI()->getTargetMachine();
204237

205238
const FeatureBitset &CallerBits =

llvm/test/Transforms/Inline/AArch64/sme-pstatesm-attrs.ll

Lines changed: 108 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,11 @@ entry:
102102
; [ ] N -> SC
103103
; [ ] N -> N + B
104104
; [ ] N -> SC + B
105-
define void @normal_caller_streaming_callee_dont_inline() {
106-
; CHECK-LABEL: define void @normal_caller_streaming_callee_dont_inline
105+
define void @normal_caller_streaming_callee_inline() {
106+
; CHECK-LABEL: define void @normal_caller_streaming_callee_inline
107107
; CHECK-SAME: () #[[ATTR1]] {
108108
; CHECK-NEXT: entry:
109-
; CHECK-NEXT: call void @streaming_callee()
109+
; CHECK-NEXT: call void @inlined_body()
110110
; CHECK-NEXT: ret void
111111
;
112112
entry:
@@ -136,11 +136,11 @@ entry:
136136
; [ ] N -> SC
137137
; [x] N -> N + B
138138
; [ ] N -> SC + B
139-
define void @normal_caller_locally_streaming_callee_dont_inline() {
140-
; CHECK-LABEL: define void @normal_caller_locally_streaming_callee_dont_inline
139+
define void @normal_caller_locally_streaming_callee_inline() {
140+
; CHECK-LABEL: define void @normal_caller_locally_streaming_callee_inline
141141
; CHECK-SAME: () #[[ATTR1]] {
142142
; CHECK-NEXT: entry:
143-
; CHECK-NEXT: call void @locally_streaming_callee()
143+
; CHECK-NEXT: call void @inlined_body()
144144
; CHECK-NEXT: ret void
145145
;
146146
entry:
@@ -153,11 +153,11 @@ entry:
153153
; [ ] N -> SC
154154
; [ ] N -> N + B
155155
; [x] N -> SC + B
156-
define void @normal_caller_streaming_compatible_locally_streaming_callee_dont_inline() {
157-
; CHECK-LABEL: define void @normal_caller_streaming_compatible_locally_streaming_callee_dont_inline
156+
define void @normal_caller_streaming_compatible_locally_streaming_callee_inline() {
157+
; CHECK-LABEL: define void @normal_caller_streaming_compatible_locally_streaming_callee_inline
158158
; CHECK-SAME: () #[[ATTR1]] {
159159
; CHECK-NEXT: entry:
160-
; CHECK-NEXT: call void @streaming_compatible_locally_streaming_callee()
160+
; CHECK-NEXT: call void @inlined_body()
161161
; CHECK-NEXT: ret void
162162
;
163163
entry:
@@ -170,11 +170,11 @@ entry:
170170
; [ ] S -> SC
171171
; [ ] S -> N + B
172172
; [ ] S -> SC + B
173-
define void @streaming_caller_normal_callee_dont_inline() "aarch64_pstate_sm_enabled" {
174-
; CHECK-LABEL: define void @streaming_caller_normal_callee_dont_inline
173+
define void @streaming_caller_normal_callee_inline() "aarch64_pstate_sm_enabled" {
174+
; CHECK-LABEL: define void @streaming_caller_normal_callee_inline
175175
; CHECK-SAME: () #[[ATTR2]] {
176176
; CHECK-NEXT: entry:
177-
; CHECK-NEXT: call void @normal_callee()
177+
; CHECK-NEXT: call void @inlined_body()
178178
; CHECK-NEXT: ret void
179179
;
180180
entry:
@@ -255,11 +255,11 @@ entry:
255255
; [ ] N + B -> SC
256256
; [ ] N + B -> N + B
257257
; [ ] N + B -> SC + B
258-
define void @locally_streaming_caller_normal_callee_dont_inline() "aarch64_pstate_sm_body" {
259-
; CHECK-LABEL: define void @locally_streaming_caller_normal_callee_dont_inline
258+
define void @locally_streaming_caller_normal_callee_inline() "aarch64_pstate_sm_body" {
259+
; CHECK-LABEL: define void @locally_streaming_caller_normal_callee_inline
260260
; CHECK-SAME: () #[[ATTR3]] {
261261
; CHECK-NEXT: entry:
262-
; CHECK-NEXT: call void @normal_callee()
262+
; CHECK-NEXT: call void @inlined_body()
263263
; CHECK-NEXT: ret void
264264
;
265265
entry:
@@ -340,11 +340,11 @@ entry:
340340
; [ ] SC -> SC
341341
; [ ] SC -> N + B
342342
; [ ] SC -> SC + B
343-
define void @streaming_compatible_caller_normal_callee_dont_inline() "aarch64_pstate_sm_compatible" {
344-
; CHECK-LABEL: define void @streaming_compatible_caller_normal_callee_dont_inline
343+
define void @streaming_compatible_caller_normal_callee_inline() "aarch64_pstate_sm_compatible" {
344+
; CHECK-LABEL: define void @streaming_compatible_caller_normal_callee_inline
345345
; CHECK-SAME: () #[[ATTR0]] {
346346
; CHECK-NEXT: entry:
347-
; CHECK-NEXT: call void @normal_callee()
347+
; CHECK-NEXT: call void @inlined_body()
348348
; CHECK-NEXT: ret void
349349
;
350350
entry:
@@ -357,11 +357,11 @@ entry:
357357
; [ ] SC -> SC
358358
; [ ] SC -> N + B
359359
; [ ] SC -> SC + B
360-
define void @streaming_compatible_caller_streaming_callee_dont_inline() "aarch64_pstate_sm_compatible" {
361-
; CHECK-LABEL: define void @streaming_compatible_caller_streaming_callee_dont_inline
360+
define void @streaming_compatible_caller_streaming_callee_inline() "aarch64_pstate_sm_compatible" {
361+
; CHECK-LABEL: define void @streaming_compatible_caller_streaming_callee_inline
362362
; CHECK-SAME: () #[[ATTR0]] {
363363
; CHECK-NEXT: entry:
364-
; CHECK-NEXT: call void @streaming_callee()
364+
; CHECK-NEXT: call void @inlined_body()
365365
; CHECK-NEXT: ret void
366366
;
367367
entry:
@@ -391,11 +391,11 @@ entry:
391391
; [ ] SC -> SC
392392
; [x] SC -> N + B
393393
; [ ] SC -> SC + B
394-
define void @streaming_compatible_caller_locally_streaming_callee_dont_inline() "aarch64_pstate_sm_compatible" {
395-
; CHECK-LABEL: define void @streaming_compatible_caller_locally_streaming_callee_dont_inline
394+
define void @streaming_compatible_caller_locally_streaming_callee_inline() "aarch64_pstate_sm_compatible" {
395+
; CHECK-LABEL: define void @streaming_compatible_caller_locally_streaming_callee_inline
396396
; CHECK-SAME: () #[[ATTR0]] {
397397
; CHECK-NEXT: entry:
398-
; CHECK-NEXT: call void @locally_streaming_callee()
398+
; CHECK-NEXT: call void @inlined_body()
399399
; CHECK-NEXT: ret void
400400
;
401401
entry:
@@ -408,11 +408,11 @@ entry:
408408
; [ ] SC -> SC
409409
; [ ] SC -> N + B
410410
; [x] SC -> SC + B
411-
define void @streaming_compatible_caller_streaming_compatible_locally_streaming_callee_dont_inline() "aarch64_pstate_sm_compatible" {
412-
; CHECK-LABEL: define void @streaming_compatible_caller_streaming_compatible_locally_streaming_callee_dont_inline
411+
define void @streaming_compatible_caller_streaming_compatible_locally_streaming_callee_inline() "aarch64_pstate_sm_compatible" {
412+
; CHECK-LABEL: define void @streaming_compatible_caller_streaming_compatible_locally_streaming_callee_inline
413413
; CHECK-SAME: () #[[ATTR0]] {
414414
; CHECK-NEXT: entry:
415-
; CHECK-NEXT: call void @streaming_compatible_locally_streaming_callee()
415+
; CHECK-NEXT: call void @inlined_body()
416416
; CHECK-NEXT: ret void
417417
;
418418
entry:
@@ -424,11 +424,11 @@ entry:
424424
; [ ] SC + B -> SC
425425
; [ ] SC + B -> N + B
426426
; [ ] SC + B -> SC + B
427-
define void @streaming_compatible_locally_streaming_caller_normal_callee_dont_inline() "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body" {
428-
; CHECK-LABEL: define void @streaming_compatible_locally_streaming_caller_normal_callee_dont_inline
427+
define void @streaming_compatible_locally_streaming_caller_normal_callee_inline() "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body" {
428+
; CHECK-LABEL: define void @streaming_compatible_locally_streaming_caller_normal_callee_inline
429429
; CHECK-SAME: () #[[ATTR4]] {
430430
; CHECK-NEXT: entry:
431-
; CHECK-NEXT: call void @normal_callee()
431+
; CHECK-NEXT: call void @inlined_body()
432432
; CHECK-NEXT: ret void
433433
;
434434
entry:
@@ -503,3 +503,81 @@ entry:
503503
call void @streaming_compatible_locally_streaming_callee()
504504
ret void
505505
}
506+
507+
define void @normal_callee_with_inlineasm() {
508+
; CHECK-LABEL: define void @normal_callee_with_inlineasm
509+
; CHECK-SAME: () #[[ATTR1]] {
510+
; CHECK-NEXT: entry:
511+
; CHECK-NEXT: call void asm sideeffect "
512+
; CHECK-NEXT: ret void
513+
;
514+
entry:
515+
call void asm sideeffect "; inlineasm", ""()
516+
ret void
517+
}
518+
519+
define void @streaming_caller_normal_callee_with_inlineasm_dont_inline() "aarch64_pstate_sm_enabled" {
520+
; CHECK-LABEL: define void @streaming_caller_normal_callee_with_inlineasm_dont_inline
521+
; CHECK-SAME: () #[[ATTR2]] {
522+
; CHECK-NEXT: entry:
523+
; CHECK-NEXT: call void @normal_callee_with_inlineasm()
524+
; CHECK-NEXT: ret void
525+
;
526+
entry:
527+
call void @normal_callee_with_inlineasm()
528+
ret void
529+
}
530+
531+
define i64 @normal_callee_with_intrinsic_call() {
532+
; CHECK-LABEL: define i64 @normal_callee_with_intrinsic_call
533+
; CHECK-SAME: () #[[ATTR1]] {
534+
; CHECK-NEXT: entry:
535+
; CHECK-NEXT: [[RES:%.*]] = call i64 @llvm.aarch64.sve.cntb(i32 4)
536+
; CHECK-NEXT: ret i64 [[RES]]
537+
;
538+
entry:
539+
%res = call i64 @llvm.aarch64.sve.cntb(i32 4)
540+
ret i64 %res
541+
}
542+
543+
define i64 @streaming_caller_normal_callee_with_intrinsic_call_dont_inline() "aarch64_pstate_sm_enabled" {
544+
; CHECK-LABEL: define i64 @streaming_caller_normal_callee_with_intrinsic_call_dont_inline
545+
; CHECK-SAME: () #[[ATTR2]] {
546+
; CHECK-NEXT: entry:
547+
; CHECK-NEXT: [[RES:%.*]] = call i64 @normal_callee_with_intrinsic_call()
548+
; CHECK-NEXT: ret i64 [[RES]]
549+
;
550+
entry:
551+
%res = call i64 @normal_callee_with_intrinsic_call()
552+
ret i64 %res
553+
}
554+
555+
declare i64 @llvm.aarch64.sve.cntb(i32)
556+
557+
define i64 @normal_callee_call_sme_state() {
558+
; CHECK-LABEL: define i64 @normal_callee_call_sme_state
559+
; CHECK-SAME: () #[[ATTR1]] {
560+
; CHECK-NEXT: entry:
561+
; CHECK-NEXT: [[RES:%.*]] = call { i64, i64 } @__arm_sme_state()
562+
; CHECK-NEXT: [[RES_0:%.*]] = extractvalue { i64, i64 } [[RES]], 0
563+
; CHECK-NEXT: ret i64 [[RES_0]]
564+
;
565+
entry:
566+
%res = call {i64, i64} @__arm_sme_state()
567+
%res.0 = extractvalue {i64, i64} %res, 0
568+
ret i64 %res.0
569+
}
570+
571+
declare {i64, i64} @__arm_sme_state()
572+
573+
define i64 @streaming_caller_normal_callee_call_sme_state_dont_inline() "aarch64_pstate_sm_enabled" {
574+
; CHECK-LABEL: define i64 @streaming_caller_normal_callee_call_sme_state_dont_inline
575+
; CHECK-SAME: () #[[ATTR2]] {
576+
; CHECK-NEXT: entry:
577+
; CHECK-NEXT: [[RES:%.*]] = call i64 @normal_callee_call_sme_state()
578+
; CHECK-NEXT: ret i64 [[RES]]
579+
;
580+
entry:
581+
%res = call i64 @normal_callee_call_sme_state()
582+
ret i64 %res
583+
}

0 commit comments

Comments
 (0)