2
2
# See https://llvm.org/LICENSE.txt for license information.
3
3
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
4
5
- from __future__ import annotations
6
- from typing import Callable , Optional , Sequence
5
+ from typing import Callable , Optional , Sequence , Union
7
6
8
7
from .... import ir
9
- from .... dialects import transform
10
- from .... dialects . transform import structured
8
+ from .. import AnyOpType , OperationType , NamedSequenceOp , YieldOp
9
+ from .. import structured
11
10
12
11
13
12
class Handle (ir .Value ):
@@ -25,16 +24,16 @@ def __init__(
25
24
self ,
26
25
v : ir .Value ,
27
26
* ,
28
- parent : Optional [Handle ] = None ,
29
- children : Optional [Sequence [Handle ]] = None ,
27
+ parent : Optional [" Handle" ] = None ,
28
+ children : Optional [Sequence [" Handle" ]] = None ,
30
29
):
31
30
super ().__init__ (v )
32
31
self .parent = parent
33
32
self .children = children if children is not None else []
34
33
35
34
36
- @ir .register_value_caster (transform . AnyOpType .get_static_typeid ())
37
- @ir .register_value_caster (transform . OperationType .get_static_typeid ())
35
+ @ir .register_value_caster (AnyOpType .get_static_typeid ())
36
+ @ir .register_value_caster (OperationType .get_static_typeid ())
38
37
class OpHandle (Handle ):
39
38
"""
40
39
Wrapper around a transform operation handle with methods to chain further
@@ -52,11 +51,13 @@ def __init__(
52
51
53
52
def match_ops (
54
53
self ,
55
- ops : str
56
- | ir .OpView
57
- | structured .MatchInterfaceEnum
58
- | Sequence [str | ir .OpView ],
59
- ) -> OpHandle :
54
+ ops : Union [
55
+ str ,
56
+ ir .OpView ,
57
+ structured .MatchInterfaceEnum ,
58
+ Sequence [Union [str , ir .OpView ]],
59
+ ],
60
+ ) -> "OpHandle" :
60
61
"""
61
62
Emits a `transform.structured.MatchOp`.
62
63
Returns a handle to payload ops that match the given names, types, or
@@ -70,23 +71,23 @@ def match_ops(
70
71
if isinstance (ops , str ):
71
72
ops = structured .MatchInterfaceEnum [ops ]
72
73
match_op = structured .MatchOp (
73
- transform . AnyOpType .get (),
74
+ AnyOpType .get (),
74
75
self ,
75
76
interface = ops ,
76
77
)
77
78
78
79
# Handle op name(s), either given directly as string or given as op.
79
80
else :
80
81
if isinstance (ops , str ):
81
- op_type = transform . OperationType .get (ops )
82
+ op_type = OperationType .get (ops )
82
83
op_names = [ops ]
83
84
elif isinstance (ops , Sequence ):
84
- op_type = transform . AnyOpType .get ()
85
+ op_type = AnyOpType .get ()
85
86
op_names = [
86
87
op if isinstance (op , str ) else op .OPERATION_NAME for op in ops
87
88
]
88
89
else :
89
- op_type = transform . OperationType .get (ops .OPERATION_NAME )
90
+ op_type = OperationType .get (ops .OPERATION_NAME )
90
91
op_names = [ops .OPERATION_NAME ]
91
92
match_op = structured .MatchOp .match_op_names (
92
93
op_type ,
@@ -100,7 +101,7 @@ def match_ops(
100
101
101
102
102
103
def insert_transform_script (
103
- block_or_insertion_point : ir .Block | ir .InsertionPoint ,
104
+ block_or_insertion_point : Union [ ir .Block , ir .InsertionPoint ] ,
104
105
script : Callable [[OpHandle ], None ],
105
106
dump_script : bool = False ,
106
107
) -> None :
@@ -137,12 +138,12 @@ def test_match_ops_single(module: OpHandle):
137
138
138
139
with context , ir .Location .unknown (context ):
139
140
with insertion_point :
140
- named_sequence_op = transform . NamedSequenceOp (
141
- "__transform_main" , [transform . AnyOpType .get ()], []
141
+ named_sequence_op = NamedSequenceOp (
142
+ "__transform_main" , [AnyOpType .get ()], []
142
143
)
143
144
with ir .InsertionPoint (named_sequence_op .body ):
144
145
script (named_sequence_op .bodyTarget )
145
- transform . YieldOp ([])
146
+ YieldOp ([])
146
147
147
148
if dump_script :
148
149
print (named_sequence_op )
0 commit comments