Skip to content

Commit d9e04b0

Browse files
committed
[mlir][CAPI] Expose the rest of MLIRContext's constructors
It's recommended practice that people calling MLIR in a loop pre-create a LLVM ThreadPool and a dialect registry and then explicitly pass those into a MLIRContext for each compilation. However, the C API does not expose the functions needed to follow this recommendation from a project that isn't calling MLIR's C++ dilectly. Add the necessary APIs to mlir-c, including a wrapper around LLVM's ThreadPool struct (so as to avoid having to amend or re-export parts of the LLVM API). Reviewed By: makslevental Differential Revision: https://reviews.llvm.org/D153593
1 parent 29252fd commit d9e04b0

File tree

6 files changed

+83
-0
lines changed

6 files changed

+83
-0
lines changed

mlir/include/mlir-c/IR.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,19 @@ typedef struct MlirNamedAttribute MlirNamedAttribute;
8484
//===----------------------------------------------------------------------===//
8585

8686
/// Creates an MLIR context and transfers its ownership to the caller.
87+
/// This sets the default multithreading option (enabled).
8788
MLIR_CAPI_EXPORTED MlirContext mlirContextCreate(void);
8889

90+
/// Creates an MLIR context with an explicit setting of the multithreading
91+
/// setting and transfers its ownership to the caller.
92+
MLIR_CAPI_EXPORTED MlirContext
93+
mlirContextCreateWithThreading(bool threadingEnabled);
94+
95+
/// Creates an MLIR context, setting the multithreading setting explicitly and
96+
/// pre-loading the dialects from the provided DialectRegistry.
97+
MLIR_CAPI_EXPORTED MlirContext mlirContextCreateWithRegistry(
98+
MlirDialectRegistry registry, bool threadingEnabled);
99+
89100
/// Checks if two contexts are equal.
90101
MLIR_CAPI_EXPORTED bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2);
91102

@@ -144,6 +155,13 @@ mlirContextLoadAllAvailableDialects(MlirContext context);
144155
MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context,
145156
MlirStringRef name);
146157

158+
/// Sets the thread pool of the context explicitly, enabling multithreading in
159+
/// the process. This API should be used to avoid re-creating thread pools in
160+
/// long-running applications that perform multiple compilations, see
161+
/// the C++ documentation for MLIRContext for details.
162+
MLIR_CAPI_EXPORTED void mlirContextSetThreadPool(MlirContext context,
163+
MlirLlvmThreadPool threadPool);
164+
147165
//===----------------------------------------------------------------------===//
148166
// Dialect API.
149167
//===----------------------------------------------------------------------===//

mlir/include/mlir-c/Support.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ extern "C" {
5656
}; \
5757
typedef struct name name
5858

59+
/// Re-export llvm::ThreadPool so as to avoid including the LLVM C API directly.
60+
DEFINE_C_API_STRUCT(MlirLlvmThreadPool, void);
5961
DEFINE_C_API_STRUCT(MlirTypeID, const void);
6062
DEFINE_C_API_STRUCT(MlirTypeIDAllocator, void);
6163

@@ -138,6 +140,17 @@ inline static MlirLogicalResult mlirLogicalResultFailure(void) {
138140
return res;
139141
}
140142

143+
//===----------------------------------------------------------------------===//
144+
// MlirLlvmThreadPool.
145+
//===----------------------------------------------------------------------===//
146+
147+
/// Create an LLVM thread pool. This is reexported here to avoid directly
148+
/// pulling in the LLVM headers directly.
149+
MLIR_CAPI_EXPORTED MlirLlvmThreadPool mlirLlvmThreadPoolCreate(void);
150+
151+
/// Destroy an LLVM thread pool.
152+
MLIR_CAPI_EXPORTED void mlirLlvmThreadPoolDestroy(MlirLlvmThreadPool pool);
153+
141154
//===----------------------------------------------------------------------===//
142155
// TypeID API.
143156
//===----------------------------------------------------------------------===//

mlir/include/mlir/CAPI/Support.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
#include "mlir/Support/TypeID.h"
2222
#include "llvm/ADT/StringRef.h"
2323

24+
namespace llvm {
25+
class ThreadPool;
26+
} // namespace llvm
27+
2428
/// Converts a StringRef into its MLIR C API equivalent.
2529
inline MlirStringRef wrap(llvm::StringRef ref) {
2630
return mlirStringRefCreate(ref.data(), ref.size());
@@ -41,6 +45,7 @@ inline mlir::LogicalResult unwrap(MlirLogicalResult res) {
4145
return mlir::success(mlirLogicalResultIsSuccess(res));
4246
}
4347

48+
DEFINE_C_API_PTR_METHODS(MlirLlvmThreadPool, llvm::ThreadPool)
4449
DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID)
4550
DEFINE_C_API_PTR_METHODS(MlirTypeIDAllocator, mlir::TypeIDAllocator)
4651

mlir/lib/CAPI/IR/IR.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,23 @@ MlirContext mlirContextCreate() {
3939
return wrap(context);
4040
}
4141

42+
static inline MLIRContext::Threading toThreadingEnum(bool threadingEnabled) {
43+
return threadingEnabled ? MLIRContext::Threading::ENABLED
44+
: MLIRContext::Threading::DISABLED;
45+
}
46+
47+
MlirContext mlirContextCreateWithThreading(bool threadingEnabled) {
48+
auto *context = new MLIRContext(toThreadingEnum(threadingEnabled));
49+
return wrap(context);
50+
}
51+
52+
MlirContext mlirContextCreateWithRegistry(MlirDialectRegistry registry,
53+
bool threadingEnabled) {
54+
auto *context =
55+
new MLIRContext(*unwrap(registry), toThreadingEnum(threadingEnabled));
56+
return wrap(context);
57+
}
58+
4259
bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
4360
return unwrap(ctx1) == unwrap(ctx2);
4461
}
@@ -84,6 +101,11 @@ void mlirContextLoadAllAvailableDialects(MlirContext context) {
84101
unwrap(context)->loadAllAvailableDialects();
85102
}
86103

104+
void mlirContextSetThreadPool(MlirContext context,
105+
MlirLlvmThreadPool threadPool) {
106+
unwrap(context)->setThreadPool(*unwrap(threadPool));
107+
}
108+
87109
//===----------------------------------------------------------------------===//
88110
// Dialect API.
89111
//===----------------------------------------------------------------------===//

mlir/lib/CAPI/IR/Support.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "mlir/CAPI/Support.h"
1010
#include "llvm/ADT/StringRef.h"
11+
#include "llvm/Support/ThreadPool.h"
1112

1213
#include <cstring>
1314

@@ -20,6 +21,17 @@ bool mlirStringRefEqual(MlirStringRef string, MlirStringRef other) {
2021
llvm::StringRef(other.data, other.length);
2122
}
2223

24+
//===----------------------------------------------------------------------===//
25+
// LLVM ThreadPool API.
26+
//===----------------------------------------------------------------------===//
27+
MlirLlvmThreadPool mlirLlvmThreadPoolCreate() {
28+
return wrap(new llvm::ThreadPool());
29+
}
30+
31+
void mlirLlvmThreadPoolDestroy(MlirLlvmThreadPool threadPool) {
32+
delete unwrap(threadPool);
33+
}
34+
2335
//===----------------------------------------------------------------------===//
2436
// TypeID API.
2537
//===----------------------------------------------------------------------===//

mlir/test/CAPI/ir.c

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2210,6 +2210,18 @@ int testDialectRegistry(void) {
22102210
return 0;
22112211
}
22122212

2213+
void testExplicitThreadPools(void) {
2214+
MlirLlvmThreadPool threadPool = mlirLlvmThreadPoolCreate();
2215+
MlirDialectRegistry registry = mlirDialectRegistryCreate();
2216+
mlirRegisterAllDialects(registry);
2217+
MlirContext context =
2218+
mlirContextCreateWithRegistry(registry, /*threadingEnabled=*/false);
2219+
mlirContextSetThreadPool(context, threadPool);
2220+
mlirContextDestroy(context);
2221+
mlirDialectRegistryDestroy(registry);
2222+
mlirLlvmThreadPoolDestroy(threadPool);
2223+
}
2224+
22132225
void testDiagnostics(void) {
22142226
MlirContext ctx = mlirContextCreate();
22152227
MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
@@ -2310,6 +2322,7 @@ int main(void) {
23102322

23112323
mlirContextDestroy(ctx);
23122324

2325+
testExplicitThreadPools();
23132326
testDiagnostics();
23142327
return 0;
23152328
}

0 commit comments

Comments
 (0)