Skip to content

[mlir][python] enable registering dialects with the default Context #72488

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions mlir/python/mlir/_mlir_libs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,28 @@ def get_include_dirs() -> Sequence[str]:
#
# This facility allows downstreams to customize Context creation to their
# needs.

_dialect_registry = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document please?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My first todo for tomorrow is to update the python docs with all the stuff from the last ~6 months 😄.



def get_dialect_registry():
global _dialect_registry

if _dialect_registry is None:
from ._mlir import ir

_dialect_registry = ir.DialectRegistry()

return _dialect_registry


def _site_initialize():
import importlib
import itertools
import logging
from ._mlir import ir

logger = logging.getLogger(__name__)
registry = ir.DialectRegistry()
post_init_hooks = []
disable_multithreading = False

Expand All @@ -84,7 +98,7 @@ def process_initializer_module(module_name):
logger.debug("Initializing MLIR with module: %s", module_name)
if hasattr(m, "register_dialects"):
logger.debug("Registering dialects from initializer %r", m)
m.register_dialects(registry)
m.register_dialects(get_dialect_registry())
if hasattr(m, "context_init_hook"):
logger.debug("Adding context init hook from %r", m)
post_init_hooks.append(m.context_init_hook)
Expand All @@ -110,7 +124,7 @@ def process_initializer_module(module_name):
class Context(ir._BaseContext):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.append_dialect_registry(registry)
self.append_dialect_registry(get_dialect_registry())
for hook in post_init_hooks:
hook(self)
if not disable_multithreading:
Expand Down
4 changes: 2 additions & 2 deletions mlir/python/mlir/dialects/python_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)


def register_python_test_dialect(context, load=True):
def register_python_test_dialect(registry):
from .._mlir_libs import _mlirPythonTest

_mlirPythonTest.register_python_test_dialect(context, load)
_mlirPythonTest.register_dialect(registry)
1 change: 1 addition & 0 deletions mlir/python/mlir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ._mlir_libs._mlir.ir import *
from ._mlir_libs._mlir.ir import _GlobalDebug
from ._mlir_libs._mlir import register_type_caster, register_value_caster
from ._mlir_libs import get_dialect_registry


# Convenience decorator for registering user-friendly Attribute builders.
Expand Down
16 changes: 2 additions & 14 deletions mlir/test/python/dialects/python_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import mlir.dialects.tensor as tensor
import mlir.dialects.arith as arith

test.register_python_test_dialect(get_dialect_registry())


def run(f):
print("\nTEST:", f.__name__)
Expand All @@ -17,7 +19,6 @@ def run(f):
@run
def testAttributes():
with Context() as ctx, Location.unknown():
test.register_python_test_dialect(ctx)
#
# Check op construction with attributes.
#
Expand Down Expand Up @@ -138,7 +139,6 @@ def testAttributes():
@run
def attrBuilder():
with Context() as ctx, Location.unknown():
test.register_python_test_dialect(ctx)
# CHECK: python_test.attributes_op
op = test.AttributesOp(
# CHECK-DAG: x_affinemap = affine_map<() -> (2)>
Expand Down Expand Up @@ -215,7 +215,6 @@ def attrBuilder():
@run
def inferReturnTypes():
with Context() as ctx, Location.unknown(ctx):
test.register_python_test_dialect(ctx)
module = Module.create()
with InsertionPoint(module.body):
op = test.InferResultsOp()
Expand Down Expand Up @@ -260,7 +259,6 @@ def inferReturnTypes():
@run
def resultTypesDefinedByTraits():
with Context() as ctx, Location.unknown(ctx):
test.register_python_test_dialect(ctx)
module = Module.create()
with InsertionPoint(module.body):
inferred = test.InferResultsOp()
Expand Down Expand Up @@ -295,8 +293,6 @@ def resultTypesDefinedByTraits():
@run
def testOptionalOperandOp():
with Context() as ctx, Location.unknown():
test.register_python_test_dialect(ctx)

module = Module.create()
with InsertionPoint(module.body):
op1 = test.OptionalOperandOp()
Expand All @@ -312,7 +308,6 @@ def testOptionalOperandOp():
@run
def testCustomAttribute():
with Context() as ctx:
test.register_python_test_dialect(ctx)
a = test.TestAttr.get()
# CHECK: #python_test.test_attr
print(a)
Expand Down Expand Up @@ -350,7 +345,6 @@ def testCustomAttribute():
@run
def testCustomType():
with Context() as ctx:
test.register_python_test_dialect(ctx)
a = test.TestType.get()
# CHECK: !python_test.test_type
print(a)
Expand Down Expand Up @@ -397,8 +391,6 @@ def testCustomType():
# CHECK-LABEL: TEST: testTensorValue
def testTensorValue():
with Context() as ctx, Location.unknown():
test.register_python_test_dialect(ctx)

i8 = IntegerType.get_signless(8)

class Tensor(test.TestTensorValue):
Expand Down Expand Up @@ -436,7 +428,6 @@ def __str__(self):
@run
def inferReturnTypeComponents():
with Context() as ctx, Location.unknown(ctx):
test.register_python_test_dialect(ctx)
module = Module.create()
i32 = IntegerType.get_signless(32)
with InsertionPoint(module.body):
Expand Down Expand Up @@ -488,8 +479,6 @@ def inferReturnTypeComponents():
@run
def testCustomTypeTypeCaster():
with Context() as ctx, Location.unknown():
test.register_python_test_dialect(ctx)

a = test.TestType.get()
assert a.typeid is not None

Expand Down Expand Up @@ -542,7 +531,6 @@ def type_caster(pytype):
@run
def testInferTypeOpInterface():
with Context() as ctx, Location.unknown(ctx):
test.register_python_test_dialect(ctx)
module = Module.create()
with InsertionPoint(module.body):
i64 = IntegerType.get_signless(64)
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/python/lib/PythonTestModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
},
py::arg("context"), py::arg("load") = true);

m.def(
"register_dialect",
[](MlirDialectRegistry registry) {
MlirDialectHandle pythonTestDialect =
mlirGetDialectHandle__python_test__();
mlirDialectHandleInsertDialect(pythonTestDialect, registry);
},
py::arg("registry"));

mlir_attribute_subclass(m, "TestAttr",
mlirAttributeIsAPythonTestTestAttribute)
.def_classmethod(
Expand Down