Skip to content

Commit 7d5f2d2

Browse files
committed
[mlir][mlir-spirv-cpu-runner] Move MLIR pass pipeline to mlir-opt
Adds a new mlir-opt test-only pass, -test-spirv-cpu-runner-pipeline, which runs the set of MLIR passes needed for the mlir-spirv-cpu-runner, and removes them from the runner. The tests are changed to invoke mlir-opt with this flag before running the runner. The eventual goal is to move all host/device code generation steps out of the runner, like with some of the other runners.
1 parent e13cbac commit 7d5f2d2

File tree

6 files changed

+94
-26
lines changed

6 files changed

+94
-26
lines changed

mlir/test/lib/Pass/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
add_mlir_library(MLIRTestPass
33
TestDynamicPipeline.cpp
44
TestPassManager.cpp
5+
TestSPIRVCPURunnerPipeline.cpp
56

67
EXCLUDE_FROM_LIBMLIR
78

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
//===------------------ TestSPIRVCPURunnerPipeline.cpp --------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass for use by mlir-spirv-cpu-runner tests.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
14+
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
15+
#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h"
16+
#include "mlir/Dialect/Arith/IR/Arith.h"
17+
#include "mlir/Dialect/DLTI/DLTI.h"
18+
#include "mlir/Dialect/Func/IR/FuncOps.h"
19+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
20+
#include "mlir/Dialect/GPU/Transforms/Passes.h"
21+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
23+
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
24+
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
25+
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
26+
#include "mlir/ExecutionEngine/JitRunner.h"
27+
#include "mlir/ExecutionEngine/OptUtils.h"
28+
#include "mlir/IR/BuiltinOps.h"
29+
#include "mlir/Pass/Pass.h"
30+
#include "mlir/Pass/PassManager.h"
31+
32+
using namespace mlir;
33+
34+
namespace {
35+
36+
class TestSPIRVCPURunnerPipelinePass
37+
: public PassWrapper<TestSPIRVCPURunnerPipelinePass,
38+
OperationPass<ModuleOp>> {
39+
public:
40+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSPIRVCPURunnerPipelinePass)
41+
42+
StringRef getArgument() const final {
43+
return "test-spirv-cpu-runner-pipeline";
44+
}
45+
StringRef getDescription() const final {
46+
return "Runs a series of passes for lowering SPIR-V-dialect MLIR to "
47+
"LLVM-dialect MLIR intended for mlir-spirv-cpu-runner.";
48+
}
49+
void getDependentDialects(DialectRegistry &registry) const override {
50+
registry.insert<mlir::arith::ArithDialect, mlir::LLVM::LLVMDialect,
51+
mlir::gpu::GPUDialect, mlir::spirv::SPIRVDialect,
52+
mlir::func::FuncDialect, mlir::memref::MemRefDialect,
53+
mlir::DLTIDialect>();
54+
}
55+
56+
TestSPIRVCPURunnerPipelinePass() = default;
57+
TestSPIRVCPURunnerPipelinePass(const TestSPIRVCPURunnerPipelinePass &) {}
58+
59+
void runOnOperation() override {
60+
ModuleOp module = getOperation();
61+
62+
PassManager passManager(module->getContext(),
63+
module->getName().getStringRef());
64+
if (failed(applyPassManagerCLOptions(passManager)))
65+
return signalPassFailure();
66+
passManager.addPass(createGpuKernelOutliningPass());
67+
passManager.addPass(createConvertGPUToSPIRVPass(/*mapMemorySpace=*/true));
68+
69+
OpPassManager &nestedPM = passManager.nest<spirv::ModuleOp>();
70+
nestedPM.addPass(spirv::createSPIRVLowerABIAttributesPass());
71+
nestedPM.addPass(spirv::createSPIRVUpdateVCEPass());
72+
passManager.addPass(createLowerHostCodeToLLVMPass());
73+
passManager.addPass(createConvertSPIRVToLLVMPass());
74+
75+
if (failed(runPipeline(passManager, module)))
76+
signalPassFailure();
77+
}
78+
};
79+
} // namespace
80+
81+
namespace mlir {
82+
namespace test {
83+
void registerTestSPIRVCPURunnerPipelinePass() {
84+
PassRegistration<TestSPIRVCPURunnerPipelinePass>();
85+
}
86+
} // namespace test
87+
} // namespace mlir

mlir/test/mlir-spirv-cpu-runner/double.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
// RUN: mlir-spirv-cpu-runner %s -e main --entry-point-result=void --shared-libs=%mlir_runner_utils,%mlir_test_spirv_cpu_runner_c_wrappers \
1+
// RUN: mlir-opt %s -test-spirv-cpu-runner-pipeline \
2+
// RUN: | mlir-spirv-cpu-runner - -e main --entry-point-result=void --shared-libs=%mlir_runner_utils,%mlir_test_spirv_cpu_runner_c_wrappers \
23
// RUN: | FileCheck %s
34

45
// CHECK: [8, 8, 8, 8, 8, 8]

mlir/test/mlir-spirv-cpu-runner/simple_add.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
// RUN: mlir-spirv-cpu-runner %s -e main --entry-point-result=void --shared-libs=%mlir_runner_utils,%mlir_test_spirv_cpu_runner_c_wrappers \
1+
// RUN: mlir-opt %s -test-spirv-cpu-runner-pipeline \
2+
// RUN: | mlir-spirv-cpu-runner - -e main --entry-point-result=void --shared-libs=%mlir_runner_utils,%mlir_test_spirv_cpu_runner_c_wrappers \
23
// RUN: | FileCheck %s
34

45
// CHECK: data =

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ void registerTestSCFWhileOpBuilderPass();
142142
void registerTestSCFWrapInZeroTripCheckPasses();
143143
void registerTestShapeMappingPass();
144144
void registerTestSliceAnalysisPass();
145+
void registerTestSPIRVCPURunnerPipelinePass();
145146
void registerTestSPIRVFuncSignatureConversion();
146147
void registerTestSPIRVVectorUnrolling();
147148
void registerTestTensorCopyInsertionPass();
@@ -278,6 +279,7 @@ void registerTestPasses() {
278279
mlir::test::registerTestSCFWrapInZeroTripCheckPasses();
279280
mlir::test::registerTestShapeMappingPass();
280281
mlir::test::registerTestSliceAnalysisPass();
282+
mlir::test::registerTestSPIRVCPURunnerPipelinePass();
281283
mlir::test::registerTestSPIRVFuncSignatureConversion();
282284
mlir::test::registerTestSPIRVVectorUnrolling();
283285
mlir::test::registerTestTensorCopyInsertionPass();

mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,12 @@
1212
//
1313
//===----------------------------------------------------------------------===//
1414

15-
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
16-
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
17-
#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h"
1815
#include "mlir/Dialect/Arith/IR/Arith.h"
1916
#include "mlir/Dialect/Func/IR/FuncOps.h"
2017
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
21-
#include "mlir/Dialect/GPU/Transforms/Passes.h"
2218
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
2319
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2420
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
25-
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
26-
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
2721
#include "mlir/ExecutionEngine/JitRunner.h"
2822
#include "mlir/ExecutionEngine/OptUtils.h"
2923
#include "mlir/Pass/Pass.h"
@@ -75,31 +69,13 @@ convertMLIRModule(Operation *op, llvm::LLVMContext &context) {
7569
return mainModule;
7670
}
7771

78-
static LogicalResult runMLIRPasses(Operation *module,
79-
JitRunnerOptions &options) {
80-
PassManager passManager(module->getContext(),
81-
module->getName().getStringRef());
82-
if (failed(applyPassManagerCLOptions(passManager)))
83-
return failure();
84-
passManager.addPass(createGpuKernelOutliningPass());
85-
passManager.addPass(createConvertGPUToSPIRVPass(/*mapMemorySpace=*/true));
86-
87-
OpPassManager &nestedPM = passManager.nest<spirv::ModuleOp>();
88-
nestedPM.addPass(spirv::createSPIRVLowerABIAttributesPass());
89-
nestedPM.addPass(spirv::createSPIRVUpdateVCEPass());
90-
passManager.addPass(createLowerHostCodeToLLVMPass());
91-
passManager.addPass(createConvertSPIRVToLLVMPass());
92-
return passManager.run(module);
93-
}
94-
9572
int main(int argc, char **argv) {
9673
llvm::InitLLVM y(argc, argv);
9774

9875
llvm::InitializeNativeTarget();
9976
llvm::InitializeNativeTargetAsmPrinter();
10077

10178
mlir::JitRunnerConfig jitRunnerConfig;
102-
jitRunnerConfig.mlirTransformer = runMLIRPasses;
10379
jitRunnerConfig.llvmModuleBuilder = convertMLIRModule;
10480

10581
mlir::DialectRegistry registry;

0 commit comments

Comments
 (0)