Skip to content

Commit 17ec364

Browse files
authored
[mlir][python] enable registering dialects with the default Context (#72488)
1 parent 7cbf959 commit 17ec364

File tree

5 files changed

+31
-19
lines changed

5 files changed

+31
-19
lines changed

mlir/python/mlir/_mlir_libs/__init__.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,28 @@ def get_include_dirs() -> Sequence[str]:
5656
#
5757
# This facility allows downstreams to customize Context creation to their
5858
# needs.
59+
60+
_dialect_registry = None
61+
62+
63+
def get_dialect_registry():
64+
global _dialect_registry
65+
66+
if _dialect_registry is None:
67+
from ._mlir import ir
68+
69+
_dialect_registry = ir.DialectRegistry()
70+
71+
return _dialect_registry
72+
73+
5974
def _site_initialize():
6075
import importlib
6176
import itertools
6277
import logging
6378
from ._mlir import ir
6479

6580
logger = logging.getLogger(__name__)
66-
registry = ir.DialectRegistry()
6781
post_init_hooks = []
6882
disable_multithreading = False
6983

@@ -84,7 +98,7 @@ def process_initializer_module(module_name):
8498
logger.debug("Initializing MLIR with module: %s", module_name)
8599
if hasattr(m, "register_dialects"):
86100
logger.debug("Registering dialects from initializer %r", m)
87-
m.register_dialects(registry)
101+
m.register_dialects(get_dialect_registry())
88102
if hasattr(m, "context_init_hook"):
89103
logger.debug("Adding context init hook from %r", m)
90104
post_init_hooks.append(m.context_init_hook)
@@ -110,7 +124,7 @@ def process_initializer_module(module_name):
110124
class Context(ir._BaseContext):
111125
def __init__(self, *args, **kwargs):
112126
super().__init__(*args, **kwargs)
113-
self.append_dialect_registry(registry)
127+
self.append_dialect_registry(get_dialect_registry())
114128
for hook in post_init_hooks:
115129
hook(self)
116130
if not disable_multithreading:

mlir/python/mlir/dialects/python_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
)
1212

1313

14-
def register_python_test_dialect(context, load=True):
14+
def register_python_test_dialect(registry):
1515
from .._mlir_libs import _mlirPythonTest
1616

17-
_mlirPythonTest.register_python_test_dialect(context, load)
17+
_mlirPythonTest.register_dialect(registry)

mlir/python/mlir/ir.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from ._mlir_libs._mlir.ir import *
66
from ._mlir_libs._mlir.ir import _GlobalDebug
77
from ._mlir_libs._mlir import register_type_caster, register_value_caster
8+
from ._mlir_libs import get_dialect_registry
89

910

1011
# Convenience decorator for registering user-friendly Attribute builders.

mlir/test/python/dialects/python_test.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import mlir.dialects.tensor as tensor
77
import mlir.dialects.arith as arith
88

9+
test.register_python_test_dialect(get_dialect_registry())
10+
911

1012
def run(f):
1113
print("\nTEST:", f.__name__)
@@ -17,7 +19,6 @@ def run(f):
1719
@run
1820
def testAttributes():
1921
with Context() as ctx, Location.unknown():
20-
test.register_python_test_dialect(ctx)
2122
#
2223
# Check op construction with attributes.
2324
#
@@ -138,7 +139,6 @@ def testAttributes():
138139
@run
139140
def attrBuilder():
140141
with Context() as ctx, Location.unknown():
141-
test.register_python_test_dialect(ctx)
142142
# CHECK: python_test.attributes_op
143143
op = test.AttributesOp(
144144
# CHECK-DAG: x_affinemap = affine_map<() -> (2)>
@@ -215,7 +215,6 @@ def attrBuilder():
215215
@run
216216
def inferReturnTypes():
217217
with Context() as ctx, Location.unknown(ctx):
218-
test.register_python_test_dialect(ctx)
219218
module = Module.create()
220219
with InsertionPoint(module.body):
221220
op = test.InferResultsOp()
@@ -260,7 +259,6 @@ def inferReturnTypes():
260259
@run
261260
def resultTypesDefinedByTraits():
262261
with Context() as ctx, Location.unknown(ctx):
263-
test.register_python_test_dialect(ctx)
264262
module = Module.create()
265263
with InsertionPoint(module.body):
266264
inferred = test.InferResultsOp()
@@ -295,8 +293,6 @@ def resultTypesDefinedByTraits():
295293
@run
296294
def testOptionalOperandOp():
297295
with Context() as ctx, Location.unknown():
298-
test.register_python_test_dialect(ctx)
299-
300296
module = Module.create()
301297
with InsertionPoint(module.body):
302298
op1 = test.OptionalOperandOp()
@@ -312,7 +308,6 @@ def testOptionalOperandOp():
312308
@run
313309
def testCustomAttribute():
314310
with Context() as ctx:
315-
test.register_python_test_dialect(ctx)
316311
a = test.TestAttr.get()
317312
# CHECK: #python_test.test_attr
318313
print(a)
@@ -350,7 +345,6 @@ def testCustomAttribute():
350345
@run
351346
def testCustomType():
352347
with Context() as ctx:
353-
test.register_python_test_dialect(ctx)
354348
a = test.TestType.get()
355349
# CHECK: !python_test.test_type
356350
print(a)
@@ -397,8 +391,6 @@ def testCustomType():
397391
# CHECK-LABEL: TEST: testTensorValue
398392
def testTensorValue():
399393
with Context() as ctx, Location.unknown():
400-
test.register_python_test_dialect(ctx)
401-
402394
i8 = IntegerType.get_signless(8)
403395

404396
class Tensor(test.TestTensorValue):
@@ -436,7 +428,6 @@ def __str__(self):
436428
@run
437429
def inferReturnTypeComponents():
438430
with Context() as ctx, Location.unknown(ctx):
439-
test.register_python_test_dialect(ctx)
440431
module = Module.create()
441432
i32 = IntegerType.get_signless(32)
442433
with InsertionPoint(module.body):
@@ -488,8 +479,6 @@ def inferReturnTypeComponents():
488479
@run
489480
def testCustomTypeTypeCaster():
490481
with Context() as ctx, Location.unknown():
491-
test.register_python_test_dialect(ctx)
492-
493482
a = test.TestType.get()
494483
assert a.typeid is not None
495484

@@ -542,7 +531,6 @@ def type_caster(pytype):
542531
@run
543532
def testInferTypeOpInterface():
544533
with Context() as ctx, Location.unknown(ctx):
545-
test.register_python_test_dialect(ctx)
546534
module = Module.create()
547535
with InsertionPoint(module.body):
548536
i64 = IntegerType.get_signless(64)

mlir/test/python/lib/PythonTestModule.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,15 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
3434
},
3535
py::arg("context"), py::arg("load") = true);
3636

37+
m.def(
38+
"register_dialect",
39+
[](MlirDialectRegistry registry) {
40+
MlirDialectHandle pythonTestDialect =
41+
mlirGetDialectHandle__python_test__();
42+
mlirDialectHandleInsertDialect(pythonTestDialect, registry);
43+
},
44+
py::arg("registry"));
45+
3746
mlir_attribute_subclass(m, "TestAttr",
3847
mlirAttributeIsAPythonTestTestAttribute)
3948
.def_classmethod(

0 commit comments

Comments
 (0)