Skip to content

Commit 5b116bf

Browse files
[mlir][python] Reland - Add sugared builder for transform.named_sequence (llvm#71597)
1 parent be056f6 commit 5b116bf

File tree

2 files changed

+55
-13
lines changed

2 files changed

+55
-13
lines changed

mlir/python/mlir/dialects/transform/__init__.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,34 @@ def bodyExtraArgs(self) -> BlockArgumentList:
165165
return self.body.arguments[1:]
166166

167167

168+
@_ods_cext.register_operation(_Dialect, replace=True)
169+
class NamedSequenceOp(NamedSequenceOp):
170+
def __init__(
171+
self,
172+
sym_name,
173+
input_types: Sequence[Type],
174+
result_types: Sequence[Type],
175+
):
176+
function_type = FunctionType.get(input_types, result_types)
177+
super().__init__(
178+
sym_name=sym_name,
179+
function_type=TypeAttr.get(function_type),
180+
)
181+
self.regions[0].blocks.append(*input_types)
182+
183+
@property
184+
def body(self) -> Block:
185+
return self.regions[0].blocks[0]
186+
187+
@property
188+
def bodyTarget(self) -> Value:
189+
return self.body.arguments[0]
190+
191+
@property
192+
def bodyExtraArgs(self) -> BlockArgumentList:
193+
return self.body.arguments[1:]
194+
195+
168196
@_ods_cext.register_operation(_Dialect, replace=True)
169197
class YieldOp(YieldOp):
170198
def __init__(

mlir/test/python/dialects/transform.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@ def run(f):
1010
module = Module.create()
1111
with InsertionPoint(module.body):
1212
print("\nTEST:", f.__name__)
13-
f()
13+
f(module)
1414
print(module)
1515
return f
1616

1717

1818
@run
19-
def testTypes():
19+
def testTypes(module: Module):
2020
# CHECK-LABEL: TEST: testTypes
2121
# CHECK: !transform.any_op
2222
any_op = transform.AnyOpType.get()
@@ -44,7 +44,7 @@ def testTypes():
4444

4545

4646
@run
47-
def testSequenceOp():
47+
def testSequenceOp(module: Module):
4848
sequence = transform.SequenceOp(
4949
transform.FailurePropagationMode.Propagate,
5050
[transform.AnyOpType.get()],
@@ -58,9 +58,8 @@ def testSequenceOp():
5858
# CHECK: yield %[[ARG0]] : !transform.any_op
5959
# CHECK: }
6060

61-
6261
@run
63-
def testNestedSequenceOp():
62+
def testNestedSequenceOp(module: Module):
6463
sequence = transform.SequenceOp(
6564
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
6665
)
@@ -92,7 +91,7 @@ def testNestedSequenceOp():
9291

9392

9493
@run
95-
def testSequenceOpWithExtras():
94+
def testSequenceOpWithExtras(module: Module):
9695
sequence = transform.SequenceOp(
9796
transform.FailurePropagationMode.Propagate,
9897
[],
@@ -107,7 +106,7 @@ def testSequenceOpWithExtras():
107106

108107

109108
@run
110-
def testNestedSequenceOpWithExtras():
109+
def testNestedSequenceOpWithExtras(module: Module):
111110
sequence = transform.SequenceOp(
112111
transform.FailurePropagationMode.Propagate,
113112
[],
@@ -131,7 +130,7 @@ def testNestedSequenceOpWithExtras():
131130

132131

133132
@run
134-
def testTransformPDLOps():
133+
def testTransformPDLOps(module: Module):
135134
withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
136135
with InsertionPoint(withPdl.body):
137136
sequence = transform.SequenceOp(
@@ -154,9 +153,24 @@ def testTransformPDLOps():
154153
# CHECK: }
155154
# CHECK: }
156155

156+
@run
157+
def testNamedSequenceOp(module: Module):
158+
module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()
159+
named_sequence = transform.NamedSequenceOp(
160+
"__transform_main",
161+
[transform.AnyOpType.get()],
162+
[transform.AnyOpType.get()],
163+
)
164+
with InsertionPoint(named_sequence.body):
165+
transform.YieldOp([named_sequence.bodyTarget])
166+
# CHECK-LABEL: TEST: testNamedSequenceOp
167+
# CHECK: module attributes {transform.with_named_sequence} {
168+
# CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op) -> !transform.any_op {
169+
# CHECK: yield %[[ARG0]] : !transform.any_op
170+
157171

158172
@run
159-
def testGetParentOp():
173+
def testGetParentOp(module: Module):
160174
sequence = transform.SequenceOp(
161175
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
162176
)
@@ -175,7 +189,7 @@ def testGetParentOp():
175189

176190

177191
@run
178-
def testMergeHandlesOp():
192+
def testMergeHandlesOp(module: Module):
179193
sequence = transform.SequenceOp(
180194
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
181195
)
@@ -189,7 +203,7 @@ def testMergeHandlesOp():
189203

190204

191205
@run
192-
def testApplyPatternsOpCompact():
206+
def testApplyPatternsOpCompact(module: Module):
193207
sequence = transform.SequenceOp(
194208
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
195209
)
@@ -204,7 +218,7 @@ def testApplyPatternsOpCompact():
204218

205219

206220
@run
207-
def testApplyPatternsOpWithType():
221+
def testApplyPatternsOpWithType(module: Module):
208222
sequence = transform.SequenceOp(
209223
transform.FailurePropagationMode.Propagate, [],
210224
transform.OperationType.get('test.dummy')
@@ -220,7 +234,7 @@ def testApplyPatternsOpWithType():
220234

221235

222236
@run
223-
def testReplicateOp():
237+
def testReplicateOp(module: Module):
224238
with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
225239
with InsertionPoint(with_pdl.body):
226240
sequence = transform.SequenceOp(

0 commit comments

Comments
 (0)