[mlir][Hoisting] Hoisting vector.extract/vector.broadcast pairs #86108

Merged
merged 3 commits into from
Apr 22, 2024

stevenvar
Contributor

@stevenvar stevenvar commented Mar 21, 2024

This transformation, inspired by what is done in hoist_redundant_vector_transfers, hoists pairs of vector.extract/vector.broadcast operations out of scf.for loops.

It changes a loop of the form:

%res = scf.for _ = _ to _ step _ iter_args(%iarg = %v) -> (t1) {
  %e = vector.extract %iarg : t1 to t2
  %u = "some_use"(%e) : (t2) -> t2
  %b = vector.broadcast %u : t2 to t1
  scf.yield %b : t1
}

into the following:

%e = vector.extract %v : t1 to t2
%res' = scf.for _ = _ to _ step _ iter_args(%iarg = %e) -> (t2) {
  %u' = "some_use"(%iarg) : (t2) -> t2
  scf.yield %u' : t2
}
%res = vector.broadcast %res' : t2 to t1
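
To make the equivalence of the two forms easy to check by hand, here is a minimal C++ sketch that models them on plain arrays. It is not MLIR code: `Row`, `Mat`, `extractRow`, `broadcastRow`, `someUse`, `loopBefore` and `loopAfter` are all hypothetical names chosen for illustration, with t1 fixed to a 3x4 array and t2 to a 4-element row. In this toy model the two forms agree whenever the loop body runs at least once. With zero iterations the original form returns `%v` unchanged while the hoisted form returns a broadcast of one slice, so the comparison only makes sense for a non-zero trip count.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Toy stand-ins for vector<3x4xf32> (t1) and vector<4xf32> (t2).
using Row = std::array<float, 4>;
using Mat = std::array<Row, 3>;

// Models vector.extract %m[pos] : t2 from t1.
Row extractRow(const Mat &m, std::size_t pos) {
  assert(pos < m.size());
  return m[pos];
}

// Models vector.broadcast %r : t2 to t1 (every row becomes a copy of r).
Mat broadcastRow(const Row &r) {
  Mat m{};
  m.fill(r);
  return m;
}

// Hypothetical placeholder for "some_use"; any pure t2 -> t2 function works.
Row someUse(Row r) {
  for (float &x : r)
    x = 2.0f * x + 1.0f;
  return r;
}

// Original form: extract and broadcast on every iteration.
Mat loopBefore(Mat v, int trips, std::size_t pos) {
  for (int i = 0; i < trips; ++i)
    v = broadcastRow(someUse(extractRow(v, pos)));
  return v;
}

// Hoisted form: extract once, carry t2 through the loop, broadcast once.
Mat loopAfter(const Mat &v, int trips, std::size_t pos) {
  Row e = extractRow(v, pos);
  for (int i = 0; i < trips; ++i)
    e = someUse(e);
  return broadcastRow(e);
}
```

For a matrix `v` whose rows differ, `loopBefore(v, 3, 1) == loopAfter(v, 3, 1)` holds, whereas `loopBefore(v, 0, 1)` is `v` itself and `loopAfter(v, 0, 1)` is not.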

@stevenvar
Contributor Author

stevenvar commented Mar 21, 2024

Motivation: during internal development we noticed quite a few instances of this pattern, where a loop starts with a vector.extract of its iter_arg and ends with a vector.broadcast of the yielded value. Hoisting the pair gave us a significant speedup, and we think it could benefit others in the community who encounter similar patterns.

@llvmbot
Member

llvmbot commented Mar 21, 2024

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Steven Varoumas (stevenvar)

Changes

This transformation, inspired by what is done in hoist_redundant_vector_transfers, hoists pairs of vector.extract/vector.broadcast operations out of scf.for loops.

It changes a loop of the form:

%loop = scf.for _ = _ to _ step _ iter_args(%iarg = %v) -> (t1) {
  %e = vector.extract %iarg : t1 to t2
  %u = "some_use"(%e) : (t2) -> t2
  %b = vector.broadcast %u : t2 to t1
  scf.yield %b : t1
}

into the following:

%e = vector.extract %v : t1 to t2
%loop' = scf.for _ = _ to _ step _ iter_args(%iarg = %e) -> (t2) {
  %u' = "some_use"(%iarg) : (t2) -> t2
  scf.yield %u' : t2
}
%loop = vector.broadcast %loop' : t2 to t1

Full diff: https://github.com/llvm/llvm-project/pull/86108.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+36)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h (+2)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+14)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp (+114)
  • (modified) mlir/test/Dialect/Linalg/hoisting.mlir (+109)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 4f34016066b4ce..7e1b2894dc0126 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2213,6 +2213,42 @@ def HoistRedundantVectorTransfersOp :
    }];
 }
 
+//===----------------------------------------------------------------------===//
+// HoistRedundantVectorBroadcastsOp
+//===----------------------------------------------------------------------===//
+
+def HoistRedundantVectorBroadcastsOp :
+  Op<Transform_Dialect, "structured.hoist_redundant_vector_broadcasts",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformEachOpTrait, TransformOpInterface,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Hoist vector.extract / vector.broadcast pairs out of the immediately
+    enclosing scf::ForOp, iteratively.
+
+    #### Return modes:
+
+    The operation always succeeds and returns a handle to the transformed
+    function op.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>,
+  ];
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+         ::mlir::transform::TransformRewriter &rewriter,
+         ::mlir::func::FuncOp target,
+         ::mlir::transform::ApplyToEachResultList &results,
+         ::mlir::transform::TransformState &state);
+   }];
+}
+
 //===----------------------------------------------------------------------===//
 // ConvertConv2DToImg2ColOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
index 186e83a57580f3..11886d4876a97f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
@@ -43,6 +43,8 @@ namespace linalg {
 /// when used on distributed loops with memref semantics!
 void hoistRedundantVectorTransfers(Operation *root);
 
+void hoistRedundantVectorBroadcasts(Operation *root);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ecf9983124821a..d7e6a21a565a75 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3224,6 +3224,20 @@ transform::HoistRedundantVectorTransfersOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// HoistRedundantVectorBroadcastsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::HoistRedundantVectorBroadcastsOp::applyToOne(
+    transform::TransformRewriter &rewriter, func::FuncOp target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  linalg::hoistRedundantVectorBroadcasts(target);
+  results.push_back(target);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // ConvertConv2DToImg2ColOp.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 34c9b2c282965c..98521cd745216c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -43,6 +43,120 @@ using llvm::dbgs;
 using namespace mlir;
 using namespace mlir::linalg;
 
+scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp loop,
+                                     Value newInitOperand, int index,
+                                     Value newYieldValue) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(loop.getOperation());
+  auto inits = llvm::to_vector(loop.getInits());
+
+  // Replace the init value with the new operand
+  inits[index] = newInitOperand;
+
+  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
+      loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
+      inits, [](OpBuilder &, Location, Value, ValueRange) {});
+
+  // Generate the new yield with the replaced operand
+  auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
+  yieldOp->getOperand(index).replaceAllUsesWith(newYieldValue);
+
+  // Move the loop body to the new op.
+  rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(),
+                       newLoop.getBody()->getArguments().take_front(
+                           loop.getBody()->getNumArguments()));
+
+  // Replace the old loop.
+  rewriter.replaceOp(loop.getOperation(),
+                     newLoop->getResults().take_front(loop.getNumResults()));
+  return newLoop;
+}
+
+// Hoist out a pair of corresponding vector.extract+vector.broadcast
+// operations. This function transforms a loop like this:
+//  %loop = scf.for _ = _ to _ step _ iter_args(%iterarg = %v) -> (t1) {
+//   %e = vector.extract %iterarg : t1 to t2
+//   %u = // do something with %e : t2
+//   %b = vector.broadcast %u : t2 to t1
+//   scf.yield %b : t1
+//  }
+// into the following:
+//  %e = vector.extract %v: t1 to t2
+//  %loop' = scf.for _ = _ to _ step _ iter_args(%iterarg = %e) -> (t2) {
+//   %u' = // do something with %iterarg : t2
+//   scf.yield %u' : t2
+//  }
+//  %loop = vector.broadcast %loop' : t2 to t1
+void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
+  bool changed = true;
+  while (changed) {
+    changed = false;
+    // First move loop invariant ops outside of their loop. This needs to be
+    // done before as we cannot move ops without interrupting the function walk.
+    root->walk(
+        [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
+
+    root->walk([&](vector::ExtractOp extractOp) {
+      LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
+                        << *extractOp.getOperation() << "\n");
+
+      auto loop = dyn_cast<scf::ForOp>(extractOp->getParentOp());
+      if (!loop)
+        return WalkResult::advance();
+
+      // Check that the vector to extract from is an iter_arg of this loop;
+      // getTiedLoopInit is expected to return null for other block arguments.
+      auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector());
+      OpOperand *tiedInit = blockArg ? loop.getTiedLoopInit(blockArg) : nullptr;
+      if (!tiedInit)
+        return WalkResult::advance();
+      // If the iter_arg does not have only one use, it won't be possible to
+      // hoist the extractOp out.
+      if (!blockArg.hasOneUse())
+        return WalkResult::advance();
+      Value initArg = tiedInit->get();
+      auto index = blockArg.getArgNumber() - loop.getNumInductionVars();
+
+      // Check that the loop yields a broadcast
+      auto lastOp = loop.getBody()->getTerminator();
+      auto yieldOp = dyn_cast<scf::YieldOp>(lastOp);
+      if (!yieldOp)
+        return WalkResult::advance();
+
+      auto broadcast =
+          yieldOp->getOperand(index).getDefiningOp<vector::BroadcastOp>();
+      if (!broadcast)
+        return WalkResult::advance();
+      LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n");
+      Type broadcastInputType = broadcast.getSourceType();
+      if (broadcastInputType != extractOp.getType())
+        return WalkResult::advance();
+
+      // The position of the extract must be defined outside of the loop if
+      // it is dynamic
+      for (auto operand : extractOp.getDynamicPosition())
+        if (!loop.isDefinedOutsideOfLoop(operand))
+          return WalkResult::advance();
+
+      extractOp.getVectorMutable().assign(initArg);
+      loop.moveOutOfLoop(extractOp);
+      broadcast->moveAfter(loop);
+
+      IRRewriter rewriter(extractOp.getContext());
+      auto newLoop = replaceWithDifferentYield(
+          rewriter, loop, extractOp.getResult(), index, broadcast.getSource());
+
+      LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n");
+
+      newLoop.getResult(index).replaceAllUsesWith(broadcast);
+      broadcast.getSourceMutable().assign(newLoop.getResult(index));
+
+      changed = true;
+      return WalkResult::interrupt();
+    });
+  }
+}
+
 static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
                                 LoopLikeOpInterface loop) {
   Value source = transferRead.getSource();
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 550ffbc7bab678..5a640348be90cf 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -565,3 +565,112 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// Test hoisting of vector.extract/vector.broadcast pairs
+
+// CHECK-LABEL:  func.func @hoist_vector_broadcasts
+//       CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>) -> vector<3x4xf32> {
+//       CHECK:        %[[EXTRACT:.+]] = vector.extract %[[VEC]][0] : vector<4xf32> from vector<3x4xf32>
+//       CHECK-NEXT:   %[[LOOP:.+]] = scf.for {{.*}} {
+//       CHECK-NEXT:     %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
+//       CHECK-NEXT:     scf.yield %[[USE]] : vector<4xf32>
+//       CHECK-NEXT:   }
+//       CHECK-NEXT:   %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
+//       CHECK-NEXT:   return %[[BCAST]] : vector<3x4xf32>
+
+func.func @hoist_vector_broadcasts(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>) -> vector<3x4xf32> {
+  %bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
+    %extract = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
+    %use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
+    %broadcast = vector.broadcast %use : vector<4xf32> to vector<3x4xf32>
+    scf.yield %broadcast : vector<3x4xf32>
+  }
+  return %bcast_vec : vector<3x4xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_broadcasts %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Test hoisting of vector.extract/vector.broadcast pairs with dynamic position
+
+// CHECK-LABEL:  func.func @hoist_vector_broadcasts_dynamic
+//       CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>, %[[POS:.+]]: index) -> vector<3x4xf32> {
+//       CHECK:        %[[EXTRACT:.+]] = vector.extract %[[VEC]][%[[POS]]] : vector<4xf32> from vector<3x4xf32>
+//       CHECK-NEXT:   %[[LOOP:.+]] = scf.for {{.*}} {
+//       CHECK-NEXT:     %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
+//       CHECK-NEXT:     scf.yield %[[USE]] : vector<4xf32>
+//       CHECK-NEXT:   }
+//       CHECK-NEXT:   %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
+//       CHECK-NEXT:   return %[[BCAST]] : vector<3x4xf32>
+
+func.func @hoist_vector_broadcasts_dynamic(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>, %pos: index) -> vector<3x4xf32> {
+  %bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
+    %extract = vector.extract %iarg[%pos] : vector<4xf32> from vector<3x4xf32>
+    %use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
+    %broadcast = vector.broadcast %use : vector<4xf32> to vector<3x4xf32>
+    scf.yield %broadcast : vector<3x4xf32>
+  }
+  return %bcast_vec : vector<3x4xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_broadcasts %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Test hoisting of vector.extract/vector.broadcast pairs with multiple iter_args
+
+// CHECK-LABEL:  func.func @hoist_vector_broadcasts_multiple
+//       CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC1:.+]]: vector<3x4xf32>,
+//       CHECK-SAME:  %[[VEC2:.+]]: vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
+//       CHECK-DAG:     %[[EXTRACT1:.+]] = vector.extract %[[VEC1]][0] : vector<4xf32> from vector<3x4xf32>
+//       CHECK-DAG:     %[[EXTRACT2:.+]] = vector.extract %[[VEC2]][1] : vector<5xf32> from vector<3x5xf32>
+//       CHECK-NEXT:    %[[LOOP:.+]]:2 = scf.for {{.*}} {
+//       CHECK-DAG:       %[[USE1:.+]] = "some_use1"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
+//       CHECK-DAG:       %[[USE2:.+]] = "some_use2"({{.*}}) : (vector<5xf32>) -> vector<5xf32>
+//       CHECK-NEXT:      scf.yield %[[USE1]], %[[USE2]]  : vector<4xf32>, vector<5xf32>
+//       CHECK-NEXT:    }
+//       CHECK-DAG:     %[[BCAST1:.+]] = vector.broadcast %[[LOOP]]#0 : vector<4xf32> to vector<3x4xf32>
+//       CHECK-DAG:     %[[BCAST2:.+]] = vector.broadcast %[[LOOP]]#1 : vector<5xf32> to vector<3x5xf32>
+//       CHECK-NEXT:    return %[[BCAST1]], %[[BCAST2]] : vector<3x4xf32>, vector<3x5xf32>
+
+func.func @hoist_vector_broadcasts_multiple(%lb : index, %ub : index, %step : index, %vec1 : vector<3x4xf32>, %vec2 : vector<3x5xf32>) ->  (vector<3x4xf32>, vector<3x5xf32>) {
+  %bcast_vec:2 = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec1, %iarg2 = %vec2) -> (vector<3x4xf32>, vector<3x5xf32>) {
+    %extract1 = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
+    %extract2 = vector.extract %iarg2[1] : vector<5xf32> from vector<3x5xf32>
+    %use1 = "some_use1"(%extract1) : (vector<4xf32>) -> vector<4xf32>
+    %use2 = "some_use2"(%extract2) : (vector<5xf32>) -> vector<5xf32>
+    %broadcast1 = vector.broadcast %use1 : vector<4xf32> to vector<3x4xf32>
+    %broadcast2 = vector.broadcast %use2 : vector<5xf32> to vector<3x5xf32>
+    scf.yield %broadcast1, %broadcast2 : vector<3x4xf32>,vector<3x5xf32>
+  }
+  return %bcast_vec#0, %bcast_vec#1 :  vector<3x4xf32>, vector<3x5xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_broadcasts %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
\ No newline at end of file
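
The driver loop in `hoistRedundantVectorBroadcasts` hoists one pair per walk and then restarts, since replacing the loop invalidates the walk in progress. The following C++ sketch shows that restart-until-fixed-point shape in isolation. `ToyLoop`, `hoistOnePair` and `hoistAll` are hypothetical stand-ins, not part of the patch, and the counters merely simulate how many pairs remain in each loop.

```cpp
#include <cassert>
#include <vector>

// Toy stand-in for an scf.for: it only counts the extract/broadcast pairs
// still inside the loop body and the pairs already hoisted out of it.
struct ToyLoop {
  int pairsInside;
  int pairsHoisted;
};

// One walk over all loops. Mirrors the shape of the patch: rewrite the first
// candidate found, then stop the walk (WalkResult::interrupt in the patch)
// because the rewrite replaced the loop being visited.
bool hoistOnePair(std::vector<ToyLoop> &loops) {
  for (ToyLoop &loop : loops) {
    assert(loop.pairsInside >= 0);
    if (loop.pairsInside == 0)
      continue; // No candidate here; keep walking (WalkResult::advance).
    --loop.pairsInside;
    ++loop.pairsHoisted;
    return true; // Something changed; the caller restarts the walk.
  }
  return false;
}

// Fixed-point driver: repeat until a full walk makes no change.
// Returns the number of walks performed, including the final no-op walk.
int hoistAll(std::vector<ToyLoop> &loops) {
  int walks = 0;
  bool changed = true;
  while (changed) {
    changed = hoistOnePair(loops);
    ++walks;
  }
  return walks;
}
```

With three toy loops holding 2, 0 and 1 pairs, `hoistAll` performs four walks: three that each hoist one pair and a final one that finds nothing to do.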

@stevenvar stevenvar force-pushed the hoist_extract_broadcast branch from bc1e38a to 9516800 Compare March 21, 2024 10:26
@stevenvar stevenvar force-pushed the hoist_extract_broadcast branch from 9516800 to e245656 Compare April 2, 2024 08:26
@stevenvar stevenvar requested a review from ftynse as a code owner April 2, 2024 08:26
@stevenvar
Contributor Author

Ping

@stevenvar stevenvar force-pushed the hoist_extract_broadcast branch from e245656 to 26ccd91 Compare April 10, 2024 10:39
@stevenvar
Contributor Author

Ping

@stevenvar stevenvar force-pushed the hoist_extract_broadcast branch 2 times, most recently from ec49491 to 13d780e Compare April 16, 2024 08:53
@stevenvar stevenvar force-pushed the hoist_extract_broadcast branch from 13d780e to 5ebc4fb Compare April 18, 2024 08:59
@stevenvar stevenvar force-pushed the hoist_extract_broadcast branch from e859861 to 3835681 Compare April 19, 2024 10:31
@stevenvar stevenvar force-pushed the hoist_extract_broadcast branch from 3835681 to bf4eb78 Compare April 19, 2024 15:00
@stevenvar stevenvar requested a review from ftynse April 22, 2024 11:25
@stevenvar
Contributor Author

Thank you @ftynse for the quick approval. Are you happy with merging this patch? I do not have committer permissions.
