Skip to content

Commit c703b46

Browse files
authored
[mlir][py] Enable loading only specified dialects during creation. (#121421)
Gives option post as global list as well as arg to control which dialects are loaded during context creation. This enables setting either a good base set or skipping in individual cases.
1 parent 4b57783 commit c703b46

File tree

3 files changed

+80
-4
lines changed

3 files changed

+80
-4
lines changed

mlir/python/mlir/_mlir_libs/__init__.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def get_include_dirs() -> Sequence[str]:
5858
# needs.
5959

6060
_dialect_registry = None
61+
_load_on_create_dialects = None
6162

6263

6364
def get_dialect_registry():
@@ -71,6 +72,21 @@ def get_dialect_registry():
7172
return _dialect_registry
7273

7374

75+
def append_load_on_create_dialect(dialect: str):
76+
global _load_on_create_dialects
77+
if _load_on_create_dialects is None:
78+
_load_on_create_dialects = [dialect]
79+
else:
80+
_load_on_create_dialects.append(dialect)
81+
82+
83+
def get_load_on_create_dialects():
84+
global _load_on_create_dialects
85+
if _load_on_create_dialects is None:
86+
_load_on_create_dialects = []
87+
return _load_on_create_dialects
88+
89+
7490
def _site_initialize():
7591
import importlib
7692
import itertools
@@ -132,15 +148,35 @@ def process_initializer_module(module_name):
132148
break
133149

134150
class Context(ir._BaseContext):
135-
def __init__(self, *args, **kwargs):
151+
def __init__(self, load_on_create_dialects=None, *args, **kwargs):
136152
super().__init__(*args, **kwargs)
137153
self.append_dialect_registry(get_dialect_registry())
138154
for hook in post_init_hooks:
139155
hook(self)
140156
if not disable_multithreading:
141157
self.enable_multithreading(True)
142-
if not disable_load_all_available_dialects:
143-
self.load_all_available_dialects()
158+
if load_on_create_dialects is not None:
159+
logger.debug(
160+
"Loading all dialects from load_on_create_dialects arg %r",
161+
load_on_create_dialects,
162+
)
163+
for dialect in load_on_create_dialects:
164+
# This triggers loading the dialect into the context.
165+
_ = self.dialects[dialect]
166+
else:
167+
if disable_load_all_available_dialects:
168+
dialects = get_load_on_create_dialects()
169+
if dialects:
170+
logger.debug(
171+
"Loading all dialects from global load_on_create_dialects %r",
172+
dialects,
173+
)
174+
for dialect in dialects:
175+
# This triggers loading the dialect into the context.
176+
_ = self.dialects[dialect]
177+
else:
178+
logger.debug("Loading all available dialects")
179+
self.load_all_available_dialects()
144180
if init_module:
145181
logger.debug(
146182
"Registering translations from initializer %r", init_module

mlir/python/mlir/ir.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
from ._mlir_libs._mlir.ir import *
66
from ._mlir_libs._mlir.ir import _GlobalDebug
77
from ._mlir_libs._mlir import register_type_caster, register_value_caster
8-
from ._mlir_libs import get_dialect_registry
8+
from ._mlir_libs import (
9+
get_dialect_registry,
10+
append_load_on_create_dialect,
11+
get_load_on_create_dialects,
12+
)
913

1014

1115
# Convenience decorator for registering user-friendly Attribute builders.

mlir/test/python/ir/dialects.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,39 @@ def testAppendPrefixSearchPath():
121121
sys.path.append(".")
122122
_cext.globals.append_dialect_search_prefix("custom_dialect")
123123
assert _cext.globals._check_dialect_module_loaded("custom")
124+
125+
126+
# CHECK-LABEL: TEST: testDialectLoadOnCreate
127+
@run
128+
def testDialectLoadOnCreate():
129+
with Context(load_on_create_dialects=[]) as ctx:
130+
ctx.emit_error_diagnostics = True
131+
ctx.allow_unregistered_dialects = True
132+
133+
def callback(d):
134+
# CHECK: DIAGNOSTIC
135+
# CHECK-SAME: op created with unregistered dialect
136+
print(f"DIAGNOSTIC={d.message}")
137+
return True
138+
139+
handler = ctx.attach_diagnostic_handler(callback)
140+
loc = Location.unknown(ctx)
141+
try:
142+
op = Operation.create("arith.addi", loc=loc)
143+
ctx.allow_unregistered_dialects = False
144+
op.verify()
145+
except MLIRError as e:
146+
pass
147+
148+
with Context(load_on_create_dialects=["func"]) as ctx:
149+
loc = Location.unknown(ctx)
150+
fn = Operation.create("func.func", loc=loc)
151+
152+
# TODO: This may require an update if a site wide policy is set.
153+
# CHECK: Load on create: []
154+
print(f"Load on create: {get_load_on_create_dialects()}")
155+
append_load_on_create_dialect("func")
156+
# CHECK: Load on create:
157+
# CHECK-SAME: func
158+
print(f"Load on create: {get_load_on_create_dialects()}")
159+
print(get_load_on_create_dialects())

0 commit comments

Comments
 (0)