Skip to content

Commit 97fc568

Browse files
committed
[mlir][capi] Add DialectRegistry to MLIR C-API
Exposes mlir::DialectRegistry to the C API as MlirDialectRegistry along with helper functions. A hook has been added to MlirDialectHandle that inserts the dialect into a registry. A future possible change is removing mlirDialectHandleRegisterDialect in favor of using mlirDialectHandleInsertDialect, which it is now implemented with. Differential Revision: https://reviews.llvm.org/D118293
1 parent 79606ee commit 97fc568

File tree

7 files changed

+93
-9
lines changed

7 files changed

+93
-9
lines changed

mlir/include/mlir-c/IR.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ extern "C" {
5050

5151
DEFINE_C_API_STRUCT(MlirContext, void);
5252
DEFINE_C_API_STRUCT(MlirDialect, void);
53+
DEFINE_C_API_STRUCT(MlirDialectRegistry, void);
5354
DEFINE_C_API_STRUCT(MlirOperation, void);
5455
DEFINE_C_API_STRUCT(MlirOpPrintingFlags, void);
5556
DEFINE_C_API_STRUCT(MlirBlock, void);
@@ -108,6 +109,11 @@ mlirContextGetAllowUnregisteredDialects(MlirContext context);
108109
MLIR_CAPI_EXPORTED intptr_t
109110
mlirContextGetNumRegisteredDialects(MlirContext context);
110111

112+
/// Append the contents of the given dialect registry to the registry associated
113+
/// with the context.
114+
MLIR_CAPI_EXPORTED void
115+
mlirContextAppendDialectRegistry(MlirContext ctx, MlirDialectRegistry registry);
116+
111117
/// Returns the number of dialects loaded by the context.
112118

113119
MLIR_CAPI_EXPORTED intptr_t
@@ -152,6 +158,22 @@ MLIR_CAPI_EXPORTED bool mlirDialectEqual(MlirDialect dialect1,
152158
/// Returns the namespace of the given dialect.
153159
MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect);
154160

161+
//===----------------------------------------------------------------------===//
162+
// DialectRegistry API.
163+
//===----------------------------------------------------------------------===//
164+
165+
/// Creates a dialect registry and transfers its ownership to the caller.
166+
MLIR_CAPI_EXPORTED MlirDialectRegistry mlirDialectRegistryCreate();
167+
168+
/// Checks if the dialect registry is null.
169+
static inline bool mlirDialectRegistryIsNull(MlirDialectRegistry registry) {
170+
return !registry.ptr;
171+
}
172+
173+
/// Takes a dialect registry owned by the caller and destroys it.
174+
MLIR_CAPI_EXPORTED void
175+
mlirDialectRegistryDestroy(MlirDialectRegistry registry);
176+
155177
//===----------------------------------------------------------------------===//
156178
// Location API.
157179
//===----------------------------------------------------------------------===//

mlir/include/mlir-c/Registration.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ typedef struct MlirDialectHandle MlirDialectHandle;
4444
MLIR_CAPI_EXPORTED
4545
MlirStringRef mlirDialectHandleGetNamespace(MlirDialectHandle);
4646

47+
/// Inserts the dialect associated with the provided dialect handle into the
48+
/// provided dialect registry
49+
MLIR_CAPI_EXPORTED void mlirDialectHandleInsertDialect(MlirDialectHandle,
50+
MlirDialectRegistry);
51+
4752
/// Registers the dialect associated with the provided dialect handle.
4853
MLIR_CAPI_EXPORTED void mlirDialectHandleRegisterDialect(MlirDialectHandle,
4954
MlirContext);

mlir/include/mlir/CAPI/IR.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext)
2424
DEFINE_C_API_PTR_METHODS(MlirDialect, mlir::Dialect)
25+
DEFINE_C_API_PTR_METHODS(MlirDialectRegistry, mlir::DialectRegistry)
2526
DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation)
2627
DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block)
2728
DEFINE_C_API_PTR_METHODS(MlirOpPrintingFlags, mlir::OpPrintingFlags)

mlir/include/mlir/CAPI/Registration.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,23 +21,23 @@
2121
//===----------------------------------------------------------------------===//
2222

2323
/// Hooks for dynamic discovery of dialects.
24-
typedef void (*MlirContextRegisterDialectHook)(MlirContext context);
24+
typedef void (*MlirDialectRegistryInsertDialectHook)(
25+
MlirDialectRegistry registry);
2526
typedef MlirDialect (*MlirContextLoadDialectHook)(MlirContext context);
2627
typedef MlirStringRef (*MlirDialectGetNamespaceHook)();
2728

2829
/// Structure of dialect registration hooks.
2930
struct MlirDialectRegistrationHooks {
30-
MlirContextRegisterDialectHook registerHook;
31+
MlirDialectRegistryInsertDialectHook insertHook;
3132
MlirContextLoadDialectHook loadHook;
3233
MlirDialectGetNamespaceHook getNamespaceHook;
3334
};
3435
typedef struct MlirDialectRegistrationHooks MlirDialectRegistrationHooks;
3536

3637
#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Name, Namespace, ClassName) \
37-
static void mlirContextRegister##Name##Dialect(MlirContext context) { \
38-
mlir::DialectRegistry registry; \
39-
registry.insert<ClassName>(); \
40-
unwrap(context)->appendDialectRegistry(registry); \
38+
static void mlirDialectRegistryInsert##Name##Dialect( \
39+
MlirDialectRegistry registry) { \
40+
unwrap(registry)->insert<ClassName>(); \
4141
} \
4242
static MlirDialect mlirContextLoad##Name##Dialect(MlirContext context) { \
4343
return wrap(unwrap(context)->getOrLoadDialect<ClassName>()); \
@@ -47,8 +47,8 @@ typedef struct MlirDialectRegistrationHooks MlirDialectRegistrationHooks;
4747
} \
4848
MlirDialectHandle mlirGetDialectHandle__##Namespace##__() { \
4949
static MlirDialectRegistrationHooks hooks = { \
50-
mlirContextRegister##Name##Dialect, mlirContextLoad##Name##Dialect, \
51-
mlir##Name##DialectGetNamespace}; \
50+
mlirDialectRegistryInsert##Name##Dialect, \
51+
mlirContextLoad##Name##Dialect, mlir##Name##DialectGetNamespace}; \
5252
return MlirDialectHandle{&hooks}; \
5353
}
5454

mlir/lib/CAPI/IR/DialectHandle.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,16 @@ MlirStringRef mlirDialectHandleGetNamespace(MlirDialectHandle handle) {
1717
return unwrap(handle)->getNamespaceHook();
1818
}
1919

20+
void mlirDialectHandleInsertDialect(MlirDialectHandle handle,
21+
MlirDialectRegistry registry) {
22+
unwrap(handle)->insertHook(registry);
23+
}
24+
2025
void mlirDialectHandleRegisterDialect(MlirDialectHandle handle,
2126
MlirContext ctx) {
22-
unwrap(handle)->registerHook(ctx);
27+
mlir::DialectRegistry registry;
28+
mlirDialectHandleInsertDialect(handle, wrap(&registry));
29+
unwrap(ctx)->appendDialectRegistry(registry);
2330
}
2431

2532
MlirDialect mlirDialectHandleLoadDialect(MlirDialectHandle handle,

mlir/lib/CAPI/IR/IR.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@ intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) {
5353
return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size());
5454
}
5555

56+
void mlirContextAppendDialectRegistry(MlirContext ctx,
57+
MlirDialectRegistry registry) {
58+
unwrap(ctx)->appendDialectRegistry(*unwrap(registry));
59+
}
60+
5661
// TODO: expose a cheaper way than constructing + sorting a vector only to take
5762
// its size.
5863
intptr_t mlirContextGetNumLoadedDialects(MlirContext context) {
@@ -88,6 +93,18 @@ MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) {
8893
return wrap(unwrap(dialect)->getNamespace());
8994
}
9095

96+
//===----------------------------------------------------------------------===//
97+
// DialectRegistry API.
98+
//===----------------------------------------------------------------------===//
99+
100+
MlirDialectRegistry mlirDialectRegistryCreate() {
101+
return wrap(new DialectRegistry());
102+
}
103+
104+
void mlirDialectRegistryDestroy(MlirDialectRegistry registry) {
105+
delete unwrap(registry);
106+
}
107+
91108
//===----------------------------------------------------------------------===//
92109
// Printing flags API.
93110
//===----------------------------------------------------------------------===//

mlir/test/CAPI/ir.c

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1904,6 +1904,36 @@ int testSymbolTable(MlirContext ctx) {
19041904
return 0;
19051905
}
19061906

1907+
int testDialectRegistry() {
1908+
fprintf(stderr, "@testDialectRegistry\n");
1909+
1910+
MlirDialectRegistry registry = mlirDialectRegistryCreate();
1911+
if (mlirDialectRegistryIsNull(registry)) {
1912+
fprintf(stderr, "ERROR: Expected registry to be present\n");
1913+
return 1;
1914+
}
1915+
1916+
MlirDialectHandle stdHandle = mlirGetDialectHandle__std__();
1917+
mlirDialectHandleInsertDialect(stdHandle, registry);
1918+
1919+
MlirContext ctx = mlirContextCreate();
1920+
if (mlirContextGetNumRegisteredDialects(ctx) != 0) {
1921+
fprintf(stderr,
1922+
"ERROR: Expected no dialects to be registered to new context\n");
1923+
}
1924+
1925+
mlirContextAppendDialectRegistry(ctx, registry);
1926+
if (mlirContextGetNumRegisteredDialects(ctx) != 1) {
1927+
fprintf(stderr, "ERROR: Expected the dialect in the registry to be "
1928+
"registered to the context\n");
1929+
}
1930+
1931+
mlirContextDestroy(ctx);
1932+
mlirDialectRegistryDestroy(registry);
1933+
1934+
return 0;
1935+
}
1936+
19071937
void testDiagnostics() {
19081938
MlirContext ctx = mlirContextCreate();
19091939
MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
@@ -1988,6 +2018,8 @@ int main() {
19882018
return 13;
19892019
if (testSymbolTable(ctx))
19902020
return 14;
2021+
if (testDialectRegistry())
2022+
return 15;
19912023

19922024
mlirContextDestroy(ctx);
19932025

0 commit comments

Comments
 (0)