[mlir][TilingInterface] Update PartialReductionOpInterface to get it more in line with TilingInterface. #95460

MaheshRavishankar

The `TilingInterface` methods have return values that let an interface implementation return multiple operations and return the tiled values explicitly. This avoids the assumption that the implementation returns a single operation whose results are the expected tiled values. Make `PartialReductionOpInterface::tileToPartialReduction` return `TilingResult` as well, for the same reason.

Similarly, make `PartialReductionOpInterface::mergeReductions` return a list of the generated operations and the values to use as replacements.

This is just a refactoring to allow deprecating `linalg::tileReductionUsingForall` in favor of `scf::tileReductionUsingScf`.
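
For readers who want the shape of the change without the MLIR headers, here is a minimal standalone sketch. Nothing in it is real MLIR API: `Operation`, `Value`, `mergeReductions`, and `replaceWithMerged` are simplified stand-ins invented for illustration, `std::optional` plays the role of `FailureOr`, and a `std::deque` is used only so that pointers to the created operations stay valid. The one thing taken from the patch is the layout of `MergeResult`.

```cpp
#include <deque>
#include <optional>
#include <string>
#include <vector>

// Simplified stand-ins for mlir::Operation and mlir::Value (not the real types).
struct Operation { std::string name; };
struct Value { std::string name; };

// Same field layout as the MergeResult struct added in TilingInterface.h.
struct MergeResult {
  std::vector<Operation *> mergeOps;
  std::vector<Value> replacements;
};

// Hypothetical merge step. std::optional stands in for FailureOr, and
// std::nullopt models `return failure();`.
std::optional<MergeResult> mergeReductions(std::deque<Operation> &ir,
                                           const std::vector<Value> &partials) {
  if (partials.empty())
    return std::nullopt;
  // An implementation is free to create more than one op while merging.
  Operation &reduce = ir.emplace_back(Operation{"reduce"});
  Operation &cast = ir.emplace_back(Operation{"cast"});
  // Replacement values are returned explicitly rather than being read off
  // the results of a single returned op.
  return MergeResult{{&reduce, &cast}, {Value{"merged"}}};
}

// Caller-side pattern: check for failure first, then use `replacements`.
bool replaceWithMerged(std::deque<Operation> &ir,
                       const std::vector<Value> &partials,
                       std::vector<Value> &replaced) {
  std::optional<MergeResult> mergeResult = mergeReductions(ir, partials);
  if (!mergeResult)
    return false;
  replaced = mergeResult->replacements;
  return true;
}
```

The caller no longer needs to know which of the created operations carries the final values, which is the assumption the old `Operation *` return type baked in.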

llvmbot commented Jun 13, 2024

@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir-linalg

Author: None (MaheshRavishankar)


Full diff: https://github.com/llvm/llvm-project/pull/95460.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+2-2)
  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+4-2)
  • (modified) mlir/include/mlir/Interfaces/TilingInterface.h (+10)
  • (modified) mlir/include/mlir/Interfaces/TilingInterface.td (+4-4)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+8-4)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp (+7-4)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+16-10)
  • (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+34-21)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 308ce92e35520..05e97befdec1f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -873,9 +873,9 @@ tileToForallOpUsingTileSizes(RewriterBase &builder, TilingInterface op,
 /// Transformation information returned after reduction tiling.
 struct ForallReductionTilingResult {
   /// The partial reduction tiled op generated.
-  Operation *parallelTiledOp;
+  SmallVector<Operation *> parallelTiledOps;
   /// The final reduction operation merging all the partial reductions.
-  Operation *mergeOp;
+  SmallVector<Operation *> mergeOps;
   /// Initial values used for partial reductions.
   SmallVector<Value> initialValues;
   /// The `scf.forall` operation that iterate over the tiles.
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index dac79111af3c9..6316f1d130d19 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -261,13 +261,15 @@ lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
 /// Transformation information returned after reduction tiling.
 struct SCFReductionTilingResult {
   /// The partial reduction tiled op generated.
-  Operation *parallelTiledOp;
+  SmallVector<Operation *> parallelTiledOps;
   /// The final reduction operation merging all the partial reductions.
-  Operation *mergeOp;
+  SmallVector<Operation *> mergeOps;
   /// Initial values used for reduction.
   SmallVector<Value> initialValues;
   /// The loop operations that iterate over the tiles.
   SmallVector<LoopLikeOpInterface> loops;
+  /// The replacements to use for the results of the tiled operation.
+  SmallVector<Value> replacements;
 };
 
 /// Method to tile a reduction and generate a parallel op within a serial loop.
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h
index ca570490ccf5b..0cfd23587a7ad 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.h
+++ b/mlir/include/mlir/Interfaces/TilingInterface.h
@@ -33,6 +33,16 @@ struct TilingResult {
   SmallVector<Value> tiledValues;
 };
 
+/// Container for the result of merge operation of tiling.
+/// - `mergeOps` contains operations created during the merge.
+/// - `replacements` contains the values that represents the result of the
+/// merge.
+///    These are used as replacements for the original tiled operation.
+struct MergeResult {
+  SmallVector<Operation *> mergeOps;
+  SmallVector<Value> replacements;
+};
+
 } // namespace mlir
 
 /// Include the ODS generated interface header files.
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index bc83c81c0086c..031270fdda8ae 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -247,7 +247,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
           less or equal to the tile size. This is meant to be used with
           `mergeReductions` method which will combine the partial reductions.
         }],
-        /*retType=*/"Operation*",
+        /*retType=*/"FailureOr<TilingResult>",
         /*methodName=*/"tileToPartialReduction",
         /*args=*/(ins
             "OpBuilder &":$b,
@@ -258,7 +258,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
             "ArrayRef<int>":$reductionDims),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
-          return nullptr;
+          return failure();
         }]
       >,
       InterfaceMethod<
@@ -267,7 +267,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
           tiled along the reduction dimensions. This will only apply the
           reduction the operation.
         }],
-        /*retType=*/"Operation*",
+        /*retType=*/"FailureOr<MergeResult>",
         /*methodName=*/"mergeReductions",
         /*args=*/(ins
             "OpBuilder &":$b,
@@ -276,7 +276,7 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
             "ArrayRef<int>":$reductionDim),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
-          return nullptr;
+          return failure();
         }]
       >
   ];
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9b3121774ab3a..2807b3ce42abd 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2525,8 +2525,10 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
     return emitDefaultSilenceableFailure(target);
   for (Value initValue : result->initialValues)
     results.push_back(initValue.getDefiningOp());
-  results.push_back(result->parallelTiledOp);
-  results.push_back(result->mergeOp);
+  for (auto parallelTiledOp : result->parallelTiledOps)
+    results.push_back(parallelTiledOp);
+  for (auto mergeOp : result->mergeOps)
+    results.push_back(mergeOp);
   results.push_back(result->loops.front());
   return DiagnosedSilenceableFailure::success();
 }
@@ -2577,8 +2579,10 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
   }
   for (Value initValue : result->initialValues)
     results.push_back(initValue.getDefiningOp());
-  results.push_back(result->parallelTiledOp);
-  results.push_back(result->mergeOp);
+  for (auto parallelTiledOp : result->parallelTiledOps)
+    results.push_back(parallelTiledOp);
+  for (auto mergeOp : result->mergeOps)
+    results.push_back(mergeOp);
   results.push_back(result->loops);
   return DiagnosedSilenceableFailure::success();
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index a0a0e11a6903d..d8dee82237156 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -833,16 +833,19 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
 
   // 7. Merge the partial reductions.
   b.setInsertionPointAfter(forallOp);
-  Operation *mergeOp =
+  FailureOr<MergeResult> mergeResult =
       op.mergeReductions(b, loc, forallOp->getResults(), reductionDim);
-  b.replaceOp(op, mergeOp->getResults());
+  if (failed(mergeResult)) {
+    return failure();
+  }
+  b.replaceOp(op, mergeResult->replacements);
 
   // 8. Return.
   ForallReductionTilingResult results;
   results.initialValues = initTensors;
   results.loops = forallOp;
-  results.parallelTiledOp = tiledOp;
-  results.mergeOp = mergeOp;
+  results.parallelTiledOps.push_back(tiledOp);
+  results.mergeOps.append(mergeResult->mergeOps);
   return results;
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index c3ab3cecfada7..b2a1e7c71f58e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -368,11 +368,11 @@ struct LinalgOpPartialReductionInterface
     return inits;
   }
 
-  Operation *tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
-                                    ValueRange init,
-                                    ArrayRef<OpFoldResult> offsets,
-                                    ArrayRef<OpFoldResult> sizes,
-                                    ArrayRef<int> reductionDims) const {
+  FailureOr<TilingResult>
+  tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
+                         ValueRange init, ArrayRef<OpFoldResult> offsets,
+                         ArrayRef<OpFoldResult> sizes,
+                         ArrayRef<int> reductionDims) const {
     OpBuilder::InsertionGuard guard(b);
     auto linalgOp = cast<LinalgOp>(op);
 
@@ -437,12 +437,15 @@ struct LinalgOpPartialReductionInterface
     IRMapping mapping;
     op->getRegion(0).cloneInto(&genericOp.getRegion(),
                                genericOp.getRegion().begin(), mapping);
-    return genericOp.getOperation();
+    return TilingResult{
+        {genericOp.getOperation()},
+        llvm::map_to_vector(genericOp->getResults(),
+                            [](OpResult r) -> Value { return r; })};
   }
 
-  Operation *mergeReductions(Operation *op, OpBuilder &b, Location loc,
-                             ValueRange partialReduce,
-                             ArrayRef<int> reductionDims) const {
+  FailureOr<MergeResult> mergeReductions(Operation *op, OpBuilder &b,
+                                         Location loc, ValueRange partialReduce,
+                                         ArrayRef<int> reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);
 
     // Step 1. Recover the dims that actually need to be merged from the
@@ -493,7 +496,10 @@ struct LinalgOpPartialReductionInterface
           }
           b.create<linalg::YieldOp>(loc, yieldedValues);
         });
-    return reduction.getOperation();
+    return MergeResult{
+        {reduction.getOperation()},
+        llvm::map_to_vector(reduction->getResults(),
+                            [](OpResult r) -> Value { return r; })};
   }
 };
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index f3d6b7a530117..35edd490f72eb 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -718,7 +718,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
   SmallVector<Value> &initTensors = maybeInitTensors.value();
 
   // 3. Define the callback to use for generating the inner most tile loop body.
-  Operation *parallelOp = nullptr;
+  SmallVector<Operation *> parallelTiledOps;
   auto innerYieldTiledValuesFn =
       [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
           ValueRange regionIterArgs, SmallVector<Value> &tiledResult,
@@ -743,26 +743,33 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
     }
 
     // 4a. Clone the operation.
-    auto clonedOp = cast<PartialReductionOpInterface>(
-        cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
+    {
+      auto clonedOp = cast<PartialReductionOpInterface>(
+          cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
+
+      // 4b. Tile the cloned operation.
+      FailureOr<TilingResult> partialTilingResult =
+          clonedOp.tileToPartialReduction(b, loc, regionIterArgs, offsets,
+                                          sizes, reductionDims);
+      if (failed(partialTilingResult)) {
+        return failure();
+      }
+      std::swap(parallelTiledOps, partialTilingResult->tiledOps);
+      std::swap(tiledResult, partialTilingResult->tiledValues);
 
-    // 4b. Tile the cloned operation.
-    parallelOp = clonedOp.tileToPartialReduction(b, loc, regionIterArgs,
-                                                 offsets, sizes, reductionDims);
-    // 4c. Delete the cloned operation.
-    b.eraseOp(clonedOp);
+      // 4c. Delete the cloned operation.
+      b.eraseOp(clonedOp);
+    }
 
-    tiledResult.append(parallelOp->result_begin(), parallelOp->result_end());
     // 4d. Compute the offsets and sizes needed to insert the result of the
     // tiled value back into destination before yielding the destination.
-    for (int resultIdx : llvm::seq<int>(0, parallelOp->getNumResults())) {
+    for (auto result : tiledResult) {
       SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
       resultOffsets.emplace_back(std::move(outOffsets));
 
       SmallVector<OpFoldResult> outSizes;
       for (size_t i = 0; i < offsets.size(); i++) {
-        outSizes.push_back(
-            tensor::getMixedSize(b, loc, parallelOp->getResult(resultIdx), i));
+        outSizes.push_back(tensor::getMixedSize(b, loc, result, i));
       }
       resultSizes.emplace_back(std::move(outSizes));
     }
@@ -782,15 +789,21 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
 
   // 5. Apply the merge reduction to combine all the partial values.
   b.setInsertionPointAfter(*loops.begin());
-  Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDims);
-  b.replaceOp(op, mergeOp->getResults());
-
-  SCFReductionTilingResult results;
-  results.initialValues = initTensors;
-  results.loops = loops;
-  results.parallelTiledOp = parallelOp;
-  results.mergeOp = mergeOp;
-  return results;
+  FailureOr<MergeResult> mergeResult =
+      op.mergeReductions(b, loc, replacements, reductionDims);
+  if (failed(mergeResult)) {
+    return failure();
+  }
+  b.replaceOp(op, mergeResult->replacements);
+
+  SCFReductionTilingResult reductionTilingResult;
+  std::swap(reductionTilingResult.parallelTiledOps, parallelTiledOps);
+  std::swap(reductionTilingResult.mergeOps, mergeResult->mergeOps);
+  std::swap(reductionTilingResult.initialValues, initTensors);
+  std::swap(reductionTilingResult.loops, loops);
+  std::swap(reductionTilingResult.replacements, mergeResult->replacements);
+
+  return reductionTilingResult;
 }
 
 //===----------------------------------------------------------------------===//

@@ -2525,8 +2525,10 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
     return emitDefaultSilenceableFailure(target);
   for (Value initValue : result->initialValues)
     results.push_back(initValue.getDefiningOp());
-  results.push_back(result->parallelTiledOp);
-  results.push_back(result->mergeOp);
+  for (auto parallelTiledOp : result->parallelTiledOps)

Expand auto, here and below.
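
One reading of this request is that the loop variables should spell out their element type instead of using `auto`, in line with the LLVM Coding Standards guidance to use `auto` only where it makes code more readable. A minimal sketch of that, using stand-in types rather than the real MLIR ones (`Operation`, `ReductionTilingResult`, and `collectOps` below are hypothetical and exist only for this illustration):

```cpp
#include <vector>

// Stand-in for mlir::Operation (not the real type).
struct Operation {};

// Stand-in for the result struct, reduced to the two fields used here.
struct ReductionTilingResult {
  std::vector<Operation *> parallelTiledOps;
  std::vector<Operation *> mergeOps;
};

// Collects the tiled ops followed by the merge ops, as the loops in
// applyToOne do, but with the element type written out.
std::vector<Operation *> collectOps(const ReductionTilingResult &result) {
  std::vector<Operation *> results;
  for (Operation *parallelTiledOp : result.parallelTiledOps)
    results.push_back(parallelTiledOp);
  for (Operation *mergeOp : result.mergeOps)
    results.push_back(mergeOp);
  return results;
}
```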

MaheshRavishankar force-pushed the update_partial_reduction_op_interface branch from 50dc889 to 998ffc9, June 18, 2024 16:07
MaheshRavishankar merged commit b99d0b3 into llvm:main, June 18, 2024
3 checks passed
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024
…t more in line with `TilingInterface`. (llvm#95460)
