[mlir][TilingInterface] Allow controlling what fusion is done within tile and fuse #76871

Merged: 1 commit into llvm:main on Jan 8, 2024

Conversation

MaheshRavishankar
Contributor

Currently the tileConsumerAndFuseProducerGreedilyUsingSCFForOp method greedily fuses through all slices that are generated during the tile-and-fuse flow. That is not the normal use case: callers typically want to control which slices get fused and which don't. This patch introduces a new field in SCFTileAndFuseOptions to specify this control.

The control function also lets the caller specify whether the replacement for the fused producer needs to be yielded from within the tiled computation. This allows replacing fused producers that have other uses. Without it, the original producers survive, negating the utility of the fusion.

This change also means that the name of the function tileConsumerAndFuseProducerGreedily... can be updated. Deferring that to a later stage to reduce the churn of API changes.
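
To make the shape of the new option concrete, here is a minimal standalone sketch. It does not use MLIR: `MockSlice`, `MockProducer`, `MockTileAndFuseOptions`, and `makeOptionsSkippingDestinations` are hypothetical stand-ins that only mirror the signature of `ControlFnTy` and the default behaviour shown in the diff (fuse everything, yield nothing). The "skip destination operands" policy is an illustrative assumption, not something this patch prescribes.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <tuple>
#include <utility>

// Hypothetical stand-ins for tensor::ExtractSliceOp and OpResult.
struct MockSlice {
  std::string sourceName;
};
struct MockProducer {
  std::string name;
};

// Mirrors the shape of the new field on SCFTileAndFuseOptions (not the real
// MLIR struct).
struct MockTileAndFuseOptions {
  // Returns {fuse this slice?, yield a replacement for the producer?}.
  using ControlFnTy = std::function<std::tuple<bool, bool>(
      const MockSlice &, const MockProducer &, bool isDestinationOperand)>;

  // Same default as in the patch: fuse every slice, yield no replacement.
  ControlFnTy fusionControlFn = [](const MockSlice &, const MockProducer &,
                                   bool) {
    return std::make_tuple(true, false);
  };

  MockTileAndFuseOptions &setFusionControlFn(ControlFnTy controlFn) {
    assert(controlFn && "control function must be callable");
    fusionControlFn = std::move(controlFn);
    return *this;
  }
};

// Illustrative policy (an assumption): never fuse through destination
// operands, and never yield a replacement.
inline MockTileAndFuseOptions makeOptionsSkippingDestinations() {
  MockTileAndFuseOptions options;
  options.setFusionControlFn(
      [](const MockSlice &, const MockProducer &, bool isDestinationOperand) {
        return std::make_tuple(!isDestinationOperand, false);
      });
  return options;
}
```

With the real API, the equivalent would presumably be a lambda over `tensor::ExtractSliceOp`, `OpResult`, and `bool` passed to `setFusionControlFn`, as the updated test in this PR does.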

@llvmbot
Member

llvmbot commented Jan 3, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-scf

Author: None (MaheshRavishankar)

Full diff: https://github.com/llvm/llvm-project/pull/76871.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+17)
  • (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+47-18)
  • (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp (+34-64)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 2f8f337bb8057c..0571b4aede4885 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -97,6 +97,23 @@ struct SCFTileAndFuseOptions {
     tilingOptions = options;
     return *this;
   }
+
+  /// Control function to check if a slice needs to be fused or not.
+  /// The control function receives
+  /// 1) the slice along which fusion is to be done,
+  /// 2) the producer value that is to be fused
+  /// 3) a boolean value set to `true` if the fusion is from
+  ///    a destination operand.
+  using ControlFnTy = std::function<std::tuple<bool, bool>(
+      tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
+      bool isDestinationOperand)>;
+  ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult, bool) {
+    return std::make_tuple(true, false);
+  };
+  SCFTileAndFuseOptions &setFusionControlFn(ControlFnTy controlFn) {
+    fusionControlFn = controlFn;
+    return *this;
+  }
 };
 
 /// Fuse the producer of the source of `candidateSliceOp` by computing the
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1b6b4db9d20907..22826cababe779 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -728,32 +728,36 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
   }
 
   // 1. First tile the consumer.
-  SmallVector<scf::ForOp> forLoops;
   SetVector<Operation *> fusedProducers, tiledAndFusedOps;
-  DenseMap<Value, Value> replacements;
-  llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber;
-  {
-    FailureOr<scf::SCFTilingResult> tilingResult =
-        tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
-    if (failed(tilingResult))
-      return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
-    for (auto *tiledOp : tilingResult->tiledOps)
-      tiledAndFusedOps.insert(tiledOp);
-    forLoops = castToTypedOperations<scf::ForOp>(tilingResult->loops);
-    for (auto [index, origValue, replacement] :
-         llvm::enumerate(consumer->getResults(), tilingResult->replacements)) {
-      replacements[origValue] = replacement;
-      yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult(
-          index)] = index;
-    }
-  }
+  llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum;
+  FailureOr<scf::SCFTilingResult> tilingResult =
+      tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
+  if (failed(tilingResult))
+    return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
+  for (auto *tiledOp : tilingResult->tiledOps)
+    tiledAndFusedOps.insert(tiledOp);
+  SmallVector<scf::ForOp> forLoops =
+      castToTypedOperations<scf::ForOp>(tilingResult->loops);
 
   // If there are no loops generated, fusion is immaterial.
   if (forLoops.empty()) {
+    DenseMap<Value, Value> replacements;
+    for (auto [origVal, replacement] :
+         llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
+      replacements[origVal] = replacement;
+    }
     return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
                                      getAsOperations(forLoops), replacements};
   }
 
+  // To keep track of replacements for now just record the map from the original
+  // untiled value to the result number of the for loop. Since the loop gets
+  // potentially replaced during fusion, keeping the value directly won't work.
+  DenseMap<Value, size_t> origValToResultNumber;
+  for (auto [index, result] : llvm::enumerate(consumer->getResults())) {
+    origValToResultNumber[result] = index;
+  }
+
   // 2. Typically, the operands of the tiled operation are slices of the
   //    operands of the untiled operation. These are expressed in IR using
   //    `tensor.extract_slice` operations with source being the operands of the
@@ -776,6 +780,18 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
     tensor::ExtractSliceOp candidateSliceOp = candidates.front();
     candidates.pop_front();
 
+    // Find the original producer of the slice.
+    auto [fusableProducer, destinationInitArg] =
+        getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
+                                          forLoops);
+    if (!fusableProducer)
+      continue;
+
+    auto [fuseSlice, yieldReplacement] = options.fusionControlFn(
+        candidateSliceOp, fusableProducer, destinationInitArg.has_value());
+    if (!fuseSlice)
+      continue;
+
     // The operands of the fused producer might themselved be slices of
     // values produced by operations that implement the `TilingInterface`.
     // Add these operations to the worklist.
@@ -784,6 +800,13 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
     if (!fusedResult)
       continue;
 
+    if (yieldReplacement) {
+      yieldReplacementForFusedProducer(rewriter, candidateSliceOp,
+                                       fusedResult.value(), forLoops);
+      origValToResultNumber[fusableProducer] =
+          forLoops.front().getNumResults() - 1;
+    }
+
     if (Operation *tiledAndFusedOp =
             fusedResult->tiledAndFusedProducer.getDefiningOp()) {
       fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
@@ -791,6 +814,12 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
       addCandidateSlices(tiledAndFusedOp, candidates);
     }
   }
+
+  DenseMap<Value, Value> replacements;
+  for (auto [origVal, resultNumber] : origValToResultNumber) {
+    replacements[origVal] = forLoops.front()->getResult(resultNumber);
+  }
+
   return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
                                    getAsOperations(forLoops), replacements};
 }
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index 112ad6cbde8589..798293bc1327e1 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -311,80 +311,50 @@ struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp
     // Collect list of operations that can be tiled and fused.
     llvm::SmallDenseSet<Operation *> tiledAndFusedOps =
         collectTiledAndFusedOps(rootOp);
-    auto isIgnoredUser = [&](Operation *user, scf::ForOp outerMostTiledLoop) {
-      return tiledAndFusedOps.count(user) || isa<tensor::DimOp>(user) ||
-             outerMostTiledLoop->isAncestor(user);
+    llvm::SmallDenseMap<Operation *, bool> yielded;
+    auto isIgnoredUser = [&](Operation *user) {
+      return tiledAndFusedOps.count(user) || isa<tensor::DimOp>(user);
     };
-
-    // The rest of this method is similar to
-    // scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp, except that also
-    // yields replacements for values of the fused producer.
-
-    // 1. Tile the consumer.
-    SmallVector<OpResult> yieldedValuesToOrigValues;
-    FailureOr<scf::SCFTilingResult> tilingResult =
-        scf::tileUsingSCFForOp(rewriter, rootOp, options);
-    if (failed(tilingResult)) {
-      return rewriter.notifyMatchFailure(rootOp,
-                                         "failed to tile base operation");
+    for (Operation *op : tiledAndFusedOps) {
+      yielded[op] = llvm::any_of(op->getUsers(), [&](Operation *user) {
+        return !isIgnoredUser(user);
+      });
     }
-    yieldedValuesToOrigValues.append(rootOp->result_begin(),
-                                     rootOp->result_end());
-
-    // 2. Tiling each operation results in generation of slices. The source of
-    // these slices could be producers that can be fused into the tiled loops by
-    // computing the slices of these producers in-place. This results in more
-    // slices created for operands of the "fused producer". This open up more
-    // opportunities for fusion. Use a worklist to fuse greedily.
-    auto addCandidateSlices =
-        [](Operation *fusedOp, std::deque<tensor::ExtractSliceOp> &candidates) {
-          for (Value operand : fusedOp->getOperands())
-            if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
-              candidates.push_back(sliceOp);
-        };
 
-    std::deque<tensor::ExtractSliceOp> candidates;
-    addCandidateSlices(tilingResult->tiledOps.back(), candidates);
-    OpBuilder::InsertionGuard g(rewriter);
-    auto forLoops = llvm::to_vector(llvm::map_range(
-        tilingResult->loops, [](auto op) { return cast<scf::ForOp>(op); }));
-    while (!candidates.empty()) {
-      // Traverse the slices in BFS fashion.
-      tensor::ExtractSliceOp candidateSliceOp = candidates.front();
-      candidates.pop_front();
-
-      // Materialize the slice of the producer in place.
-      std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
-          tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, forLoops);
-      if (!fusedProducer)
-        continue;
-
-      // Check if the fused producer has other uses that require the value
-      // to be yielded from within the tiled loop.
-      OpResult untiledProducer = fusedProducer->origProducer;
-      if (llvm::any_of(untiledProducer.getUsers(), [&](Operation *user) {
-            return !isIgnoredUser(user, forLoops.front());
-          })) {
-        yieldReplacementForFusedProducer(rewriter, candidateSliceOp,
-                                         fusedProducer.value(), forLoops);
-        yieldedValuesToOrigValues.push_back(untiledProducer);
-      }
+    scf::SCFTileAndFuseOptions tileAndFuseOptions;
+    tileAndFuseOptions.setTilingOptions(options);
+    scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
+        [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
+            bool isDestinationOperand) {
+          Operation *owner = originalProducer.getOwner();
+          return std::make_tuple(true,
+                                 yielded.contains(owner) && yielded[owner]);
+        };
+    tileAndFuseOptions.setFusionControlFn(controlFn);
 
-      // Add more fusion candidates to the worklist.
-      if (auto fusedProducerOp =
-              fusedProducer->tiledAndFusedProducer.getDefiningOp())
-        addCandidateSlices(fusedProducerOp, candidates);
+    FailureOr<scf::SCFTileAndFuseResult> tileAndFuseResult =
+        scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
+            rewriter, rootOp, tileAndFuseOptions);
+    if (failed(tileAndFuseResult)) {
+      return rewriter.notifyMatchFailure(
+          rootOp, "failed to tile and fuse with op as root");
     }
 
-    scf::ForOp outermostLoop = forLoops.front();
-    for (auto [index, origVal] : llvm::enumerate(yieldedValuesToOrigValues)) {
-      Value replacement = outermostLoop.getResult(index);
+    for (auto it : tileAndFuseResult->replacements) {
+      Value origVal = it.first;
+      Value replacement = it.second;
       rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
-        return !isIgnoredUser(use.getOwner(), outermostLoop);
+        Operation *user = use.getOwner();
+        return !isIgnoredUser(user) &&
+               !tileAndFuseResult->loops.front()->isAncestor(user);
       });
     }
+
     rewriter.eraseOp(rootOp);
-    filter.replaceTransformationFilter(rewriter, tilingResult->tiledOps.back());
+    for (auto tiledAndFusedOp : tileAndFuseResult->tiledAndFusedOps)
+      if (tiledAndFusedOp->hasAttr(kTransformMarker))
+        filter.replaceTransformationFilter(rewriter, tiledAndFusedOp);
+
     return success();
   }
 

Contributor

@hanhanW hanhanW left a comment

The code looks okay to me. I wondered whether all the cases are covered, and found that they are: https://github.com/llvm/llvm-project/blob/main/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir

This looks cool to me because it internalizes some of the logic through the control function. It is easier to use.
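
As an illustration of the logic that moves into the control function, the updated test precomputes which fused ops have users outside the fused set and yields a replacement only for those. The sketch below models that policy without MLIR: `MockOp`, `YieldControlFn`, `computeYielded`, and `makeYieldControlFn` are hypothetical names, ops are identified by plain strings, and the special-casing of `tensor::DimOp` users from the real test is omitted.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <set>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

// Hypothetical stand-in for an operation: its name and the names of its users.
struct MockOp {
  std::string name;
  std::vector<std::string> users;
};

// Simplified control function: keyed by producer name only.
using YieldControlFn =
    std::function<std::tuple<bool, bool>(const std::string &producer)>;

// For every op in the fusion set, record whether it has at least one user
// outside that set (loosely mirrors the `yielded` map in the test).
inline std::map<std::string, bool>
computeYielded(const std::vector<MockOp> &fusedOps) {
  std::set<std::string> fusedNames;
  for (const MockOp &op : fusedOps) {
    assert(!op.name.empty() && "ops need a name to be tracked");
    fusedNames.insert(op.name);
  }
  std::map<std::string, bool> yielded;
  for (const MockOp &op : fusedOps) {
    bool hasOutsideUser = false;
    for (const std::string &user : op.users)
      if (fusedNames.count(user) == 0)
        hasOutsideUser = true;
    yielded[op.name] = hasOutsideUser;
  }
  return yielded;
}

// Always fuse; yield a replacement only for producers with outside users.
inline YieldControlFn makeYieldControlFn(std::map<std::string, bool> yielded) {
  return [yielded = std::move(yielded)](const std::string &producer) {
    auto it = yielded.find(producer);
    bool yieldReplacement = it != yielded.end() && it->second;
    return std::make_tuple(true, yieldReplacement);
  };
}
```

Producers that are not in the precomputed map fall back to "fuse, do not yield", which matches the `yielded.contains(owner) && yielded[owner]` check in the diff.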

@chelini
Contributor

chelini commented Jan 4, 2024

Indeed, thanks for working on this @MaheshRavishankar. A callback to control what to fuse is good, in my opinion.
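
For readers following the thread, the callback is consulted once per candidate slice inside the fusion worklist. A rough standalone sketch of that loop shape follows. It is not the MLIR implementation: `MockCandidate`, `SliceControlFn`, `MockFuseResult`, and `fuseWithControl` are hypothetical, and "yielding" is modelled simply as appending one more result to the loop and recording its index, which is the bookkeeping idea behind `origValToResultNumber` in the diff.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <functional>
#include <map>
#include <string>
#include <tuple>
#include <vector>

// Hypothetical candidate slice: the producer it reads from, and whether the
// slice feeds a destination operand.
struct MockCandidate {
  std::string producer;
  bool isDestinationOperand;
};

using SliceControlFn =
    std::function<std::tuple<bool, bool>(const MockCandidate &)>;

struct MockFuseResult {
  std::vector<std::string> fusedProducers;
  // Producer name -> index of the loop result that replaces it.
  std::map<std::string, std::size_t> producerToLoopResult;
};

// Drains the worklist front to back, asking the control function about each
// candidate. `numLoopResults` is the loop's result count before fusion.
inline MockFuseResult fuseWithControl(std::deque<MockCandidate> candidates,
                                      std::size_t numLoopResults,
                                      const SliceControlFn &controlFn) {
  assert(controlFn && "control function must be callable");
  MockFuseResult result;
  while (!candidates.empty()) {
    MockCandidate candidate = candidates.front();
    candidates.pop_front();

    bool fuseSlice = false, yieldReplacement = false;
    std::tie(fuseSlice, yieldReplacement) = controlFn(candidate);
    if (!fuseSlice)
      continue; // The caller declined to fuse through this slice.

    result.fusedProducers.push_back(candidate.producer);
    if (yieldReplacement) {
      // Yielding appends one result to the loop. Record its index rather
      // than the value, since the loop may be rebuilt by later fusions.
      result.producerToLoopResult[candidate.producer] = numLoopResults++;
    }
  }
  return result;
}
```

Storing result indices instead of values is the design choice the diff's comment explains: the loop can be replaced during fusion, so values captured early would go stale, while indices can be resolved against the final loop at the end.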

Contributor

@chelini chelini left a comment

Ok thanks for clarifying this. The PR looks good to me.

@MaheshRavishankar MaheshRavishankar merged commit 4435ced into llvm:main Jan 8, 2024
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
…tile and fuse (llvm#76871)
