-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][python] enable registering dialects with the default Context
#72488
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][python] enable registering dialects with the default Context
#72488
Conversation
✅ With the latest revision this PR passed the Python code formatter. |
f846272
to
829c627
Compare
Context
@llvm/pr-subscribers-mlir Author: Maksim Levental (makslevental) ChangesRight now the only way to sign up for inclusion in the default registry (the registry appended to Full diff: https://github.com/llvm/llvm-project/pull/72488.diff 5 Files Affected:
diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index 6ce77b4cb93f609..0761579da15fb94 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -56,6 +56,17 @@ def get_include_dirs() -> Sequence[str]:
#
# This facility allows downstreams to customize Context creation to their
# needs.
+
+
+def get_registry():
+ if not hasattr(get_registry, "__registry"):
+ from ._mlir import ir
+
+ get_registry.__registry = ir.DialectRegistry()
+
+ return get_registry.__registry
+
+
def _site_initialize():
import importlib
import itertools
@@ -63,7 +74,6 @@ def _site_initialize():
from ._mlir import ir
logger = logging.getLogger(__name__)
- registry = ir.DialectRegistry()
post_init_hooks = []
disable_multithreading = False
@@ -84,7 +94,7 @@ def process_initializer_module(module_name):
logger.debug("Initializing MLIR with module: %s", module_name)
if hasattr(m, "register_dialects"):
logger.debug("Registering dialects from initializer %r", m)
- m.register_dialects(registry)
+ m.register_dialects(get_registry())
if hasattr(m, "context_init_hook"):
logger.debug("Adding context init hook from %r", m)
post_init_hooks.append(m.context_init_hook)
@@ -110,7 +120,7 @@ def process_initializer_module(module_name):
class Context(ir._BaseContext):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.append_dialect_registry(registry)
+ self.append_dialect_registry(get_registry())
for hook in post_init_hooks:
hook(self)
if not disable_multithreading:
diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py
index 6579e02d8549efa..b5baa80bc767fb3 100644
--- a/mlir/python/mlir/dialects/python_test.py
+++ b/mlir/python/mlir/dialects/python_test.py
@@ -11,7 +11,7 @@
)
-def register_python_test_dialect(context, load=True):
+def register_python_test_dialect(registry):
from .._mlir_libs import _mlirPythonTest
- _mlirPythonTest.register_python_test_dialect(context, load)
+ _mlirPythonTest.register_dialect(registry)
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 18526ab8c3c02dc..82403c0b8d5fed1 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -5,6 +5,7 @@
from ._mlir_libs._mlir.ir import *
from ._mlir_libs._mlir.ir import _GlobalDebug
from ._mlir_libs._mlir import register_type_caster, register_value_caster
+from ._mlir_libs import get_registry
# Convenience decorator for registering user-friendly Attribute builders.
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index f313a400b73c0a5..562190c6fcdf5d5 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -6,6 +6,8 @@
import mlir.dialects.tensor as tensor
import mlir.dialects.arith as arith
+test.register_python_test_dialect(get_registry())
+
def run(f):
print("\nTEST:", f.__name__)
@@ -17,7 +19,6 @@ def run(f):
@run
def testAttributes():
with Context() as ctx, Location.unknown():
- test.register_python_test_dialect(ctx)
#
# Check op construction with attributes.
#
@@ -138,7 +139,6 @@ def testAttributes():
@run
def attrBuilder():
with Context() as ctx, Location.unknown():
- test.register_python_test_dialect(ctx)
# CHECK: python_test.attributes_op
op = test.AttributesOp(
# CHECK-DAG: x_affinemap = affine_map<() -> (2)>
@@ -215,7 +215,6 @@ def attrBuilder():
@run
def inferReturnTypes():
with Context() as ctx, Location.unknown(ctx):
- test.register_python_test_dialect(ctx)
module = Module.create()
with InsertionPoint(module.body):
op = test.InferResultsOp()
@@ -260,7 +259,6 @@ def inferReturnTypes():
@run
def resultTypesDefinedByTraits():
with Context() as ctx, Location.unknown(ctx):
- test.register_python_test_dialect(ctx)
module = Module.create()
with InsertionPoint(module.body):
inferred = test.InferResultsOp()
@@ -295,8 +293,6 @@ def resultTypesDefinedByTraits():
@run
def testOptionalOperandOp():
with Context() as ctx, Location.unknown():
- test.register_python_test_dialect(ctx)
-
module = Module.create()
with InsertionPoint(module.body):
op1 = test.OptionalOperandOp()
@@ -312,7 +308,6 @@ def testOptionalOperandOp():
@run
def testCustomAttribute():
with Context() as ctx:
- test.register_python_test_dialect(ctx)
a = test.TestAttr.get()
# CHECK: #python_test.test_attr
print(a)
@@ -350,7 +345,6 @@ def testCustomAttribute():
@run
def testCustomType():
with Context() as ctx:
- test.register_python_test_dialect(ctx)
a = test.TestType.get()
# CHECK: !python_test.test_type
print(a)
@@ -397,8 +391,6 @@ def testCustomType():
# CHECK-LABEL: TEST: testTensorValue
def testTensorValue():
with Context() as ctx, Location.unknown():
- test.register_python_test_dialect(ctx)
-
i8 = IntegerType.get_signless(8)
class Tensor(test.TestTensorValue):
@@ -436,7 +428,6 @@ def __str__(self):
@run
def inferReturnTypeComponents():
with Context() as ctx, Location.unknown(ctx):
- test.register_python_test_dialect(ctx)
module = Module.create()
i32 = IntegerType.get_signless(32)
with InsertionPoint(module.body):
@@ -488,8 +479,6 @@ def inferReturnTypeComponents():
@run
def testCustomTypeTypeCaster():
with Context() as ctx, Location.unknown():
- test.register_python_test_dialect(ctx)
-
a = test.TestType.get()
assert a.typeid is not None
@@ -542,7 +531,6 @@ def type_caster(pytype):
@run
def testInferTypeOpInterface():
with Context() as ctx, Location.unknown(ctx):
- test.register_python_test_dialect(ctx)
module = Module.create()
with InsertionPoint(module.body):
i64 = IntegerType.get_signless(64)
diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp
index aff414894cb825a..f81b851f8759bf7 100644
--- a/mlir/test/python/lib/PythonTestModule.cpp
+++ b/mlir/test/python/lib/PythonTestModule.cpp
@@ -34,6 +34,15 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
},
py::arg("context"), py::arg("load") = true);
+ m.def(
+ "register_dialect",
+ [](MlirDialectRegistry registry) {
+ MlirDialectHandle pythonTestDialect =
+ mlirGetDialectHandle__python_test__();
+ mlirDialectHandleInsertDialect(pythonTestDialect, registry);
+ },
+ py::arg("registry"));
+
mlir_attribute_subclass(m, "TestAttr",
mlirAttributeIsAPythonTestTestAttribute)
.def_classmethod(
|
Context
Context
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Naming nits.
|
||
|
||
def get_registry(): | ||
if not hasattr(get_registry, "__registry"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't remember if this is still a thing, but didn't it used to be that double underscore prefixed attributes were lexically mangled? I've got it stuck in my "never do that" category, but may be a legacy lint check in my brain :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's true but that's only for class fields - here I'm doing something even "dirtier" and setting it on the function object so it doesn't get mangled.
But just this morning was thinking I'd refactor this to be a module global behind threading.local()
instead of this hackery. Thoughts?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, just move it up a level and use nonlocal. No need for dirty tricks that someone will need to grok later. I'm not sure it needs to be thread local. This is a really basic facility in the same vein as site_initialize, which is global-global.
Also, I note that your branch is named "remove_site_initialize_2". I assume this is all in addition to the current approach, which is used and works just fine for what we need it for.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, I note that your branch is named "remove_site_initialize_2". I assume this is all in addition to the current approach, which is used and works just fine for what we need it for.
remove_site_initialize_1
changed things around a lot (didn't exactly remove but large refactor) and then I had the lightbulb moment that all I needed was this one helper and hence remove_site_initialize_2
.
@@ -56,14 +56,28 @@ def get_include_dirs() -> Sequence[str]: | |||
# | |||
# This facility allows downstreams to customize Context creation to their | |||
# needs. | |||
|
|||
_dialect_registry = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Document please?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My first todo for tomorrow is to update the python docs with all the stuff from the last ~6 months 😄.
Right now the only way to sign up for inclusion in the default registry (the registry appended to
Context
by default) is through_site_initialize
. That's fine but it's path dependent for the_site_initialize_0
C extension module. This change enables registering a dialect by default in a path independent way.