Skip to content

[mlir][python] enable registering dialects with the default Context #72488

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 28, 2023

Conversation

makslevental
Copy link
Contributor

@makslevental makslevental commented Nov 16, 2023

Right now the only way to sign up for inclusion in the default registry (the registry appended to Context by default) is through _site_initialize. That's fine but it's path dependent for the _site_initialize_0 C extension module. This change enables registering a dialect by default in a path independent way.

Copy link

github-actions bot commented Nov 16, 2023

✅ With the latest revision this PR passed the Python code formatter.

@makslevental makslevental force-pushed the remove_site_initialize_2 branch from f846272 to 829c627 Compare November 16, 2023 16:08
@makslevental makslevental changed the title [mlir][python] allow people to register dialects with the default context [mlir][python] allow people to register dialects with the default Context Nov 16, 2023
@makslevental makslevental requested review from ftynse, stellaraccident and rkayaith and removed request for ftynse November 16, 2023 16:38
@makslevental makslevental marked this pull request as ready for review November 16, 2023 16:38
@llvmbot llvmbot added mlir:python MLIR Python bindings mlir labels Nov 16, 2023
@llvmbot
Copy link
Member

llvmbot commented Nov 16, 2023

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

Changes

Right now the only way to sign up for inclusion in the default registry (the registry appended to Context by default) is through _site_initialize. That's fine but it's path dependent for the _site_initialize_0 C extension module. This change enables registering a dialect by default in a path independent way.


Full diff: https://github.com/llvm/llvm-project/pull/72488.diff

5 Files Affected:

  • (modified) mlir/python/mlir/_mlir_libs/init.py (+13-3)
  • (modified) mlir/python/mlir/dialects/python_test.py (+2-2)
  • (modified) mlir/python/mlir/ir.py (+1)
  • (modified) mlir/test/python/dialects/python_test.py (+2-14)
  • (modified) mlir/test/python/lib/PythonTestModule.cpp (+9)
diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index 6ce77b4cb93f609..0761579da15fb94 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -56,6 +56,17 @@ def get_include_dirs() -> Sequence[str]:
 #
 # This facility allows downstreams to customize Context creation to their
 # needs.
+
+
+def get_registry():
+    if not hasattr(get_registry, "__registry"):
+        from ._mlir import ir
+
+        get_registry.__registry = ir.DialectRegistry()
+
+    return get_registry.__registry
+
+
 def _site_initialize():
     import importlib
     import itertools
@@ -63,7 +74,6 @@ def _site_initialize():
     from ._mlir import ir
 
     logger = logging.getLogger(__name__)
-    registry = ir.DialectRegistry()
     post_init_hooks = []
     disable_multithreading = False
 
@@ -84,7 +94,7 @@ def process_initializer_module(module_name):
         logger.debug("Initializing MLIR with module: %s", module_name)
         if hasattr(m, "register_dialects"):
             logger.debug("Registering dialects from initializer %r", m)
-            m.register_dialects(registry)
+            m.register_dialects(get_registry())
         if hasattr(m, "context_init_hook"):
             logger.debug("Adding context init hook from %r", m)
             post_init_hooks.append(m.context_init_hook)
@@ -110,7 +120,7 @@ def process_initializer_module(module_name):
     class Context(ir._BaseContext):
         def __init__(self, *args, **kwargs):
             super().__init__(*args, **kwargs)
-            self.append_dialect_registry(registry)
+            self.append_dialect_registry(get_registry())
             for hook in post_init_hooks:
                 hook(self)
             if not disable_multithreading:
diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py
index 6579e02d8549efa..b5baa80bc767fb3 100644
--- a/mlir/python/mlir/dialects/python_test.py
+++ b/mlir/python/mlir/dialects/python_test.py
@@ -11,7 +11,7 @@
 )
 
 
-def register_python_test_dialect(context, load=True):
+def register_python_test_dialect(registry):
     from .._mlir_libs import _mlirPythonTest
 
-    _mlirPythonTest.register_python_test_dialect(context, load)
+    _mlirPythonTest.register_dialect(registry)
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 18526ab8c3c02dc..82403c0b8d5fed1 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -5,6 +5,7 @@
 from ._mlir_libs._mlir.ir import *
 from ._mlir_libs._mlir.ir import _GlobalDebug
 from ._mlir_libs._mlir import register_type_caster, register_value_caster
+from ._mlir_libs import get_registry
 
 
 # Convenience decorator for registering user-friendly Attribute builders.
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index f313a400b73c0a5..562190c6fcdf5d5 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -6,6 +6,8 @@
 import mlir.dialects.tensor as tensor
 import mlir.dialects.arith as arith
 
+test.register_python_test_dialect(get_registry())
+
 
 def run(f):
     print("\nTEST:", f.__name__)
@@ -17,7 +19,6 @@ def run(f):
 @run
 def testAttributes():
     with Context() as ctx, Location.unknown():
-        test.register_python_test_dialect(ctx)
         #
         # Check op construction with attributes.
         #
@@ -138,7 +139,6 @@ def testAttributes():
 @run
 def attrBuilder():
     with Context() as ctx, Location.unknown():
-        test.register_python_test_dialect(ctx)
         # CHECK: python_test.attributes_op
         op = test.AttributesOp(
             # CHECK-DAG: x_affinemap = affine_map<() -> (2)>
@@ -215,7 +215,6 @@ def attrBuilder():
 @run
 def inferReturnTypes():
     with Context() as ctx, Location.unknown(ctx):
-        test.register_python_test_dialect(ctx)
         module = Module.create()
         with InsertionPoint(module.body):
             op = test.InferResultsOp()
@@ -260,7 +259,6 @@ def inferReturnTypes():
 @run
 def resultTypesDefinedByTraits():
     with Context() as ctx, Location.unknown(ctx):
-        test.register_python_test_dialect(ctx)
         module = Module.create()
         with InsertionPoint(module.body):
             inferred = test.InferResultsOp()
@@ -295,8 +293,6 @@ def resultTypesDefinedByTraits():
 @run
 def testOptionalOperandOp():
     with Context() as ctx, Location.unknown():
-        test.register_python_test_dialect(ctx)
-
         module = Module.create()
         with InsertionPoint(module.body):
             op1 = test.OptionalOperandOp()
@@ -312,7 +308,6 @@ def testOptionalOperandOp():
 @run
 def testCustomAttribute():
     with Context() as ctx:
-        test.register_python_test_dialect(ctx)
         a = test.TestAttr.get()
         # CHECK: #python_test.test_attr
         print(a)
@@ -350,7 +345,6 @@ def testCustomAttribute():
 @run
 def testCustomType():
     with Context() as ctx:
-        test.register_python_test_dialect(ctx)
         a = test.TestType.get()
         # CHECK: !python_test.test_type
         print(a)
@@ -397,8 +391,6 @@ def testCustomType():
 # CHECK-LABEL: TEST: testTensorValue
 def testTensorValue():
     with Context() as ctx, Location.unknown():
-        test.register_python_test_dialect(ctx)
-
         i8 = IntegerType.get_signless(8)
 
         class Tensor(test.TestTensorValue):
@@ -436,7 +428,6 @@ def __str__(self):
 @run
 def inferReturnTypeComponents():
     with Context() as ctx, Location.unknown(ctx):
-        test.register_python_test_dialect(ctx)
         module = Module.create()
         i32 = IntegerType.get_signless(32)
         with InsertionPoint(module.body):
@@ -488,8 +479,6 @@ def inferReturnTypeComponents():
 @run
 def testCustomTypeTypeCaster():
     with Context() as ctx, Location.unknown():
-        test.register_python_test_dialect(ctx)
-
         a = test.TestType.get()
         assert a.typeid is not None
 
@@ -542,7 +531,6 @@ def type_caster(pytype):
 @run
 def testInferTypeOpInterface():
     with Context() as ctx, Location.unknown(ctx):
-        test.register_python_test_dialect(ctx)
         module = Module.create()
         with InsertionPoint(module.body):
             i64 = IntegerType.get_signless(64)
diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp
index aff414894cb825a..f81b851f8759bf7 100644
--- a/mlir/test/python/lib/PythonTestModule.cpp
+++ b/mlir/test/python/lib/PythonTestModule.cpp
@@ -34,6 +34,15 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
       },
       py::arg("context"), py::arg("load") = true);
 
+  m.def(
+      "register_dialect",
+      [](MlirDialectRegistry registry) {
+        MlirDialectHandle pythonTestDialect =
+            mlirGetDialectHandle__python_test__();
+        mlirDialectHandleInsertDialect(pythonTestDialect, registry);
+      },
+      py::arg("registry"));
+
   mlir_attribute_subclass(m, "TestAttr",
                           mlirAttributeIsAPythonTestTestAttribute)
       .def_classmethod(

@makslevental makslevental changed the title [mlir][python] allow people to register dialects with the default Context [mlir][python] enable registering dialects with the default Context Nov 16, 2023
Copy link
Contributor

@stellaraccident stellaraccident left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Naming nits.



def get_registry():
if not hasattr(get_registry, "__registry"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't remember if this is still a thing, but didn't it used to be that double underscore prefixed attributes were lexically mangled? I've got it stuck in my "never do that" category, but may be a legacy lint check in my brain :)

Copy link
Contributor Author

@makslevental makslevental Nov 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's true but that's only for class fields - here I'm doing something even "dirtier" and setting it on the function object so it doesn't get mangled.

But just this morning was thinking I'd refactor this to be a module global behind threading.local() instead of this hackery. Thoughts?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, just move it up a level and use nonlocal. No need for dirty tricks that someone will need to grok later. I'm not sure it needs to be thread local. This is a really basic facility in the same vein as site_initialize, which is global-global.

Also, I note that your branch is named "remove_site_initialize_2". I assume this is all in addition to the current approach, which is used and works just fine for what we need it for.

Copy link
Contributor Author

@makslevental makslevental Nov 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, I note that your branch is named "remove_site_initialize_2". I assume this is all in addition to the current approach, which is used and works just fine for what we need it for.

remove_site_initialize_1 changed things around a lot (didn't exactly remove but large refactor) and then I had the lightbulb moment that all I needed was this one helper and hence remove_site_initialize_2.

@makslevental makslevental merged commit 17ec364 into llvm:main Nov 28, 2023
@makslevental makslevental deleted the remove_site_initialize_2 branch November 28, 2023 01:26
@@ -56,14 +56,28 @@ def get_include_dirs() -> Sequence[str]:
#
# This facility allows downstreams to customize Context creation to their
# needs.

_dialect_registry = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document please?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My first todo for tomorrow is to update the python docs with all the stuff from the last ~6 months 😄.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants