@@ -10,13 +10,13 @@ def run(f):
10
10
module = Module .create ()
11
11
with InsertionPoint (module .body ):
12
12
print ("\n TEST:" , f .__name__ )
13
- f (module )
13
+ f ()
14
14
print (module )
15
15
return f
16
16
17
17
18
18
@run
19
- def testTypes (module : Module ):
19
+ def testTypes ():
20
20
# CHECK-LABEL: TEST: testTypes
21
21
# CHECK: !transform.any_op
22
22
any_op = transform .AnyOpType .get ()
@@ -44,7 +44,7 @@ def testTypes(module: Module):
44
44
45
45
46
46
@run
47
- def testSequenceOp (module : Module ):
47
+ def testSequenceOp ():
48
48
sequence = transform .SequenceOp (
49
49
transform .FailurePropagationMode .Propagate ,
50
50
[transform .AnyOpType .get ()],
@@ -60,23 +60,103 @@ def testSequenceOp(module: Module):
60
60
61
61
62
62
@run
63
- def testNamedSequenceOp (module : Module ):
64
- module .operation .attributes ["transform.with_named_sequence" ] = UnitAttr .get ()
65
- named_sequence = transform .NamedSequenceOp (
66
- '__transform_main' ,
67
- [transform .AnyOpType .get ()],
63
+ def testNestedSequenceOp ():
64
+ sequence = transform .SequenceOp (
65
+ transform .FailurePropagationMode .Propagate , [], transform .AnyOpType .get ()
66
+ )
67
+ with InsertionPoint (sequence .body ):
68
+ nested = transform .SequenceOp (
69
+ transform .FailurePropagationMode .Propagate , [], sequence .bodyTarget
70
+ )
71
+ with InsertionPoint (nested .body ):
72
+ doubly_nested = transform .SequenceOp (
73
+ transform .FailurePropagationMode .Propagate ,
74
+ [transform .AnyOpType .get ()],
75
+ nested .bodyTarget ,
76
+ )
77
+ with InsertionPoint (doubly_nested .body ):
78
+ transform .YieldOp ([doubly_nested .bodyTarget ])
79
+ transform .YieldOp ()
80
+ transform .YieldOp ()
81
+ # CHECK-LABEL: TEST: testNestedSequenceOp
82
+ # CHECK: transform.sequence failures(propagate) {
83
+ # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
84
+ # CHECK: sequence %[[ARG0]] : !transform.any_op failures(propagate) {
85
+ # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
86
+ # CHECK: = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) {
87
+ # CHECK: ^{{.*}}(%[[ARG2:.+]]: !transform.any_op):
88
+ # CHECK: yield %[[ARG2]] : !transform.any_op
89
+ # CHECK: }
90
+ # CHECK: }
91
+ # CHECK: }
92
+
93
+
94
+ @run
95
+ def testSequenceOpWithExtras ():
96
+ sequence = transform .SequenceOp (
97
+ transform .FailurePropagationMode .Propagate ,
98
+ [],
99
+ transform .AnyOpType .get (),
100
+ [transform .AnyOpType .get (), transform .OperationType .get ("foo.bar" )],
101
+ )
102
+ with InsertionPoint (sequence .body ):
103
+ transform .YieldOp ()
104
+ # CHECK-LABEL: TEST: testSequenceOpWithExtras
105
+ # CHECK: transform.sequence failures(propagate)
106
+ # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
107
+
108
+
109
+ @run
110
+ def testNestedSequenceOpWithExtras ():
111
+ sequence = transform .SequenceOp (
112
+ transform .FailurePropagationMode .Propagate ,
113
+ [],
114
+ transform .AnyOpType .get (),
115
+ [transform .AnyOpType .get (), transform .OperationType .get ("foo.bar" )],
116
+ )
117
+ with InsertionPoint (sequence .body ):
118
+ nested = transform .SequenceOp (
119
+ transform .FailurePropagationMode .Propagate ,
120
+ [],
121
+ sequence .bodyTarget ,
122
+ sequence .bodyExtraArgs ,
123
+ )
124
+ with InsertionPoint (nested .body ):
125
+ transform .YieldOp ()
126
+ transform .YieldOp ()
127
+ # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
128
+ # CHECK: transform.sequence failures(propagate)
129
+ # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
130
+ # CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
131
+
132
+
133
+ @run
134
+ def testTransformPDLOps ():
135
+ withPdl = transform_pdl .WithPDLPatternsOp (transform .AnyOpType .get ())
136
+ with InsertionPoint (withPdl .body ):
137
+ sequence = transform .SequenceOp (
138
+ transform .FailurePropagationMode .Propagate ,
68
139
[transform .AnyOpType .get ()],
140
+ withPdl .bodyTarget ,
69
141
)
70
- with InsertionPoint (named_sequence .body ):
71
- transform .YieldOp ([named_sequence .bodyTarget ])
72
- # CHECK-LABEL: TEST: testNamedSequenceOp
73
- # CHECK: module attributes {transform.with_named_sequence} {
74
- # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op) -> !transform.any_op {
75
- # CHECK: yield %[[ARG0]] : !transform.any_op
142
+ with InsertionPoint (sequence .body ):
143
+ match = transform_pdl .PDLMatchOp (
144
+ transform .AnyOpType .get (), sequence .bodyTarget , "pdl_matcher"
145
+ )
146
+ transform .YieldOp (match )
147
+ # CHECK-LABEL: TEST: testTransformPDLOps
148
+ # CHECK: transform.with_pdl_patterns {
149
+ # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
150
+ # CHECK: = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) {
151
+ # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
152
+ # CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
153
+ # CHECK: yield %[[RES]] : !transform.any_op
154
+ # CHECK: }
155
+ # CHECK: }
76
156
77
157
78
158
@run
79
- def testGetParentOp (module : Module ):
159
+ def testGetParentOp ():
80
160
sequence = transform .SequenceOp (
81
161
transform .FailurePropagationMode .Propagate , [], transform .AnyOpType .get ()
82
162
)
@@ -95,7 +175,7 @@ def testGetParentOp(module: Module):
95
175
96
176
97
177
@run
98
- def testMergeHandlesOp (module : Module ):
178
+ def testMergeHandlesOp ():
99
179
sequence = transform .SequenceOp (
100
180
transform .FailurePropagationMode .Propagate , [], transform .AnyOpType .get ()
101
181
)
@@ -109,7 +189,7 @@ def testMergeHandlesOp(module: Module):
109
189
110
190
111
191
@run
112
- def testApplyPatternsOpCompact (module : Module ):
192
+ def testApplyPatternsOpCompact ():
113
193
sequence = transform .SequenceOp (
114
194
transform .FailurePropagationMode .Propagate , [], transform .AnyOpType .get ()
115
195
)
@@ -124,7 +204,7 @@ def testApplyPatternsOpCompact(module: Module):
124
204
125
205
126
206
@run
127
- def testApplyPatternsOpWithType (module : Module ):
207
+ def testApplyPatternsOpWithType ():
128
208
sequence = transform .SequenceOp (
129
209
transform .FailurePropagationMode .Propagate , [],
130
210
transform .OperationType .get ('test.dummy' )
@@ -140,7 +220,7 @@ def testApplyPatternsOpWithType(module: Module):
140
220
141
221
142
222
@run
143
- def testReplicateOp (module : Module ):
223
+ def testReplicateOp ():
144
224
with_pdl = transform_pdl .WithPDLPatternsOp (transform .AnyOpType .get ())
145
225
with InsertionPoint (with_pdl .body ):
146
226
sequence = transform .SequenceOp (
0 commit comments