Skip to content

[mlir][Transforms] Add a utility method to move value definitions. #130874

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions mlir/include/mlir/Transforms/RegionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "mlir/IR/Region.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"

#include "llvm/ADT/SetVector.h"

Expand Down Expand Up @@ -80,6 +81,16 @@ LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
Operation *insertionPoint);

/// Move definitions of `values` before an insertion point. Current support is
/// only for movement of definitions within the same basic block. Note that this
/// is an all-or-nothing approach. Either definitions of all values are moved
/// before insertion point, or none of them are.
LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values,
Operation *insertionPoint,
DominanceInfo &dominance);
LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values,
Operation *insertionPoint);

/// Run a set of structural simplifications over the given regions. This
/// includes transformations like unreachable block elimination, dead argument
/// elimination, as well as some other DCE. This function returns success if any
Expand Down
69 changes: 68 additions & 1 deletion mlir/lib/Transforms/Utils/RegionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1070,7 +1070,7 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
// in different basic blocks.
if (op->getBlock() != insertionPoint->getBlock()) {
return rewriter.notifyMatchFailure(
op, "unsupported caes where operation and insertion point are not in "
op, "unsupported case where operation and insertion point are not in "
"the same basic block");
}
// If `insertionPoint` does not dominate `op`, do nothing
Expand Down Expand Up @@ -1115,3 +1115,70 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
DominanceInfo dominance(op);
return moveOperationDependencies(rewriter, op, insertionPoint, dominance);
}

LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
ValueRange values,
Operation *insertionPoint,
DominanceInfo &dominance) {
// Remove the values that already dominate the insertion point.
SmallVector<Value> prunedValues;
for (auto value : values) {
if (dominance.properlyDominates(value, insertionPoint)) {
continue;
}
// Block arguments are not supported.
if (isa<BlockArgument>(value)) {
return rewriter.notifyMatchFailure(
insertionPoint,
"unsupported case of moving block argument before insertion point");
}
// Check for currently unsupported case if the insertion point is in a
// different block.
if (value.getDefiningOp()->getBlock() != insertionPoint->getBlock()) {
return rewriter.notifyMatchFailure(
insertionPoint,
"unsupported case of moving definition of value before an insertion "
"point in a different basic block");
}
prunedValues.push_back(value);
}

// Find the backward slice of operation for each `Value` the operation
// depends on. Prune the slice to only include operations not already
// dominated by the `insertionPoint`
BackwardSliceOptions options;
options.inclusive = true;
options.omitUsesFromAbove = false;
// Since current support is to only move within a same basic block,
// the slices dont need to look past block arguments.
options.omitBlockArguments = true;
options.filter = [&](Operation *sliceBoundaryOp) {
return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
};
llvm::SetVector<Operation *> slice;
for (auto value : prunedValues) {
getBackwardSlice(value, &slice, options);
}

// If the slice contains `insertionPoint` cannot move the dependencies.
if (slice.contains(insertionPoint)) {
return rewriter.notifyMatchFailure(
insertionPoint,
"cannot move dependencies before operation in backward slice of op");
}

// Sort operations topologically before moving.
mlir::topologicalSort(slice);

for (Operation *op : slice) {
rewriter.moveOpBefore(op, insertionPoint);
}
return success();
}

LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
ValueRange values,
Operation *insertionPoint) {
DominanceInfo dominance(insertionPoint);
return moveValueDefinitions(rewriter, values, insertionPoint, dominance);
}
226 changes: 226 additions & 0 deletions mlir/test/Transforms/move-operation-deps.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,229 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}

// -----

// Check simple move value definitions before insertion operation.
func.func @simple_move_values() -> f32 {
%0 = "before"() : () -> (f32)
%1 = "moved_op_1"() : () -> (f32)
%2 = "moved_op_2"() : () -> (f32)
%3 = "foo"(%1, %2) : (f32, f32) -> (f32)
return %3 : f32
}
// CHECK-LABEL: func @simple_move_values()
// CHECK: %[[MOVED1:.+]] = "moved_op_1"
// CHECK: %[[MOVED2:.+]] = "moved_op_2"
// CHECK: %[[BEFORE:.+]] = "before"
// CHECK: %[[FOO:.+]] = "foo"(%[[MOVED1]], %[[MOVED2]])
// CHECK: return %[[FOO]]

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["moved_op_1"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["moved_op_2"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op3 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
%v2 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
transform.test.move_value_defns %v1, %v2 before %op3
: (!transform.any_value, !transform.any_value), !transform.any_op
transform.yield
}
}

// -----

// Compute slice including the implicitly captured values.
func.func @move_region_dependencies_values() -> f32 {
%0 = "before"() : () -> (f32)
%1 = "moved_op_1"() : () -> (f32)
%2 = "moved_op_2"() ({
%3 = "inner_op"(%1) : (f32) -> (f32)
"yield"(%3) : (f32) -> ()
}) : () -> (f32)
return %2 : f32
}
// CHECK-LABEL: func @move_region_dependencies_values()
// CHECK: %[[MOVED1:.+]] = "moved_op_1"
// CHECK: %[[MOVED2:.+]] = "moved_op_2"
// CHECK: %[[BEFORE:.+]] = "before"

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["moved_op_2"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
transform.test.move_value_defns %v1 before %op2
: (!transform.any_value), !transform.any_op
transform.yield
}
}

// -----

// Move operations in toplogical sort order
func.func @move_values_in_topological_sort_order() -> f32 {
%0 = "before"() : () -> (f32)
%1 = "moved_op_1"() : () -> (f32)
%2 = "moved_op_2"() : () -> (f32)
%3 = "moved_op_3"(%1) : (f32) -> (f32)
%4 = "moved_op_4"(%1, %3) : (f32, f32) -> (f32)
%5 = "moved_op_5"(%2) : (f32) -> (f32)
%6 = "foo"(%4, %5) : (f32, f32) -> (f32)
return %6 : f32
}
// CHECK-LABEL: func @move_values_in_topological_sort_order()
// CHECK: %[[MOVED_1:.+]] = "moved_op_1"
// CHECK-DAG: %[[MOVED_2:.+]] = "moved_op_3"(%[[MOVED_1]])
// CHECK-DAG: %[[MOVED_3:.+]] = "moved_op_4"(%[[MOVED_1]], %[[MOVED_2]])
// CHECK-DAG: %[[MOVED_4:.+]] = "moved_op_2"
// CHECK-DAG: %[[MOVED_5:.+]] = "moved_op_5"(%[[MOVED_4]])
// CHECK: %[[BEFORE:.+]] = "before"
// CHECK: %[[FOO:.+]] = "foo"(%[[MOVED_3]], %[[MOVED_5]])
// CHECK: return %[[FOO]]

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["moved_op_4"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["moved_op_5"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op3 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
%v2 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
transform.test.move_value_defns %v1, %v2 before %op3
: (!transform.any_value, !transform.any_value), !transform.any_op
transform.yield
}
}

// -----

// Move only those value definitions that are not dominated by insertion point

func.func @move_only_required_defns() -> (f32, f32, f32, f32) {
%0 = "unmoved_op"() : () -> (f32)
%1 = "dummy_op"() : () -> (f32)
%2 = "before"() : () -> (f32)
%3 = "moved_op"() : () -> (f32)
return %0, %1, %2, %3 : f32, f32, f32, f32
}
// CHECK-LABEL: func @move_only_required_defns()
// CHECK: %[[UNMOVED:.+]] = "unmoved_op"
// CHECK: %[[DUMMY:.+]] = "dummy_op"
// CHECK: %[[MOVED:.+]] = "moved_op"
// CHECK: %[[BEFORE:.+]] = "before"

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["unmoved_op"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["dummy_op"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op3 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op4 = transform.structured.match ops{["moved_op"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
%v2 = transform.get_result %op4[0] : (!transform.any_op) -> !transform.any_value
transform.test.move_value_defns %v1, %v2 before %op3
: (!transform.any_value, !transform.any_value), !transform.any_op
transform.yield
}
}

// -----

// Move only those value definitions that are not dominated by insertion point

func.func @move_only_required_defns() -> (f32, f32, f32, f32) {
%0 = "unmoved_op"() : () -> (f32)
%1 = "dummy_op"() : () -> (f32)
%2 = "before"() : () -> (f32)
%3 = "moved_op"() : () -> (f32)
return %0, %1, %2, %3 : f32, f32, f32, f32
}
// CHECK-LABEL: func @move_only_required_defns()
// CHECK: %[[UNMOVED:.+]] = "unmoved_op"
// CHECK: %[[DUMMY:.+]] = "dummy_op"
// CHECK: %[[MOVED:.+]] = "moved_op"
// CHECK: %[[BEFORE:.+]] = "before"

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["unmoved_op"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["dummy_op"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op3 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op4 = transform.structured.match ops{["moved_op"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
%v2 = transform.get_result %op4[0] : (!transform.any_op) -> !transform.any_value
transform.test.move_value_defns %v1, %v2 before %op3
: (!transform.any_value, !transform.any_value), !transform.any_op
transform.yield
}
}

// -----

// Check handling of block arguments
func.func @move_only_required_defns() -> (f32, f32) {
%0 = "unmoved_op"() : () -> (f32)
cf.br ^bb0(%0 : f32)
^bb0(%arg0 : f32) :
%1 = "before"() : () -> (f32)
%2 = "moved_op"(%arg0) : (f32) -> (f32)
return %1, %2 : f32, f32
}
// CHECK-LABEL: func @move_only_required_defns()
// CHECK: %[[MOVED:.+]] = "moved_op"
// CHECK: %[[BEFORE:.+]] = "before"

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["moved_op"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
transform.test.move_value_defns %v1 before %op1
: (!transform.any_value), !transform.any_op
transform.yield
}
}

// -----

// Do not move across basic blocks
func.func @no_move_across_basic_blocks() -> (f32, f32) {
%0 = "unmoved_op"() : () -> (f32)
%1 = "before"() : () -> (f32)
cf.br ^bb0(%0 : f32)
^bb0(%arg0 : f32) :
%2 = "moved_op"(%arg0) : (f32) -> (f32)
return %1, %2 : f32, f32
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["moved_op"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
// expected-remark@+1{{unsupported case of moving definition of value before an insertion point in a different basic block}}
transform.test.move_value_defns %v1 before %op1
: (!transform.any_value), !transform.any_op
transform.yield
}
}
17 changes: 17 additions & 0 deletions mlir/test/lib/Transforms/TestTransformsOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,23 @@ transform::TestMoveOperandDeps::apply(TransformRewriter &rewriter,
return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure
transform::TestMoveValueDefns::apply(TransformRewriter &rewriter,
TransformResults &TransformResults,
TransformState &state) {
SmallVector<Value> values;
for (auto tdValue : getValues()) {
values.push_back(*state.getPayloadValues(tdValue).begin());
}
Operation *moveBefore = *state.getPayloadOps(getInsertionPoint()).begin();
if (failed(moveValueDefinitions(rewriter, values, moveBefore))) {
auto listener = cast<ErrorCheckingTrackingListener>(rewriter.getListener());
std::string errorMsg = listener->getLatestMatchFailureMessage();
(void)emitRemark(errorMsg);
}
return DiagnosedSilenceableFailure::success();
}

namespace {

class TestTransformsDialectExtension
Expand Down
22 changes: 22 additions & 0 deletions mlir/test/lib/Transforms/TestTransformsOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,26 @@ def TestMoveOperandDeps :
}];
}

def TestMoveValueDefns :
Op<Transform_Dialect, "test.move_value_defns",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Moves all dependencies of on operation before another operation.
}];

let arguments =
(ins Variadic<TransformValueHandleTypeInterface>:$values,
TransformHandleTypeInterface:$insertion_point);

let results = (outs);

let assemblyFormat = [{
$values `before` $insertion_point attr-dict
`:` `(` type($values) `)` `` `,` type($insertion_point)
}];
}


#endif // TEST_TRANSFORM_OPS