@@ -92,6 +92,23 @@ struct Options {
92
92
" object-filename" ,
93
93
llvm::cl::desc (" Dump JITted-compiled object to file <input file>.o" )};
94
94
};
95
+
96
+ struct CompileAndExecuteConfig {
97
+ // / LLVM module transformer that is passed to ExecutionEngine.
98
+ llvm::function_ref<llvm::Error(llvm::Module *)> transformer;
99
+
100
+ // / A custom function that is passed to ExecutionEngine. It processes MLIR
101
+ // / module and creates LLVM IR module.
102
+ llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp,
103
+ llvm::LLVMContext &)>
104
+ llvmModuleBuilder;
105
+
106
+ // / A custom function that is passed to ExecutinEngine to register symbols at
107
+ // / runtime.
108
+ llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
109
+ runtimeSymbolMap;
110
+ };
111
+
95
112
} // end anonymous namespace
96
113
97
114
static OwningModuleRef parseMLIRInput (StringRef inputFilename,
@@ -131,23 +148,25 @@ static Optional<unsigned> getCommandLineOptLevel(Options &options) {
131
148
}
132
149
133
150
// JIT-compile the given module and run "entryPoint" with "args" as arguments.
134
- static Error
135
- compileAndExecute (Options &options, ModuleOp module ,
136
- TranslationCallback llvmModuleBuilder, StringRef entryPoint,
137
- std::function<llvm::Error(llvm::Module *)> transformer,
138
- void **args) {
151
+ static Error compileAndExecute (Options &options, ModuleOp module ,
152
+ StringRef entryPoint,
153
+ CompileAndExecuteConfig config, void **args) {
139
154
Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel;
140
155
if (auto clOptLevel = getCommandLineOptLevel (options))
141
156
jitCodeGenOptLevel =
142
157
static_cast <llvm::CodeGenOpt::Level>(clOptLevel.getValue ());
143
158
SmallVector<StringRef, 4 > libs (options.clSharedLibs .begin (),
144
159
options.clSharedLibs .end ());
145
160
auto expectedEngine = mlir::ExecutionEngine::create (
146
- module , llvmModuleBuilder, transformer, jitCodeGenOptLevel, libs);
161
+ module , config.llvmModuleBuilder , config.transformer , jitCodeGenOptLevel,
162
+ libs);
147
163
if (!expectedEngine)
148
164
return expectedEngine.takeError ();
149
165
150
166
auto engine = std::move (*expectedEngine);
167
+ if (config.runtimeSymbolMap )
168
+ engine->registerSymbols (config.runtimeSymbolMap );
169
+
151
170
auto expectedFPtr = engine->lookup (entryPoint);
152
171
if (!expectedFPtr)
153
172
return expectedFPtr.takeError ();
@@ -163,16 +182,14 @@ compileAndExecute(Options &options, ModuleOp module,
163
182
return Error::success ();
164
183
}
165
184
166
- static Error compileAndExecuteVoidFunction (
167
- Options &options, ModuleOp module , TranslationCallback llvmModuleBuilder,
168
- StringRef entryPoint,
169
- std::function<llvm::Error(llvm::Module *)> transformer) {
185
+ static Error compileAndExecuteVoidFunction (Options &options, ModuleOp module ,
186
+ StringRef entryPoint,
187
+ CompileAndExecuteConfig config) {
170
188
auto mainFunction = module .lookupSymbol <LLVM::LLVMFuncOp>(entryPoint);
171
189
if (!mainFunction || mainFunction.empty ())
172
190
return make_string_error (" entry point not found" );
173
191
void *empty = nullptr ;
174
- return compileAndExecute (options, module , llvmModuleBuilder, entryPoint,
175
- transformer, &empty);
192
+ return compileAndExecute (options, module , entryPoint, config, &empty);
176
193
}
177
194
178
195
template <typename Type>
@@ -196,10 +213,9 @@ Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
196
213
return Error::success ();
197
214
}
198
215
template <typename Type>
199
- Error compileAndExecuteSingleReturnFunction (
200
- Options &options, ModuleOp module , TranslationCallback llvmModuleBuilder,
201
- StringRef entryPoint,
202
- std::function<llvm::Error(llvm::Module *)> transformer) {
216
+ Error compileAndExecuteSingleReturnFunction (Options &options, ModuleOp module ,
217
+ StringRef entryPoint,
218
+ CompileAndExecuteConfig config) {
203
219
auto mainFunction = module .lookupSymbol <LLVM::LLVMFuncOp>(entryPoint);
204
220
if (!mainFunction || mainFunction.isExternal ())
205
221
return make_string_error (" entry point not found" );
@@ -215,8 +231,8 @@ Error compileAndExecuteSingleReturnFunction(
215
231
void *data;
216
232
} data;
217
233
data.data = &res;
218
- if (auto error = compileAndExecute (options, module , llvmModuleBuilder ,
219
- entryPoint, transformer, (void **)&data))
234
+ if (auto error = compileAndExecute (options, module , entryPoint, config ,
235
+ (void **)&data))
220
236
return error;
221
237
222
238
// Intentional printing of the output so we can test.
@@ -226,15 +242,8 @@ Error compileAndExecuteSingleReturnFunction(
226
242
}
227
243
228
244
// / Entry point for all CPU runners. Expects the common argc/argv arguments for
229
- // / standard C++ main functions, `mlirTransformer` and `llvmModuleBuilder`.
230
- // / `mlirTransformer` is applied after parsing the input into MLIR IR and before
231
- // / passing the MLIR module to the ExecutionEngine.
232
- // / `llvmModuleBuilder` is a custom function that is passed to ExecutionEngine.
233
- // / It processes MLIR module and creates LLVM IR module.
234
- int mlir::JitRunnerMain (
235
- int argc, char **argv,
236
- function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer,
237
- TranslationCallback llvmModuleBuilder) {
245
+ // / standard C++ main functions.
246
+ int mlir::JitRunnerMain (int argc, char **argv, JitRunnerConfig config) {
238
247
// Create the options struct containing the command line options for the
239
248
// runner. This must come before the command line options are parsed.
240
249
Options options;
@@ -274,8 +283,8 @@ int mlir::JitRunnerMain(
274
283
return 1 ;
275
284
}
276
285
277
- if (mlirTransformer)
278
- if (failed (mlirTransformer (m.get ())))
286
+ if (config. mlirTransformer )
287
+ if (failed (config. mlirTransformer (m.get ())))
279
288
return EXIT_FAILURE;
280
289
281
290
auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost ();
@@ -292,10 +301,14 @@ int mlir::JitRunnerMain(
292
301
auto transformer = mlir::makeLLVMPassesTransformer (
293
302
passes, optLevel, /* targetMachine=*/ tmOrError->get (), optPosition);
294
303
304
+ CompileAndExecuteConfig compileAndExecuteConfig;
305
+ compileAndExecuteConfig.transformer = transformer;
306
+ compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder ;
307
+ compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap ;
308
+
295
309
// Get the function used to compile and execute the module.
296
310
using CompileAndExecuteFnT =
297
- Error (*)(Options &, ModuleOp, TranslationCallback, StringRef,
298
- std::function<llvm::Error (llvm::Module *)>);
311
+ Error (*)(Options &, ModuleOp, StringRef, CompileAndExecuteConfig);
299
312
auto compileAndExecuteFn =
300
313
StringSwitch<CompileAndExecuteFnT>(options.mainFuncType .getValue ())
301
314
.Case (" i32" , compileAndExecuteSingleReturnFunction<int32_t >)
@@ -304,11 +317,11 @@ int mlir::JitRunnerMain(
304
317
.Case (" void" , compileAndExecuteVoidFunction)
305
318
.Default (nullptr );
306
319
307
- Error error =
308
- compileAndExecuteFn
309
- ? compileAndExecuteFn (options, m. get (), llvmModuleBuilder ,
310
- options. mainFuncName . getValue (), transformer )
311
- : make_string_error (" unsupported function type" );
320
+ Error error = compileAndExecuteFn
321
+ ? compileAndExecuteFn (options, m. get (),
322
+ options. mainFuncName . getValue () ,
323
+ compileAndExecuteConfig )
324
+ : make_string_error (" unsupported function type" );
312
325
313
326
int exitCode = EXIT_SUCCESS;
314
327
llvm::handleAllErrors (std::move (error),
0 commit comments