Skip to content

[mlir][SCF] Use transform.get_parent_op instead of transform.loop.get_parent_for #70757

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 0 additions & 24 deletions mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -68,30 +68,6 @@ def ForallToForOp : Op<Transform_Dialect, "loop.forall_to_for",
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
}

def GetParentForOp : Op<Transform_Dialect, "loop.get_parent_for",
[NavigationTransformOpTrait, MemoryEffectsOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
let summary = "Gets a handle to the parent 'for' loop of the given operation";
let description = [{
Produces a handle to the n-th (default 1) parent `scf.for` or `affine.for`
(when the affine flag is true) loop for each Payload IR operation
associated with the operand. Fails if such a loop cannot be found. The list
of operations associated with the handle contains parent operations in the
same order as the list associated with the operand, except for operations
that are parents to more than one input which are only present once.
}];

let arguments =
(ins TransformHandleTypeInterface:$target,
DefaultValuedAttr<ConfinedAttr<I64Attr, [IntPositive]>,
"1">:$num_loops,
DefaultValuedAttr<BoolAttr, "false">:$affine);
let results = (outs TransformHandleTypeInterface : $parent);

let assemblyFormat =
"$target attr-dict `:` functional-type(operands, results)";
}

def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
Expand Down
11 changes: 7 additions & 4 deletions mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -620,10 +620,11 @@ def GetParentOp : TransformDialectOp<"get_parent_op",
that case for each target op, the closest parent op that fulfills all
requirements, is returned.
- `isolated_from_above`: the parent op must be isolated from above
- `allow_empty_results`: get_parent_op is allowed to return an empty list and
still succeeds. In such a case, if get_parent_op fails for any operation
in the list, the entire transform returns an empty handle.
- `allow_empty_results`: get_parent_op is allowed to return an empty list
and still succeeds. In such a case, if get_parent_op fails for any
operation in the list, the entire transform returns an empty handle.
- `op_name`: the parent op must have the specified name
- `nth_parent`: get the n-th parent of that satisfies the above requirements

If `deduplicate` is set, the result handle does not contain any duplicate
ops. For example, given the list
Expand All @@ -641,7 +642,9 @@ def GetParentOp : TransformDialectOp<"get_parent_op",
UnitAttr:$isolated_from_above,
UnitAttr:$allow_empty_results,
OptionalAttr<StrAttr>:$op_name,
UnitAttr:$deduplicate);
UnitAttr:$deduplicate,
DefaultValuedAttr<ConfinedAttr<I64Attr, [IntPositive]>,
"1">:$nth_parent);
let results = (outs TransformHandleTypeInterface:$parent);
let assemblyFormat =
"$target attr-dict `:` functional-type(operands, results)";
Expand Down
33 changes: 0 additions & 33 deletions mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,39 +49,6 @@ void transform::ApplySCFStructuralConversionPatternsOp::
conversionTarget);
}

//===----------------------------------------------------------------------===//
// GetParentForOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::GetParentForOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
SetVector<Operation *> parents;
for (Operation *target : state.getPayloadOps(getTarget())) {
Operation *loop, *current = target;
for (unsigned i = 0, e = getNumLoops(); i < e; ++i) {
loop = getAffine()
? current->getParentOfType<AffineForOp>().getOperation()
: current->getParentOfType<scf::ForOp>().getOperation();
if (!loop) {
DiagnosedSilenceableFailure diag =
emitSilenceableError()
<< "could not find an '"
<< (getAffine() ? AffineForOp::getOperationName()
: scf::ForOp::getOperationName())
<< "' parent";
diag.attachNote(target->getLoc()) << "target op";
return diag;
}
current = loop;
}
parents.insert(loop);
}
results.set(cast<OpResult>(getResult()), parents.getArrayRef());
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// ForallToForOp
//===----------------------------------------------------------------------===//
Expand Down
41 changes: 22 additions & 19 deletions mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1232,27 +1232,30 @@ transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
SmallVector<Operation *> parents;
DenseSet<Operation *> resultSet;
for (Operation *target : state.getPayloadOps(getTarget())) {
Operation *parent = target->getParentOp();
while (parent) {
bool checkIsolatedFromAbove =
!getIsolatedFromAbove() ||
parent->hasTrait<OpTrait::IsIsolatedFromAbove>();
bool checkOpName = !getOpName().has_value() ||
parent->getName().getStringRef() == *getOpName();
if (checkIsolatedFromAbove && checkOpName)
break;
Operation *parent = target;
for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
parent = parent->getParentOp();
}
if (!parent) {
if (getAllowEmptyResults()) {
results.set(llvm::cast<OpResult>(getResult()), parents);
return DiagnosedSilenceableFailure::success();
while (parent) {
bool checkIsolatedFromAbove =
!getIsolatedFromAbove() ||
parent->hasTrait<OpTrait::IsIsolatedFromAbove>();
bool checkOpName = !getOpName().has_value() ||
parent->getName().getStringRef() == *getOpName();
if (checkIsolatedFromAbove && checkOpName)
break;
parent = parent->getParentOp();
}
if (!parent) {
if (getAllowEmptyResults()) {
results.set(llvm::cast<OpResult>(getResult()), parents);
return DiagnosedSilenceableFailure::success();
}
DiagnosedSilenceableFailure diag =
emitSilenceableError()
<< "could not find a parent op that matches all requirements";
diag.attachNote(target->getLoc()) << "target op";
return diag;
}
DiagnosedSilenceableFailure diag =
emitSilenceableError()
<< "could not find a parent op that matches all requirements";
diag.attachNote(target->getLoc()) << "target op";
return diag;
}
if (getDeduplicate()) {
if (!resultSet.contains(parent)) {
Expand Down
42 changes: 22 additions & 20 deletions mlir/python/mlir/dialects/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,26 +52,28 @@ def patterns(self) -> Block:

@_ods_cext.register_operation(_Dialect, replace=True)
class GetParentOp(GetParentOp):
def __init__(
self,
result_type: Type,
target: Union[Operation, Value],
*,
isolated_from_above: bool = False,
op_name: Optional[str] = None,
deduplicate: bool = False,
loc=None,
ip=None,
):
super().__init__(
result_type,
_get_op_result_or_value(target),
isolated_from_above=isolated_from_above,
op_name=op_name,
deduplicate=deduplicate,
loc=loc,
ip=ip,
)
def __init__(
self,
result_type: Type,
target: Union[Operation, Value],
*,
isolated_from_above: bool = False,
op_name: Optional[str] = None,
deduplicate: bool = False,
nth_parent: int = 1,
loc=None,
ip=None,
):
super().__init__(
result_type,
_get_op_result_or_value(target),
isolated_from_above=isolated_from_above,
op_name=op_name,
deduplicate=deduplicate,
nth_parent=nth_parent,
loc=loc,
ip=ip,
)


@_ods_cext.register_operation(_Dialect, replace=True)
Expand Down
24 changes: 0 additions & 24 deletions mlir/python/mlir/dialects/transform/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,6 @@
from typing import Optional, Union


@_ods_cext.register_operation(_Dialect, replace=True)
class GetParentForOp(GetParentForOp):
"""Extension for GetParentForOp."""

def __init__(
self,
result_type: Type,
target: Union[Operation, Value],
*,
num_loops: Optional[int] = None,
ip=None,
loc=None,
):
if num_loops is None:
num_loops = 1
super().__init__(
result_type,
_get_op_result_or_value(target),
num_loops=num_loops,
ip=ip,
loc=loc,
)


@_ods_cext.register_operation(_Dialect, replace=True)
class LoopOutlineOp(LoopOutlineOp):
"""Extension for LoopOutlineOp."""
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/SCF/transform-ops-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func.func @test_loops_do_not_get_unrolled() {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.loop.get_parent_for %0 { affine = true } : (!transform.any_op) -> !transform.op<"affine.for">
%1 = transform.get_parent_op %0 {op_name = "affine.for"} : (!transform.any_op) -> !transform.op<"affine.for">
// expected-error @below {{failed to unroll}}
transform.loop.unroll %1 { factor = 8 } : !transform.op<"affine.for">
transform.yield
Expand Down Expand Up @@ -81,7 +81,7 @@ func.func @test_loops_do_not_get_peeled() {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for">
%1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
// expected-error @below {{failed to peel}}
transform.loop.peel %1 : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op)
transform.yield
Expand Down
Loading