Skip to content

Commit 04736c7

Browse files
[mlir][SCF] Use transform.get_parent_op instead of transform.loop.get_parent_for (#70757)
Add a new attribute to `get_parent_op` to get the n-th parent. Remove `transform.loop.get_parent_for`, which is no longer needed.
1 parent 83bf8e9 commit 04736c7

File tree

11 files changed

+91
-246
lines changed

11 files changed

+91
-246
lines changed

mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -68,30 +68,6 @@ def ForallToForOp : Op<Transform_Dialect, "loop.forall_to_for",
6868
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
6969
}
7070

71-
def GetParentForOp : Op<Transform_Dialect, "loop.get_parent_for",
72-
[NavigationTransformOpTrait, MemoryEffectsOpInterface,
73-
DeclareOpInterfaceMethods<TransformOpInterface>]> {
74-
let summary = "Gets a handle to the parent 'for' loop of the given operation";
75-
let description = [{
76-
Produces a handle to the n-th (default 1) parent `scf.for` or `affine.for`
77-
(when the affine flag is true) loop for each Payload IR operation
78-
associated with the operand. Fails if such a loop cannot be found. The list
79-
of operations associated with the handle contains parent operations in the
80-
same order as the list associated with the operand, except for operations
81-
that are parents to more than one input which are only present once.
82-
}];
83-
84-
let arguments =
85-
(ins TransformHandleTypeInterface:$target,
86-
DefaultValuedAttr<ConfinedAttr<I64Attr, [IntPositive]>,
87-
"1">:$num_loops,
88-
DefaultValuedAttr<BoolAttr, "false">:$affine);
89-
let results = (outs TransformHandleTypeInterface : $parent);
90-
91-
let assemblyFormat =
92-
"$target attr-dict `:` functional-type(operands, results)";
93-
}
94-
9571
def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
9672
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
9773
DeclareOpInterfaceMethods<TransformOpInterface>]> {

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -620,10 +620,11 @@ def GetParentOp : TransformDialectOp<"get_parent_op",
620620
that case for each target op, the closest parent op that fulfills all
621621
requirements, is returned.
622622
- `isolated_from_above`: the parent op must be isolated from above
623-
- `allow_empty_results`: get_parent_op is allowed to return an empty list and
624-
still succeeds. In such a case, if get_parent_op fails for any operation
625-
in the list, the entire transform returns an empty handle.
623+
- `allow_empty_results`: get_parent_op is allowed to return an empty list
624+
and still succeeds. In such a case, if get_parent_op fails for any
625+
operation in the list, the entire transform returns an empty handle.
626626
- `op_name`: the parent op must have the specified name
627+
- `nth_parent`: get the n-th parent of that satisfies the above requirements
627628

628629
If `deduplicate` is set, the result handle does not contain any duplicate
629630
ops. For example, given the list
@@ -641,7 +642,9 @@ def GetParentOp : TransformDialectOp<"get_parent_op",
641642
UnitAttr:$isolated_from_above,
642643
UnitAttr:$allow_empty_results,
643644
OptionalAttr<StrAttr>:$op_name,
644-
UnitAttr:$deduplicate);
645+
UnitAttr:$deduplicate,
646+
DefaultValuedAttr<ConfinedAttr<I64Attr, [IntPositive]>,
647+
"1">:$nth_parent);
645648
let results = (outs TransformHandleTypeInterface:$parent);
646649
let assemblyFormat =
647650
"$target attr-dict `:` functional-type(operands, results)";

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -49,39 +49,6 @@ void transform::ApplySCFStructuralConversionPatternsOp::
4949
conversionTarget);
5050
}
5151

52-
//===----------------------------------------------------------------------===//
53-
// GetParentForOp
54-
//===----------------------------------------------------------------------===//
55-
56-
DiagnosedSilenceableFailure
57-
transform::GetParentForOp::apply(transform::TransformRewriter &rewriter,
58-
transform::TransformResults &results,
59-
transform::TransformState &state) {
60-
SetVector<Operation *> parents;
61-
for (Operation *target : state.getPayloadOps(getTarget())) {
62-
Operation *loop, *current = target;
63-
for (unsigned i = 0, e = getNumLoops(); i < e; ++i) {
64-
loop = getAffine()
65-
? current->getParentOfType<AffineForOp>().getOperation()
66-
: current->getParentOfType<scf::ForOp>().getOperation();
67-
if (!loop) {
68-
DiagnosedSilenceableFailure diag =
69-
emitSilenceableError()
70-
<< "could not find an '"
71-
<< (getAffine() ? AffineForOp::getOperationName()
72-
: scf::ForOp::getOperationName())
73-
<< "' parent";
74-
diag.attachNote(target->getLoc()) << "target op";
75-
return diag;
76-
}
77-
current = loop;
78-
}
79-
parents.insert(loop);
80-
}
81-
results.set(cast<OpResult>(getResult()), parents.getArrayRef());
82-
return DiagnosedSilenceableFailure::success();
83-
}
84-
8552
//===----------------------------------------------------------------------===//
8653
// ForallToForOp
8754
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,27 +1232,30 @@ transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
12321232
SmallVector<Operation *> parents;
12331233
DenseSet<Operation *> resultSet;
12341234
for (Operation *target : state.getPayloadOps(getTarget())) {
1235-
Operation *parent = target->getParentOp();
1236-
while (parent) {
1237-
bool checkIsolatedFromAbove =
1238-
!getIsolatedFromAbove() ||
1239-
parent->hasTrait<OpTrait::IsIsolatedFromAbove>();
1240-
bool checkOpName = !getOpName().has_value() ||
1241-
parent->getName().getStringRef() == *getOpName();
1242-
if (checkIsolatedFromAbove && checkOpName)
1243-
break;
1235+
Operation *parent = target;
1236+
for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
12441237
parent = parent->getParentOp();
1245-
}
1246-
if (!parent) {
1247-
if (getAllowEmptyResults()) {
1248-
results.set(llvm::cast<OpResult>(getResult()), parents);
1249-
return DiagnosedSilenceableFailure::success();
1238+
while (parent) {
1239+
bool checkIsolatedFromAbove =
1240+
!getIsolatedFromAbove() ||
1241+
parent->hasTrait<OpTrait::IsIsolatedFromAbove>();
1242+
bool checkOpName = !getOpName().has_value() ||
1243+
parent->getName().getStringRef() == *getOpName();
1244+
if (checkIsolatedFromAbove && checkOpName)
1245+
break;
1246+
parent = parent->getParentOp();
1247+
}
1248+
if (!parent) {
1249+
if (getAllowEmptyResults()) {
1250+
results.set(llvm::cast<OpResult>(getResult()), parents);
1251+
return DiagnosedSilenceableFailure::success();
1252+
}
1253+
DiagnosedSilenceableFailure diag =
1254+
emitSilenceableError()
1255+
<< "could not find a parent op that matches all requirements";
1256+
diag.attachNote(target->getLoc()) << "target op";
1257+
return diag;
12501258
}
1251-
DiagnosedSilenceableFailure diag =
1252-
emitSilenceableError()
1253-
<< "could not find a parent op that matches all requirements";
1254-
diag.attachNote(target->getLoc()) << "target op";
1255-
return diag;
12561259
}
12571260
if (getDeduplicate()) {
12581261
if (!resultSet.contains(parent)) {

mlir/python/mlir/dialects/transform/__init__.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -52,26 +52,28 @@ def patterns(self) -> Block:
5252

5353
@_ods_cext.register_operation(_Dialect, replace=True)
5454
class GetParentOp(GetParentOp):
55-
def __init__(
56-
self,
57-
result_type: Type,
58-
target: Union[Operation, Value],
59-
*,
60-
isolated_from_above: bool = False,
61-
op_name: Optional[str] = None,
62-
deduplicate: bool = False,
63-
loc=None,
64-
ip=None,
65-
):
66-
super().__init__(
67-
result_type,
68-
_get_op_result_or_value(target),
69-
isolated_from_above=isolated_from_above,
70-
op_name=op_name,
71-
deduplicate=deduplicate,
72-
loc=loc,
73-
ip=ip,
74-
)
55+
def __init__(
56+
self,
57+
result_type: Type,
58+
target: Union[Operation, Value],
59+
*,
60+
isolated_from_above: bool = False,
61+
op_name: Optional[str] = None,
62+
deduplicate: bool = False,
63+
nth_parent: int = 1,
64+
loc=None,
65+
ip=None,
66+
):
67+
super().__init__(
68+
result_type,
69+
_get_op_result_or_value(target),
70+
isolated_from_above=isolated_from_above,
71+
op_name=op_name,
72+
deduplicate=deduplicate,
73+
nth_parent=nth_parent,
74+
loc=loc,
75+
ip=ip,
76+
)
7577

7678

7779
@_ods_cext.register_operation(_Dialect, replace=True)

mlir/python/mlir/dialects/transform/loop.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,30 +17,6 @@
1717
from typing import Optional, Union
1818

1919

20-
@_ods_cext.register_operation(_Dialect, replace=True)
21-
class GetParentForOp(GetParentForOp):
22-
"""Extension for GetParentForOp."""
23-
24-
def __init__(
25-
self,
26-
result_type: Type,
27-
target: Union[Operation, Value],
28-
*,
29-
num_loops: Optional[int] = None,
30-
ip=None,
31-
loc=None,
32-
):
33-
if num_loops is None:
34-
num_loops = 1
35-
super().__init__(
36-
result_type,
37-
_get_op_result_or_value(target),
38-
num_loops=num_loops,
39-
ip=ip,
40-
loc=loc,
41-
)
42-
43-
4420
@_ods_cext.register_operation(_Dialect, replace=True)
4521
class LoopOutlineOp(LoopOutlineOp):
4622
"""Extension for LoopOutlineOp."""

mlir/test/Dialect/SCF/transform-ops-invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ func.func @test_loops_do_not_get_unrolled() {
3232
module attributes {transform.with_named_sequence} {
3333
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
3434
%0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
35-
%1 = transform.loop.get_parent_for %0 { affine = true } : (!transform.any_op) -> !transform.op<"affine.for">
35+
%1 = transform.get_parent_op %0 {op_name = "affine.for"} : (!transform.any_op) -> !transform.op<"affine.for">
3636
// expected-error @below {{failed to unroll}}
3737
transform.loop.unroll %1 { factor = 8 } : !transform.op<"affine.for">
3838
transform.yield
@@ -81,7 +81,7 @@ func.func @test_loops_do_not_get_peeled() {
8181
module attributes {transform.with_named_sequence} {
8282
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
8383
%0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
84-
%1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for">
84+
%1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
8585
// expected-error @below {{failed to peel}}
8686
transform.loop.peel %1 : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op)
8787
transform.yield

0 commit comments

Comments
 (0)