Skip to content

Commit d11ed19

Browse files
xurui1995dchigarev
authored andcommitted
replace all-in-one pass with real pipeline (intel#174)
1 parent 77933ca commit d11ed19

File tree

7 files changed

+37
-46
lines changed

7 files changed

+37
-46
lines changed

include/gc-c/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ extern "C" {
2626

2727
#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.capi.h.inc"
2828
#include "gc/Transforms/Passes.capi.h.inc"
29+
30+
MLIR_CAPI_EXPORTED void mlirRegisterAllGCPassesAndPipelines(void);
31+
2932
#ifdef __cplusplus
3033
}
3134
#endif

include/gc/Transforms/Passes.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,12 @@ namespace xegpu {
4040
class XeGPUDialect;
4141
}
4242

43-
class PassManager;
43+
class OpPassManager;
4444

4545
namespace gc {
4646

47-
void populateFrontendPasses(mlir::PassManager &);
48-
void populateCPUPipeline(mlir::PassManager &);
47+
void populateFrontendPasses(mlir::OpPassManager &);
48+
void populateCPUPipeline(mlir::OpPassManager &);
4949

5050
#define GEN_PASS_DECL
5151
#include "gc/Transforms/Passes.h.inc"

include/gc/Transforms/Passes.td

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,6 @@ def ConvertOneDNNGraphToLinalg : Pass<"convert-onednn-graph-to-linalg"> {
3232
];
3333
}
3434

35-
def GCCPUPipeline: Pass<"gc-cpu-pipeline"> {
36-
let summary = "All-in-one pipeline for GC for CPU";
37-
let dependentDialects = ["onednn_graph::OneDNNGraphDialect",
38-
"tensor::TensorDialect",
39-
"memref::MemRefDialect",
40-
"linalg::LinalgDialect",
41-
"linalgx::LinalgxDialect",
42-
"LLVM::LLVMDialect",
43-
"scf::SCFDialect",
44-
"bufferization::BufferizationDialect",
45-
"omp::OpenMPDialect",
46-
"vector::VectorDialect"];
47-
}
4835

4936
def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> {
5037
let summary = "Convert linalg dialect to XeGPU dialect.";

lib/gc/CAPI/Passes.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,22 @@
1616
using namespace mlir::gc;
1717
using namespace mlir::cpuruntime;
1818

19+
namespace mlir::gc {
20+
void registerCPUPipeline();
21+
} // namespace mlir::gc
22+
1923
#ifdef __cplusplus
2024
extern "C" {
2125
#endif
2226

2327
#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.capi.cpp.inc"
2428
#include "gc/Transforms/Passes.capi.cpp.inc"
2529

30+
MLIR_CAPI_EXPORTED void mlirRegisterAllGCPassesAndPipelines() {
31+
registerCPUPipeline();
32+
mlirRegisterCPURuntimePasses();
33+
mlirRegisterGraphCompilerPasses();
34+
}
2635
#ifdef __cplusplus
2736
}
2837
#endif

lib/gc/Transforms/Pipeline.cpp

Lines changed: 15 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
2020
#include "mlir/Dialect/SCF/IR/SCF.h"
2121
#include "mlir/Dialect/Tensor/IR/Tensor.h"
22+
#include "mlir/IR/DialectRegistry.h"
2223
#include "mlir/InitAllPasses.h"
2324
#include "mlir/Pass/PassManager.h"
2425
#include "mlir/Support/LogicalResult.h"
@@ -32,12 +33,12 @@
3233
namespace mlir::gc {
3334

3435
// linalg + linalgX + tensor
35-
void populateFrontendPasses(mlir::PassManager &pm) {
36+
void populateFrontendPasses(mlir::OpPassManager &pm) {
3637
pm.addPass(createConvertOneDNNGraphToLinalg());
3738
}
3839

3940
// scf + arith + math + vector + tensor + linalg.brgemm + tensor.pack/unpack
40-
void populateTensorPasses(mlir::PassManager &pm) {
41+
void populateTensorPasses(mlir::OpPassManager &pm) {
4142
// todo: padding propagation pass
4243
// todo: layout propagation pass
4344
// todo: tensor constant propagation pass
@@ -51,7 +52,7 @@ void populateTensorPasses(mlir::PassManager &pm) {
5152
}
5253

5354
// scf + arith + math + vector + tensor + linalg.brgemm
54-
void populateVectorPasses(mlir::PassManager &pm) {
55+
void populateVectorPasses(mlir::OpPassManager &pm) {
5556
// Do promotion for math / arith ops
5657
pm.addNestedPass<func::FuncOp>(math::createMathLegalizeToF32());
5758
// sourceTypeStrs can be extended
@@ -69,7 +70,7 @@ void populateVectorPasses(mlir::PassManager &pm) {
6970
}
7071

7172
// scf + arith + math + vector + memref + linalg.brgemm
72-
void populateBufferizationPasses(mlir::PassManager &pm) {
73+
void populateBufferizationPasses(mlir::OpPassManager &pm) {
7374
bufferization::OneShotBufferizationOptions options;
7475
options.bufferizeFunctionBoundaries = true;
7576
options.setFunctionBoundaryTypeConversion(
@@ -88,7 +89,7 @@ void populateBufferizationPasses(mlir::PassManager &pm) {
8889
}
8990

9091
// scf + arith + math + vector + memref + func/microkernel
91-
void populateMicroKernelPasses(mlir::PassManager &pm) {
92+
void populateMicroKernelPasses(mlir::OpPassManager &pm) {
9293
// todo: ConvertLinalgToMicrokernel pass
9394
// todo: CleanupInvalidMicrokernel pass
9495
// todo: InvariantMicrokernelMotion pass
@@ -98,13 +99,13 @@ void populateMicroKernelPasses(mlir::PassManager &pm) {
9899
// todo: DispatchMicrokernel
99100
}
100101

101-
void populateCPURuntimePasses(mlir::PassManager &pm) {
102+
void populateCPURuntimePasses(mlir::OpPassManager &pm) {
102103
// todo: flatten nested parallel pass to support coarse-grain usion
103104
// remove this pass after we add FlattenNestedParallel
104105
pm.addPass(createConvertSCFToOpenMPPass());
105106
}
106107

107-
void populateLoweringToLLVMPasses(mlir::PassManager &pm) {
108+
void populateLoweringToLLVMPasses(mlir::OpPassManager &pm) {
108109
pm.addPass(createFinalizeMemRefToLLVMConversionPass());
109110
pm.addPass(createConvertSCFToCFPass());
110111
pm.addPass(cpuruntime::createCPURuntimeToLLVM());
@@ -120,13 +121,13 @@ void populateLoweringToLLVMPasses(mlir::PassManager &pm) {
120121
pm.addPass(createSymbolDCEPass());
121122
}
122123

123-
void populateLLVMPasses(mlir::PassManager &pm) {
124+
void populateLLVMPasses(mlir::OpPassManager &pm) {
124125
pm.addPass(memref::createExpandOpsPass());
125126
pm.addPass(memref::createExpandStridedMetadataPass());
126127
populateLoweringToLLVMPasses(pm);
127128
}
128129

129-
void populateCPUPipeline(mlir::PassManager &pm) {
130+
void populateCPUPipeline(mlir::OpPassManager &pm) {
130131
// front-end, oneDNN graph dialect
131132
populateFrontendPasses(pm);
132133
// middle-end, LinalgX/Linalg/tensor dialects
@@ -144,24 +145,10 @@ void populateCPUPipeline(mlir::PassManager &pm) {
144145
populateLLVMPasses(pm);
145146
}
146147

147-
#define GEN_PASS_DEF_GCCPUPIPELINE
148-
#include "gc/Transforms/Passes.h.inc"
149-
namespace {
150-
151-
class GCCPUPipeline : public impl::GCCPUPipelineBase<GCCPUPipeline> {
152-
public:
153-
friend struct PassHelper;
154-
using impl::GCCPUPipelineBase<GCCPUPipeline>::GCCPUPipelineBase;
155-
void runOnOperation() final {
156-
auto op = getOperation();
157-
PassManager pm{op->getContext()};
158-
populateCPUPipeline(pm);
159-
// TODO(longsheng): add a option to
160-
// disable threading and enable pm.enableIRPrinting();
161-
if (failed(pm.run(op)))
162-
signalPassFailure();
163-
}
164-
};
148+
void registerCPUPipeline() {
149+
PassPipelineRegistration<>("gc-cpu-pipeline",
150+
"The CPU pipeline for Graph Compiler",
151+
populateCPUPipeline);
152+
}
165153

166-
} // namespace
167154
} // namespace mlir::gc

python/MainModule.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
PYBIND11_MODULE(_gc_mlir, m) {
2525
m.doc() = "Graph-compiler MLIR Python binding";
2626

27-
mlirRegisterGraphCompilerPasses();
27+
mlirRegisterAllGCPassesAndPipelines();
28+
2829
//===----------------------------------------------------------------------===//
2930
// OneDNNGraph
3031
//===----------------------------------------------------------------------===//
@@ -44,7 +45,6 @@ PYBIND11_MODULE(_gc_mlir, m) {
4445
//===----------------------------------------------------------------------===//
4546
// CPURuntime
4647
//===----------------------------------------------------------------------===//
47-
mlirRegisterCPURuntimePasses();
4848
auto cpuruntimeM = m.def_submodule("cpuruntime");
4949
cpuruntimeM.def(
5050
"register_dialect",

src/gc-opt/gc-opt.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@
3030
#include <imex/InitIMEXPasses.h>
3131
#endif
3232

33+
namespace mlir::gc {
34+
void registerCPUPipeline();
35+
} // namespace mlir::gc
36+
3337
int main(int argc, char *argv[]) {
3438
#ifdef GC_USE_GPU
3539
imex::registerTransformsPasses();
@@ -41,6 +45,7 @@ int main(int argc, char *argv[]) {
4145
imex::registerConvertXeTileToXeGPU();
4246
#endif
4347
mlir::registerAllPasses();
48+
mlir::gc::registerCPUPipeline();
4449
mlir::gc::registerGraphCompilerPasses();
4550
mlir::cpuruntime::registerCPURuntimePasses();
4651
mlir::DialectRegistry registry;

0 commit comments

Comments
 (0)