Skip to content

[Coroutines] Support for Custom ABIs #111755

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

TylerNowicki
Copy link
Collaborator

@TylerNowicki TylerNowicki commented Oct 9, 2024

This change extends the current method for creating ABI object to allow users (plugin libraries) to create custom ABI objects for their needs. This is accomplished by inheriting one of the common ABIs and overriding one or more of the methods to create a custom ABI. To use a custom ABI for a given coroutine the coro.begin.custom.abi intrinsic is used in place of the coro.begin intrinsic. This takes an additional i32 arg that specifies the index of an ABI generator for the custom ABI object in a SmallVector passed to the CoroSplitPass ctor.

The detailed changes include:

  • Add the llvm.coro.begin.custom intrinsic used to specify the index of the custom ABI to use for the given coroutine.
  • Add constructors to CoroSplit that take a list of generators that create the custom ABI object.
  • Extend the CreateNewABI function used by CoroSplit to return a unique_ptr to an ABI object.
  • Add has/getCustomABI methods to CoroBeginInst class.
  • Add a unittest for a custom ABI.

See doc update here: #111781

* Add the llvm.coro.begin.custom intrinsic used to specify the index of
  the custom ABI to use for the given coroutine.
* Add constructors to CoroSplit that take a list of generators that
  create the custom ABI object.
* Extend the CreateNewABI function used by CoroSplit to return a
  unique_ptr to an ABI object.
* Add has/getCustomABI methods to CoroBeginInst class.
* Add a unittest for a custom ABI.
@llvmbot llvmbot added coroutines C++20 coroutines llvm:ir llvm:analysis Includes value tracking, cost tables and constant folding llvm:transforms labels Oct 9, 2024
@llvmbot
Copy link
Member

llvmbot commented Oct 9, 2024

@llvm/pr-subscribers-llvm-analysis

Author: Tyler Nowicki (TylerNowicki)

Changes

This change extends the current method for creating ABI object to allow users (plugin libraries) to create custom ABI objects for their needs. This is accomplished by inheriting one of the common ABIs and overriding one or more of the methods to create a custom ABI. To use a custom ABI for a given coroutine the coro.begin.custom.abi intrinsic is used in place of the coro.begin intrinsic. This takes an additional i32 arg that specifies the index of an ABI generator for the custom ABI object in a SmallVector passed to CoroSplitPass ctor.

The detailed changes include:

  • Add the llvm.coro.begin.custom intrinsic used to specify the index of the custom ABI to use for the given coroutine.
  • Add constructors to CoroSplit that take a list of generators that create the custom ABI object.
  • Extend the CreateNewABI function used by CoroSplit to return a unique_ptr to an ABI object.
  • Add has/getCustomABI methods to CoroBeginInst class.
  • Add a unittest for a custom ABI.

Full diff: https://github.com/llvm/llvm-project/pull/111755.diff

9 Files Affected:

  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+1)
  • (modified) llvm/include/llvm/IR/Intrinsics.td (+2-1)
  • (modified) llvm/include/llvm/Transforms/Coroutines/ABI.h (+7-1)
  • (modified) llvm/include/llvm/Transforms/Coroutines/CoroInstr.h (+15-4)
  • (modified) llvm/include/llvm/Transforms/Coroutines/CoroSplit.h (+11-2)
  • (modified) llvm/lib/Transforms/Coroutines/CoroCleanup.cpp (+3-1)
  • (modified) llvm/lib/Transforms/Coroutines/CoroSplit.cpp (+35-3)
  • (modified) llvm/lib/Transforms/Coroutines/Coroutines.cpp (+3-1)
  • (modified) llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp (+87)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 01a16e7c7b1e59..f6888d001fed69 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -778,6 +778,7 @@ class TargetTransformInfoImplBase {
     case Intrinsic::experimental_gc_relocate:
     case Intrinsic::coro_alloc:
     case Intrinsic::coro_begin:
+    case Intrinsic::coro_begin_custom_abi:
     case Intrinsic::coro_free:
     case Intrinsic::coro_end:
     case Intrinsic::coro_frame:
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 20dd921ddbd230..8a0721cf23f538 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1719,7 +1719,8 @@ def int_coro_prepare_async : Intrinsic<[llvm_ptr_ty], [llvm_ptr_ty],
                                        [IntrNoMem]>;
 def int_coro_begin : Intrinsic<[llvm_ptr_ty], [llvm_token_ty, llvm_ptr_ty],
                                [WriteOnly<ArgIndex<1>>]>;
-
+def int_coro_begin_custom_abi : Intrinsic<[llvm_ptr_ty], [llvm_token_ty, llvm_ptr_ty, llvm_i32_ty],
+                               [WriteOnly<ArgIndex<1>>]>;
 def int_coro_free : Intrinsic<[llvm_ptr_ty], [llvm_token_ty, llvm_ptr_ty],
                               [IntrReadMem, IntrArgMemOnly,
                                ReadOnly<ArgIndex<1>>,
diff --git a/llvm/include/llvm/Transforms/Coroutines/ABI.h b/llvm/include/llvm/Transforms/Coroutines/ABI.h
index e7568d275c1615..8b83c5308056eb 100644
--- a/llvm/include/llvm/Transforms/Coroutines/ABI.h
+++ b/llvm/include/llvm/Transforms/Coroutines/ABI.h
@@ -29,7 +29,13 @@ namespace coro {
 // This interface/API is to provide an object oriented way to implement ABI
 // functionality. This is intended to replace use of the ABI enum to perform
 // ABI operations. The ABIs (e.g. Switch, Async, Retcon{Once}) are the common
-// ABIs.
+// ABIs. However, specific users may need to modify the behavior of these. This
+// can be accomplished by inheriting one of the common ABIs and overriding one
+// or more of the methods to create a custom ABI. To use a custom ABI for a
+// given coroutine the coro.begin.custom.abi intrinsic is used in place of the
+// coro.begin intrinsic. This takes an additional i32 arg that specifies the
+// index of an ABI generator for the custom ABI object in a SmallVector passed
+// to CoroSplitPass ctor.
 
 class BaseABI {
 public:
diff --git a/llvm/include/llvm/Transforms/Coroutines/CoroInstr.h b/llvm/include/llvm/Transforms/Coroutines/CoroInstr.h
index a329a06bf13891..3aa30bec85c3a5 100644
--- a/llvm/include/llvm/Transforms/Coroutines/CoroInstr.h
+++ b/llvm/include/llvm/Transforms/Coroutines/CoroInstr.h
@@ -124,7 +124,8 @@ class AnyCoroIdInst : public IntrinsicInst {
   IntrinsicInst *getCoroBegin() {
     for (User *U : users())
       if (auto *II = dyn_cast<IntrinsicInst>(U))
-        if (II->getIntrinsicID() == Intrinsic::coro_begin)
+        if (II->getIntrinsicID() == Intrinsic::coro_begin ||
+            II->getIntrinsicID() == Intrinsic::coro_begin_custom_abi)
           return II;
     llvm_unreachable("no coro.begin associated with coro.id");
   }
@@ -442,20 +443,30 @@ class CoroFreeInst : public IntrinsicInst {
   }
 };
 
-/// This class represents the llvm.coro.begin instructions.
+/// This class represents the llvm.coro.begin or llvm.coro.begin.custom.abi
+/// instructions.
 class CoroBeginInst : public IntrinsicInst {
-  enum { IdArg, MemArg };
+  enum { IdArg, MemArg, CustomABIArg };
 
 public:
   AnyCoroIdInst *getId() const {
     return cast<AnyCoroIdInst>(getArgOperand(IdArg));
   }
 
+  bool hasCustomABI() const {
+    return getIntrinsicID() == Intrinsic::coro_begin_custom_abi;
+  }
+
+  int getCustomABI() const {
+    return cast<ConstantInt>(getArgOperand(CustomABIArg))->getZExtValue();
+  }
+
   Value *getMem() const { return getArgOperand(MemArg); }
 
   // Methods for support type inquiry through isa, cast, and dyn_cast:
   static bool classof(const IntrinsicInst *I) {
-    return I->getIntrinsicID() == Intrinsic::coro_begin;
+    return I->getIntrinsicID() == Intrinsic::coro_begin ||
+           I->getIntrinsicID() == Intrinsic::coro_begin_custom_abi;
   }
   static bool classof(const Value *V) {
     return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
diff --git a/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h b/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h
index a5fd57f8f9dfab..6c6a982e828050 100644
--- a/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h
+++ b/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h
@@ -28,17 +28,26 @@ struct Shape;
 } // namespace coro
 
 struct CoroSplitPass : PassInfoMixin<CoroSplitPass> {
+  using BaseABITy =
+      std::function<std::unique_ptr<coro::BaseABI>(Function &, coro::Shape &)>;
 
   CoroSplitPass(bool OptimizeFrame = false);
+
+  CoroSplitPass(SmallVector<BaseABITy> GenCustomABIs,
+                bool OptimizeFrame = false);
+
   CoroSplitPass(std::function<bool(Instruction &)> MaterializableCallback,
                 bool OptimizeFrame = false);
 
+  CoroSplitPass(std::function<bool(Instruction &)> MaterializableCallback,
+                SmallVector<BaseABITy> GenCustomABIs,
+                bool OptimizeFrame = false);
+
   PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM,
                         LazyCallGraph &CG, CGSCCUpdateResult &UR);
+
   static bool isRequired() { return true; }
 
-  using BaseABITy =
-      std::function<std::unique_ptr<coro::BaseABI>(Function &, coro::Shape &)>;
   // Generator for an ABI transformer
   BaseABITy CreateAndInitABI;
 
diff --git a/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp b/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp
index dd92b3593af92e..1cda7f93f72a2c 100644
--- a/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp
@@ -53,6 +53,7 @@ bool Lowerer::lower(Function &F) {
       default:
         continue;
       case Intrinsic::coro_begin:
+      case Intrinsic::coro_begin_custom_abi:
         II->replaceAllUsesWith(II->getArgOperand(1));
         break;
       case Intrinsic::coro_free:
@@ -112,7 +113,8 @@ static bool declaresCoroCleanupIntrinsics(const Module &M) {
       M, {"llvm.coro.alloc", "llvm.coro.begin", "llvm.coro.subfn.addr",
           "llvm.coro.free", "llvm.coro.id", "llvm.coro.id.retcon",
           "llvm.coro.id.async", "llvm.coro.id.retcon.once",
-          "llvm.coro.async.size.replace", "llvm.coro.async.resume"});
+          "llvm.coro.async.size.replace", "llvm.coro.async.resume",
+          "llvm.coro.begin.custom.abi"});
 }
 
 PreservedAnalyses CoroCleanupPass::run(Module &M,
diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
index ef1f27118bc14b..88ce331c8cfb64 100644
--- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -2200,7 +2200,15 @@ static void addPrepareFunction(const Module &M,
 
 static std::unique_ptr<coro::BaseABI>
 CreateNewABI(Function &F, coro::Shape &S,
-             std::function<bool(Instruction &)> IsMatCallback) {
+             std::function<bool(Instruction &)> IsMatCallback,
+             const SmallVector<CoroSplitPass::BaseABITy> GenCustomABIs) {
+  if (S.CoroBegin->hasCustomABI()) {
+    unsigned CustomABI = S.CoroBegin->getCustomABI();
+    if (CustomABI >= GenCustomABIs.size())
+      llvm_unreachable("Custom ABI not found amoung those specified");
+    return GenCustomABIs[CustomABI](F, S);
+  }
+
   switch (S.ABI) {
   case coro::ABI::Switch:
     return std::unique_ptr<coro::BaseABI>(
@@ -2221,7 +2229,17 @@ CreateNewABI(Function &F, coro::Shape &S,
 CoroSplitPass::CoroSplitPass(bool OptimizeFrame)
     : CreateAndInitABI([](Function &F, coro::Shape &S) {
         std::unique_ptr<coro::BaseABI> ABI =
-            CreateNewABI(F, S, coro::isTriviallyMaterializable);
+            CreateNewABI(F, S, coro::isTriviallyMaterializable, {});
+        ABI->init();
+        return ABI;
+      }),
+      OptimizeFrame(OptimizeFrame) {}
+
+CoroSplitPass::CoroSplitPass(
+    SmallVector<CoroSplitPass::BaseABITy> GenCustomABIs, bool OptimizeFrame)
+    : CreateAndInitABI([=](Function &F, coro::Shape &S) {
+        std::unique_ptr<coro::BaseABI> ABI =
+            CreateNewABI(F, S, coro::isTriviallyMaterializable, GenCustomABIs);
         ABI->init();
         return ABI;
       }),
@@ -2232,7 +2250,21 @@ CoroSplitPass::CoroSplitPass(bool OptimizeFrame)
 CoroSplitPass::CoroSplitPass(std::function<bool(Instruction &)> IsMatCallback,
                              bool OptimizeFrame)
     : CreateAndInitABI([=](Function &F, coro::Shape &S) {
-        std::unique_ptr<coro::BaseABI> ABI = CreateNewABI(F, S, IsMatCallback);
+        std::unique_ptr<coro::BaseABI> ABI =
+            CreateNewABI(F, S, IsMatCallback, {});
+        ABI->init();
+        return ABI;
+      }),
+      OptimizeFrame(OptimizeFrame) {}
+
+// For back compatibility, constructor takes a materializable callback and
+// creates a generator for an ABI with a modified materializable callback.
+CoroSplitPass::CoroSplitPass(
+    std::function<bool(Instruction &)> IsMatCallback,
+    SmallVector<CoroSplitPass::BaseABITy> GenCustomABIs, bool OptimizeFrame)
+    : CreateAndInitABI([=](Function &F, coro::Shape &S) {
+        std::unique_ptr<coro::BaseABI> ABI =
+            CreateNewABI(F, S, IsMatCallback, GenCustomABIs);
         ABI->init();
         return ABI;
       }),
diff --git a/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/llvm/lib/Transforms/Coroutines/Coroutines.cpp
index f4d9a7a8aa8569..1c45bcd7f6a837 100644
--- a/llvm/lib/Transforms/Coroutines/Coroutines.cpp
+++ b/llvm/lib/Transforms/Coroutines/Coroutines.cpp
@@ -73,6 +73,7 @@ static const char *const CoroIntrinsics[] = {
     "llvm.coro.await.suspend.handle",
     "llvm.coro.await.suspend.void",
     "llvm.coro.begin",
+    "llvm.coro.begin.custom.abi",
     "llvm.coro.destroy",
     "llvm.coro.done",
     "llvm.coro.end",
@@ -247,7 +248,8 @@ void coro::Shape::analyze(Function &F,
         }
         break;
       }
-      case Intrinsic::coro_begin: {
+      case Intrinsic::coro_begin:
+      case Intrinsic::coro_begin_custom_abi: {
         auto CB = cast<CoroBeginInst>(II);
 
         // Ignore coro id's that aren't pre-split.
diff --git a/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp b/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp
index 1d55889a32d7aa..c3394fdaa940ba 100644
--- a/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp
+++ b/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp
@@ -182,4 +182,91 @@ TEST_F(ExtraRematTest, TestCoroRematWithCallback) {
   CallInst *CI = getCallByName(Resume1, "should.remat");
   ASSERT_TRUE(CI);
 }
+
+StringRef TextCoroBeginCustomABI = R"(
+    define ptr @f(i32 %n) presplitcoroutine {
+    entry:
+      %id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null)
+      %size = call i32 @llvm.coro.size.i32()
+      %alloc = call ptr @malloc(i32 %size)
+      %hdl = call ptr @llvm.coro.begin.custom.abi(token %id, ptr %alloc, i32 0)
+
+      %inc1 = add i32 %n, 1
+      %val2 = call i32 @should.remat(i32 %inc1)
+      %sp1 = call i8 @llvm.coro.suspend(token none, i1 false)
+      switch i8 %sp1, label %suspend [i8 0, label %resume1
+                                      i8 1, label %cleanup]
+    resume1:
+      %inc2 = add i32 %val2, 1
+      %sp2 = call i8 @llvm.coro.suspend(token none, i1 false)
+      switch i8 %sp1, label %suspend [i8 0, label %resume2
+                                      i8 1, label %cleanup]
+
+    resume2:
+      call void @print(i32 %val2)
+      call void @print(i32 %inc2)
+      br label %cleanup
+
+    cleanup:
+      %mem = call ptr @llvm.coro.free(token %id, ptr %hdl)
+      call void @free(ptr %mem)
+      br label %suspend
+    suspend:
+      call i1 @llvm.coro.end(ptr %hdl, i1 0)
+      ret ptr %hdl
+    }
+
+    declare ptr @llvm.coro.free(token, ptr)
+    declare i32 @llvm.coro.size.i32()
+    declare i8  @llvm.coro.suspend(token, i1)
+    declare void @llvm.coro.resume(ptr)
+    declare void @llvm.coro.destroy(ptr)
+
+    declare token @llvm.coro.id(i32, ptr, ptr, ptr)
+    declare i1 @llvm.coro.alloc(token)
+    declare ptr @llvm.coro.begin.custom.abi(token, ptr, i32)
+    declare i1 @llvm.coro.end(ptr, i1)
+
+    declare i32 @should.remat(i32)
+
+    declare noalias ptr @malloc(i32)
+    declare void @print(i32)
+    declare void @free(ptr)
+  )";
+
+// SwitchABI with overridden isMaterializable
+class ExtraCustomABI : public coro::SwitchABI {
+public:
+  ExtraCustomABI(Function &F, coro::Shape &S)
+      : coro::SwitchABI(F, S, ExtraMaterializable) {}
+};
+
+TEST_F(ExtraRematTest, TestCoroRematWithCustomABI) {
+  ParseAssembly(TextCoroBeginCustomABI);
+
+  ASSERT_TRUE(M);
+
+  CoroSplitPass::BaseABITy GenCustomABI = [](Function &F, coro::Shape &S) {
+    return std::unique_ptr<coro::BaseABI>(new ExtraCustomABI(F, S));
+  };
+
+  CGSCCPassManager CGPM;
+  CGPM.addPass(CoroSplitPass({GenCustomABI}));
+  MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM)));
+  MPM.run(*M, MAM);
+
+  // Verify that extra rematerializable instruction has been rematerialized
+  Function *F = M->getFunction("f.resume");
+  ASSERT_TRUE(F) << "could not find split function f.resume";
+
+  BasicBlock *Resume1 = getBasicBlockByName(F, "resume1");
+  ASSERT_TRUE(Resume1)
+      << "could not find expected BB resume1 in split function";
+
+  // With callback the extra rematerialization of the function should have
+  // happened
+  CallInst *CI = getCallByName(Resume1, "should.remat");
+  ASSERT_TRUE(CI);
+}
+
 } // namespace

@llvmbot
Copy link
Member

llvmbot commented Oct 9, 2024

@llvm/pr-subscribers-llvm-ir

Author: Tyler Nowicki (TylerNowicki)

Changes

This change extends the current method for creating ABI object to allow users (plugin libraries) to create custom ABI objects for their needs. This is accomplished by inheriting one of the common ABIs and overriding one or more of the methods to create a custom ABI. To use a custom ABI for a given coroutine the coro.begin.custom.abi intrinsic is used in place of the coro.begin intrinsic. This takes an additional i32 arg that specifies the index of an ABI generator for the custom ABI object in a SmallVector passed to CoroSplitPass ctor.

The detailed changes include:

  • Add the llvm.coro.begin.custom intrinsic used to specify the index of the custom ABI to use for the given coroutine.
  • Add constructors to CoroSplit that take a list of generators that create the custom ABI object.
  • Extend the CreateNewABI function used by CoroSplit to return a unique_ptr to an ABI object.
  • Add has/getCustomABI methods to CoroBeginInst class.
  • Add a unittest for a custom ABI.

Full diff: https://github.com/llvm/llvm-project/pull/111755.diff

9 Files Affected:

  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+1)
  • (modified) llvm/include/llvm/IR/Intrinsics.td (+2-1)
  • (modified) llvm/include/llvm/Transforms/Coroutines/ABI.h (+7-1)
  • (modified) llvm/include/llvm/Transforms/Coroutines/CoroInstr.h (+15-4)
  • (modified) llvm/include/llvm/Transforms/Coroutines/CoroSplit.h (+11-2)
  • (modified) llvm/lib/Transforms/Coroutines/CoroCleanup.cpp (+3-1)
  • (modified) llvm/lib/Transforms/Coroutines/CoroSplit.cpp (+35-3)
  • (modified) llvm/lib/Transforms/Coroutines/Coroutines.cpp (+3-1)
  • (modified) llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp (+87)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 01a16e7c7b1e59..f6888d001fed69 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -778,6 +778,7 @@ class TargetTransformInfoImplBase {
     case Intrinsic::experimental_gc_relocate:
     case Intrinsic::coro_alloc:
     case Intrinsic::coro_begin:
+    case Intrinsic::coro_begin_custom_abi:
     case Intrinsic::coro_free:
     case Intrinsic::coro_end:
     case Intrinsic::coro_frame:
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 20dd921ddbd230..8a0721cf23f538 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1719,7 +1719,8 @@ def int_coro_prepare_async : Intrinsic<[llvm_ptr_ty], [llvm_ptr_ty],
                                        [IntrNoMem]>;
 def int_coro_begin : Intrinsic<[llvm_ptr_ty], [llvm_token_ty, llvm_ptr_ty],
                                [WriteOnly<ArgIndex<1>>]>;
-
+def int_coro_begin_custom_abi : Intrinsic<[llvm_ptr_ty], [llvm_token_ty, llvm_ptr_ty, llvm_i32_ty],
+                               [WriteOnly<ArgIndex<1>>]>;
 def int_coro_free : Intrinsic<[llvm_ptr_ty], [llvm_token_ty, llvm_ptr_ty],
                               [IntrReadMem, IntrArgMemOnly,
                                ReadOnly<ArgIndex<1>>,
diff --git a/llvm/include/llvm/Transforms/Coroutines/ABI.h b/llvm/include/llvm/Transforms/Coroutines/ABI.h
index e7568d275c1615..8b83c5308056eb 100644
--- a/llvm/include/llvm/Transforms/Coroutines/ABI.h
+++ b/llvm/include/llvm/Transforms/Coroutines/ABI.h
@@ -29,7 +29,13 @@ namespace coro {
 // This interface/API is to provide an object oriented way to implement ABI
 // functionality. This is intended to replace use of the ABI enum to perform
 // ABI operations. The ABIs (e.g. Switch, Async, Retcon{Once}) are the common
-// ABIs.
+// ABIs. However, specific users may need to modify the behavior of these. This
+// can be accomplished by inheriting one of the common ABIs and overriding one
+// or more of the methods to create a custom ABI. To use a custom ABI for a
+// given coroutine the coro.begin.custom.abi intrinsic is used in place of the
+// coro.begin intrinsic. This takes an additional i32 arg that specifies the
+// index of an ABI generator for the custom ABI object in a SmallVector passed
+// to CoroSplitPass ctor.
 
 class BaseABI {
 public:
diff --git a/llvm/include/llvm/Transforms/Coroutines/CoroInstr.h b/llvm/include/llvm/Transforms/Coroutines/CoroInstr.h
index a329a06bf13891..3aa30bec85c3a5 100644
--- a/llvm/include/llvm/Transforms/Coroutines/CoroInstr.h
+++ b/llvm/include/llvm/Transforms/Coroutines/CoroInstr.h
@@ -124,7 +124,8 @@ class AnyCoroIdInst : public IntrinsicInst {
   IntrinsicInst *getCoroBegin() {
     for (User *U : users())
       if (auto *II = dyn_cast<IntrinsicInst>(U))
-        if (II->getIntrinsicID() == Intrinsic::coro_begin)
+        if (II->getIntrinsicID() == Intrinsic::coro_begin ||
+            II->getIntrinsicID() == Intrinsic::coro_begin_custom_abi)
           return II;
     llvm_unreachable("no coro.begin associated with coro.id");
   }
@@ -442,20 +443,30 @@ class CoroFreeInst : public IntrinsicInst {
   }
 };
 
-/// This class represents the llvm.coro.begin instructions.
+/// This class represents the llvm.coro.begin or llvm.coro.begin.custom.abi
+/// instructions.
 class CoroBeginInst : public IntrinsicInst {
-  enum { IdArg, MemArg };
+  enum { IdArg, MemArg, CustomABIArg };
 
 public:
   AnyCoroIdInst *getId() const {
     return cast<AnyCoroIdInst>(getArgOperand(IdArg));
   }
 
+  bool hasCustomABI() const {
+    return getIntrinsicID() == Intrinsic::coro_begin_custom_abi;
+  }
+
+  int getCustomABI() const {
+    return cast<ConstantInt>(getArgOperand(CustomABIArg))->getZExtValue();
+  }
+
   Value *getMem() const { return getArgOperand(MemArg); }
 
   // Methods for support type inquiry through isa, cast, and dyn_cast:
   static bool classof(const IntrinsicInst *I) {
-    return I->getIntrinsicID() == Intrinsic::coro_begin;
+    return I->getIntrinsicID() == Intrinsic::coro_begin ||
+           I->getIntrinsicID() == Intrinsic::coro_begin_custom_abi;
   }
   static bool classof(const Value *V) {
     return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
diff --git a/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h b/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h
index a5fd57f8f9dfab..6c6a982e828050 100644
--- a/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h
+++ b/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h
@@ -28,17 +28,26 @@ struct Shape;
 } // namespace coro
 
 struct CoroSplitPass : PassInfoMixin<CoroSplitPass> {
+  using BaseABITy =
+      std::function<std::unique_ptr<coro::BaseABI>(Function &, coro::Shape &)>;
 
   CoroSplitPass(bool OptimizeFrame = false);
+
+  CoroSplitPass(SmallVector<BaseABITy> GenCustomABIs,
+                bool OptimizeFrame = false);
+
   CoroSplitPass(std::function<bool(Instruction &)> MaterializableCallback,
                 bool OptimizeFrame = false);
 
+  CoroSplitPass(std::function<bool(Instruction &)> MaterializableCallback,
+                SmallVector<BaseABITy> GenCustomABIs,
+                bool OptimizeFrame = false);
+
   PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM,
                         LazyCallGraph &CG, CGSCCUpdateResult &UR);
+
   static bool isRequired() { return true; }
 
-  using BaseABITy =
-      std::function<std::unique_ptr<coro::BaseABI>(Function &, coro::Shape &)>;
   // Generator for an ABI transformer
   BaseABITy CreateAndInitABI;
 
diff --git a/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp b/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp
index dd92b3593af92e..1cda7f93f72a2c 100644
--- a/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp
@@ -53,6 +53,7 @@ bool Lowerer::lower(Function &F) {
       default:
         continue;
       case Intrinsic::coro_begin:
+      case Intrinsic::coro_begin_custom_abi:
         II->replaceAllUsesWith(II->getArgOperand(1));
         break;
       case Intrinsic::coro_free:
@@ -112,7 +113,8 @@ static bool declaresCoroCleanupIntrinsics(const Module &M) {
       M, {"llvm.coro.alloc", "llvm.coro.begin", "llvm.coro.subfn.addr",
           "llvm.coro.free", "llvm.coro.id", "llvm.coro.id.retcon",
           "llvm.coro.id.async", "llvm.coro.id.retcon.once",
-          "llvm.coro.async.size.replace", "llvm.coro.async.resume"});
+          "llvm.coro.async.size.replace", "llvm.coro.async.resume",
+          "llvm.coro.begin.custom.abi"});
 }
 
 PreservedAnalyses CoroCleanupPass::run(Module &M,
diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
index ef1f27118bc14b..88ce331c8cfb64 100644
--- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -2200,7 +2200,15 @@ static void addPrepareFunction(const Module &M,
 
 static std::unique_ptr<coro::BaseABI>
 CreateNewABI(Function &F, coro::Shape &S,
-             std::function<bool(Instruction &)> IsMatCallback) {
+             std::function<bool(Instruction &)> IsMatCallback,
+             const SmallVector<CoroSplitPass::BaseABITy> GenCustomABIs) {
+  if (S.CoroBegin->hasCustomABI()) {
+    unsigned CustomABI = S.CoroBegin->getCustomABI();
+    if (CustomABI >= GenCustomABIs.size())
+      llvm_unreachable("Custom ABI not found amoung those specified");
+    return GenCustomABIs[CustomABI](F, S);
+  }
+
   switch (S.ABI) {
   case coro::ABI::Switch:
     return std::unique_ptr<coro::BaseABI>(
@@ -2221,7 +2229,17 @@ CreateNewABI(Function &F, coro::Shape &S,
 CoroSplitPass::CoroSplitPass(bool OptimizeFrame)
     : CreateAndInitABI([](Function &F, coro::Shape &S) {
         std::unique_ptr<coro::BaseABI> ABI =
-            CreateNewABI(F, S, coro::isTriviallyMaterializable);
+            CreateNewABI(F, S, coro::isTriviallyMaterializable, {});
+        ABI->init();
+        return ABI;
+      }),
+      OptimizeFrame(OptimizeFrame) {}
+
+CoroSplitPass::CoroSplitPass(
+    SmallVector<CoroSplitPass::BaseABITy> GenCustomABIs, bool OptimizeFrame)
+    : CreateAndInitABI([=](Function &F, coro::Shape &S) {
+        std::unique_ptr<coro::BaseABI> ABI =
+            CreateNewABI(F, S, coro::isTriviallyMaterializable, GenCustomABIs);
         ABI->init();
         return ABI;
       }),
@@ -2232,7 +2250,21 @@ CoroSplitPass::CoroSplitPass(bool OptimizeFrame)
 CoroSplitPass::CoroSplitPass(std::function<bool(Instruction &)> IsMatCallback,
                              bool OptimizeFrame)
     : CreateAndInitABI([=](Function &F, coro::Shape &S) {
-        std::unique_ptr<coro::BaseABI> ABI = CreateNewABI(F, S, IsMatCallback);
+        std::unique_ptr<coro::BaseABI> ABI =
+            CreateNewABI(F, S, IsMatCallback, {});
+        ABI->init();
+        return ABI;
+      }),
+      OptimizeFrame(OptimizeFrame) {}
+
+// For back compatibility, constructor takes a materializable callback and
+// creates a generator for an ABI with a modified materializable callback.
+CoroSplitPass::CoroSplitPass(
+    std::function<bool(Instruction &)> IsMatCallback,
+    SmallVector<CoroSplitPass::BaseABITy> GenCustomABIs, bool OptimizeFrame)
+    : CreateAndInitABI([=](Function &F, coro::Shape &S) {
+        std::unique_ptr<coro::BaseABI> ABI =
+            CreateNewABI(F, S, IsMatCallback, GenCustomABIs);
         ABI->init();
         return ABI;
       }),
diff --git a/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/llvm/lib/Transforms/Coroutines/Coroutines.cpp
index f4d9a7a8aa8569..1c45bcd7f6a837 100644
--- a/llvm/lib/Transforms/Coroutines/Coroutines.cpp
+++ b/llvm/lib/Transforms/Coroutines/Coroutines.cpp
@@ -73,6 +73,7 @@ static const char *const CoroIntrinsics[] = {
     "llvm.coro.await.suspend.handle",
     "llvm.coro.await.suspend.void",
     "llvm.coro.begin",
+    "llvm.coro.begin.custom.abi",
     "llvm.coro.destroy",
     "llvm.coro.done",
     "llvm.coro.end",
@@ -247,7 +248,8 @@ void coro::Shape::analyze(Function &F,
         }
         break;
       }
-      case Intrinsic::coro_begin: {
+      case Intrinsic::coro_begin:
+      case Intrinsic::coro_begin_custom_abi: {
         auto CB = cast<CoroBeginInst>(II);
 
         // Ignore coro id's that aren't pre-split.
diff --git a/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp b/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp
index 1d55889a32d7aa..c3394fdaa940ba 100644
--- a/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp
+++ b/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp
@@ -182,4 +182,91 @@ TEST_F(ExtraRematTest, TestCoroRematWithCallback) {
   CallInst *CI = getCallByName(Resume1, "should.remat");
   ASSERT_TRUE(CI);
 }
+
+StringRef TextCoroBeginCustomABI = R"(
+    define ptr @f(i32 %n) presplitcoroutine {
+    entry:
+      %id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null)
+      %size = call i32 @llvm.coro.size.i32()
+      %alloc = call ptr @malloc(i32 %size)
+      %hdl = call ptr @llvm.coro.begin.custom.abi(token %id, ptr %alloc, i32 0)
+
+      %inc1 = add i32 %n, 1
+      %val2 = call i32 @should.remat(i32 %inc1)
+      %sp1 = call i8 @llvm.coro.suspend(token none, i1 false)
+      switch i8 %sp1, label %suspend [i8 0, label %resume1
+                                      i8 1, label %cleanup]
+    resume1:
+      %inc2 = add i32 %val2, 1
+      %sp2 = call i8 @llvm.coro.suspend(token none, i1 false)
+      switch i8 %sp1, label %suspend [i8 0, label %resume2
+                                      i8 1, label %cleanup]
+
+    resume2:
+      call void @print(i32 %val2)
+      call void @print(i32 %inc2)
+      br label %cleanup
+
+    cleanup:
+      %mem = call ptr @llvm.coro.free(token %id, ptr %hdl)
+      call void @free(ptr %mem)
+      br label %suspend
+    suspend:
+      call i1 @llvm.coro.end(ptr %hdl, i1 0)
+      ret ptr %hdl
+    }
+
+    declare ptr @llvm.coro.free(token, ptr)
+    declare i32 @llvm.coro.size.i32()
+    declare i8  @llvm.coro.suspend(token, i1)
+    declare void @llvm.coro.resume(ptr)
+    declare void @llvm.coro.destroy(ptr)
+
+    declare token @llvm.coro.id(i32, ptr, ptr, ptr)
+    declare i1 @llvm.coro.alloc(token)
+    declare ptr @llvm.coro.begin.custom.abi(token, ptr, i32)
+    declare i1 @llvm.coro.end(ptr, i1)
+
+    declare i32 @should.remat(i32)
+
+    declare noalias ptr @malloc(i32)
+    declare void @print(i32)
+    declare void @free(ptr)
+  )";
+
+// SwitchABI with overridden isMaterializable
+class ExtraCustomABI : public coro::SwitchABI {
+public:
+  ExtraCustomABI(Function &F, coro::Shape &S)
+      : coro::SwitchABI(F, S, ExtraMaterializable) {}
+};
+
+TEST_F(ExtraRematTest, TestCoroRematWithCustomABI) {
+  ParseAssembly(TextCoroBeginCustomABI);
+
+  ASSERT_TRUE(M);
+
+  CoroSplitPass::BaseABITy GenCustomABI = [](Function &F, coro::Shape &S) {
+    return std::unique_ptr<coro::BaseABI>(new ExtraCustomABI(F, S));
+  };
+
+  CGSCCPassManager CGPM;
+  CGPM.addPass(CoroSplitPass({GenCustomABI}));
+  MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM)));
+  MPM.run(*M, MAM);
+
+  // Verify that extra rematerializable instruction has been rematerialized
+  Function *F = M->getFunction("f.resume");
+  ASSERT_TRUE(F) << "could not find split function f.resume";
+
+  BasicBlock *Resume1 = getBasicBlockByName(F, "resume1");
+  ASSERT_TRUE(Resume1)
+      << "could not find expected BB resume1 in split function";
+
+  // With callback the extra rematerialization of the function should have
+  // happened
+  CallInst *CI = getCallByName(Resume1, "should.remat");
+  ASSERT_TRUE(CI);
+}
+
 } // namespace

@llvmbot
Copy link
Member

llvmbot commented Oct 9, 2024

@llvm/pr-subscribers-coroutines

Author: Tyler Nowicki (TylerNowicki)

Changes

This change extends the current method for creating ABI object to allow users (plugin libraries) to create custom ABI objects for their needs. This is accomplished by inheriting one of the common ABIs and overriding one or more of the methods to create a custom ABI. To use a custom ABI for a given coroutine the coro.begin.custom.abi intrinsic is used in place of the coro.begin intrinsic. This takes an additional i32 arg that specifies the index of an ABI generator for the custom ABI object in a SmallVector passed to CoroSplitPass ctor.

The detailed changes include:

  • Add the llvm.coro.begin.custom intrinsic used to specify the index of the custom ABI to use for the given coroutine.
  • Add constructors to CoroSplit that take a list of generators that create the custom ABI object.
  • Extend the CreateNewABI function used by CoroSplit to return a unique_ptr to an ABI object.
  • Add has/getCustomABI methods to CoroBeginInst class.
  • Add a unittest for a custom ABI.

Full diff: https://github.com/llvm/llvm-project/pull/111755.diff

9 Files Affected:

  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+1)
  • (modified) llvm/include/llvm/IR/Intrinsics.td (+2-1)
  • (modified) llvm/include/llvm/Transforms/Coroutines/ABI.h (+7-1)
  • (modified) llvm/include/llvm/Transforms/Coroutines/CoroInstr.h (+15-4)
  • (modified) llvm/include/llvm/Transforms/Coroutines/CoroSplit.h (+11-2)
  • (modified) llvm/lib/Transforms/Coroutines/CoroCleanup.cpp (+3-1)
  • (modified) llvm/lib/Transforms/Coroutines/CoroSplit.cpp (+35-3)
  • (modified) llvm/lib/Transforms/Coroutines/Coroutines.cpp (+3-1)
  • (modified) llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp (+87)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 01a16e7c7b1e59..f6888d001fed69 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -778,6 +778,7 @@ class TargetTransformInfoImplBase {
     case Intrinsic::experimental_gc_relocate:
     case Intrinsic::coro_alloc:
     case Intrinsic::coro_begin:
+    case Intrinsic::coro_begin_custom_abi:
     case Intrinsic::coro_free:
     case Intrinsic::coro_end:
     case Intrinsic::coro_frame:
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 20dd921ddbd230..8a0721cf23f538 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1719,7 +1719,8 @@ def int_coro_prepare_async : Intrinsic<[llvm_ptr_ty], [llvm_ptr_ty],
                                        [IntrNoMem]>;
 def int_coro_begin : Intrinsic<[llvm_ptr_ty], [llvm_token_ty, llvm_ptr_ty],
                                [WriteOnly<ArgIndex<1>>]>;
-
+def int_coro_begin_custom_abi : Intrinsic<[llvm_ptr_ty], [llvm_token_ty, llvm_ptr_ty, llvm_i32_ty],
+                               [WriteOnly<ArgIndex<1>>]>;
 def int_coro_free : Intrinsic<[llvm_ptr_ty], [llvm_token_ty, llvm_ptr_ty],
                               [IntrReadMem, IntrArgMemOnly,
                                ReadOnly<ArgIndex<1>>,
diff --git a/llvm/include/llvm/Transforms/Coroutines/ABI.h b/llvm/include/llvm/Transforms/Coroutines/ABI.h
index e7568d275c1615..8b83c5308056eb 100644
--- a/llvm/include/llvm/Transforms/Coroutines/ABI.h
+++ b/llvm/include/llvm/Transforms/Coroutines/ABI.h
@@ -29,7 +29,13 @@ namespace coro {
 // This interface/API is to provide an object oriented way to implement ABI
 // functionality. This is intended to replace use of the ABI enum to perform
 // ABI operations. The ABIs (e.g. Switch, Async, Retcon{Once}) are the common
-// ABIs.
+// ABIs. However, specific users may need to modify the behavior of these. This
+// can be accomplished by inheriting one of the common ABIs and overriding one
+// or more of the methods to create a custom ABI. To use a custom ABI for a
+// given coroutine the coro.begin.custom.abi intrinsic is used in place of the
+// coro.begin intrinsic. This takes an additional i32 arg that specifies the
+// index of an ABI generator for the custom ABI object in a SmallVector passed
+// to CoroSplitPass ctor.
 
 class BaseABI {
 public:
diff --git a/llvm/include/llvm/Transforms/Coroutines/CoroInstr.h b/llvm/include/llvm/Transforms/Coroutines/CoroInstr.h
index a329a06bf13891..3aa30bec85c3a5 100644
--- a/llvm/include/llvm/Transforms/Coroutines/CoroInstr.h
+++ b/llvm/include/llvm/Transforms/Coroutines/CoroInstr.h
@@ -124,7 +124,8 @@ class AnyCoroIdInst : public IntrinsicInst {
   IntrinsicInst *getCoroBegin() {
     for (User *U : users())
       if (auto *II = dyn_cast<IntrinsicInst>(U))
-        if (II->getIntrinsicID() == Intrinsic::coro_begin)
+        if (II->getIntrinsicID() == Intrinsic::coro_begin ||
+            II->getIntrinsicID() == Intrinsic::coro_begin_custom_abi)
           return II;
     llvm_unreachable("no coro.begin associated with coro.id");
   }
@@ -442,20 +443,30 @@ class CoroFreeInst : public IntrinsicInst {
   }
 };
 
-/// This class represents the llvm.coro.begin instructions.
+/// This class represents the llvm.coro.begin or llvm.coro.begin.custom.abi
+/// instructions.
 class CoroBeginInst : public IntrinsicInst {
-  enum { IdArg, MemArg };
+  enum { IdArg, MemArg, CustomABIArg };
 
 public:
   AnyCoroIdInst *getId() const {
     return cast<AnyCoroIdInst>(getArgOperand(IdArg));
   }
 
+  bool hasCustomABI() const {
+    return getIntrinsicID() == Intrinsic::coro_begin_custom_abi;
+  }
+
+  int getCustomABI() const {
+    return cast<ConstantInt>(getArgOperand(CustomABIArg))->getZExtValue();
+  }
+
   Value *getMem() const { return getArgOperand(MemArg); }
 
   // Methods for support type inquiry through isa, cast, and dyn_cast:
   static bool classof(const IntrinsicInst *I) {
-    return I->getIntrinsicID() == Intrinsic::coro_begin;
+    return I->getIntrinsicID() == Intrinsic::coro_begin ||
+           I->getIntrinsicID() == Intrinsic::coro_begin_custom_abi;
   }
   static bool classof(const Value *V) {
     return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
diff --git a/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h b/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h
index a5fd57f8f9dfab..6c6a982e828050 100644
--- a/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h
+++ b/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h
@@ -28,17 +28,26 @@ struct Shape;
 } // namespace coro
 
 struct CoroSplitPass : PassInfoMixin<CoroSplitPass> {
+  using BaseABITy =
+      std::function<std::unique_ptr<coro::BaseABI>(Function &, coro::Shape &)>;
 
   CoroSplitPass(bool OptimizeFrame = false);
+
+  CoroSplitPass(SmallVector<BaseABITy> GenCustomABIs,
+                bool OptimizeFrame = false);
+
   CoroSplitPass(std::function<bool(Instruction &)> MaterializableCallback,
                 bool OptimizeFrame = false);
 
+  CoroSplitPass(std::function<bool(Instruction &)> MaterializableCallback,
+                SmallVector<BaseABITy> GenCustomABIs,
+                bool OptimizeFrame = false);
+
   PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM,
                         LazyCallGraph &CG, CGSCCUpdateResult &UR);
+
   static bool isRequired() { return true; }
 
-  using BaseABITy =
-      std::function<std::unique_ptr<coro::BaseABI>(Function &, coro::Shape &)>;
   // Generator for an ABI transformer
   BaseABITy CreateAndInitABI;
 
diff --git a/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp b/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp
index dd92b3593af92e..1cda7f93f72a2c 100644
--- a/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp
@@ -53,6 +53,7 @@ bool Lowerer::lower(Function &F) {
       default:
         continue;
       case Intrinsic::coro_begin:
+      case Intrinsic::coro_begin_custom_abi:
         II->replaceAllUsesWith(II->getArgOperand(1));
         break;
       case Intrinsic::coro_free:
@@ -112,7 +113,8 @@ static bool declaresCoroCleanupIntrinsics(const Module &M) {
       M, {"llvm.coro.alloc", "llvm.coro.begin", "llvm.coro.subfn.addr",
           "llvm.coro.free", "llvm.coro.id", "llvm.coro.id.retcon",
           "llvm.coro.id.async", "llvm.coro.id.retcon.once",
-          "llvm.coro.async.size.replace", "llvm.coro.async.resume"});
+          "llvm.coro.async.size.replace", "llvm.coro.async.resume",
+          "llvm.coro.begin.custom.abi"});
 }
 
 PreservedAnalyses CoroCleanupPass::run(Module &M,
diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
index ef1f27118bc14b..88ce331c8cfb64 100644
--- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -2200,7 +2200,15 @@ static void addPrepareFunction(const Module &M,
 
 static std::unique_ptr<coro::BaseABI>
 CreateNewABI(Function &F, coro::Shape &S,
-             std::function<bool(Instruction &)> IsMatCallback) {
+             std::function<bool(Instruction &)> IsMatCallback,
+             const SmallVector<CoroSplitPass::BaseABITy> GenCustomABIs) {
+  if (S.CoroBegin->hasCustomABI()) {
+    unsigned CustomABI = S.CoroBegin->getCustomABI();
+    if (CustomABI >= GenCustomABIs.size())
+      llvm_unreachable("Custom ABI not found amoung those specified");
+    return GenCustomABIs[CustomABI](F, S);
+  }
+
   switch (S.ABI) {
   case coro::ABI::Switch:
     return std::unique_ptr<coro::BaseABI>(
@@ -2221,7 +2229,17 @@ CreateNewABI(Function &F, coro::Shape &S,
 CoroSplitPass::CoroSplitPass(bool OptimizeFrame)
     : CreateAndInitABI([](Function &F, coro::Shape &S) {
         std::unique_ptr<coro::BaseABI> ABI =
-            CreateNewABI(F, S, coro::isTriviallyMaterializable);
+            CreateNewABI(F, S, coro::isTriviallyMaterializable, {});
+        ABI->init();
+        return ABI;
+      }),
+      OptimizeFrame(OptimizeFrame) {}
+
+CoroSplitPass::CoroSplitPass(
+    SmallVector<CoroSplitPass::BaseABITy> GenCustomABIs, bool OptimizeFrame)
+    : CreateAndInitABI([=](Function &F, coro::Shape &S) {
+        std::unique_ptr<coro::BaseABI> ABI =
+            CreateNewABI(F, S, coro::isTriviallyMaterializable, GenCustomABIs);
         ABI->init();
         return ABI;
       }),
@@ -2232,7 +2250,21 @@ CoroSplitPass::CoroSplitPass(bool OptimizeFrame)
 CoroSplitPass::CoroSplitPass(std::function<bool(Instruction &)> IsMatCallback,
                              bool OptimizeFrame)
     : CreateAndInitABI([=](Function &F, coro::Shape &S) {
-        std::unique_ptr<coro::BaseABI> ABI = CreateNewABI(F, S, IsMatCallback);
+        std::unique_ptr<coro::BaseABI> ABI =
+            CreateNewABI(F, S, IsMatCallback, {});
+        ABI->init();
+        return ABI;
+      }),
+      OptimizeFrame(OptimizeFrame) {}
+
+// For back compatibility, constructor takes a materializable callback and
+// creates a generator for an ABI with a modified materializable callback.
+CoroSplitPass::CoroSplitPass(
+    std::function<bool(Instruction &)> IsMatCallback,
+    SmallVector<CoroSplitPass::BaseABITy> GenCustomABIs, bool OptimizeFrame)
+    : CreateAndInitABI([=](Function &F, coro::Shape &S) {
+        std::unique_ptr<coro::BaseABI> ABI =
+            CreateNewABI(F, S, IsMatCallback, GenCustomABIs);
         ABI->init();
         return ABI;
       }),
diff --git a/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/llvm/lib/Transforms/Coroutines/Coroutines.cpp
index f4d9a7a8aa8569..1c45bcd7f6a837 100644
--- a/llvm/lib/Transforms/Coroutines/Coroutines.cpp
+++ b/llvm/lib/Transforms/Coroutines/Coroutines.cpp
@@ -73,6 +73,7 @@ static const char *const CoroIntrinsics[] = {
     "llvm.coro.await.suspend.handle",
     "llvm.coro.await.suspend.void",
     "llvm.coro.begin",
+    "llvm.coro.begin.custom.abi",
     "llvm.coro.destroy",
     "llvm.coro.done",
     "llvm.coro.end",
@@ -247,7 +248,8 @@ void coro::Shape::analyze(Function &F,
         }
         break;
       }
-      case Intrinsic::coro_begin: {
+      case Intrinsic::coro_begin:
+      case Intrinsic::coro_begin_custom_abi: {
         auto CB = cast<CoroBeginInst>(II);
 
         // Ignore coro id's that aren't pre-split.
diff --git a/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp b/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp
index 1d55889a32d7aa..c3394fdaa940ba 100644
--- a/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp
+++ b/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp
@@ -182,4 +182,91 @@ TEST_F(ExtraRematTest, TestCoroRematWithCallback) {
   CallInst *CI = getCallByName(Resume1, "should.remat");
   ASSERT_TRUE(CI);
 }
+
+StringRef TextCoroBeginCustomABI = R"(
+    define ptr @f(i32 %n) presplitcoroutine {
+    entry:
+      %id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null)
+      %size = call i32 @llvm.coro.size.i32()
+      %alloc = call ptr @malloc(i32 %size)
+      %hdl = call ptr @llvm.coro.begin.custom.abi(token %id, ptr %alloc, i32 0)
+
+      %inc1 = add i32 %n, 1
+      %val2 = call i32 @should.remat(i32 %inc1)
+      %sp1 = call i8 @llvm.coro.suspend(token none, i1 false)
+      switch i8 %sp1, label %suspend [i8 0, label %resume1
+                                      i8 1, label %cleanup]
+    resume1:
+      %inc2 = add i32 %val2, 1
+      %sp2 = call i8 @llvm.coro.suspend(token none, i1 false)
+      switch i8 %sp1, label %suspend [i8 0, label %resume2
+                                      i8 1, label %cleanup]
+
+    resume2:
+      call void @print(i32 %val2)
+      call void @print(i32 %inc2)
+      br label %cleanup
+
+    cleanup:
+      %mem = call ptr @llvm.coro.free(token %id, ptr %hdl)
+      call void @free(ptr %mem)
+      br label %suspend
+    suspend:
+      call i1 @llvm.coro.end(ptr %hdl, i1 0)
+      ret ptr %hdl
+    }
+
+    declare ptr @llvm.coro.free(token, ptr)
+    declare i32 @llvm.coro.size.i32()
+    declare i8  @llvm.coro.suspend(token, i1)
+    declare void @llvm.coro.resume(ptr)
+    declare void @llvm.coro.destroy(ptr)
+
+    declare token @llvm.coro.id(i32, ptr, ptr, ptr)
+    declare i1 @llvm.coro.alloc(token)
+    declare ptr @llvm.coro.begin.custom.abi(token, ptr, i32)
+    declare i1 @llvm.coro.end(ptr, i1)
+
+    declare i32 @should.remat(i32)
+
+    declare noalias ptr @malloc(i32)
+    declare void @print(i32)
+    declare void @free(ptr)
+  )";
+
+// SwitchABI with overridden isMaterializable
+class ExtraCustomABI : public coro::SwitchABI {
+public:
+  ExtraCustomABI(Function &F, coro::Shape &S)
+      : coro::SwitchABI(F, S, ExtraMaterializable) {}
+};
+
+TEST_F(ExtraRematTest, TestCoroRematWithCustomABI) {
+  ParseAssembly(TextCoroBeginCustomABI);
+
+  ASSERT_TRUE(M);
+
+  CoroSplitPass::BaseABITy GenCustomABI = [](Function &F, coro::Shape &S) {
+    return std::unique_ptr<coro::BaseABI>(new ExtraCustomABI(F, S));
+  };
+
+  CGSCCPassManager CGPM;
+  CGPM.addPass(CoroSplitPass({GenCustomABI}));
+  MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM)));
+  MPM.run(*M, MAM);
+
+  // Verify that extra rematerializable instruction has been rematerialized
+  Function *F = M->getFunction("f.resume");
+  ASSERT_TRUE(F) << "could not find split function f.resume";
+
+  BasicBlock *Resume1 = getBasicBlockByName(F, "resume1");
+  ASSERT_TRUE(Resume1)
+      << "could not find expected BB resume1 in split function";
+
+  // With callback the extra rematerialization of the function should have
+  // happened
+  CallInst *CI = getCallByName(Resume1, "should.remat");
+  ASSERT_TRUE(CI);
+}
+
 } // namespace

@llvmbot
Copy link
Member

llvmbot commented Oct 9, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Tyler Nowicki (TylerNowicki)

Changes

This change extends the current method for creating ABI object to allow users (plugin libraries) to create custom ABI objects for their needs. This is accomplished by inheriting one of the common ABIs and overriding one or more of the methods to create a custom ABI. To use a custom ABI for a given coroutine the coro.begin.custom.abi intrinsic is used in place of the coro.begin intrinsic. This takes an additional i32 arg that specifies the index of an ABI generator for the custom ABI object in a SmallVector passed to CoroSplitPass ctor.

The detailed changes include:

  • Add the llvm.coro.begin.custom intrinsic used to specify the index of the custom ABI to use for the given coroutine.
  • Add constructors to CoroSplit that take a list of generators that create the custom ABI object.
  • Extend the CreateNewABI function used by CoroSplit to return a unique_ptr to an ABI object.
  • Add has/getCustomABI methods to CoroBeginInst class.
  • Add a unittest for a custom ABI.

Full diff: https://github.com/llvm/llvm-project/pull/111755.diff

9 Files Affected:

  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+1)
  • (modified) llvm/include/llvm/IR/Intrinsics.td (+2-1)
  • (modified) llvm/include/llvm/Transforms/Coroutines/ABI.h (+7-1)
  • (modified) llvm/include/llvm/Transforms/Coroutines/CoroInstr.h (+15-4)
  • (modified) llvm/include/llvm/Transforms/Coroutines/CoroSplit.h (+11-2)
  • (modified) llvm/lib/Transforms/Coroutines/CoroCleanup.cpp (+3-1)
  • (modified) llvm/lib/Transforms/Coroutines/CoroSplit.cpp (+35-3)
  • (modified) llvm/lib/Transforms/Coroutines/Coroutines.cpp (+3-1)
  • (modified) llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp (+87)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 01a16e7c7b1e59..f6888d001fed69 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -778,6 +778,7 @@ class TargetTransformInfoImplBase {
     case Intrinsic::experimental_gc_relocate:
     case Intrinsic::coro_alloc:
     case Intrinsic::coro_begin:
+    case Intrinsic::coro_begin_custom_abi:
     case Intrinsic::coro_free:
     case Intrinsic::coro_end:
     case Intrinsic::coro_frame:
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 20dd921ddbd230..8a0721cf23f538 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1719,7 +1719,8 @@ def int_coro_prepare_async : Intrinsic<[llvm_ptr_ty], [llvm_ptr_ty],
                                        [IntrNoMem]>;
 def int_coro_begin : Intrinsic<[llvm_ptr_ty], [llvm_token_ty, llvm_ptr_ty],
                                [WriteOnly<ArgIndex<1>>]>;
-
+def int_coro_begin_custom_abi : Intrinsic<[llvm_ptr_ty], [llvm_token_ty, llvm_ptr_ty, llvm_i32_ty],
+                               [WriteOnly<ArgIndex<1>>]>;
 def int_coro_free : Intrinsic<[llvm_ptr_ty], [llvm_token_ty, llvm_ptr_ty],
                               [IntrReadMem, IntrArgMemOnly,
                                ReadOnly<ArgIndex<1>>,
diff --git a/llvm/include/llvm/Transforms/Coroutines/ABI.h b/llvm/include/llvm/Transforms/Coroutines/ABI.h
index e7568d275c1615..8b83c5308056eb 100644
--- a/llvm/include/llvm/Transforms/Coroutines/ABI.h
+++ b/llvm/include/llvm/Transforms/Coroutines/ABI.h
@@ -29,7 +29,13 @@ namespace coro {
 // This interface/API is to provide an object oriented way to implement ABI
 // functionality. This is intended to replace use of the ABI enum to perform
 // ABI operations. The ABIs (e.g. Switch, Async, Retcon{Once}) are the common
-// ABIs.
+// ABIs. However, specific users may need to modify the behavior of these. This
+// can be accomplished by inheriting one of the common ABIs and overriding one
+// or more of the methods to create a custom ABI. To use a custom ABI for a
+// given coroutine the coro.begin.custom.abi intrinsic is used in place of the
+// coro.begin intrinsic. This takes an additional i32 arg that specifies the
+// index of an ABI generator for the custom ABI object in a SmallVector passed
+// to CoroSplitPass ctor.
 
 class BaseABI {
 public:
diff --git a/llvm/include/llvm/Transforms/Coroutines/CoroInstr.h b/llvm/include/llvm/Transforms/Coroutines/CoroInstr.h
index a329a06bf13891..3aa30bec85c3a5 100644
--- a/llvm/include/llvm/Transforms/Coroutines/CoroInstr.h
+++ b/llvm/include/llvm/Transforms/Coroutines/CoroInstr.h
@@ -124,7 +124,8 @@ class AnyCoroIdInst : public IntrinsicInst {
   IntrinsicInst *getCoroBegin() {
     for (User *U : users())
       if (auto *II = dyn_cast<IntrinsicInst>(U))
-        if (II->getIntrinsicID() == Intrinsic::coro_begin)
+        if (II->getIntrinsicID() == Intrinsic::coro_begin ||
+            II->getIntrinsicID() == Intrinsic::coro_begin_custom_abi)
           return II;
     llvm_unreachable("no coro.begin associated with coro.id");
   }
@@ -442,20 +443,30 @@ class CoroFreeInst : public IntrinsicInst {
   }
 };
 
-/// This class represents the llvm.coro.begin instructions.
+/// This class represents the llvm.coro.begin or llvm.coro.begin.custom.abi
+/// instructions.
 class CoroBeginInst : public IntrinsicInst {
-  enum { IdArg, MemArg };
+  enum { IdArg, MemArg, CustomABIArg };
 
 public:
   AnyCoroIdInst *getId() const {
     return cast<AnyCoroIdInst>(getArgOperand(IdArg));
   }
 
+  bool hasCustomABI() const {
+    return getIntrinsicID() == Intrinsic::coro_begin_custom_abi;
+  }
+
+  int getCustomABI() const {
+    return cast<ConstantInt>(getArgOperand(CustomABIArg))->getZExtValue();
+  }
+
   Value *getMem() const { return getArgOperand(MemArg); }
 
   // Methods for support type inquiry through isa, cast, and dyn_cast:
   static bool classof(const IntrinsicInst *I) {
-    return I->getIntrinsicID() == Intrinsic::coro_begin;
+    return I->getIntrinsicID() == Intrinsic::coro_begin ||
+           I->getIntrinsicID() == Intrinsic::coro_begin_custom_abi;
   }
   static bool classof(const Value *V) {
     return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
diff --git a/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h b/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h
index a5fd57f8f9dfab..6c6a982e828050 100644
--- a/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h
+++ b/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h
@@ -28,17 +28,26 @@ struct Shape;
 } // namespace coro
 
 struct CoroSplitPass : PassInfoMixin<CoroSplitPass> {
+  using BaseABITy =
+      std::function<std::unique_ptr<coro::BaseABI>(Function &, coro::Shape &)>;
 
   CoroSplitPass(bool OptimizeFrame = false);
+
+  CoroSplitPass(SmallVector<BaseABITy> GenCustomABIs,
+                bool OptimizeFrame = false);
+
   CoroSplitPass(std::function<bool(Instruction &)> MaterializableCallback,
                 bool OptimizeFrame = false);
 
+  CoroSplitPass(std::function<bool(Instruction &)> MaterializableCallback,
+                SmallVector<BaseABITy> GenCustomABIs,
+                bool OptimizeFrame = false);
+
   PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM,
                         LazyCallGraph &CG, CGSCCUpdateResult &UR);
+
   static bool isRequired() { return true; }
 
-  using BaseABITy =
-      std::function<std::unique_ptr<coro::BaseABI>(Function &, coro::Shape &)>;
   // Generator for an ABI transformer
   BaseABITy CreateAndInitABI;
 
diff --git a/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp b/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp
index dd92b3593af92e..1cda7f93f72a2c 100644
--- a/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp
@@ -53,6 +53,7 @@ bool Lowerer::lower(Function &F) {
       default:
         continue;
       case Intrinsic::coro_begin:
+      case Intrinsic::coro_begin_custom_abi:
         II->replaceAllUsesWith(II->getArgOperand(1));
         break;
       case Intrinsic::coro_free:
@@ -112,7 +113,8 @@ static bool declaresCoroCleanupIntrinsics(const Module &M) {
       M, {"llvm.coro.alloc", "llvm.coro.begin", "llvm.coro.subfn.addr",
           "llvm.coro.free", "llvm.coro.id", "llvm.coro.id.retcon",
           "llvm.coro.id.async", "llvm.coro.id.retcon.once",
-          "llvm.coro.async.size.replace", "llvm.coro.async.resume"});
+          "llvm.coro.async.size.replace", "llvm.coro.async.resume",
+          "llvm.coro.begin.custom.abi"});
 }
 
 PreservedAnalyses CoroCleanupPass::run(Module &M,
diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
index ef1f27118bc14b..88ce331c8cfb64 100644
--- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -2200,7 +2200,15 @@ static void addPrepareFunction(const Module &M,
 
 static std::unique_ptr<coro::BaseABI>
 CreateNewABI(Function &F, coro::Shape &S,
-             std::function<bool(Instruction &)> IsMatCallback) {
+             std::function<bool(Instruction &)> IsMatCallback,
+             const SmallVector<CoroSplitPass::BaseABITy> GenCustomABIs) {
+  if (S.CoroBegin->hasCustomABI()) {
+    unsigned CustomABI = S.CoroBegin->getCustomABI();
+    if (CustomABI >= GenCustomABIs.size())
+      llvm_unreachable("Custom ABI not found amoung those specified");
+    return GenCustomABIs[CustomABI](F, S);
+  }
+
   switch (S.ABI) {
   case coro::ABI::Switch:
     return std::unique_ptr<coro::BaseABI>(
@@ -2221,7 +2229,17 @@ CreateNewABI(Function &F, coro::Shape &S,
 CoroSplitPass::CoroSplitPass(bool OptimizeFrame)
     : CreateAndInitABI([](Function &F, coro::Shape &S) {
         std::unique_ptr<coro::BaseABI> ABI =
-            CreateNewABI(F, S, coro::isTriviallyMaterializable);
+            CreateNewABI(F, S, coro::isTriviallyMaterializable, {});
+        ABI->init();
+        return ABI;
+      }),
+      OptimizeFrame(OptimizeFrame) {}
+
+CoroSplitPass::CoroSplitPass(
+    SmallVector<CoroSplitPass::BaseABITy> GenCustomABIs, bool OptimizeFrame)
+    : CreateAndInitABI([=](Function &F, coro::Shape &S) {
+        std::unique_ptr<coro::BaseABI> ABI =
+            CreateNewABI(F, S, coro::isTriviallyMaterializable, GenCustomABIs);
         ABI->init();
         return ABI;
       }),
@@ -2232,7 +2250,21 @@ CoroSplitPass::CoroSplitPass(bool OptimizeFrame)
 CoroSplitPass::CoroSplitPass(std::function<bool(Instruction &)> IsMatCallback,
                              bool OptimizeFrame)
     : CreateAndInitABI([=](Function &F, coro::Shape &S) {
-        std::unique_ptr<coro::BaseABI> ABI = CreateNewABI(F, S, IsMatCallback);
+        std::unique_ptr<coro::BaseABI> ABI =
+            CreateNewABI(F, S, IsMatCallback, {});
+        ABI->init();
+        return ABI;
+      }),
+      OptimizeFrame(OptimizeFrame) {}
+
+// For back compatibility, constructor takes a materializable callback and
+// creates a generator for an ABI with a modified materializable callback.
+CoroSplitPass::CoroSplitPass(
+    std::function<bool(Instruction &)> IsMatCallback,
+    SmallVector<CoroSplitPass::BaseABITy> GenCustomABIs, bool OptimizeFrame)
+    : CreateAndInitABI([=](Function &F, coro::Shape &S) {
+        std::unique_ptr<coro::BaseABI> ABI =
+            CreateNewABI(F, S, IsMatCallback, GenCustomABIs);
         ABI->init();
         return ABI;
       }),
diff --git a/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/llvm/lib/Transforms/Coroutines/Coroutines.cpp
index f4d9a7a8aa8569..1c45bcd7f6a837 100644
--- a/llvm/lib/Transforms/Coroutines/Coroutines.cpp
+++ b/llvm/lib/Transforms/Coroutines/Coroutines.cpp
@@ -73,6 +73,7 @@ static const char *const CoroIntrinsics[] = {
     "llvm.coro.await.suspend.handle",
     "llvm.coro.await.suspend.void",
     "llvm.coro.begin",
+    "llvm.coro.begin.custom.abi",
     "llvm.coro.destroy",
     "llvm.coro.done",
     "llvm.coro.end",
@@ -247,7 +248,8 @@ void coro::Shape::analyze(Function &F,
         }
         break;
       }
-      case Intrinsic::coro_begin: {
+      case Intrinsic::coro_begin:
+      case Intrinsic::coro_begin_custom_abi: {
         auto CB = cast<CoroBeginInst>(II);
 
         // Ignore coro id's that aren't pre-split.
diff --git a/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp b/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp
index 1d55889a32d7aa..c3394fdaa940ba 100644
--- a/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp
+++ b/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp
@@ -182,4 +182,91 @@ TEST_F(ExtraRematTest, TestCoroRematWithCallback) {
   CallInst *CI = getCallByName(Resume1, "should.remat");
   ASSERT_TRUE(CI);
 }
+
+StringRef TextCoroBeginCustomABI = R"(
+    define ptr @f(i32 %n) presplitcoroutine {
+    entry:
+      %id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null)
+      %size = call i32 @llvm.coro.size.i32()
+      %alloc = call ptr @malloc(i32 %size)
+      %hdl = call ptr @llvm.coro.begin.custom.abi(token %id, ptr %alloc, i32 0)
+
+      %inc1 = add i32 %n, 1
+      %val2 = call i32 @should.remat(i32 %inc1)
+      %sp1 = call i8 @llvm.coro.suspend(token none, i1 false)
+      switch i8 %sp1, label %suspend [i8 0, label %resume1
+                                      i8 1, label %cleanup]
+    resume1:
+      %inc2 = add i32 %val2, 1
+      %sp2 = call i8 @llvm.coro.suspend(token none, i1 false)
+      switch i8 %sp1, label %suspend [i8 0, label %resume2
+                                      i8 1, label %cleanup]
+
+    resume2:
+      call void @print(i32 %val2)
+      call void @print(i32 %inc2)
+      br label %cleanup
+
+    cleanup:
+      %mem = call ptr @llvm.coro.free(token %id, ptr %hdl)
+      call void @free(ptr %mem)
+      br label %suspend
+    suspend:
+      call i1 @llvm.coro.end(ptr %hdl, i1 0)
+      ret ptr %hdl
+    }
+
+    declare ptr @llvm.coro.free(token, ptr)
+    declare i32 @llvm.coro.size.i32()
+    declare i8  @llvm.coro.suspend(token, i1)
+    declare void @llvm.coro.resume(ptr)
+    declare void @llvm.coro.destroy(ptr)
+
+    declare token @llvm.coro.id(i32, ptr, ptr, ptr)
+    declare i1 @llvm.coro.alloc(token)
+    declare ptr @llvm.coro.begin.custom.abi(token, ptr, i32)
+    declare i1 @llvm.coro.end(ptr, i1)
+
+    declare i32 @should.remat(i32)
+
+    declare noalias ptr @malloc(i32)
+    declare void @print(i32)
+    declare void @free(ptr)
+  )";
+
+// SwitchABI with overridden isMaterializable
+class ExtraCustomABI : public coro::SwitchABI {
+public:
+  ExtraCustomABI(Function &F, coro::Shape &S)
+      : coro::SwitchABI(F, S, ExtraMaterializable) {}
+};
+
+TEST_F(ExtraRematTest, TestCoroRematWithCustomABI) {
+  ParseAssembly(TextCoroBeginCustomABI);
+
+  ASSERT_TRUE(M);
+
+  CoroSplitPass::BaseABITy GenCustomABI = [](Function &F, coro::Shape &S) {
+    return std::unique_ptr<coro::BaseABI>(new ExtraCustomABI(F, S));
+  };
+
+  CGSCCPassManager CGPM;
+  CGPM.addPass(CoroSplitPass({GenCustomABI}));
+  MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM)));
+  MPM.run(*M, MAM);
+
+  // Verify that extra rematerializable instruction has been rematerialized
+  Function *F = M->getFunction("f.resume");
+  ASSERT_TRUE(F) << "could not find split function f.resume";
+
+  BasicBlock *Resume1 = getBasicBlockByName(F, "resume1");
+  ASSERT_TRUE(Resume1)
+      << "could not find expected BB resume1 in split function";
+
+  // With callback the extra rematerialization of the function should have
+  // happened
+  CallInst *CI = getCallByName(Resume1, "should.remat");
+  ASSERT_TRUE(CI);
+}
+
 } // namespace

@TylerNowicki TylerNowicki self-assigned this Oct 9, 2024
Copy link
Member

@ChuanqiXu9 ChuanqiXu9 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For such changes, we need to update coroutines.rst (the https://llvm.org/docs/Coroutines.html page)

@TylerNowicki
Copy link
Collaborator Author

For such changes, we need to update coroutines.rst (the https://llvm.org/docs/Coroutines.html page)

Should I include this (#111781) in the same PR? Perhaps I wrongly assumed that updating the docs should be a separate PR.

@ChuanqiXu9
Copy link
Member

For such changes, we need to update coroutines.rst (the https://llvm.org/docs/Coroutines.html page)

Should I include this (#111781) in the same PR? Perhaps I wrongly assumed that updating the docs should be a separate PR.

It is fine to seperate them. I just didn't know it.

@TylerNowicki
Copy link
Collaborator Author

For such changes, we need to update coroutines.rst (the https://llvm.org/docs/Coroutines.html page)

Should I include this (#111781) in the same PR? Perhaps I wrongly assumed that updating the docs should be a separate PR.

It is fine to seperate them. I just didn't know it.

No no my mistake. I only posted it since you brought it up. I got so used to submitting the PRs one by one I didn't think to post it at the same time.

@TylerNowicki TylerNowicki merged commit 3737a53 into llvm:main Oct 10, 2024
13 checks passed
@TylerNowicki TylerNowicki deleted the users/tylernowicki/coro-refactor8 branch October 15, 2024 14:24
DanielCChen pushed a commit to DanielCChen/llvm-project that referenced this pull request Oct 16, 2024
This change extends the current method for creating ABI object to allow
users (plugin libraries) to create custom ABI objects for their needs.
This is accomplished by inheriting one of the common ABIs and overriding
one or more of the methods to create a custom ABI. To use a custom ABI
for a given coroutine the coro.begin.custom.abi intrinsic is used in
place of the coro.begin intrinsic. This takes an additional i32 arg that
specifies the index of an ABI generator for the custom ABI object in a
SmallVector passed to the CoroSplitPass ctor.

The detailed changes include:
* Add the llvm.coro.begin.custom intrinsic used to specify the index of
the custom ABI to use for the given coroutine.
* Add constructors to CoroSplit that take a list of generators that
create the custom ABI object.
* Extend the CreateNewABI function used by CoroSplit to return a
unique_ptr to an ABI object.
* Add has/getCustomABI methods to CoroBeginInst class.
* Add a unittest for a custom ABI.

See doc update here: llvm#111781
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
coroutines C++20 coroutines llvm:analysis Includes value tracking, cost tables and constant folding llvm:ir llvm:transforms
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants