Skip to content

[mlir][py] Enable loading only specified dialects during creation. #121421

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 2, 2025

Conversation

jpienaar
Copy link
Member

@jpienaar jpienaar commented Jan 1, 2025

Gives option post as global list as well as arg to control which dialects are loaded during context creation. This enables setting either a good base set or skipping in individual cases.

Gives option post as global list as well as arg to control which
dialects are loaded during context creation. This enables setting either
a good base set or skipping in individual cases.
@jpienaar jpienaar requested a review from makslevental January 1, 2025 01:57
@llvmbot llvmbot added mlir:python MLIR Python bindings mlir labels Jan 1, 2025
@llvmbot
Copy link
Member

llvmbot commented Jan 1, 2025

@llvm/pr-subscribers-mlir

Author: Jacques Pienaar (jpienaar)

Changes

Gives option post as global list as well as arg to control which dialects are loaded during context creation. This enables setting either a good base set or skipping in individual cases.


Full diff: https://github.com/llvm/llvm-project/pull/121421.diff

3 Files Affected:

  • (modified) mlir/python/mlir/_mlir_libs/init.py (+32-3)
  • (modified) mlir/python/mlir/ir.py (+1-1)
  • (modified) mlir/test/python/ir/dialects.py (+28)
diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index c5cb22c6dccb8f..dbc458b887d671 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -58,6 +58,7 @@ def get_include_dirs() -> Sequence[str]:
 # needs.
 
 _dialect_registry = None
+_load_on_create_dialects = None
 
 
 def get_dialect_registry():
@@ -71,6 +72,21 @@ def get_dialect_registry():
     return _dialect_registry
 
 
+def append_load_on_create_dialect(dialect: str):
+    global _load_on_create_dialects
+    if _load_on_create_dialects is None:
+        _load_on_create_dialects = [dialect]
+    else:
+        _load_on_create_dialects.append(dialect)
+
+
+def get_load_on_create_dialects():
+    global _load_on_create_dialects
+    if _load_on_create_dialects is None:
+        _load_on_create_dialects = []
+    return _load_on_create_dialects
+
+
 def _site_initialize():
     import importlib
     import itertools
@@ -132,15 +148,28 @@ def process_initializer_module(module_name):
             break
 
     class Context(ir._BaseContext):
-        def __init__(self, *args, **kwargs):
+        def __init__(self, load_on_create_dialects=None, *args, **kwargs):
             super().__init__(*args, **kwargs)
             self.append_dialect_registry(get_dialect_registry())
             for hook in post_init_hooks:
                 hook(self)
             if not disable_multithreading:
                 self.enable_multithreading(True)
-            if not disable_load_all_available_dialects:
-                self.load_all_available_dialects()
+            if load_on_create_dialects is not None:
+                logger.debug("Loading all dialects from load_on_create_dialects arg %r", _load_on_create_dialects)
+                for dialect in load_on_create_dialects:
+                    # Load dialect.
+                    _ = self.dialects[dialect]
+            else:
+                if disable_load_all_available_dialects:
+                    if _load_on_create_dialects:
+                        logger.debug("Loading all dialects from global load_on_create_dialects %r", _load_on_create_dialects)
+                        for dialect in _load_on_create_dialects:
+                            # Load dialect.
+                            _ = self.dialects[dialect]
+                else:
+                    logger.debug("Loading all available dialects")
+                    self.load_all_available_dialects()
             if init_module:
                 logger.debug(
                     "Registering translations from initializer %r", init_module
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 9a6ce462047ad2..6f1c0da8a4e5d6 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -5,7 +5,7 @@
 from ._mlir_libs._mlir.ir import *
 from ._mlir_libs._mlir.ir import _GlobalDebug
 from ._mlir_libs._mlir import register_type_caster, register_value_caster
-from ._mlir_libs import get_dialect_registry
+from ._mlir_libs import get_dialect_registry, append_load_on_create_dialect, get_load_on_create_dialects
 
 
 # Convenience decorator for registering user-friendly Attribute builders.
diff --git a/mlir/test/python/ir/dialects.py b/mlir/test/python/ir/dialects.py
index d59c6a6bc424e6..3742835208a5d9 100644
--- a/mlir/test/python/ir/dialects.py
+++ b/mlir/test/python/ir/dialects.py
@@ -121,3 +121,31 @@ def testAppendPrefixSearchPath():
         sys.path.append(".")
         _cext.globals.append_dialect_search_prefix("custom_dialect")
         assert _cext.globals._check_dialect_module_loaded("custom")
+
+
+# CHECK-LABEL: TEST: testDialectLoadOnCreate
+@run
+def testDialectLoadOnCreate():
+    with Context(load_on_create_dialects=[]) as ctx:
+        ctx.emit_error_diagnostics = True
+        ctx.allow_unregistered_dialects = True
+        
+        def callback(d):
+            # CHECK: DIAGNOSTIC
+            # CHECK-SAME: op created with unregistered dialect
+            print(f"DIAGNOSTIC={d.message}")
+            return True
+
+        handler = ctx.attach_diagnostic_handler(callback)
+        loc = Location.unknown(ctx)
+        try:
+          op = Operation.create("arith.addi", loc=loc)
+          ctx.allow_unregistered_dialects = False
+          op.verify()
+        except MLIRError as e:
+          pass
+  
+    with Context(load_on_create_dialects=["func"]) as ctx:
+      loc = Location.unknown(ctx)
+      fn = Operation.create("func.func", loc=loc)
+

Copy link

github-actions bot commented Jan 1, 2025

✅ With the latest revision this PR passed the Python code formatter.

@jpienaar jpienaar merged commit c703b46 into llvm:main Jan 2, 2025
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants