Skip to content

Commit be056f6

Browse files
Revert "[mlir][python]Add sugared buider for transform.named_sequence (llvm#71597)"
This reverts commit 4f51b2b.
1 parent 4c9f7b6 commit be056f6

File tree

2 files changed

+99
-47
lines changed

2 files changed

+99
-47
lines changed

mlir/python/mlir/dialects/transform/__init__.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -165,34 +165,6 @@ def bodyExtraArgs(self) -> BlockArgumentList:
165165
return self.body.arguments[1:]
166166

167167

168-
@_ods_cext.register_operation(_Dialect, replace=True)
169-
class NamedSequenceOp(NamedSequenceOp):
170-
def __init__(
171-
self,
172-
sym_name,
173-
input_types: Sequence[Type],
174-
result_types: Sequence[Type],
175-
):
176-
function_type = FunctionType.get(input_types, result_types)
177-
super().__init__(
178-
sym_name=sym_name,
179-
function_type=TypeAttr.get(function_type),
180-
)
181-
self.regions[0].blocks.append(*input_types)
182-
183-
@property
184-
def body(self) -> Block:
185-
return self.regions[0].blocks[0]
186-
187-
@property
188-
def bodyTarget(self) -> Value:
189-
return self.body.arguments[0]
190-
191-
@property
192-
def bodyExtraArgs(self) -> BlockArgumentList:
193-
return self.body.arguments[1:]
194-
195-
196168
@_ods_cext.register_operation(_Dialect, replace=True)
197169
class YieldOp(YieldOp):
198170
def __init__(

mlir/test/python/dialects/transform.py

Lines changed: 99 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@ def run(f):
1010
module = Module.create()
1111
with InsertionPoint(module.body):
1212
print("\nTEST:", f.__name__)
13-
f(module)
13+
f()
1414
print(module)
1515
return f
1616

1717

1818
@run
19-
def testTypes(module: Module):
19+
def testTypes():
2020
# CHECK-LABEL: TEST: testTypes
2121
# CHECK: !transform.any_op
2222
any_op = transform.AnyOpType.get()
@@ -44,7 +44,7 @@ def testTypes(module: Module):
4444

4545

4646
@run
47-
def testSequenceOp(module: Module):
47+
def testSequenceOp():
4848
sequence = transform.SequenceOp(
4949
transform.FailurePropagationMode.Propagate,
5050
[transform.AnyOpType.get()],
@@ -60,23 +60,103 @@ def testSequenceOp(module: Module):
6060

6161

6262
@run
63-
def testNamedSequenceOp(module: Module):
64-
module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()
65-
named_sequence = transform.NamedSequenceOp(
66-
'__transform_main',
67-
[transform.AnyOpType.get()],
63+
def testNestedSequenceOp():
64+
sequence = transform.SequenceOp(
65+
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
66+
)
67+
with InsertionPoint(sequence.body):
68+
nested = transform.SequenceOp(
69+
transform.FailurePropagationMode.Propagate, [], sequence.bodyTarget
70+
)
71+
with InsertionPoint(nested.body):
72+
doubly_nested = transform.SequenceOp(
73+
transform.FailurePropagationMode.Propagate,
74+
[transform.AnyOpType.get()],
75+
nested.bodyTarget,
76+
)
77+
with InsertionPoint(doubly_nested.body):
78+
transform.YieldOp([doubly_nested.bodyTarget])
79+
transform.YieldOp()
80+
transform.YieldOp()
81+
# CHECK-LABEL: TEST: testNestedSequenceOp
82+
# CHECK: transform.sequence failures(propagate) {
83+
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
84+
# CHECK: sequence %[[ARG0]] : !transform.any_op failures(propagate) {
85+
# CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
86+
# CHECK: = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) {
87+
# CHECK: ^{{.*}}(%[[ARG2:.+]]: !transform.any_op):
88+
# CHECK: yield %[[ARG2]] : !transform.any_op
89+
# CHECK: }
90+
# CHECK: }
91+
# CHECK: }
92+
93+
94+
@run
95+
def testSequenceOpWithExtras():
96+
sequence = transform.SequenceOp(
97+
transform.FailurePropagationMode.Propagate,
98+
[],
99+
transform.AnyOpType.get(),
100+
[transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
101+
)
102+
with InsertionPoint(sequence.body):
103+
transform.YieldOp()
104+
# CHECK-LABEL: TEST: testSequenceOpWithExtras
105+
# CHECK: transform.sequence failures(propagate)
106+
# CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
107+
108+
109+
@run
110+
def testNestedSequenceOpWithExtras():
111+
sequence = transform.SequenceOp(
112+
transform.FailurePropagationMode.Propagate,
113+
[],
114+
transform.AnyOpType.get(),
115+
[transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
116+
)
117+
with InsertionPoint(sequence.body):
118+
nested = transform.SequenceOp(
119+
transform.FailurePropagationMode.Propagate,
120+
[],
121+
sequence.bodyTarget,
122+
sequence.bodyExtraArgs,
123+
)
124+
with InsertionPoint(nested.body):
125+
transform.YieldOp()
126+
transform.YieldOp()
127+
# CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
128+
# CHECK: transform.sequence failures(propagate)
129+
# CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
130+
# CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
131+
132+
133+
@run
134+
def testTransformPDLOps():
135+
withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
136+
with InsertionPoint(withPdl.body):
137+
sequence = transform.SequenceOp(
138+
transform.FailurePropagationMode.Propagate,
68139
[transform.AnyOpType.get()],
140+
withPdl.bodyTarget,
69141
)
70-
with InsertionPoint(named_sequence.body):
71-
transform.YieldOp([named_sequence.bodyTarget])
72-
# CHECK-LABEL: TEST: testNamedSequenceOp
73-
# CHECK: module attributes {transform.with_named_sequence} {
74-
# CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op) -> !transform.any_op {
75-
# CHECK: yield %[[ARG0]] : !transform.any_op
142+
with InsertionPoint(sequence.body):
143+
match = transform_pdl.PDLMatchOp(
144+
transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher"
145+
)
146+
transform.YieldOp(match)
147+
# CHECK-LABEL: TEST: testTransformPDLOps
148+
# CHECK: transform.with_pdl_patterns {
149+
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
150+
# CHECK: = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) {
151+
# CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
152+
# CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
153+
# CHECK: yield %[[RES]] : !transform.any_op
154+
# CHECK: }
155+
# CHECK: }
76156

77157

78158
@run
79-
def testGetParentOp(module: Module):
159+
def testGetParentOp():
80160
sequence = transform.SequenceOp(
81161
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
82162
)
@@ -95,7 +175,7 @@ def testGetParentOp(module: Module):
95175

96176

97177
@run
98-
def testMergeHandlesOp(module: Module):
178+
def testMergeHandlesOp():
99179
sequence = transform.SequenceOp(
100180
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
101181
)
@@ -109,7 +189,7 @@ def testMergeHandlesOp(module: Module):
109189

110190

111191
@run
112-
def testApplyPatternsOpCompact(module: Module):
192+
def testApplyPatternsOpCompact():
113193
sequence = transform.SequenceOp(
114194
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
115195
)
@@ -124,7 +204,7 @@ def testApplyPatternsOpCompact(module: Module):
124204

125205

126206
@run
127-
def testApplyPatternsOpWithType(module: Module):
207+
def testApplyPatternsOpWithType():
128208
sequence = transform.SequenceOp(
129209
transform.FailurePropagationMode.Propagate, [],
130210
transform.OperationType.get('test.dummy')
@@ -140,7 +220,7 @@ def testApplyPatternsOpWithType(module: Module):
140220

141221

142222
@run
143-
def testReplicateOp(module: Module):
223+
def testReplicateOp():
144224
with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
145225
with InsertionPoint(with_pdl.body):
146226
sequence = transform.SequenceOp(

0 commit comments

Comments
 (0)