@@ -10,13 +10,13 @@ def run(f):
10
10
module = Module .create ()
11
11
with InsertionPoint (module .body ):
12
12
print ("\n TEST:" , f .__name__ )
13
- f ()
13
+ f (module )
14
14
print (module )
15
15
return f
16
16
17
17
18
18
@run
19
- def testTypes ():
19
+ def testTypes (module : Module ):
20
20
# CHECK-LABEL: TEST: testTypes
21
21
# CHECK: !transform.any_op
22
22
any_op = transform .AnyOpType .get ()
@@ -44,7 +44,7 @@ def testTypes():
44
44
45
45
46
46
@run
47
- def testSequenceOp ():
47
+ def testSequenceOp (module : Module ):
48
48
sequence = transform .SequenceOp (
49
49
transform .FailurePropagationMode .Propagate ,
50
50
[transform .AnyOpType .get ()],
@@ -58,9 +58,8 @@ def testSequenceOp():
58
58
# CHECK: yield %[[ARG0]] : !transform.any_op
59
59
# CHECK: }
60
60
61
-
62
61
@run
63
- def testNestedSequenceOp ():
62
+ def testNestedSequenceOp (module : Module ):
64
63
sequence = transform .SequenceOp (
65
64
transform .FailurePropagationMode .Propagate , [], transform .AnyOpType .get ()
66
65
)
@@ -92,7 +91,7 @@ def testNestedSequenceOp():
92
91
93
92
94
93
@run
95
- def testSequenceOpWithExtras ():
94
+ def testSequenceOpWithExtras (module : Module ):
96
95
sequence = transform .SequenceOp (
97
96
transform .FailurePropagationMode .Propagate ,
98
97
[],
@@ -107,7 +106,7 @@ def testSequenceOpWithExtras():
107
106
108
107
109
108
@run
110
- def testNestedSequenceOpWithExtras ():
109
+ def testNestedSequenceOpWithExtras (module : Module ):
111
110
sequence = transform .SequenceOp (
112
111
transform .FailurePropagationMode .Propagate ,
113
112
[],
@@ -131,7 +130,7 @@ def testNestedSequenceOpWithExtras():
131
130
132
131
133
132
@run
134
- def testTransformPDLOps ():
133
+ def testTransformPDLOps (module : Module ):
135
134
withPdl = transform_pdl .WithPDLPatternsOp (transform .AnyOpType .get ())
136
135
with InsertionPoint (withPdl .body ):
137
136
sequence = transform .SequenceOp (
@@ -154,9 +153,24 @@ def testTransformPDLOps():
154
153
# CHECK: }
155
154
# CHECK: }
156
155
156
+ @run
157
+ def testNamedSequenceOp (module : Module ):
158
+ module .operation .attributes ["transform.with_named_sequence" ] = UnitAttr .get ()
159
+ named_sequence = transform .NamedSequenceOp (
160
+ "__transform_main" ,
161
+ [transform .AnyOpType .get ()],
162
+ [transform .AnyOpType .get ()],
163
+ )
164
+ with InsertionPoint (named_sequence .body ):
165
+ transform .YieldOp ([named_sequence .bodyTarget ])
166
+ # CHECK-LABEL: TEST: testNamedSequenceOp
167
+ # CHECK: module attributes {transform.with_named_sequence} {
168
+ # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op) -> !transform.any_op {
169
+ # CHECK: yield %[[ARG0]] : !transform.any_op
170
+
157
171
158
172
@run
159
- def testGetParentOp ():
173
+ def testGetParentOp (module : Module ):
160
174
sequence = transform .SequenceOp (
161
175
transform .FailurePropagationMode .Propagate , [], transform .AnyOpType .get ()
162
176
)
@@ -175,7 +189,7 @@ def testGetParentOp():
175
189
176
190
177
191
@run
178
- def testMergeHandlesOp ():
192
+ def testMergeHandlesOp (module : Module ):
179
193
sequence = transform .SequenceOp (
180
194
transform .FailurePropagationMode .Propagate , [], transform .AnyOpType .get ()
181
195
)
@@ -189,7 +203,7 @@ def testMergeHandlesOp():
189
203
190
204
191
205
@run
192
- def testApplyPatternsOpCompact ():
206
+ def testApplyPatternsOpCompact (module : Module ):
193
207
sequence = transform .SequenceOp (
194
208
transform .FailurePropagationMode .Propagate , [], transform .AnyOpType .get ()
195
209
)
@@ -204,7 +218,7 @@ def testApplyPatternsOpCompact():
204
218
205
219
206
220
@run
207
- def testApplyPatternsOpWithType ():
221
+ def testApplyPatternsOpWithType (module : Module ):
208
222
sequence = transform .SequenceOp (
209
223
transform .FailurePropagationMode .Propagate , [],
210
224
transform .OperationType .get ('test.dummy' )
@@ -220,7 +234,7 @@ def testApplyPatternsOpWithType():
220
234
221
235
222
236
@run
223
- def testReplicateOp ():
237
+ def testReplicateOp (module : Module ):
224
238
with_pdl = transform_pdl .WithPDLPatternsOp (transform .AnyOpType .get ())
225
239
with InsertionPoint (with_pdl .body ):
226
240
sequence = transform .SequenceOp (
0 commit comments