Skip to content

[MLIR][NFC] Retire let constructor for Async #137461

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 0 additions & 16 deletions mlir/include/mlir/Dialect/Async/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,9 @@ class ConversionTarget;
#define GEN_PASS_DECL
#include "mlir/Dialect/Async/Passes.h.inc"

std::unique_ptr<Pass> createAsyncParallelForPass();

std::unique_ptr<Pass> createAsyncParallelForPass(bool asyncDispatch,
int32_t numWorkerThreads,
int32_t minTaskSize);

void populateAsyncFuncToAsyncRuntimeConversionPatterns(
RewritePatternSet &patterns, ConversionTarget &target);

std::unique_ptr<OperationPass<ModuleOp>> createAsyncFuncToAsyncRuntimePass();

std::unique_ptr<OperationPass<ModuleOp>> createAsyncToAsyncRuntimePass();

std::unique_ptr<Pass> createAsyncRuntimeRefCountingPass();

std::unique_ptr<Pass> createAsyncRuntimeRefCountingOptPass();

std::unique_ptr<Pass> createAsyncRuntimePolicyBasedRefCountingPass();

//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
Expand Down
20 changes: 8 additions & 12 deletions mlir/include/mlir/Dialect/Async/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@

include "mlir/Pass/PassBase.td"

def AsyncParallelFor : Pass<"async-parallel-for", "ModuleOp"> {
def AsyncParallelForPass : Pass<"async-parallel-for", "ModuleOp"> {
let summary = "Convert scf.parallel operations to multiple async compute ops "
"executed concurrently for non-overlapping iteration ranges";
let constructor = "mlir::createAsyncParallelForPass()";

let options = [
Option<"asyncDispatch", "async-dispatch",
Expand All @@ -41,21 +40,20 @@ def AsyncParallelFor : Pass<"async-parallel-for", "ModuleOp"> {
];
}

def AsyncToAsyncRuntime : Pass<"async-to-async-runtime", "ModuleOp"> {
def AsyncToAsyncRuntimePass : Pass<"async-to-async-runtime", "ModuleOp"> {
let summary = "Lower all high level async operations (e.g. async.execute) to"
"the explicit async.runtime and async.coro operations";
let constructor = "mlir::createAsyncToAsyncRuntimePass()";
let dependentDialects = ["async::AsyncDialect", "func::FuncDialect", "cf::ControlFlowDialect"];
}

def AsyncFuncToAsyncRuntime : Pass<"async-func-to-async-runtime", "ModuleOp"> {
def AsyncFuncToAsyncRuntimePass
: Pass<"async-func-to-async-runtime", "ModuleOp"> {
let summary = "Lower async.func operations to the explicit async.runtime and"
"async.coro operations";
let constructor = "mlir::createAsyncFuncToAsyncRuntimePass()";
let dependentDialects = ["async::AsyncDialect", "func::FuncDialect"];
}

def AsyncRuntimeRefCounting : Pass<"async-runtime-ref-counting"> {
def AsyncRuntimeRefCountingPass : Pass<"async-runtime-ref-counting"> {
let summary = "Automatic reference counting for Async runtime operations";
let description = [{
This pass works at the async runtime abtraction level, after all
Expand All @@ -68,18 +66,17 @@ def AsyncRuntimeRefCounting : Pass<"async-runtime-ref-counting"> {
See: https://llvm.org/docs/Coroutines.html#switched-resume-lowering
}];

let constructor = "mlir::createAsyncRuntimeRefCountingPass()";
let dependentDialects = ["async::AsyncDialect"];
}

def AsyncRuntimeRefCountingOpt : Pass<"async-runtime-ref-counting-opt"> {
def AsyncRuntimeRefCountingOptPass : Pass<"async-runtime-ref-counting-opt"> {
let summary = "Optimize automatic reference counting operations for the"
"Async runtime by removing redundant operations";
let constructor = "mlir::createAsyncRuntimeRefCountingOptPass()";

let dependentDialects = ["async::AsyncDialect"];
}

def AsyncRuntimePolicyBasedRefCounting
def AsyncRuntimePolicyBasedRefCountingPass
: Pass<"async-runtime-policy-based-ref-counting"> {
let summary = "Policy based reference counting for Async runtime operations";
let description = [{
Expand Down Expand Up @@ -107,7 +104,6 @@ def AsyncRuntimePolicyBasedRefCounting
automatic reference counting.
}];

let constructor = "mlir::createAsyncRuntimePolicyBasedRefCountingPass()";
let dependentDialects = ["async::AsyncDialect"];
}

Expand Down
24 changes: 3 additions & 21 deletions mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_ASYNCPARALLELFOR
#define GEN_PASS_DEF_ASYNCPARALLELFORPASS
#include "mlir/Dialect/Async/Passes.h.inc"
} // namespace mlir

Expand Down Expand Up @@ -99,15 +99,8 @@ namespace {
// }
//
struct AsyncParallelForPass
: public impl::AsyncParallelForBase<AsyncParallelForPass> {
AsyncParallelForPass() = default;

AsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads,
int32_t minTaskSize) {
this->asyncDispatch = asyncDispatch;
this->numWorkerThreads = numWorkerThreads;
this->minTaskSize = minTaskSize;
}
: public impl::AsyncParallelForPassBase<AsyncParallelForPass> {
using Base::Base;

void runOnOperation() override;
};
Expand Down Expand Up @@ -935,17 +928,6 @@ void AsyncParallelForPass::runOnOperation() {
signalPassFailure();
}

std::unique_ptr<Pass> mlir::createAsyncParallelForPass() {
return std::make_unique<AsyncParallelForPass>();
}

std::unique_ptr<Pass> mlir::createAsyncParallelForPass(bool asyncDispatch,
int32_t numWorkerThreads,
int32_t minTaskSize) {
return std::make_unique<AsyncParallelForPass>(asyncDispatch, numWorkerThreads,
minTaskSize);
}

void mlir::async::populateAsyncParallelForPatterns(
RewritePatternSet &patterns, bool asyncDispatch, int32_t numWorkerThreads,
const AsyncMinTaskSizeComputationFunction &computeMinTaskSize) {
Expand Down
19 changes: 5 additions & 14 deletions mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
#include "llvm/ADT/SmallSet.h"

namespace mlir {
#define GEN_PASS_DEF_ASYNCRUNTIMEREFCOUNTING
#define GEN_PASS_DEF_ASYNCRUNTIMEPOLICYBASEDREFCOUNTING
#define GEN_PASS_DEF_ASYNCRUNTIMEREFCOUNTINGPASS
#define GEN_PASS_DEF_ASYNCRUNTIMEPOLICYBASEDREFCOUNTINGPASS
#include "mlir/Dialect/Async/Passes.h.inc"
} // namespace mlir

Expand Down Expand Up @@ -109,7 +109,8 @@ static LogicalResult walkReferenceCountedValues(
namespace {

class AsyncRuntimeRefCountingPass
: public impl::AsyncRuntimeRefCountingBase<AsyncRuntimeRefCountingPass> {
: public impl::AsyncRuntimeRefCountingPassBase<
AsyncRuntimeRefCountingPass> {
public:
AsyncRuntimeRefCountingPass() = default;
void runOnOperation() override;
Expand Down Expand Up @@ -468,7 +469,7 @@ void AsyncRuntimeRefCountingPass::runOnOperation() {
namespace {

class AsyncRuntimePolicyBasedRefCountingPass
: public impl::AsyncRuntimePolicyBasedRefCountingBase<
: public impl::AsyncRuntimePolicyBasedRefCountingPassBase<
AsyncRuntimePolicyBasedRefCountingPass> {
public:
AsyncRuntimePolicyBasedRefCountingPass() { initializeDefaultPolicy(); }
Expand Down Expand Up @@ -553,13 +554,3 @@ void AsyncRuntimePolicyBasedRefCountingPass::runOnOperation() {
if (failed(walkReferenceCountedValues(getOperation(), functor)))
signalPassFailure();
}

//----------------------------------------------------------------------------//

std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingPass() {
return std::make_unique<AsyncRuntimeRefCountingPass>();
}

std::unique_ptr<Pass> mlir::createAsyncRuntimePolicyBasedRefCountingPass() {
return std::make_unique<AsyncRuntimePolicyBasedRefCountingPass>();
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#include "llvm/Support/Debug.h"

namespace mlir {
#define GEN_PASS_DEF_ASYNCRUNTIMEREFCOUNTINGOPT
#define GEN_PASS_DEF_ASYNCRUNTIMEREFCOUNTINGOPTPASS
#include "mlir/Dialect/Async/Passes.h.inc"
} // namespace mlir

Expand All @@ -30,7 +30,7 @@ using namespace mlir::async;
namespace {

class AsyncRuntimeRefCountingOptPass
: public impl::AsyncRuntimeRefCountingOptBase<
: public impl::AsyncRuntimeRefCountingOptPassBase<
AsyncRuntimeRefCountingOptPass> {
public:
AsyncRuntimeRefCountingOptPass() = default;
Expand Down Expand Up @@ -230,7 +230,3 @@ void AsyncRuntimeRefCountingOptPass::runOnOperation() {
kv.second->erase();
}
}

std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingOptPass() {
return std::make_unique<AsyncRuntimeRefCountingOptPass>();
}
18 changes: 5 additions & 13 deletions mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_ASYNCTOASYNCRUNTIME
#define GEN_PASS_DEF_ASYNCFUNCTOASYNCRUNTIME
#define GEN_PASS_DEF_ASYNCTOASYNCRUNTIMEPASS
#define GEN_PASS_DEF_ASYNCFUNCTOASYNCRUNTIMEPASS
#include "mlir/Dialect/Async/Passes.h.inc"
} // namespace mlir

Expand All @@ -47,7 +47,7 @@ static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";
namespace {

class AsyncToAsyncRuntimePass
: public impl::AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> {
: public impl::AsyncToAsyncRuntimePassBase<AsyncToAsyncRuntimePass> {
public:
AsyncToAsyncRuntimePass() = default;
void runOnOperation() override;
Expand All @@ -58,7 +58,8 @@ class AsyncToAsyncRuntimePass
namespace {

class AsyncFuncToAsyncRuntimePass
: public impl::AsyncFuncToAsyncRuntimeBase<AsyncFuncToAsyncRuntimePass> {
: public impl::AsyncFuncToAsyncRuntimePassBase<
AsyncFuncToAsyncRuntimePass> {
public:
AsyncFuncToAsyncRuntimePass() = default;
void runOnOperation() override;
Expand Down Expand Up @@ -896,12 +897,3 @@ void AsyncFuncToAsyncRuntimePass::runOnOperation() {
return;
}
}

std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() {
return std::make_unique<AsyncToAsyncRuntimePass>();
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createAsyncFuncToAsyncRuntimePass() {
return std::make_unique<AsyncFuncToAsyncRuntimePass>();
}
Loading