Skip to content

[mlir][Transforms] Add a utility method to move value definitions. #130874

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

MaheshRavishankar
Copy link
Contributor

205c532 added a transform utility that moved all SSA dependences of an operation before an insertion point. Similar to that, this PR adds a transform utility function, moveValueDefinitions to move the slice of operations that define all values in a ValueRange before the insertion point. While very similar to moveOperationDependencies, this method differs in a few ways

  1. When computing the backward slice since the start of the slice is value, the slice computed needs to be inclusive.
  2. The combined backward slice needs to be sorted topologically before moving them to avoid SSA use-def violations while moving individual ops.

The PR also adds a new transform op to test this new utility function.

llvm@205c532
added a transform utility that moved all SSA dependences of an
operation before an insertion point. Similar to that, this PR adds a
transform utility function, `moveValueDefinitions` to move the slice
of operations that define all values in a `ValueRange` before the
insertion point. While very similar to `moveOperationDependencies`,
this method differs in a few ways

1. When computing the backward slice since the start of the slice is
   value, the slice computed needs to be inclusive.

2. The combined backward slice needs to be sorted topologically before
   moving them to avoid SSA use-def violations while moving individual
   ops.

The PR also adds a new transform op to test this new utility function.

Signed-off-by: MaheshRavishankar <[email protected]>
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Mar 12, 2025
@llvmbot
Copy link
Member

llvmbot commented Mar 12, 2025

@llvm/pr-subscribers-mlir

Author: None (MaheshRavishankar)

Changes

205c532 added a transform utility that moved all SSA dependences of an operation before an insertion point. Similar to that, this PR adds a transform utility function, moveValueDefinitions to move the slice of operations that define all values in a ValueRange before the insertion point. While very similar to moveOperationDependencies, this method differs in a few ways

  1. When computing the backward slice since the start of the slice is value, the slice computed needs to be inclusive.
  2. The combined backward slice needs to be sorted topologically before moving them to avoid SSA use-def violations while moving individual ops.

The PR also adds a new transform op to test this new utility function.


Full diff: https://github.com/llvm/llvm-project/pull/130874.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Transforms/RegionUtils.h (+11)
  • (modified) mlir/lib/Transforms/Utils/RegionUtils.cpp (+70-1)
  • (modified) mlir/test/Transforms/move-operation-deps.mlir (+226)
  • (modified) mlir/test/lib/Transforms/TestTransformsOps.cpp (+18)
  • (modified) mlir/test/lib/Transforms/TestTransformsOps.td (+22)
diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h
index e6b928d8ebecc..2ed96afbace81 100644
--- a/mlir/include/mlir/Transforms/RegionUtils.h
+++ b/mlir/include/mlir/Transforms/RegionUtils.h
@@ -11,6 +11,7 @@
 
 #include "mlir/IR/Region.h"
 #include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
 
 #include "llvm/ADT/SetVector.h"
 
@@ -80,6 +81,16 @@ LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
 LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
                                         Operation *insertionPoint);
 
+/// Move definitions of `values` before an insertion point. Current support is
+/// only for movement of definitions within the same basic block. Note that this
+/// is an all-or-nothing approach. Either definitions of all values are moved
+/// before insertion point, or none of them are.
+LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values,
+                                   Operation *insertionPoint,
+                                   DominanceInfo &dominance);
+LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values,
+                                   Operation *insertionPoint);
+
 /// Run a set of structural simplifications over the given regions. This
 /// includes transformations like unreachable block elimination, dead argument
 /// elimination, as well as some other DCE. This function returns success if any
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index da0d486f0fdcb..6987a13b309d7 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -1070,7 +1070,7 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
   // in different basic blocks.
   if (op->getBlock() != insertionPoint->getBlock()) {
     return rewriter.notifyMatchFailure(
-        op, "unsupported caes where operation and insertion point are not in "
+        op, "unsupported case where operation and insertion point are not in "
             "the same basic block");
   }
   // If `insertionPoint` does not dominate `op`, do nothing
@@ -1115,3 +1115,72 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
   DominanceInfo dominance(op);
   return moveOperationDependencies(rewriter, op, insertionPoint, dominance);
 }
+
+LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
+                                         ValueRange values,
+                                         Operation *insertionPoint,
+                                         DominanceInfo &dominance) {
+  // Remove the values that already dominate the insertion point.
+  SmallVector<Value> prunedValues;
+  for (auto value : values) {
+    if (dominance.properlyDominates(value, insertionPoint)) {
+      continue;
+    }
+    // Block arguments are not supported.
+    if (isa<BlockArgument>(value)) {
+      return rewriter.notifyMatchFailure(
+          insertionPoint,
+          "unsupported case of moving block argument before insertion point");
+    }
+    // Check for currently unsupported case if the insertion point is in a
+    // different block.
+    if (value.getDefiningOp()->getBlock() != insertionPoint->getBlock()) {
+      return rewriter.notifyMatchFailure(
+          insertionPoint,
+          "unsupported case of moving definition of value before an insertion "
+          "point in a different basic block");
+    }
+    prunedValues.push_back(value);
+  }
+
+  // Find the backward slice of operation for each `Value` the operation
+  // depends on. Prune the slice to only include operations not already
+  // dominated by the `insertionPoint`
+  BackwardSliceOptions options;
+  options.inclusive = true;
+  options.omitUsesFromAbove = false;
+  // Since current support is to only move within a same basic block,
+  // the slices dont need to look past block arguments.
+  options.omitBlockArguments = true;
+  options.filter = [&](Operation *sliceBoundaryOp) {
+    return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
+  };
+  llvm::SetVector<Operation *> slice;
+  for (auto value : prunedValues) {
+    getBackwardSlice(value, &slice, options);
+  }
+
+  // If the slice contains `insertionPoint` cannot move the dependencies.
+  if (slice.contains(insertionPoint)) {
+    return rewriter.notifyMatchFailure(
+        insertionPoint,
+        "cannot move dependencies before operation in backward slice of op");
+  }
+
+  // Sort operations topologically before moving.
+  mlir::topologicalSort(slice);
+
+  // We should move the slice in topological order, but `getBackwardSlice`
+  // already does that. So no need to sort again.
+  for (Operation *op : slice) {
+    rewriter.moveOpBefore(op, insertionPoint);
+  }
+  return success();
+}
+
+LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
+                                         ValueRange values,
+                                         Operation *insertionPoint) {
+  DominanceInfo dominance(insertionPoint);
+  return moveValueDefinitions(rewriter, values, insertionPoint, dominance);
+}
diff --git a/mlir/test/Transforms/move-operation-deps.mlir b/mlir/test/Transforms/move-operation-deps.mlir
index 37637152938f6..aa7b5dc2a240a 100644
--- a/mlir/test/Transforms/move-operation-deps.mlir
+++ b/mlir/test/Transforms/move-operation-deps.mlir
@@ -234,3 +234,229 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// Check simple move value definitions before insertion operation.
+func.func @simple_move_values() -> f32 {
+  %0 = "before"() : () -> (f32)
+  %1 = "moved_op_1"() : () -> (f32)
+  %2 = "moved_op_2"() : () -> (f32)
+  %3 = "foo"(%1, %2) : (f32, f32) -> (f32)
+  return %3 : f32
+}
+// CHECK-LABEL: func @simple_move_values()
+//       CHECK:   %[[MOVED1:.+]] = "moved_op_1"
+//       CHECK:   %[[MOVED2:.+]] = "moved_op_2"
+//       CHECK:   %[[BEFORE:.+]] = "before"
+//       CHECK:   %[[FOO:.+]] = "foo"(%[[MOVED1]], %[[MOVED2]])
+//       CHECK:   return %[[FOO]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op1 = transform.structured.match ops{["moved_op_1"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op2 = transform.structured.match ops{["moved_op_2"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op3 = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
+    %v2 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
+    transform.test.move_value_defns %v1, %v2 before %op3
+        : (!transform.any_value, !transform.any_value), !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Compute slice including the implicitly captured values.
+func.func @move_region_dependencies_values() -> f32 {
+  %0 = "before"() : () -> (f32)
+  %1 = "moved_op_1"() : () -> (f32)
+  %2 = "moved_op_2"() ({
+    %3 = "inner_op"(%1) : (f32) -> (f32)
+    "yield"(%3) : (f32) -> ()
+  }) : () -> (f32)
+  return %2 : f32
+}
+// CHECK-LABEL: func @move_region_dependencies_values()
+//       CHECK:   %[[MOVED1:.+]] = "moved_op_1"
+//       CHECK:   %[[MOVED2:.+]] = "moved_op_2"
+//       CHECK:   %[[BEFORE:.+]] = "before"
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op1 = transform.structured.match ops{["moved_op_2"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op2 = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
+    transform.test.move_value_defns %v1 before %op2
+        : (!transform.any_value), !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Move operations in toplogical sort order
+func.func @move_values_in_topological_sort_order() -> f32 {
+  %0 = "before"() : () -> (f32)
+  %1 = "moved_op_1"() : () -> (f32)
+  %2 = "moved_op_2"() : () -> (f32)
+  %3 = "moved_op_3"(%1) : (f32) -> (f32)
+  %4 = "moved_op_4"(%1, %3) : (f32, f32) -> (f32)
+  %5 = "moved_op_5"(%2) : (f32) -> (f32)
+  %6 = "foo"(%4, %5) : (f32, f32) -> (f32)
+  return %6 : f32
+}
+// CHECK-LABEL: func @move_values_in_topological_sort_order()
+//       CHECK:   %[[MOVED_1:.+]] = "moved_op_1"
+//   CHECK-DAG:   %[[MOVED_2:.+]] = "moved_op_3"(%[[MOVED_1]])
+//   CHECK-DAG:   %[[MOVED_3:.+]] = "moved_op_4"(%[[MOVED_1]], %[[MOVED_2]])
+//   CHECK-DAG:   %[[MOVED_4:.+]] = "moved_op_2"
+//   CHECK-DAG:   %[[MOVED_5:.+]] = "moved_op_5"(%[[MOVED_4]])
+//       CHECK:   %[[BEFORE:.+]] = "before"
+//       CHECK:   %[[FOO:.+]] = "foo"(%[[MOVED_3]], %[[MOVED_5]])
+//       CHECK:   return %[[FOO]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op1 = transform.structured.match ops{["moved_op_4"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op2 = transform.structured.match ops{["moved_op_5"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op3 = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
+    %v2 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
+    transform.test.move_value_defns %v1, %v2 before %op3
+        : (!transform.any_value, !transform.any_value), !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Move only those value definitions that are not dominated by insertion point
+
+func.func @move_only_required_defns() -> (f32, f32, f32, f32) {
+  %0 = "unmoved_op"() : () -> (f32)
+  %1 = "dummy_op"() : () -> (f32)
+  %2 = "before"() : () -> (f32)
+  %3 = "moved_op"() : () -> (f32)
+  return %0, %1, %2, %3 : f32, f32, f32, f32
+}
+// CHECK-LABEL: func @move_only_required_defns()
+//       CHECK:   %[[UNMOVED:.+]] = "unmoved_op"
+//       CHECK:   %[[DUMMY:.+]] = "dummy_op"
+//       CHECK:   %[[MOVED:.+]] = "moved_op"
+//       CHECK:   %[[BEFORE:.+]] = "before"
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op1 = transform.structured.match ops{["unmoved_op"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op2 = transform.structured.match ops{["dummy_op"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op3 = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op4 = transform.structured.match ops{["moved_op"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
+    %v2 = transform.get_result %op4[0] : (!transform.any_op) -> !transform.any_value
+    transform.test.move_value_defns %v1, %v2 before %op3
+        : (!transform.any_value, !transform.any_value), !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Move only those value definitions that are not dominated by insertion point
+
+func.func @move_only_required_defns() -> (f32, f32, f32, f32) {
+  %0 = "unmoved_op"() : () -> (f32)
+  %1 = "dummy_op"() : () -> (f32)
+  %2 = "before"() : () -> (f32)
+  %3 = "moved_op"() : () -> (f32)
+  return %0, %1, %2, %3 : f32, f32, f32, f32
+}
+// CHECK-LABEL: func @move_only_required_defns()
+//       CHECK:   %[[UNMOVED:.+]] = "unmoved_op"
+//       CHECK:   %[[DUMMY:.+]] = "dummy_op"
+//       CHECK:   %[[MOVED:.+]] = "moved_op"
+//       CHECK:   %[[BEFORE:.+]] = "before"
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op1 = transform.structured.match ops{["unmoved_op"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op2 = transform.structured.match ops{["dummy_op"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op3 = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op4 = transform.structured.match ops{["moved_op"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
+    %v2 = transform.get_result %op4[0] : (!transform.any_op) -> !transform.any_value
+    transform.test.move_value_defns %v1, %v2 before %op3
+        : (!transform.any_value, !transform.any_value), !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Check handling of block arguments
+func.func @move_only_required_defns() -> (f32, f32) {
+  %0 = "unmoved_op"() : () -> (f32)
+  cf.br ^bb0(%0 : f32) 
+ ^bb0(%arg0 : f32) :
+  %1 = "before"() : () -> (f32)
+  %2 = "moved_op"(%arg0) : (f32) -> (f32)
+  return %1, %2 : f32, f32
+}
+// CHECK-LABEL: func @move_only_required_defns()
+//       CHECK:   %[[MOVED:.+]] = "moved_op"
+//       CHECK:   %[[BEFORE:.+]] = "before"
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op1 = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op2 = transform.structured.match ops{["moved_op"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
+    transform.test.move_value_defns %v1 before %op1
+        : (!transform.any_value), !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Do not move across basic blocks
+func.func @no_move_across_basic_blocks() -> (f32, f32) {
+  %0 = "unmoved_op"() : () -> (f32)
+  %1 = "before"() : () -> (f32)
+  cf.br ^bb0(%0 : f32) 
+ ^bb0(%arg0 : f32) :
+  %2 = "moved_op"(%arg0) : (f32) -> (f32)
+  return %1, %2 : f32, f32
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op1 = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op2 = transform.structured.match ops{["moved_op"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
+    // expected-remark@+1{{unsupported case of moving definition of value before an insertion point in a different basic block}}
+    transform.test.move_value_defns %v1 before %op1
+        : (!transform.any_value), !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/lib/Transforms/TestTransformsOps.cpp b/mlir/test/lib/Transforms/TestTransformsOps.cpp
index aaa566d9938a3..3d95af59f6da3 100644
--- a/mlir/test/lib/Transforms/TestTransformsOps.cpp
+++ b/mlir/test/lib/Transforms/TestTransformsOps.cpp
@@ -39,6 +39,24 @@ transform::TestMoveOperandDeps::apply(TransformRewriter &rewriter,
   return DiagnosedSilenceableFailure::success();
 }
 
+DiagnosedSilenceableFailure
+transform::TestMoveValueDefns::apply(TransformRewriter &rewriter,
+                                      TransformResults &TransformResults,
+                                      TransformState &state) {
+  SmallVector<Value> values;
+  for (auto tdValue : getValues()) {
+    values.push_back(*state.getPayloadValues(tdValue).begin());
+  }
+  Operation *moveBefore = *state.getPayloadOps(getInsertionPoint()).begin();
+  if (failed(moveValueDefinitions(rewriter, values, moveBefore))) {
+    auto listener = cast<ErrorCheckingTrackingListener>(rewriter.getListener());
+    std::string errorMsg = listener->getLatestMatchFailureMessage();
+    (void)emitRemark(errorMsg);
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+
 namespace {
 
 class TestTransformsDialectExtension
diff --git a/mlir/test/lib/Transforms/TestTransformsOps.td b/mlir/test/lib/Transforms/TestTransformsOps.td
index f514702cef5bc..495579b452dfc 100644
--- a/mlir/test/lib/Transforms/TestTransformsOps.td
+++ b/mlir/test/lib/Transforms/TestTransformsOps.td
@@ -38,4 +38,26 @@ def TestMoveOperandDeps :
   }];
 }
 
+def TestMoveValueDefns :
+    Op<Transform_Dialect, "test.move_value_defns",
+        [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+         DeclareOpInterfaceMethods<TransformOpInterface>,
+         ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Moves all dependencies of on operation before another operation.
+  }];
+
+  let arguments =
+    (ins Variadic<TransformValueHandleTypeInterface>:$values,
+         TransformHandleTypeInterface:$insertion_point);
+  
+  let results = (outs);
+
+  let assemblyFormat = [{
+    $values `before` $insertion_point attr-dict
+    `:` `(` type($values) `)` `` `,` type($insertion_point)
+  }];
+}
+
+
 #endif // TEST_TRANSFORM_OPS

@llvmbot
Copy link
Member

llvmbot commented Mar 12, 2025

@llvm/pr-subscribers-mlir-core

Author: None (MaheshRavishankar)

Changes

205c532 added a transform utility that moved all SSA dependences of an operation before an insertion point. Similar to that, this PR adds a transform utility function, moveValueDefinitions to move the slice of operations that define all values in a ValueRange before the insertion point. While very similar to moveOperationDependencies, this method differs in a few ways

  1. When computing the backward slice since the start of the slice is value, the slice computed needs to be inclusive.
  2. The combined backward slice needs to be sorted topologically before moving them to avoid SSA use-def violations while moving individual ops.

The PR also adds a new transform op to test this new utility function.


Full diff: https://github.com/llvm/llvm-project/pull/130874.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Transforms/RegionUtils.h (+11)
  • (modified) mlir/lib/Transforms/Utils/RegionUtils.cpp (+70-1)
  • (modified) mlir/test/Transforms/move-operation-deps.mlir (+226)
  • (modified) mlir/test/lib/Transforms/TestTransformsOps.cpp (+18)
  • (modified) mlir/test/lib/Transforms/TestTransformsOps.td (+22)
diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h
index e6b928d8ebecc..2ed96afbace81 100644
--- a/mlir/include/mlir/Transforms/RegionUtils.h
+++ b/mlir/include/mlir/Transforms/RegionUtils.h
@@ -11,6 +11,7 @@
 
 #include "mlir/IR/Region.h"
 #include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
 
 #include "llvm/ADT/SetVector.h"
 
@@ -80,6 +81,16 @@ LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
 LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
                                         Operation *insertionPoint);
 
+/// Move definitions of `values` before an insertion point. Current support is
+/// only for movement of definitions within the same basic block. Note that this
+/// is an all-or-nothing approach. Either definitions of all values are moved
+/// before insertion point, or none of them are.
+LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values,
+                                   Operation *insertionPoint,
+                                   DominanceInfo &dominance);
+LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values,
+                                   Operation *insertionPoint);
+
 /// Run a set of structural simplifications over the given regions. This
 /// includes transformations like unreachable block elimination, dead argument
 /// elimination, as well as some other DCE. This function returns success if any
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index da0d486f0fdcb..6987a13b309d7 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -1070,7 +1070,7 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
   // in different basic blocks.
   if (op->getBlock() != insertionPoint->getBlock()) {
     return rewriter.notifyMatchFailure(
-        op, "unsupported caes where operation and insertion point are not in "
+        op, "unsupported case where operation and insertion point are not in "
             "the same basic block");
   }
   // If `insertionPoint` does not dominate `op`, do nothing
@@ -1115,3 +1115,72 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
   DominanceInfo dominance(op);
   return moveOperationDependencies(rewriter, op, insertionPoint, dominance);
 }
+
+LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
+                                         ValueRange values,
+                                         Operation *insertionPoint,
+                                         DominanceInfo &dominance) {
+  // Remove the values that already dominate the insertion point.
+  SmallVector<Value> prunedValues;
+  for (auto value : values) {
+    if (dominance.properlyDominates(value, insertionPoint)) {
+      continue;
+    }
+    // Block arguments are not supported.
+    if (isa<BlockArgument>(value)) {
+      return rewriter.notifyMatchFailure(
+          insertionPoint,
+          "unsupported case of moving block argument before insertion point");
+    }
+    // Check for currently unsupported case if the insertion point is in a
+    // different block.
+    if (value.getDefiningOp()->getBlock() != insertionPoint->getBlock()) {
+      return rewriter.notifyMatchFailure(
+          insertionPoint,
+          "unsupported case of moving definition of value before an insertion "
+          "point in a different basic block");
+    }
+    prunedValues.push_back(value);
+  }
+
+  // Find the backward slice of operation for each `Value` the operation
+  // depends on. Prune the slice to only include operations not already
+  // dominated by the `insertionPoint`
+  BackwardSliceOptions options;
+  options.inclusive = true;
+  options.omitUsesFromAbove = false;
+  // Since current support is to only move within a same basic block,
+  // the slices dont need to look past block arguments.
+  options.omitBlockArguments = true;
+  options.filter = [&](Operation *sliceBoundaryOp) {
+    return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
+  };
+  llvm::SetVector<Operation *> slice;
+  for (auto value : prunedValues) {
+    getBackwardSlice(value, &slice, options);
+  }
+
+  // If the slice contains `insertionPoint` cannot move the dependencies.
+  if (slice.contains(insertionPoint)) {
+    return rewriter.notifyMatchFailure(
+        insertionPoint,
+        "cannot move dependencies before operation in backward slice of op");
+  }
+
+  // Sort operations topologically before moving.
+  mlir::topologicalSort(slice);
+
+  // We should move the slice in topological order, but `getBackwardSlice`
+  // already does that. So no need to sort again.
+  for (Operation *op : slice) {
+    rewriter.moveOpBefore(op, insertionPoint);
+  }
+  return success();
+}
+
+LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
+                                         ValueRange values,
+                                         Operation *insertionPoint) {
+  DominanceInfo dominance(insertionPoint);
+  return moveValueDefinitions(rewriter, values, insertionPoint, dominance);
+}
diff --git a/mlir/test/Transforms/move-operation-deps.mlir b/mlir/test/Transforms/move-operation-deps.mlir
index 37637152938f6..aa7b5dc2a240a 100644
--- a/mlir/test/Transforms/move-operation-deps.mlir
+++ b/mlir/test/Transforms/move-operation-deps.mlir
@@ -234,3 +234,229 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// Check simple move value definitions before insertion operation.
+func.func @simple_move_values() -> f32 {
+  %0 = "before"() : () -> (f32)
+  %1 = "moved_op_1"() : () -> (f32)
+  %2 = "moved_op_2"() : () -> (f32)
+  %3 = "foo"(%1, %2) : (f32, f32) -> (f32)
+  return %3 : f32
+}
+// CHECK-LABEL: func @simple_move_values()
+//       CHECK:   %[[MOVED1:.+]] = "moved_op_1"
+//       CHECK:   %[[MOVED2:.+]] = "moved_op_2"
+//       CHECK:   %[[BEFORE:.+]] = "before"
+//       CHECK:   %[[FOO:.+]] = "foo"(%[[MOVED1]], %[[MOVED2]])
+//       CHECK:   return %[[FOO]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op1 = transform.structured.match ops{["moved_op_1"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op2 = transform.structured.match ops{["moved_op_2"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op3 = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
+    %v2 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
+    transform.test.move_value_defns %v1, %v2 before %op3
+        : (!transform.any_value, !transform.any_value), !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Compute slice including the implicitly captured values.
+func.func @move_region_dependencies_values() -> f32 {
+  %0 = "before"() : () -> (f32)
+  %1 = "moved_op_1"() : () -> (f32)
+  %2 = "moved_op_2"() ({
+    %3 = "inner_op"(%1) : (f32) -> (f32)
+    "yield"(%3) : (f32) -> ()
+  }) : () -> (f32)
+  return %2 : f32
+}
+// CHECK-LABEL: func @move_region_dependencies_values()
+//       CHECK:   %[[MOVED1:.+]] = "moved_op_1"
+//       CHECK:   %[[MOVED2:.+]] = "moved_op_2"
+//       CHECK:   %[[BEFORE:.+]] = "before"
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op1 = transform.structured.match ops{["moved_op_2"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op2 = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
+    transform.test.move_value_defns %v1 before %op2
+        : (!transform.any_value), !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Move operations in toplogical sort order
+func.func @move_values_in_topological_sort_order() -> f32 {
+  %0 = "before"() : () -> (f32)
+  %1 = "moved_op_1"() : () -> (f32)
+  %2 = "moved_op_2"() : () -> (f32)
+  %3 = "moved_op_3"(%1) : (f32) -> (f32)
+  %4 = "moved_op_4"(%1, %3) : (f32, f32) -> (f32)
+  %5 = "moved_op_5"(%2) : (f32) -> (f32)
+  %6 = "foo"(%4, %5) : (f32, f32) -> (f32)
+  return %6 : f32
+}
+// CHECK-LABEL: func @move_values_in_topological_sort_order()
+//       CHECK:   %[[MOVED_1:.+]] = "moved_op_1"
+//   CHECK-DAG:   %[[MOVED_2:.+]] = "moved_op_3"(%[[MOVED_1]])
+//   CHECK-DAG:   %[[MOVED_3:.+]] = "moved_op_4"(%[[MOVED_1]], %[[MOVED_2]])
+//   CHECK-DAG:   %[[MOVED_4:.+]] = "moved_op_2"
+//   CHECK-DAG:   %[[MOVED_5:.+]] = "moved_op_5"(%[[MOVED_4]])
+//       CHECK:   %[[BEFORE:.+]] = "before"
+//       CHECK:   %[[FOO:.+]] = "foo"(%[[MOVED_3]], %[[MOVED_5]])
+//       CHECK:   return %[[FOO]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op1 = transform.structured.match ops{["moved_op_4"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op2 = transform.structured.match ops{["moved_op_5"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op3 = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
+    %v2 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
+    transform.test.move_value_defns %v1, %v2 before %op3
+        : (!transform.any_value, !transform.any_value), !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Move only those value definitions that are not dominated by insertion point
+
+func.func @move_only_required_defns() -> (f32, f32, f32, f32) {
+  %0 = "unmoved_op"() : () -> (f32)
+  %1 = "dummy_op"() : () -> (f32)
+  %2 = "before"() : () -> (f32)
+  %3 = "moved_op"() : () -> (f32)
+  return %0, %1, %2, %3 : f32, f32, f32, f32
+}
+// CHECK-LABEL: func @move_only_required_defns()
+//       CHECK:   %[[UNMOVED:.+]] = "unmoved_op"
+//       CHECK:   %[[DUMMY:.+]] = "dummy_op"
+//       CHECK:   %[[MOVED:.+]] = "moved_op"
+//       CHECK:   %[[BEFORE:.+]] = "before"
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op1 = transform.structured.match ops{["unmoved_op"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op2 = transform.structured.match ops{["dummy_op"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op3 = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op4 = transform.structured.match ops{["moved_op"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
+    %v2 = transform.get_result %op4[0] : (!transform.any_op) -> !transform.any_value
+    transform.test.move_value_defns %v1, %v2 before %op3
+        : (!transform.any_value, !transform.any_value), !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Move only those value definitions that are not dominated by insertion point
+
+func.func @move_only_required_defns() -> (f32, f32, f32, f32) {
+  %0 = "unmoved_op"() : () -> (f32)
+  %1 = "dummy_op"() : () -> (f32)
+  %2 = "before"() : () -> (f32)
+  %3 = "moved_op"() : () -> (f32)
+  return %0, %1, %2, %3 : f32, f32, f32, f32
+}
+// CHECK-LABEL: func @move_only_required_defns()
+//       CHECK:   %[[UNMOVED:.+]] = "unmoved_op"
+//       CHECK:   %[[DUMMY:.+]] = "dummy_op"
+//       CHECK:   %[[MOVED:.+]] = "moved_op"
+//       CHECK:   %[[BEFORE:.+]] = "before"
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op1 = transform.structured.match ops{["unmoved_op"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op2 = transform.structured.match ops{["dummy_op"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op3 = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op4 = transform.structured.match ops{["moved_op"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
+    %v2 = transform.get_result %op4[0] : (!transform.any_op) -> !transform.any_value
+    transform.test.move_value_defns %v1, %v2 before %op3
+        : (!transform.any_value, !transform.any_value), !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Check handling of block arguments
+func.func @move_only_required_defns() -> (f32, f32) {
+  %0 = "unmoved_op"() : () -> (f32)
+  cf.br ^bb0(%0 : f32) 
+ ^bb0(%arg0 : f32) :
+  %1 = "before"() : () -> (f32)
+  %2 = "moved_op"(%arg0) : (f32) -> (f32)
+  return %1, %2 : f32, f32
+}
+// CHECK-LABEL: func @move_only_required_defns()
+//       CHECK:   %[[MOVED:.+]] = "moved_op"
+//       CHECK:   %[[BEFORE:.+]] = "before"
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op1 = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op2 = transform.structured.match ops{["moved_op"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
+    transform.test.move_value_defns %v1 before %op1
+        : (!transform.any_value), !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Do not move across basic blocks
+func.func @no_move_across_basic_blocks() -> (f32, f32) {
+  %0 = "unmoved_op"() : () -> (f32)
+  %1 = "before"() : () -> (f32)
+  cf.br ^bb0(%0 : f32) 
+ ^bb0(%arg0 : f32) :
+  %2 = "moved_op"(%arg0) : (f32) -> (f32)
+  return %1, %2 : f32, f32
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op1 = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op2 = transform.structured.match ops{["moved_op"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
+    // expected-remark@+1{{unsupported case of moving definition of value before an insertion point in a different basic block}}
+    transform.test.move_value_defns %v1 before %op1
+        : (!transform.any_value), !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/lib/Transforms/TestTransformsOps.cpp b/mlir/test/lib/Transforms/TestTransformsOps.cpp
index aaa566d9938a3..3d95af59f6da3 100644
--- a/mlir/test/lib/Transforms/TestTransformsOps.cpp
+++ b/mlir/test/lib/Transforms/TestTransformsOps.cpp
@@ -39,6 +39,24 @@ transform::TestMoveOperandDeps::apply(TransformRewriter &rewriter,
   return DiagnosedSilenceableFailure::success();
 }
 
+DiagnosedSilenceableFailure
+transform::TestMoveValueDefns::apply(TransformRewriter &rewriter,
+                                      TransformResults &TransformResults,
+                                      TransformState &state) {
+  SmallVector<Value> values;
+  for (auto tdValue : getValues()) {
+    values.push_back(*state.getPayloadValues(tdValue).begin());
+  }
+  Operation *moveBefore = *state.getPayloadOps(getInsertionPoint()).begin();
+  if (failed(moveValueDefinitions(rewriter, values, moveBefore))) {
+    auto listener = cast<ErrorCheckingTrackingListener>(rewriter.getListener());
+    std::string errorMsg = listener->getLatestMatchFailureMessage();
+    (void)emitRemark(errorMsg);
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+
 namespace {
 
 class TestTransformsDialectExtension
diff --git a/mlir/test/lib/Transforms/TestTransformsOps.td b/mlir/test/lib/Transforms/TestTransformsOps.td
index f514702cef5bc..495579b452dfc 100644
--- a/mlir/test/lib/Transforms/TestTransformsOps.td
+++ b/mlir/test/lib/Transforms/TestTransformsOps.td
@@ -38,4 +38,26 @@ def TestMoveOperandDeps :
   }];
 }
 
+def TestMoveValueDefns :
+    Op<Transform_Dialect, "test.move_value_defns",
+        [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+         DeclareOpInterfaceMethods<TransformOpInterface>,
+         ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Moves all dependencies of on operation before another operation.
+  }];
+
+  let arguments =
+    (ins Variadic<TransformValueHandleTypeInterface>:$values,
+         TransformHandleTypeInterface:$insertion_point);
+  
+  let results = (outs);
+
+  let assemblyFormat = [{
+    $values `before` $insertion_point attr-dict
+    `:` `(` type($values) `)` `` `,` type($insertion_point)
+  }];
+}
+
+
 #endif // TEST_TRANSFORM_OPS

Copy link

github-actions bot commented Mar 12, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Signed-off-by: MaheshRavishankar <[email protected]>
@MaheshRavishankar MaheshRavishankar merged commit 665299e into llvm:main Mar 12, 2025
11 checks passed
frederik-h pushed a commit to frederik-h/llvm-project that referenced this pull request Mar 18, 2025
…lvm#130874)

llvm@205c532
added a transform utility that moved all SSA dependences of an operation
before an insertion point. Similar to that, this PR adds a transform
utility function, `moveValueDefinitions` to move the slice of operations
that define all values in a `ValueRange` before the insertion point.
While very similar to `moveOperationDependencies`, this method differs
in a few ways

1. When computing the backward slice since the start of the slice is
value, the slice computed needs to be inclusive.
2. The combined backward slice needs to be sorted topologically before
moving them to avoid SSA use-def violations while moving individual ops.

The PR also adds a new transform op to test this new utility function.

---------

Signed-off-by: MaheshRavishankar <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants