Skip to content

Commit f6c9f6e

Browse files
committed
[mlir] JitRunner: add a config option to register symbols with ExecutionEngine at runtime
Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D90264
1 parent 50dfa19 commit f6c9f6e

File tree

6 files changed

+92
-55
lines changed

6 files changed

+92
-55
lines changed

mlir/include/mlir/ExecutionEngine/JitRunner.h

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,29 +18,42 @@
1818
#ifndef MLIR_SUPPORT_JITRUNNER_H_
1919
#define MLIR_SUPPORT_JITRUNNER_H_
2020

21-
#include "mlir/IR/Module.h"
22-
2321
#include "llvm/ADT/STLExtras.h"
24-
#include "llvm/IR/Module.h"
22+
#include "llvm/ExecutionEngine/Orc/Core.h"
2523

26-
namespace mlir {
24+
namespace llvm {
25+
class Module;
26+
class LLVMContext;
2727

28-
using TranslationCallback = llvm::function_ref<std::unique_ptr<llvm::Module>(
29-
ModuleOp, llvm::LLVMContext &)>;
28+
namespace orc {
29+
class MangleAndInterner;
30+
} // namespace orc
31+
} // namespace llvm
32+
33+
namespace mlir {
3034

3135
class ModuleOp;
3236
struct LogicalResult;
3337

38+
struct JitRunnerConfig {
39+
/// MLIR transformer applied after parsing the input into MLIR IR and before
40+
/// passing the MLIR module to the ExecutionEngine.
41+
llvm::function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer = nullptr;
42+
43+
/// A custom function that is passed to ExecutionEngine. It processes MLIR
44+
/// module and creates LLVM IR module.
45+
llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp,
46+
llvm::LLVMContext &)>
47+
llvmModuleBuilder = nullptr;
48+
49+
/// A callback to register symbols with ExecutionEngine at runtime.
50+
llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
51+
runtimesymbolMap = nullptr;
52+
};
53+
3454
// Entry point for all CPU runners. Expects the common argc/argv arguments for
35-
// standard C++ main functions, `mlirTransformer` and `llvmModuleBuilder`.
36-
/// `mlirTransformer` is applied after parsing the input into MLIR IR and before
37-
/// passing the MLIR module to the ExecutionEngine.
38-
/// `llvmModuleBuilder` is a custom function that is passed to ExecutionEngine.
39-
/// It processes MLIR module and creates LLVM IR module.
40-
int JitRunnerMain(
41-
int argc, char **argv,
42-
llvm::function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer,
43-
TranslationCallback llvmModuleBuilder = nullptr);
55+
// standard C++ main functions.
56+
int JitRunnerMain(int argc, char **argv, JitRunnerConfig config = {});
4457

4558
} // namespace mlir
4659

mlir/lib/ExecutionEngine/JitRunner.cpp

Lines changed: 49 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,23 @@ struct Options {
9292
"object-filename",
9393
llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")};
9494
};
95+
96+
struct CompileAndExecuteConfig {
97+
/// LLVM module transformer that is passed to ExecutionEngine.
98+
llvm::function_ref<llvm::Error(llvm::Module *)> transformer;
99+
100+
/// A custom function that is passed to ExecutionEngine. It processes MLIR
101+
/// module and creates LLVM IR module.
102+
llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp,
103+
llvm::LLVMContext &)>
104+
llvmModuleBuilder;
105+
106+
/// A custom function that is passed to ExecutinEngine to register symbols at
107+
/// runtime.
108+
llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
109+
runtimeSymbolMap;
110+
};
111+
95112
} // end anonymous namespace
96113

97114
static OwningModuleRef parseMLIRInput(StringRef inputFilename,
@@ -131,23 +148,25 @@ static Optional<unsigned> getCommandLineOptLevel(Options &options) {
131148
}
132149

133150
// JIT-compile the given module and run "entryPoint" with "args" as arguments.
134-
static Error
135-
compileAndExecute(Options &options, ModuleOp module,
136-
TranslationCallback llvmModuleBuilder, StringRef entryPoint,
137-
std::function<llvm::Error(llvm::Module *)> transformer,
138-
void **args) {
151+
static Error compileAndExecute(Options &options, ModuleOp module,
152+
StringRef entryPoint,
153+
CompileAndExecuteConfig config, void **args) {
139154
Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel;
140155
if (auto clOptLevel = getCommandLineOptLevel(options))
141156
jitCodeGenOptLevel =
142157
static_cast<llvm::CodeGenOpt::Level>(clOptLevel.getValue());
143158
SmallVector<StringRef, 4> libs(options.clSharedLibs.begin(),
144159
options.clSharedLibs.end());
145160
auto expectedEngine = mlir::ExecutionEngine::create(
146-
module, llvmModuleBuilder, transformer, jitCodeGenOptLevel, libs);
161+
module, config.llvmModuleBuilder, config.transformer, jitCodeGenOptLevel,
162+
libs);
147163
if (!expectedEngine)
148164
return expectedEngine.takeError();
149165

150166
auto engine = std::move(*expectedEngine);
167+
if (config.runtimeSymbolMap)
168+
engine->registerSymbols(config.runtimeSymbolMap);
169+
151170
auto expectedFPtr = engine->lookup(entryPoint);
152171
if (!expectedFPtr)
153172
return expectedFPtr.takeError();
@@ -163,16 +182,14 @@ compileAndExecute(Options &options, ModuleOp module,
163182
return Error::success();
164183
}
165184

166-
static Error compileAndExecuteVoidFunction(
167-
Options &options, ModuleOp module, TranslationCallback llvmModuleBuilder,
168-
StringRef entryPoint,
169-
std::function<llvm::Error(llvm::Module *)> transformer) {
185+
static Error compileAndExecuteVoidFunction(Options &options, ModuleOp module,
186+
StringRef entryPoint,
187+
CompileAndExecuteConfig config) {
170188
auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
171189
if (!mainFunction || mainFunction.empty())
172190
return make_string_error("entry point not found");
173191
void *empty = nullptr;
174-
return compileAndExecute(options, module, llvmModuleBuilder, entryPoint,
175-
transformer, &empty);
192+
return compileAndExecute(options, module, entryPoint, config, &empty);
176193
}
177194

178195
template <typename Type>
@@ -196,10 +213,9 @@ Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
196213
return Error::success();
197214
}
198215
template <typename Type>
199-
Error compileAndExecuteSingleReturnFunction(
200-
Options &options, ModuleOp module, TranslationCallback llvmModuleBuilder,
201-
StringRef entryPoint,
202-
std::function<llvm::Error(llvm::Module *)> transformer) {
216+
Error compileAndExecuteSingleReturnFunction(Options &options, ModuleOp module,
217+
StringRef entryPoint,
218+
CompileAndExecuteConfig config) {
203219
auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
204220
if (!mainFunction || mainFunction.isExternal())
205221
return make_string_error("entry point not found");
@@ -215,8 +231,8 @@ Error compileAndExecuteSingleReturnFunction(
215231
void *data;
216232
} data;
217233
data.data = &res;
218-
if (auto error = compileAndExecute(options, module, llvmModuleBuilder,
219-
entryPoint, transformer, (void **)&data))
234+
if (auto error = compileAndExecute(options, module, entryPoint, config,
235+
(void **)&data))
220236
return error;
221237

222238
// Intentional printing of the output so we can test.
@@ -226,15 +242,8 @@ Error compileAndExecuteSingleReturnFunction(
226242
}
227243

228244
/// Entry point for all CPU runners. Expects the common argc/argv arguments for
229-
/// standard C++ main functions, `mlirTransformer` and `llvmModuleBuilder`.
230-
/// `mlirTransformer` is applied after parsing the input into MLIR IR and before
231-
/// passing the MLIR module to the ExecutionEngine.
232-
/// `llvmModuleBuilder` is a custom function that is passed to ExecutionEngine.
233-
/// It processes MLIR module and creates LLVM IR module.
234-
int mlir::JitRunnerMain(
235-
int argc, char **argv,
236-
function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer,
237-
TranslationCallback llvmModuleBuilder) {
245+
/// standard C++ main functions.
246+
int mlir::JitRunnerMain(int argc, char **argv, JitRunnerConfig config) {
238247
// Create the options struct containing the command line options for the
239248
// runner. This must come before the command line options are parsed.
240249
Options options;
@@ -274,8 +283,8 @@ int mlir::JitRunnerMain(
274283
return 1;
275284
}
276285

277-
if (mlirTransformer)
278-
if (failed(mlirTransformer(m.get())))
286+
if (config.mlirTransformer)
287+
if (failed(config.mlirTransformer(m.get())))
279288
return EXIT_FAILURE;
280289

281290
auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
@@ -292,10 +301,14 @@ int mlir::JitRunnerMain(
292301
auto transformer = mlir::makeLLVMPassesTransformer(
293302
passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition);
294303

304+
CompileAndExecuteConfig compileAndExecuteConfig;
305+
compileAndExecuteConfig.transformer = transformer;
306+
compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder;
307+
compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap;
308+
295309
// Get the function used to compile and execute the module.
296310
using CompileAndExecuteFnT =
297-
Error (*)(Options &, ModuleOp, TranslationCallback, StringRef,
298-
std::function<llvm::Error(llvm::Module *)>);
311+
Error (*)(Options &, ModuleOp, StringRef, CompileAndExecuteConfig);
299312
auto compileAndExecuteFn =
300313
StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
301314
.Case("i32", compileAndExecuteSingleReturnFunction<int32_t>)
@@ -304,11 +317,11 @@ int mlir::JitRunnerMain(
304317
.Case("void", compileAndExecuteVoidFunction)
305318
.Default(nullptr);
306319

307-
Error error =
308-
compileAndExecuteFn
309-
? compileAndExecuteFn(options, m.get(), llvmModuleBuilder,
310-
options.mainFuncName.getValue(), transformer)
311-
: make_string_error("unsupported function type");
320+
Error error = compileAndExecuteFn
321+
? compileAndExecuteFn(options, m.get(),
322+
options.mainFuncName.getValue(),
323+
compileAndExecuteConfig)
324+
: make_string_error("unsupported function type");
312325

313326
int exitCode = EXIT_SUCCESS;
314327
llvm::handleAllErrors(std::move(error),

mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,5 @@ int main(int argc, char **argv) {
2424
llvm::InitializeNativeTargetAsmPrinter();
2525
mlir::initializeLLVMPasses();
2626

27-
return mlir::JitRunnerMain(argc, argv, nullptr);
27+
return mlir::JitRunnerMain(argc, argv);
2828
}

mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,5 +136,9 @@ int main(int argc, char **argv) {
136136
LLVMInitializeNVPTXAsmPrinter();
137137

138138
mlir::initializeLLVMPasses();
139-
return mlir::JitRunnerMain(argc, argv, &runMLIRPasses);
139+
140+
mlir::JitRunnerConfig jitRunnerConfig;
141+
jitRunnerConfig.mlirTransformer = &runMLIRPasses;
142+
143+
return mlir::JitRunnerMain(argc, argv, jitRunnerConfig);
140144
}

mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,5 +86,9 @@ int main(int argc, char **argv) {
8686
llvm::InitializeNativeTargetAsmPrinter();
8787
mlir::initializeLLVMPasses();
8888

89-
return mlir::JitRunnerMain(argc, argv, &runMLIRPasses, &convertMLIRModule);
89+
mlir::JitRunnerConfig jitRunnerConfig;
90+
jitRunnerConfig.mlirTransformer = &runMLIRPasses;
91+
jitRunnerConfig.llvmModuleBuilder = &convertMLIRModule;
92+
93+
return mlir::JitRunnerMain(argc, argv, jitRunnerConfig);
9094
}

mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,5 +58,8 @@ int main(int argc, char **argv) {
5858
llvm::InitializeNativeTargetAsmPrinter();
5959
mlir::initializeLLVMPasses();
6060

61-
return mlir::JitRunnerMain(argc, argv, &runMLIRPasses);
61+
mlir::JitRunnerConfig jitRunnerConfig;
62+
jitRunnerConfig.mlirTransformer = &runMLIRPasses;
63+
64+
return mlir::JitRunnerMain(argc, argv, jitRunnerConfig);
6265
}

0 commit comments

Comments
 (0)