@@ -10,13 +10,13 @@ def run(f):
10
10
module = Module .create ()
11
11
with InsertionPoint (module .body ):
12
12
print ("\n TEST:" , f .__name__ )
13
- f ()
13
+ f (module )
14
14
print (module )
15
15
return f
16
16
17
17
18
18
@run
19
- def testTypes ():
19
+ def testTypes (module : Module ):
20
20
# CHECK-LABEL: TEST: testTypes
21
21
# CHECK: !transform.any_op
22
22
any_op = transform .AnyOpType .get ()
@@ -44,7 +44,7 @@ def testTypes():
44
44
45
45
46
46
@run
47
- def testSequenceOp ():
47
+ def testSequenceOp (module : Module ):
48
48
sequence = transform .SequenceOp (
49
49
transform .FailurePropagationMode .Propagate ,
50
50
[transform .AnyOpType .get ()],
@@ -60,103 +60,23 @@ def testSequenceOp():
60
60
61
61
62
62
@run
63
- def testNestedSequenceOp ():
64
- sequence = transform .SequenceOp (
65
- transform .FailurePropagationMode .Propagate , [], transform .AnyOpType .get ()
66
- )
67
- with InsertionPoint (sequence .body ):
68
- nested = transform .SequenceOp (
69
- transform .FailurePropagationMode .Propagate , [], sequence .bodyTarget
70
- )
71
- with InsertionPoint (nested .body ):
72
- doubly_nested = transform .SequenceOp (
73
- transform .FailurePropagationMode .Propagate ,
74
- [transform .AnyOpType .get ()],
75
- nested .bodyTarget ,
76
- )
77
- with InsertionPoint (doubly_nested .body ):
78
- transform .YieldOp ([doubly_nested .bodyTarget ])
79
- transform .YieldOp ()
80
- transform .YieldOp ()
81
- # CHECK-LABEL: TEST: testNestedSequenceOp
82
- # CHECK: transform.sequence failures(propagate) {
83
- # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
84
- # CHECK: sequence %[[ARG0]] : !transform.any_op failures(propagate) {
85
- # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
86
- # CHECK: = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) {
87
- # CHECK: ^{{.*}}(%[[ARG2:.+]]: !transform.any_op):
88
- # CHECK: yield %[[ARG2]] : !transform.any_op
89
- # CHECK: }
90
- # CHECK: }
91
- # CHECK: }
92
-
93
-
94
- @run
95
- def testSequenceOpWithExtras ():
96
- sequence = transform .SequenceOp (
97
- transform .FailurePropagationMode .Propagate ,
98
- [],
99
- transform .AnyOpType .get (),
100
- [transform .AnyOpType .get (), transform .OperationType .get ("foo.bar" )],
101
- )
102
- with InsertionPoint (sequence .body ):
103
- transform .YieldOp ()
104
- # CHECK-LABEL: TEST: testSequenceOpWithExtras
105
- # CHECK: transform.sequence failures(propagate)
106
- # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
107
-
108
-
109
- @run
110
- def testNestedSequenceOpWithExtras ():
111
- sequence = transform .SequenceOp (
112
- transform .FailurePropagationMode .Propagate ,
113
- [],
114
- transform .AnyOpType .get (),
115
- [transform .AnyOpType .get (), transform .OperationType .get ("foo.bar" )],
116
- )
117
- with InsertionPoint (sequence .body ):
118
- nested = transform .SequenceOp (
119
- transform .FailurePropagationMode .Propagate ,
120
- [],
121
- sequence .bodyTarget ,
122
- sequence .bodyExtraArgs ,
123
- )
124
- with InsertionPoint (nested .body ):
125
- transform .YieldOp ()
126
- transform .YieldOp ()
127
- # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
128
- # CHECK: transform.sequence failures(propagate)
129
- # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
130
- # CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
131
-
132
-
133
- @run
134
- def testTransformPDLOps ():
135
- withPdl = transform_pdl .WithPDLPatternsOp (transform .AnyOpType .get ())
136
- with InsertionPoint (withPdl .body ):
137
- sequence = transform .SequenceOp (
138
- transform .FailurePropagationMode .Propagate ,
63
+ def testNamedSequenceOp (module : Module ):
64
+ module .operation .attributes ["transform.with_named_sequence" ] = UnitAttr .get ()
65
+ named_sequence = transform .NamedSequenceOp (
66
+ '__transform_main' ,
67
+ [transform .AnyOpType .get ()],
139
68
[transform .AnyOpType .get ()],
140
- withPdl .bodyTarget ,
141
69
)
142
- with InsertionPoint (sequence .body ):
143
- match = transform_pdl .PDLMatchOp (
144
- transform .AnyOpType .get (), sequence .bodyTarget , "pdl_matcher"
145
- )
146
- transform .YieldOp (match )
147
- # CHECK-LABEL: TEST: testTransformPDLOps
148
- # CHECK: transform.with_pdl_patterns {
149
- # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
150
- # CHECK: = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) {
151
- # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
152
- # CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
153
- # CHECK: yield %[[RES]] : !transform.any_op
154
- # CHECK: }
155
- # CHECK: }
70
+ with InsertionPoint (named_sequence .body ):
71
+ transform .YieldOp ([named_sequence .bodyTarget ])
72
+ # CHECK-LABEL: TEST: testNamedSequenceOp
73
+ # CHECK: module attributes {transform.with_named_sequence} {
74
+ # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op) -> !transform.any_op {
75
+ # CHECK: yield %[[ARG0]] : !transform.any_op
156
76
157
77
158
78
@run
159
- def testGetParentOp ():
79
+ def testGetParentOp (module : Module ):
160
80
sequence = transform .SequenceOp (
161
81
transform .FailurePropagationMode .Propagate , [], transform .AnyOpType .get ()
162
82
)
@@ -175,7 +95,7 @@ def testGetParentOp():
175
95
176
96
177
97
@run
178
- def testMergeHandlesOp ():
98
+ def testMergeHandlesOp (module : Module ):
179
99
sequence = transform .SequenceOp (
180
100
transform .FailurePropagationMode .Propagate , [], transform .AnyOpType .get ()
181
101
)
@@ -189,7 +109,7 @@ def testMergeHandlesOp():
189
109
190
110
191
111
@run
192
- def testApplyPatternsOpCompact ():
112
+ def testApplyPatternsOpCompact (module : Module ):
193
113
sequence = transform .SequenceOp (
194
114
transform .FailurePropagationMode .Propagate , [], transform .AnyOpType .get ()
195
115
)
@@ -204,7 +124,7 @@ def testApplyPatternsOpCompact():
204
124
205
125
206
126
@run
207
- def testApplyPatternsOpWithType ():
127
+ def testApplyPatternsOpWithType (module : Module ):
208
128
sequence = transform .SequenceOp (
209
129
transform .FailurePropagationMode .Propagate , [],
210
130
transform .OperationType .get ('test.dummy' )
@@ -220,7 +140,7 @@ def testApplyPatternsOpWithType():
220
140
221
141
222
142
@run
223
- def testReplicateOp ():
143
+ def testReplicateOp (module : Module ):
224
144
with_pdl = transform_pdl .WithPDLPatternsOp (transform .AnyOpType .get ())
225
145
with InsertionPoint (with_pdl .body ):
226
146
sequence = transform .SequenceOp (
0 commit comments