Skip to content

Commit e5e18d4

Browse files
committed
[OPT] Search whole BB for convergence token.
The spec for llvm.experimental.convergence.entry says that is must be in the entry block for a function, and must preceed any other convergent operation. It does not have to be the first instruction in the entry block. Inlining assumes that the call to llvm.experimental.convergence.entry will be the first instruction after any phi instructions. This commit modifies inlining to search the entire block for the call.
1 parent ab208de commit e5e18d4

File tree

2 files changed

+45
-16
lines changed

2 files changed

+45
-16
lines changed

llvm/lib/Transforms/Utils/InlineFunction.cpp

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,19 @@ namespace {
180180
}
181181
};
182182

183+
IntrinsicInst *getConevrgenceEntryIfAny(BasicBlock &BB) {
184+
auto *I = BB.getFirstNonPHI();
185+
while (I) {
186+
if (auto *IntrinsicCall = dyn_cast<IntrinsicInst>(I)) {
187+
if (IntrinsicCall->getIntrinsicID() ==
188+
Intrinsic::experimental_convergence_entry) {
189+
return IntrinsicCall;
190+
}
191+
}
192+
I = I->getNextNode();
193+
}
194+
return nullptr;
195+
}
183196
} // end anonymous namespace
184197

185198
/// Get or create a target for the branch from ResumeInsts.
@@ -2438,15 +2451,10 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
24382451
// fully implements convergence control tokens, there is no mixing of
24392452
// controlled and uncontrolled convergent operations in the whole program.
24402453
if (CB.isConvergent()) {
2441-
auto *I = CalledFunc->getEntryBlock().getFirstNonPHI();
2442-
if (auto *IntrinsicCall = dyn_cast<IntrinsicInst>(I)) {
2443-
if (IntrinsicCall->getIntrinsicID() ==
2444-
Intrinsic::experimental_convergence_entry) {
2445-
if (!ConvergenceControlToken) {
2446-
return InlineResult::failure(
2447-
"convergent call needs convergencectrl operand");
2448-
}
2449-
}
2454+
auto *I = getConevrgenceEntryIfAny(CalledFunc->getEntryBlock());
2455+
if (I && !ConvergenceControlToken) {
2456+
return InlineResult::failure(
2457+
"convergent call needs convergencectrl operand");
24502458
}
24512459
}
24522460

@@ -2737,13 +2745,10 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
27372745
}
27382746

27392747
if (ConvergenceControlToken) {
2740-
auto *I = FirstNewBlock->getFirstNonPHI();
2741-
if (auto *IntrinsicCall = dyn_cast<IntrinsicInst>(I)) {
2742-
if (IntrinsicCall->getIntrinsicID() ==
2743-
Intrinsic::experimental_convergence_entry) {
2744-
IntrinsicCall->replaceAllUsesWith(ConvergenceControlToken);
2745-
IntrinsicCall->eraseFromParent();
2746-
}
2748+
auto *IntrinsicCall = getConevrgenceEntryIfAny(*FirstNewBlock);
2749+
if (IntrinsicCall) {
2750+
IntrinsicCall->replaceAllUsesWith(ConvergenceControlToken);
2751+
IntrinsicCall->eraseFromParent();
27472752
}
27482753
}
27492754

llvm/test/Transforms/Inline/convergence-inline.ll

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,30 @@ define void @test_two_calls() convergent {
185185
ret void
186186
}
187187

188+
define i32 @token_not_first(i32 %x) convergent alwaysinline {
189+
; CHECK-LABEL: @token_not_first(
190+
; CHECK-NEXT: {{%.*}} = alloca ptr, align 8
191+
; CHECK-NEXT: [[TOKEN:%.*]] = call token @llvm.experimental.convergence.entry()
192+
; CHECK-NEXT: [[Y:%.*]] = call i32 @g(i32 [[X:%.*]]) [ "convergencectrl"(token [[TOKEN]]) ]
193+
; CHECK-NEXT: ret i32 [[Y]]
194+
;
195+
%p = alloca ptr, align 8
196+
%token = call token @llvm.experimental.convergence.entry()
197+
%y = call i32 @g(i32 %x) [ "convergencectrl"(token %token) ]
198+
ret i32 %y
199+
}
200+
201+
define void @test_token_not_first() convergent {
202+
; CHECK-LABEL: @test_token_not_first(
203+
; CHECK-NEXT: [[TOKEN:%.*]] = call token @llvm.experimental.convergence.entry()
204+
; CHECK-NEXT: {{%.*}} = call i32 @g(i32 23) [ "convergencectrl"(token [[TOKEN]]) ]
205+
; CHECK-NEXT: ret void
206+
;
207+
%token = call token @llvm.experimental.convergence.entry()
208+
%x = call i32 @token_not_first(i32 23) [ "convergencectrl"(token %token) ]
209+
ret void
210+
}
211+
188212
declare void @f(i32) convergent
189213
declare i32 @g(i32) convergent
190214

0 commit comments

Comments
 (0)