Skip to content

[IPO] Prevent removal of some convergent attr #134863

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions llvm/lib/Transforms/IPO/AttributorAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2902,6 +2902,22 @@ struct AANonConvergentImpl : public AANonConvergent {
}
};

static bool FunctionCalledWithConvergenceToken(const Function *F) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Start with lowercase

for (auto &Use : F->uses()) {
CallBase *CB = dyn_cast<CallBase>(Use.getUser());
if (!CB)
continue;

// We are not called, just used as an argument.
if (CB->getCalledFunction() != F)
continue;
Comment on lines +2908 to +2913
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (!CB)
continue;
// We are not called, just used as an argument.
if (CB->getCalledFunction() != F)
continue;
if (!CB || CB->getCalledFunction() != F)
continue;


if (CB->getConvergenceControlToken())
return true;
}
return false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be a conservatively correct true. If the address is captured you do not know the users. If the function is externally visible, you also do not know the users.

}

struct AANonConvergentFunction final : AANonConvergentImpl {
AANonConvergentFunction(const IRPosition &IRP, Attributor &A)
: AANonConvergentImpl(IRP, A) {}
Expand Down Expand Up @@ -2929,6 +2945,11 @@ struct AANonConvergentFunction final : AANonConvergentImpl {
UsedAssumedInformation)) {
return indicatePessimisticFixpoint();
}

const Function *F = this->getIRPosition().getAssociatedFunction();
if (FunctionCalledWithConvergenceToken(F))
return indicatePessimisticFixpoint();

return ChangeStatus::UNCHANGED;
}

Expand Down
21 changes: 20 additions & 1 deletion llvm/lib/Transforms/IPO/FunctionAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1864,6 +1864,23 @@ static bool InstrBreaksNonConvergent(Instruction &I,
!SCCNodes.contains(CB->getCalledFunction());
}

static bool FunctionRequiresConvergence(const Function *F) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is identical to FunctionCalledWithConvergenceToken. Can we move the definition to a common place, say Function::hasConvergentCalls()?

for (auto &Use : F->uses()) {
CallBase *CB = dyn_cast<CallBase>(Use.getUser());
if (!CB)
return true;

// We are not called, just used as an argument.
if (CB->getCalledFunction() != F)
continue;

if (CB->getConvergenceControlToken())
return true;
}

return false;
}

/// Helper for NoUnwind inference predicate InstrBreaksAttribute.
static bool InstrBreaksNonThrowing(Instruction &I, const SCCNodeSet &SCCNodes) {
if (!I.mayThrow(/* IncludePhaseOneUnwind */ true))
Expand Down Expand Up @@ -1967,7 +1984,9 @@ static void inferConvergent(const SCCNodeSet &SCCNodes,
AI.registerAttrInference(AttributeInferer::InferenceDescriptor{
Attribute::Convergent,
// Skip non-convergent functions.
[](const Function &F) { return !F.isConvergent(); },
[](const Function &F) {
return !F.isConvergent() || FunctionRequiresConvergence(&F);
},
// Instructions that break non-convergent assumption.
[SCCNodes](Instruction &I) {
return InstrBreaksNonConvergent(I, SCCNodes);
Expand Down
92 changes: 92 additions & 0 deletions llvm/test/Transforms/ADCE/convergence.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt %s -passes=adce -S | FileCheck %s

; CHECK: Function Attrs: convergent
define i32 @foo(i32 %a) #0 {
; CHECK-LABEL: define i32 @foo(
; CHECK-SAME: i32 [[A:%.*]]) #[[ATTR0:[0-9]+]] {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: ret i32 [[A]]
;
entry:
%tk = call token @llvm.experimental.convergence.entry()
ret i32 %a
}

; CHECK: Function Attrs: convergent
define void @bar() #0 {
; CHECK-LABEL: define void @bar(
; CHECK-SAME: ) #[[ATTR0]] {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: ret void
;
entry:
%tk = call token @llvm.experimental.convergence.anchor()
ret void
}

; CHECK: Function Attrs: convergent
define void @baz() #0 {
; CHECK-LABEL: define void @baz(
; CHECK-SAME: ) #[[ATTR0]] {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: br label %[[HEADER:.*]]
; CHECK: [[HEADER]]:
; CHECK-NEXT: br i1 true, label %[[BODY:.*]], label %[[EXIT:.*]]
; CHECK: [[BODY]]:
; CHECK-NEXT: br label %[[HEADER]]
; CHECK: [[EXIT]]:
; CHECK-NEXT: ret void
;
entry:
%tk0 = call token @llvm.experimental.convergence.entry()
br label %header

header:
%tk1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %tk0) ]
br i1 true, label %body, label %exit

body:
br label %header

exit:
ret void
}

define void @indirect_inner() #0 {
; CHECK-LABEL: define void @indirect_inner(
; CHECK-SAME: ) #[[ATTR0]] {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: ret void
;
entry:
%tk0 = call token @llvm.experimental.convergence.entry()
ret void
}

define void @indirect() #0 {
; CHECK-LABEL: define void @indirect(
; CHECK-SAME: ) #[[ATTR0]] {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[TK0:%.*]] = call token @llvm.experimental.convergence.entry()
; CHECK-NEXT: [[VAR:%.*]] = alloca ptr, align 8
; CHECK-NEXT: store ptr @indirect_inner, ptr [[VAR]], align 8
; CHECK-NEXT: [[PTR:%.*]] = load ptr, ptr [[VAR]], align 8
; CHECK-NEXT: call void [[PTR]]() #[[ATTR0]] [ "convergencectrl"(token [[TK0]]) ]
; CHECK-NEXT: ret void
;
entry:
%tk0 = call token @llvm.experimental.convergence.entry()
%var = alloca ptr, align 8
store ptr @indirect_inner, ptr %var, align 8
%ptr = load ptr, ptr %var, align 8
call void %ptr() convergent [ "convergencectrl"(token %tk0) ]
ret void
}

declare token @llvm.experimental.convergence.entry() #1
declare token @llvm.experimental.convergence.anchor() #1
declare token @llvm.experimental.convergence.loop() #1

attributes #0 = { convergent }
attributes #1 = { convergent nocallback nofree nosync nounwind willreturn memory(none) }
Loading