Skip to content

Commit 4f51b2b

Browse files
[mlir][python]Add sugared buider for transform.named_sequence (#71597)
1 parent 5918f62 commit 4f51b2b

File tree

2 files changed

+47
-99
lines changed

2 files changed

+47
-99
lines changed

mlir/python/mlir/dialects/transform/__init__.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,34 @@ def bodyExtraArgs(self) -> BlockArgumentList:
165165
return self.body.arguments[1:]
166166

167167

168+
@_ods_cext.register_operation(_Dialect, replace=True)
169+
class NamedSequenceOp(NamedSequenceOp):
170+
def __init__(
171+
self,
172+
sym_name,
173+
input_types: Sequence[Type],
174+
result_types: Sequence[Type],
175+
):
176+
function_type = FunctionType.get(input_types, result_types)
177+
super().__init__(
178+
sym_name=sym_name,
179+
function_type=TypeAttr.get(function_type),
180+
)
181+
self.regions[0].blocks.append(*input_types)
182+
183+
@property
184+
def body(self) -> Block:
185+
return self.regions[0].blocks[0]
186+
187+
@property
188+
def bodyTarget(self) -> Value:
189+
return self.body.arguments[0]
190+
191+
@property
192+
def bodyExtraArgs(self) -> BlockArgumentList:
193+
return self.body.arguments[1:]
194+
195+
168196
@_ods_cext.register_operation(_Dialect, replace=True)
169197
class YieldOp(YieldOp):
170198
def __init__(

mlir/test/python/dialects/transform.py

Lines changed: 19 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@ def run(f):
1010
module = Module.create()
1111
with InsertionPoint(module.body):
1212
print("\nTEST:", f.__name__)
13-
f()
13+
f(module)
1414
print(module)
1515
return f
1616

1717

1818
@run
19-
def testTypes():
19+
def testTypes(module: Module):
2020
# CHECK-LABEL: TEST: testTypes
2121
# CHECK: !transform.any_op
2222
any_op = transform.AnyOpType.get()
@@ -44,7 +44,7 @@ def testTypes():
4444

4545

4646
@run
47-
def testSequenceOp():
47+
def testSequenceOp(module: Module):
4848
sequence = transform.SequenceOp(
4949
transform.FailurePropagationMode.Propagate,
5050
[transform.AnyOpType.get()],
@@ -60,103 +60,23 @@ def testSequenceOp():
6060

6161

6262
@run
63-
def testNestedSequenceOp():
64-
sequence = transform.SequenceOp(
65-
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
66-
)
67-
with InsertionPoint(sequence.body):
68-
nested = transform.SequenceOp(
69-
transform.FailurePropagationMode.Propagate, [], sequence.bodyTarget
70-
)
71-
with InsertionPoint(nested.body):
72-
doubly_nested = transform.SequenceOp(
73-
transform.FailurePropagationMode.Propagate,
74-
[transform.AnyOpType.get()],
75-
nested.bodyTarget,
76-
)
77-
with InsertionPoint(doubly_nested.body):
78-
transform.YieldOp([doubly_nested.bodyTarget])
79-
transform.YieldOp()
80-
transform.YieldOp()
81-
# CHECK-LABEL: TEST: testNestedSequenceOp
82-
# CHECK: transform.sequence failures(propagate) {
83-
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
84-
# CHECK: sequence %[[ARG0]] : !transform.any_op failures(propagate) {
85-
# CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
86-
# CHECK: = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) {
87-
# CHECK: ^{{.*}}(%[[ARG2:.+]]: !transform.any_op):
88-
# CHECK: yield %[[ARG2]] : !transform.any_op
89-
# CHECK: }
90-
# CHECK: }
91-
# CHECK: }
92-
93-
94-
@run
95-
def testSequenceOpWithExtras():
96-
sequence = transform.SequenceOp(
97-
transform.FailurePropagationMode.Propagate,
98-
[],
99-
transform.AnyOpType.get(),
100-
[transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
101-
)
102-
with InsertionPoint(sequence.body):
103-
transform.YieldOp()
104-
# CHECK-LABEL: TEST: testSequenceOpWithExtras
105-
# CHECK: transform.sequence failures(propagate)
106-
# CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
107-
108-
109-
@run
110-
def testNestedSequenceOpWithExtras():
111-
sequence = transform.SequenceOp(
112-
transform.FailurePropagationMode.Propagate,
113-
[],
114-
transform.AnyOpType.get(),
115-
[transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
116-
)
117-
with InsertionPoint(sequence.body):
118-
nested = transform.SequenceOp(
119-
transform.FailurePropagationMode.Propagate,
120-
[],
121-
sequence.bodyTarget,
122-
sequence.bodyExtraArgs,
123-
)
124-
with InsertionPoint(nested.body):
125-
transform.YieldOp()
126-
transform.YieldOp()
127-
# CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
128-
# CHECK: transform.sequence failures(propagate)
129-
# CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
130-
# CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
131-
132-
133-
@run
134-
def testTransformPDLOps():
135-
withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
136-
with InsertionPoint(withPdl.body):
137-
sequence = transform.SequenceOp(
138-
transform.FailurePropagationMode.Propagate,
63+
def testNamedSequenceOp(module: Module):
64+
module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()
65+
named_sequence = transform.NamedSequenceOp(
66+
'__transform_main',
67+
[transform.AnyOpType.get()],
13968
[transform.AnyOpType.get()],
140-
withPdl.bodyTarget,
14169
)
142-
with InsertionPoint(sequence.body):
143-
match = transform_pdl.PDLMatchOp(
144-
transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher"
145-
)
146-
transform.YieldOp(match)
147-
# CHECK-LABEL: TEST: testTransformPDLOps
148-
# CHECK: transform.with_pdl_patterns {
149-
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
150-
# CHECK: = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) {
151-
# CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
152-
# CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
153-
# CHECK: yield %[[RES]] : !transform.any_op
154-
# CHECK: }
155-
# CHECK: }
70+
with InsertionPoint(named_sequence.body):
71+
transform.YieldOp([named_sequence.bodyTarget])
72+
# CHECK-LABEL: TEST: testNamedSequenceOp
73+
# CHECK: module attributes {transform.with_named_sequence} {
74+
# CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op) -> !transform.any_op {
75+
# CHECK: yield %[[ARG0]] : !transform.any_op
15676

15777

15878
@run
159-
def testGetParentOp():
79+
def testGetParentOp(module: Module):
16080
sequence = transform.SequenceOp(
16181
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
16282
)
@@ -175,7 +95,7 @@ def testGetParentOp():
17595

17696

17797
@run
178-
def testMergeHandlesOp():
98+
def testMergeHandlesOp(module: Module):
17999
sequence = transform.SequenceOp(
180100
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
181101
)
@@ -189,7 +109,7 @@ def testMergeHandlesOp():
189109

190110

191111
@run
192-
def testApplyPatternsOpCompact():
112+
def testApplyPatternsOpCompact(module: Module):
193113
sequence = transform.SequenceOp(
194114
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
195115
)
@@ -204,7 +124,7 @@ def testApplyPatternsOpCompact():
204124

205125

206126
@run
207-
def testApplyPatternsOpWithType():
127+
def testApplyPatternsOpWithType(module: Module):
208128
sequence = transform.SequenceOp(
209129
transform.FailurePropagationMode.Propagate, [],
210130
transform.OperationType.get('test.dummy')
@@ -220,7 +140,7 @@ def testApplyPatternsOpWithType():
220140

221141

222142
@run
223-
def testReplicateOp():
143+
def testReplicateOp(module: Module):
224144
with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
225145
with InsertionPoint(with_pdl.body):
226146
sequence = transform.SequenceOp(

0 commit comments

Comments
 (0)