Skip to content

Commit acaff70

Browse files
authored
[mlir][python] move transform extras (#76102)
1 parent f324584 commit acaff70

File tree

4 files changed

+25
-23
lines changed

4 files changed

+25
-23
lines changed

mlir/python/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ declare_mlir_python_sources(
172172
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
173173
GEN_ENUM_BINDINGS
174174
SOURCES
175-
extras/dialects/transform/__init__.py)
175+
dialects/transform/extras/__init__.py)
176176

177177
declare_mlir_dialect_extension_python_bindings(
178178
ADD_TO_PARENT MLIRPythonSources.Dialects

mlir/python/mlir/dialects/transform/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .._transform_ops_gen import *
77
from .._transform_ops_gen import _Dialect
88
from ..._mlir_libs._mlirDialectsTransform import *
9+
from ..._mlir_libs._mlirDialectsTransform import AnyOpType, OperationType
910

1011
try:
1112
from ...ir import *

mlir/python/mlir/extras/dialects/transform/__init__.py renamed to mlir/python/mlir/dialects/transform/extras/__init__.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5-
from __future__ import annotations
6-
from typing import Callable, Optional, Sequence
5+
from typing import Callable, Optional, Sequence, Union
76

87
from .... import ir
9-
from ....dialects import transform
10-
from ....dialects.transform import structured
8+
from .. import AnyOpType, OperationType, NamedSequenceOp, YieldOp
9+
from .. import structured
1110

1211

1312
class Handle(ir.Value):
@@ -25,16 +24,16 @@ def __init__(
2524
self,
2625
v: ir.Value,
2726
*,
28-
parent: Optional[Handle] = None,
29-
children: Optional[Sequence[Handle]] = None,
27+
parent: Optional["Handle"] = None,
28+
children: Optional[Sequence["Handle"]] = None,
3029
):
3130
super().__init__(v)
3231
self.parent = parent
3332
self.children = children if children is not None else []
3433

3534

36-
@ir.register_value_caster(transform.AnyOpType.get_static_typeid())
37-
@ir.register_value_caster(transform.OperationType.get_static_typeid())
35+
@ir.register_value_caster(AnyOpType.get_static_typeid())
36+
@ir.register_value_caster(OperationType.get_static_typeid())
3837
class OpHandle(Handle):
3938
"""
4039
Wrapper around a transform operation handle with methods to chain further
@@ -52,11 +51,13 @@ def __init__(
5251

5352
def match_ops(
5453
self,
55-
ops: str
56-
| ir.OpView
57-
| structured.MatchInterfaceEnum
58-
| Sequence[str | ir.OpView],
59-
) -> OpHandle:
54+
ops: Union[
55+
str,
56+
ir.OpView,
57+
structured.MatchInterfaceEnum,
58+
Sequence[Union[str, ir.OpView]],
59+
],
60+
) -> "OpHandle":
6061
"""
6162
Emits a `transform.structured.MatchOp`.
6263
Returns a handle to payload ops that match the given names, types, or
@@ -70,23 +71,23 @@ def match_ops(
7071
if isinstance(ops, str):
7172
ops = structured.MatchInterfaceEnum[ops]
7273
match_op = structured.MatchOp(
73-
transform.AnyOpType.get(),
74+
AnyOpType.get(),
7475
self,
7576
interface=ops,
7677
)
7778

7879
# Handle op name(s), either given directly as string or given as op.
7980
else:
8081
if isinstance(ops, str):
81-
op_type = transform.OperationType.get(ops)
82+
op_type = OperationType.get(ops)
8283
op_names = [ops]
8384
elif isinstance(ops, Sequence):
84-
op_type = transform.AnyOpType.get()
85+
op_type = AnyOpType.get()
8586
op_names = [
8687
op if isinstance(op, str) else op.OPERATION_NAME for op in ops
8788
]
8889
else:
89-
op_type = transform.OperationType.get(ops.OPERATION_NAME)
90+
op_type = OperationType.get(ops.OPERATION_NAME)
9091
op_names = [ops.OPERATION_NAME]
9192
match_op = structured.MatchOp.match_op_names(
9293
op_type,
@@ -100,7 +101,7 @@ def match_ops(
100101

101102

102103
def insert_transform_script(
103-
block_or_insertion_point: ir.Block | ir.InsertionPoint,
104+
block_or_insertion_point: Union[ir.Block, ir.InsertionPoint],
104105
script: Callable[[OpHandle], None],
105106
dump_script: bool = False,
106107
) -> None:
@@ -137,12 +138,12 @@ def test_match_ops_single(module: OpHandle):
137138

138139
with context, ir.Location.unknown(context):
139140
with insertion_point:
140-
named_sequence_op = transform.NamedSequenceOp(
141-
"__transform_main", [transform.AnyOpType.get()], []
141+
named_sequence_op = NamedSequenceOp(
142+
"__transform_main", [AnyOpType.get()], []
142143
)
143144
with ir.InsertionPoint(named_sequence_op.body):
144145
script(named_sequence_op.bodyTarget)
145-
transform.YieldOp([])
146+
YieldOp([])
146147

147148
if dump_script:
148149
print(named_sequence_op)

mlir/test/python/dialects/transform_extras.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from mlir import ir
55
from mlir.dialects import scf
66
from mlir.dialects.transform import structured
7-
from mlir.extras.dialects.transform import OpHandle, insert_transform_script
7+
from mlir.dialects.transform.extras import OpHandle, insert_transform_script
88

99

1010
def build_transform_script(script: Callable[[OpHandle], None]):

0 commit comments

Comments
 (0)