Skip to content

[mlir][py] Enable loading only specified dialects during creation. #121421

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 39 additions & 3 deletions mlir/python/mlir/_mlir_libs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def get_include_dirs() -> Sequence[str]:
# needs.

_dialect_registry = None
_load_on_create_dialects = None


def get_dialect_registry():
Expand All @@ -71,6 +72,21 @@ def get_dialect_registry():
return _dialect_registry


def append_load_on_create_dialect(dialect: str):
global _load_on_create_dialects
if _load_on_create_dialects is None:
_load_on_create_dialects = [dialect]
else:
_load_on_create_dialects.append(dialect)


def get_load_on_create_dialects():
global _load_on_create_dialects
if _load_on_create_dialects is None:
_load_on_create_dialects = []
return _load_on_create_dialects


def _site_initialize():
import importlib
import itertools
Expand Down Expand Up @@ -132,15 +148,35 @@ def process_initializer_module(module_name):
break

class Context(ir._BaseContext):
def __init__(self, *args, **kwargs):
def __init__(self, load_on_create_dialects=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.append_dialect_registry(get_dialect_registry())
for hook in post_init_hooks:
hook(self)
if not disable_multithreading:
self.enable_multithreading(True)
if not disable_load_all_available_dialects:
self.load_all_available_dialects()
if load_on_create_dialects is not None:
logger.debug(
"Loading all dialects from load_on_create_dialects arg %r",
load_on_create_dialects,
)
for dialect in load_on_create_dialects:
# This triggers loading the dialect into the context.
_ = self.dialects[dialect]
else:
if disable_load_all_available_dialects:
dialects = get_load_on_create_dialects()
if dialects:
logger.debug(
"Loading all dialects from global load_on_create_dialects %r",
dialects,
)
for dialect in dialects:
# This triggers loading the dialect into the context.
_ = self.dialects[dialect]
else:
logger.debug("Loading all available dialects")
self.load_all_available_dialects()
if init_module:
logger.debug(
"Registering translations from initializer %r", init_module
Expand Down
6 changes: 5 additions & 1 deletion mlir/python/mlir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
from ._mlir_libs._mlir.ir import *
from ._mlir_libs._mlir.ir import _GlobalDebug
from ._mlir_libs._mlir import register_type_caster, register_value_caster
from ._mlir_libs import get_dialect_registry
from ._mlir_libs import (
get_dialect_registry,
append_load_on_create_dialect,
get_load_on_create_dialects,
)


# Convenience decorator for registering user-friendly Attribute builders.
Expand Down
36 changes: 36 additions & 0 deletions mlir/test/python/ir/dialects.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,39 @@ def testAppendPrefixSearchPath():
sys.path.append(".")
_cext.globals.append_dialect_search_prefix("custom_dialect")
assert _cext.globals._check_dialect_module_loaded("custom")


# CHECK-LABEL: TEST: testDialectLoadOnCreate
@run
def testDialectLoadOnCreate():
with Context(load_on_create_dialects=[]) as ctx:
ctx.emit_error_diagnostics = True
ctx.allow_unregistered_dialects = True

def callback(d):
# CHECK: DIAGNOSTIC
# CHECK-SAME: op created with unregistered dialect
print(f"DIAGNOSTIC={d.message}")
return True

handler = ctx.attach_diagnostic_handler(callback)
loc = Location.unknown(ctx)
try:
op = Operation.create("arith.addi", loc=loc)
ctx.allow_unregistered_dialects = False
op.verify()
except MLIRError as e:
pass

with Context(load_on_create_dialects=["func"]) as ctx:
loc = Location.unknown(ctx)
fn = Operation.create("func.func", loc=loc)

# TODO: This may require an update if a site wide policy is set.
# CHECK: Load on create: []
print(f"Load on create: {get_load_on_create_dialects()}")
append_load_on_create_dialect("func")
# CHECK: Load on create:
# CHECK-SAME: func
print(f"Load on create: {get_load_on_create_dialects()}")
print(get_load_on_create_dialects())
Loading