[Coroutines] Inline the .noalloc ramp function marked coro_elide_safe #114004

Merged: yuxuanchen1997 merged 6 commits into main from users/yuxuanchen1997/coro-fix-cgscc-update on Nov 8, 2024

@yuxuanchen1997 yuxuanchen1997 commented Oct 29, 2024

Fixes #114487.

We found that 761bf33 causes problems, as shown by the example in the issue above.

This patch does two things:

  • It backs out 761bf33, because redirecting the call to the .noalloc variant violates the invariant for a function pass.
  • It inlines the .noalloc callee into the caller. If inlining fails, the redirect is undone and the original call is kept.
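
The second bullet describes a "redirect, try to inline, roll back on failure" sequence. As a rough illustration only, here is a standalone toy model of that control flow. It does not use LLVM: `ToyCall`, `ToyBody`, the `Inlinable` flag, and `elideOrRollBack` are hypothetical names invented for this sketch, and erasing an element merely stands in for a successful inline.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy stand-in for a call instruction. NOT an LLVM type.
struct ToyCall {
  std::string Callee;
  bool Inlinable; // hypothetical flag: would the toy "inliner" accept this call?
};

// Toy stand-in for a caller's body: just a list of calls.
using ToyBody = std::vector<ToyCall>;

// Redirect Body[Idx] to "<callee>.noalloc" and try to "inline" it.
// On success the call disappears from the body; on failure the original
// call is restored, so the body is left exactly as it was.
// Returns true if the elision was committed.
bool elideOrRollBack(ToyBody &Body, std::size_t Idx) {
  assert(Idx < Body.size() && "call index out of range");
  ToyCall Original = Body[Idx];                    // keep the old call around
  Body[Idx].Callee = Original.Callee + ".noalloc"; // tentative redirect

  if (Body[Idx].Inlinable) {
    // "Inlined": commit by dropping the call from the body.
    Body.erase(Body.begin() + static_cast<std::ptrdiff_t>(Idx));
    return true;
  }
  Body[Idx] = Original; // inlining failed: undo the redirect
  return false;
}
```

In this sketch the original call is only discarded once the inline step has succeeded, which is the ordering the patch moves to: the old `CB` is erased on success, and the new call is erased on failure.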

github-actions bot commented Oct 29, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@yuxuanchen1997 yuxuanchen1997 force-pushed the users/yuxuanchen1997/coro-fix-cgscc-update branch from 15eb5af to e796cc7 Compare October 31, 2024 23:09
@yuxuanchen1997 yuxuanchen1997 force-pushed the users/yuxuanchen1997/coro-fix-cgscc-update branch from e796cc7 to 57b66b8 Compare October 31, 2024 23:37
@yuxuanchen1997 yuxuanchen1997 marked this pull request as ready for review October 31, 2024 23:38
llvmbot commented Oct 31, 2024

@llvm/pr-subscribers-coroutines

Author: Yuxuan Chen (yuxuanchen1997)

Full diff: https://github.com/llvm/llvm-project/pull/114004.diff

5 Files Affected:

  • (modified) llvm/include/llvm/Transforms/Coroutines/CoroAnnotationElide.h (+7-3)
  • (modified) llvm/lib/Passes/PassBuilderPipelines.cpp (+3-3)
  • (modified) llvm/lib/Passes/PassRegistry.def (+1-1)
  • (modified) llvm/lib/Transforms/Coroutines/CoroAnnotationElide.cpp (+78-55)
  • (added) llvm/test/Transforms/Coroutines/gh114487-crash-in-cgscc.ll (+32)
diff --git a/llvm/include/llvm/Transforms/Coroutines/CoroAnnotationElide.h b/llvm/include/llvm/Transforms/Coroutines/CoroAnnotationElide.h
index 986a5dbd1ed0fe..352c9e14526697 100644
--- a/llvm/include/llvm/Transforms/Coroutines/CoroAnnotationElide.h
+++ b/llvm/include/llvm/Transforms/Coroutines/CoroAnnotationElide.h
@@ -17,14 +17,18 @@
 #ifndef LLVM_TRANSFORMS_COROUTINES_COROANNOTATIONELIDE_H
 #define LLVM_TRANSFORMS_COROUTINES_COROANNOTATIONELIDE_H
 
+#include "llvm/Analysis/CGSCCPassManager.h"
+#include "llvm/Analysis/LazyCallGraph.h"
 #include "llvm/IR/PassManager.h"
 
 namespace llvm {
 
-class Function;
-
 struct CoroAnnotationElidePass : PassInfoMixin<CoroAnnotationElidePass> {
-  PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM);
+  CoroAnnotationElidePass() {}
+
+  PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM,
+                        LazyCallGraph &CG, CGSCCUpdateResult &UR);
+
   static bool isRequired() { return false; }
 };
 } // end namespace llvm
diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp
index 0585e83e59a9ab..f48e5148854c39 100644
--- a/llvm/lib/Passes/PassBuilderPipelines.cpp
+++ b/llvm/lib/Passes/PassBuilderPipelines.cpp
@@ -975,8 +975,7 @@ PassBuilder::buildInlinerPipeline(OptimizationLevel Level,
 
   if (Phase != ThinOrFullLTOPhase::ThinLTOPreLink) {
     MainCGPipeline.addPass(CoroSplitPass(Level != OptimizationLevel::O0));
-    MainCGPipeline.addPass(
-        createCGSCCToFunctionPassAdaptor(CoroAnnotationElidePass()));
+    MainCGPipeline.addPass(CoroAnnotationElidePass());
   }
 
   // Make sure we don't affect potential future NoRerun CGSCC adaptors.
@@ -1027,7 +1026,8 @@ PassBuilder::buildModuleInlinerPipeline(OptimizationLevel Level,
   if (Phase != ThinOrFullLTOPhase::ThinLTOPreLink) {
     MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(
         CoroSplitPass(Level != OptimizationLevel::O0)));
-    MPM.addPass(createModuleToFunctionPassAdaptor(CoroAnnotationElidePass()));
+    MPM.addPass(
+        createModuleToPostOrderCGSCCPassAdaptor(CoroAnnotationElidePass()));
   }
 
   return MPM;
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index 017ae311c55eb4..4994005560472f 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -244,6 +244,7 @@ CGSCC_PASS("attributor-light-cgscc", AttributorLightCGSCCPass())
 CGSCC_PASS("invalidate<all>", InvalidateAllAnalysesPass())
 CGSCC_PASS("no-op-cgscc", NoOpCGSCCPass())
 CGSCC_PASS("openmp-opt-cgscc", OpenMPOptCGSCCPass())
+CGSCC_PASS("coro-annotation-elide", CoroAnnotationElidePass())
 #undef CGSCC_PASS
 
 #ifndef CGSCC_PASS_WITH_PARAMS
@@ -344,7 +345,6 @@ FUNCTION_PASS("complex-deinterleaving", ComplexDeinterleavingPass(TM))
 FUNCTION_PASS("consthoist", ConstantHoistingPass())
 FUNCTION_PASS("constraint-elimination", ConstraintEliminationPass())
 FUNCTION_PASS("coro-elide", CoroElidePass())
-FUNCTION_PASS("coro-annotation-elide", CoroAnnotationElidePass())
 FUNCTION_PASS("correlated-propagation", CorrelatedValuePropagationPass())
 FUNCTION_PASS("count-visits", CountVisitsPass())
 FUNCTION_PASS("dce", DCEPass())
diff --git a/llvm/lib/Transforms/Coroutines/CoroAnnotationElide.cpp b/llvm/lib/Transforms/Coroutines/CoroAnnotationElide.cpp
index 5f19d600a983aa..9e22d96387fc31 100644
--- a/llvm/lib/Transforms/Coroutines/CoroAnnotationElide.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroAnnotationElide.cpp
@@ -16,6 +16,7 @@
 
 #include "llvm/Transforms/Coroutines/CoroAnnotationElide.h"
 
+#include "llvm/Analysis/CGSCCPassManager.h"
 #include "llvm/Analysis/LazyCallGraph.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/IR/Analysis.h"
@@ -25,6 +26,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
+#include "llvm/Transforms/Utils/Cloning.h"
 
 #include <cassert>
 
@@ -42,10 +44,10 @@ static Instruction *getFirstNonAllocaInTheEntryBlock(Function *F) {
 // Create an alloca in the caller, using FrameSize and FrameAlign as the callee
 // coroutine's activation frame.
 static Value *allocateFrameInCaller(Function *Caller, uint64_t FrameSize,
-                                    Align FrameAlign) {
+    Align FrameAlign) {
   LLVMContext &C = Caller->getContext();
   BasicBlock::iterator InsertPt =
-      getFirstNonAllocaInTheEntryBlock(Caller)->getIterator();
+    getFirstNonAllocaInTheEntryBlock(Caller)->getIterator();
   const DataLayout &DL = Caller->getDataLayout();
   auto FrameTy = ArrayType::get(Type::getInt8Ty(C), FrameSize);
   auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt);
@@ -59,7 +61,7 @@ static Value *allocateFrameInCaller(Function *Caller, uint64_t FrameSize,
 //  - Replace the old CB with a new Call or Invoke to `NewCallee`, with the
 //    pointer to the frame as an additional argument to NewCallee.
 static void processCall(CallBase *CB, Function *Caller, Function *NewCallee,
-                        uint64_t FrameSize, Align FrameAlign) {
+    uint64_t FrameSize, Align FrameAlign) {
   // TODO: generate the lifetime intrinsics for the new frame. This will require
   // introduction of two pesudo lifetime intrinsics in the frontend around the
   // `co_await` expression and convert them to real lifetime intrinsics here.
@@ -72,13 +74,13 @@ static void processCall(CallBase *CB, Function *Caller, Function *NewCallee,
 
   if (auto *CI = dyn_cast<CallInst>(CB)) {
     auto *NewCI = CallInst::Create(NewCallee->getFunctionType(), NewCallee,
-                                   NewArgs, "", NewCBInsertPt);
+        NewArgs, "", NewCBInsertPt);
     NewCI->setTailCallKind(CI->getTailCallKind());
     NewCB = NewCI;
   } else if (auto *II = dyn_cast<InvokeInst>(CB)) {
     NewCB = InvokeInst::Create(NewCallee->getFunctionType(), NewCallee,
-                               II->getNormalDest(), II->getUnwindDest(),
-                               NewArgs, {}, "", NewCBInsertPt);
+        II->getNormalDest(), II->getUnwindDest(),
+        NewArgs, {}, "", NewCBInsertPt);
   } else {
     llvm_unreachable("CallBase should either be Call or Invoke!");
   }
@@ -88,65 +90,86 @@ static void processCall(CallBase *CB, Function *Caller, Function *NewCallee,
   NewCB->setAttributes(CB->getAttributes());
   NewCB->setDebugLoc(CB->getDebugLoc());
   std::copy(CB->bundle_op_info_begin(), CB->bundle_op_info_end(),
-            NewCB->bundle_op_info_begin());
+      NewCB->bundle_op_info_begin());
 
   NewCB->removeFnAttr(llvm::Attribute::CoroElideSafe);
   CB->replaceAllUsesWith(NewCB);
-  CB->eraseFromParent();
+
+  InlineFunctionInfo IFI;
+  InlineResult IR = InlineFunction(*NewCB, IFI);
+  if (IR.isSuccess()) {
+    CB->eraseFromParent();
+  } else {
+    NewCB->replaceAllUsesWith(CB);
+    NewCB->eraseFromParent();
+  }
 }
 
-PreservedAnalyses CoroAnnotationElidePass::run(Function &F,
-                                               FunctionAnalysisManager &FAM) {
+PreservedAnalyses CoroAnnotationElidePass::run(LazyCallGraph::SCC &C,
+    CGSCCAnalysisManager &AM,
+    LazyCallGraph &CG,
+    CGSCCUpdateResult &UR) {
   bool Changed = false;
+  CallGraphUpdater CGUpdater;
+  CGUpdater.initialize(CG, C, AM, UR);
 
-  Function *NewCallee =
-      F.getParent()->getFunction((F.getName() + ".noalloc").str());
-
-  if (!NewCallee)
-    return PreservedAnalyses::all();
-
-  auto FramePtrArgPosition = NewCallee->arg_size() - 1;
-  auto FrameSize = NewCallee->getParamDereferenceableBytes(FramePtrArgPosition);
-  auto FrameAlign = NewCallee->getParamAlign(FramePtrArgPosition).valueOrOne();
+  auto &FAM =
+    AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
 
-  SmallVector<CallBase *, 4> Users;
-  for (auto *U : F.users()) {
-    if (auto *CB = dyn_cast<CallBase>(U)) {
-      if (CB->getCalledFunction() == &F)
-        Users.push_back(CB);
-    }
-  }
-
-  auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
-
-  for (auto *CB : Users) {
-    auto *Caller = CB->getFunction();
-    if (!Caller)
+  for (LazyCallGraph::Node &N : C) {
+    Function *Callee = &N.getFunction();
+    Function *NewCallee = Callee->getParent()->getFunction(
+        (Callee->getName() + ".noalloc").str());
+    if (!NewCallee)
       continue;
 
-    bool IsCallerPresplitCoroutine = Caller->isPresplitCoroutine();
-    bool HasAttr = CB->hasFnAttr(llvm::Attribute::CoroElideSafe);
-    if (IsCallerPresplitCoroutine && HasAttr) {
-      processCall(CB, Caller, NewCallee, FrameSize, FrameAlign);
-
-      ORE.emit([&]() {
-        return OptimizationRemark(DEBUG_TYPE, "CoroAnnotationElide", Caller)
-               << "'" << ore::NV("callee", F.getName()) << "' elided in '"
-               << ore::NV("caller", Caller->getName()) << "'";
-      });
-
-      FAM.invalidate(*Caller, PreservedAnalyses::none());
-      Changed = true;
-    } else {
-      ORE.emit([&]() {
-        return OptimizationRemarkMissed(DEBUG_TYPE, "CoroAnnotationElide",
-                                        Caller)
-               << "'" << ore::NV("callee", F.getName()) << "' not elided in '"
-               << ore::NV("caller", Caller->getName()) << "' (caller_presplit="
-               << ore::NV("caller_presplit", IsCallerPresplitCoroutine)
-               << ", elide_safe_attr=" << ore::NV("elide_safe_attr", HasAttr)
-               << ")";
-      });
+    SmallVector<CallBase *, 4> Users;
+    for (auto *U : Callee->users()) {
+      if (auto *CB = dyn_cast<CallBase>(U)) {
+        if (CB->getCalledFunction() == Callee)
+          Users.push_back(CB);
+      }
+    }
+    auto FramePtrArgPosition = NewCallee->arg_size() - 1;
+    auto FrameSize = NewCallee->getParamDereferenceableBytes(FramePtrArgPosition);
+    auto FrameAlign = NewCallee->getParamAlign(FramePtrArgPosition).valueOrOne();
+
+    auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(*Callee);
+
+    for (auto *CB : Users) {
+      auto *Caller = CB->getFunction();
+      if (!Caller)
+        continue;
+
+      bool IsCallerPresplitCoroutine = Caller->isPresplitCoroutine();
+      bool HasAttr = CB->hasFnAttr(llvm::Attribute::CoroElideSafe);
+      if (IsCallerPresplitCoroutine && HasAttr) {
+        auto *CallerN = CG.lookup(*Caller);
+        auto *CallerC = CG.lookupSCC(*CallerN);
+        processCall(CB, Caller, NewCallee, FrameSize, FrameAlign);
+
+        ORE.emit([&]() {
+            return OptimizationRemark(DEBUG_TYPE, "CoroAnnotationElide", Caller)
+            << "'" << ore::NV("callee", Callee->getName()) << "' elided in '"
+            << ore::NV("caller", Caller->getName()) << "'";
+            });
+
+        FAM.invalidate(*Caller, PreservedAnalyses::none());
+        Changed = true;
+        updateCGAndAnalysisManagerForCGSCCPass(CG, *CallerC, *CallerN, AM, UR,
+            FAM);
+
+      } else {
+        ORE.emit([&]() {
+            return OptimizationRemarkMissed(DEBUG_TYPE, "CoroAnnotationElide",
+                Caller)
+            << "'" << ore::NV("callee", Callee->getName()) << "' not elided in '"
+            << ore::NV("caller", Caller->getName()) << "' (caller_presplit="
+            << ore::NV("caller_presplit", IsCallerPresplitCoroutine)
+            << ", elide_safe_attr=" << ore::NV("elide_safe_attr", HasAttr)
+            << ")";
+            });
+      }
     }
   }
 
diff --git a/llvm/test/Transforms/Coroutines/gh114487-crash-in-cgscc.ll b/llvm/test/Transforms/Coroutines/gh114487-crash-in-cgscc.ll
new file mode 100644
index 00000000000000..228e722940e18f
--- /dev/null
+++ b/llvm/test/Transforms/Coroutines/gh114487-crash-in-cgscc.ll
@@ -0,0 +1,32 @@
+; Verify that we don't crash when eliding coro_elide_safe callsites.
+; RUN: opt < %s -passes='cgscc(function<>(simplifycfg<>),function-attrs,coro-split,coro-annotation-elide)'  -S | FileCheck %s
+
+; CHECK-LABEL: define void @foo()
+define void @foo() presplitcoroutine personality ptr null {
+entry:
+  %0 = call token @llvm.coro.save(ptr null)
+  br label %branch
+
+branch:
+; Check that we don't call bar at all. 
+; CHECK-NOT: call void @bar{{.*}}
+  call void @bar() coro_elide_safe
+; CHECK: call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr @bar.resumers)
+  ret void
+}
+
+; CHECK-LABEL: define void @bar()
+define void @bar() presplitcoroutine personality ptr null {
+entry:
+  %0 = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null)
+  %1 = call ptr @llvm.coro.begin(token %0, ptr null)
+  %2 = call token @llvm.coro.save(ptr null)
+  %3 = call i8 @llvm.coro.suspend(token none, i1 false)
+  ret void
+}
+
+declare token @llvm.coro.id(i32, ptr readnone, ptr nocapture readonly, ptr) nounwind
+declare ptr @llvm.coro.begin(token, ptr writeonly) nounwind
+declare token @llvm.coro.save(ptr) nomerge nounwind
+declare i8 @llvm.coro.suspend(token, i1) nounwind
+
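
The selection logic in the rewritten `run` loop of the diff above can be summarized with a small standalone model. This is a sketch under stated assumptions, not LLVM code: `ToyCallSite`, `ToyModule`, and `findElidableCalls` are hypothetical names, the SCC is modelled as a plain list of function names, and call-graph and analysis updates are left out entirely.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Toy model only; none of these are LLVM types.
struct ToyCallSite {
  std::string Caller;
  std::string Callee;
  bool ElideSafe; // models the coro_elide_safe call-site attribute
};

struct ToyModule {
  std::map<std::string, bool> Functions; // name -> is a presplit coroutine
  std::vector<ToyCallSite> Calls;
};

// Return indices of call sites that qualify for elision: the callee is in
// the SCC and has a "<name>.noalloc" variant, the caller is a presplit
// coroutine, and the call site carries the elide-safe attribute.
std::vector<std::size_t> findElidableCalls(const ToyModule &M,
                                           const std::vector<std::string> &SCC) {
  std::vector<std::size_t> Result;
  for (const std::string &Callee : SCC) {
    // No ".noalloc" variant: nothing to redirect to, skip this function.
    if (M.Functions.count(Callee + ".noalloc") == 0)
      continue;
    for (std::size_t I = 0; I < M.Calls.size(); ++I) {
      const ToyCallSite &CS = M.Calls[I];
      if (CS.Callee != Callee)
        continue;
      auto CallerIt = M.Functions.find(CS.Caller);
      assert(CallerIt != M.Functions.end() && "caller must be in the module");
      if (CallerIt->second && CS.ElideSafe)
        Result.push_back(I);
    }
  }
  return Result;
}
```

Note that the model, like the patch, iterates over callees in the SCC and then over their call sites, so the callers it selects may live outside the SCC being visited.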

llvmbot commented Oct 31, 2024

@llvm/pr-subscribers-llvm-transforms

@yuxuanchen1997 yuxuanchen1997 force-pushed the users/yuxuanchen1997/coro-fix-cgscc-update branch from 011b265 to e48430e Compare November 1, 2024 18:04
  InlineResult IR = InlineFunction(*NewCB, IFI);
  if (IR.isSuccess()) {
    CB->eraseFromParent();
  } else {
Member

Can we add a test for the failed to inline case?

Member Author

Added. Sorry for the delay.

@yuxuanchen1997 yuxuanchen1997 force-pushed the users/yuxuanchen1997/coro-fix-cgscc-update branch from 3e21ebd to 56ca97f Compare November 7, 2024 19:45
@yuxuanchen1997 yuxuanchen1997 merged commit c641497 into main Nov 8, 2024
8 checks passed
@yuxuanchen1997 yuxuanchen1997 deleted the users/yuxuanchen1997/coro-fix-cgscc-update branch November 8, 2024 06:41
Groverkss pushed a commit to iree-org/llvm-project that referenced this pull request Nov 15, 2024
Successfully merging this pull request may close these issues.

[Coroutines] Crash during CoroAnnotationElidePass