
[mlir][TilingInterface] Allow multiple results in PartialReductionOpInterface #92624


Merged: 2 commits merged into llvm:main on May 22, 2024

Conversation

@Groverkss Groverkss (Member) commented May 18, 2024

This patch adds support for tiling reductions of operations with multiple results using PartialReductionOpInterface, and adds a multiple-result implementation of PartialReductionOpInterface for linalg.generic.
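For illustration, a hypothetical multi-result reduction of the kind this patch enables (a sketch in the style of the transform-tile-reduction.mlir tests; the op names are real linalg/arith ops, but this example is not taken from the patch):

```mlir
// Hypothetical payload: one linalg.generic producing two reduction
// results (row-wise sum and max over d1). Each init is indexed by (d0).
func.func @sum_and_max(%in: tensor<?x?xf32>, %init_sum: tensor<?xf32>,
                       %init_max: tensor<?xf32>)
    -> (tensor<?xf32>, tensor<?xf32>) {
  %res:2 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0)>,
                       affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%in : tensor<?x?xf32>)
      outs(%init_sum, %init_max : tensor<?xf32>, tensor<?xf32>) {
  ^bb0(%a: f32, %acc_sum: f32, %acc_max: f32):
    %0 = arith.addf %a, %acc_sum : f32
    %1 = arith.maximumf %a, %acc_max : f32
    linalg.yield %0, %1 : f32, f32
  } -> (tensor<?xf32>, tensor<?xf32>)
  return %res#0, %res#1 : tensor<?xf32>, tensor<?xf32>
}
```

With the variadic `$fill_op` result below, `transform.structured.tile_reduction_using_for` applied to such an op would yield one fill handle per reduction result (e.g. `%fill0, %fill1, %split, %combine, %for`) instead of a single fill handle.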

@llvmbot llvmbot (Member) commented May 18, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Kunwar Grover (Groverkss)

Changes

Patch is 29.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/92624.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+2-2)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+2-2)
  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+2-2)
  • (modified) mlir/include/mlir/Interfaces/TilingInterface.td (+3-2)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+4-2)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp (+6-7)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp (+17-15)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+96-68)
  • (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+27-24)
  • (modified) mlir/test/Dialect/Linalg/transform-tile-reduction.mlir (+40-2)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 5585ba27fdad8..93e2c2db729da 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1681,7 +1681,7 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
   // TODO: support mixed static-dynamic (see TileUsingForallOp).
   let arguments = (ins TransformHandleTypeInterface:$target,
                    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
-  let results = (outs TransformHandleTypeInterface:$fill_op,
+  let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
                       TransformHandleTypeInterface:$split_linalg_op,
                       TransformHandleTypeInterface:$combining_linalg_op,
                       TransformHandleTypeInterface:$for_op);
@@ -1787,7 +1787,7 @@ def TileReductionUsingForallOp :
                    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$num_threads,
                    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes,
                    OptionalAttr<DeviceMappingArrayAttr>:$mapping);
-  let results = (outs TransformHandleTypeInterface:$fill_op,
+  let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
                       TransformHandleTypeInterface:$split_linalg_op,
                       TransformHandleTypeInterface:$combining_linalg_op,
                       TransformHandleTypeInterface:$forall_op);
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index f77c19ed0fcce..308ce92e35520 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -876,8 +876,8 @@ struct ForallReductionTilingResult {
   Operation *parallelTiledOp;
   /// The final reduction operation merging all the partial reductions.
   Operation *mergeOp;
-  /// The op initializing the tensor used for partial reductions.
-  Operation *initialOp;
+  /// Initial values used for partial reductions.
+  SmallVector<Value> initialValues;
   /// The `scf.forall` operation that iterate over the tiles.
   scf::ForallOp loops;
 };
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 965ef9e203be2..6d567171e185a 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -250,8 +250,8 @@ struct SCFReductionTilingResult {
   Operation *parallelTiledOp;
   /// The final reduction operation merging all the partial reductions.
   Operation *mergeOp;
-  /// Initial op
-  Operation *initialOp;
+  /// Initial values used for reduction.
+  SmallVector<Value> initialValues;
   /// The loop operations that iterate over the tiles.
   SmallVector<LoopLikeOpInterface> loops;
 };
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 66382f29c2424..6fff7e0da538a 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -170,11 +170,12 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
           operation reduction. The tensor shape is equal to operation result
           shape with new dimension for each non zero tile size.
         }],
-        /*retType=*/"FailureOr<Operation*>",
+        /*retType=*/"FailureOr<Value>",
         /*methodName=*/"generateInitialTensorForPartialReduction",
         /*args=*/(ins
             "OpBuilder &":$b,
-            "Location ":$loc,
+            "Location":$loc,
+            "int64_t":$resultNumber,
             "ArrayRef<OpFoldResult>":$sizes,
             "ArrayRef<int>":$reductionDim),
         /*methodBody=*/"",
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 13582a140a965..9b3121774ab3a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2523,7 +2523,8 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
 
   if (failed(result))
     return emitDefaultSilenceableFailure(target);
-  results.push_back(result->initialOp);
+  for (Value initValue : result->initialValues)
+    results.push_back(initValue.getDefiningOp());
   results.push_back(result->parallelTiledOp);
   results.push_back(result->mergeOp);
   results.push_back(result->loops.front());
@@ -2574,7 +2575,8 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
     diag.attachNote(target.getLoc()) << "target operation";
     return diag;
   }
-  results.push_back(result->initialOp);
+  for (Value initValue : result->initialValues)
+    results.push_back(initValue.getDefiningOp());
   results.push_back(result->parallelTiledOp);
   results.push_back(result->mergeOp);
   results.push_back(result->loops);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
index 146e880765668..e0394f852fcc3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
@@ -155,12 +155,12 @@ static Value createDestinationPassingStyleInitOperand(
         tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
     PartialReductionOpInterface partialReductionIface =
         llvm::cast<PartialReductionOpInterface>(op.getOperation());
-    FailureOr<Operation *> reductionNeutralTensorOp =
+    assert(op->getNumResults() == 1 && "Multiple results not supported.");
+    FailureOr<Value> reductionNeutralTensor =
         partialReductionIface.generateInitialTensorForPartialReduction(
-            builder, builder.getLoc(), shape, {});
-    assert(succeeded(reductionNeutralTensorOp));
-    builder.create<scf::YieldOp>(
-        reductionNeutralTensorOp.value()->getResult(0));
+            builder, builder.getLoc(), 0, shape, {});
+    assert(succeeded(reductionNeutralTensor));
+    builder.create<scf::YieldOp>(reductionNeutralTensor.value());
   }
   return ifOp.getResult(0);
 }
@@ -173,8 +173,7 @@ static SmallVector<Value> createDestinationPassingStyleInitOperands(
     ImplicitLocOpBuilder &builder) {
   // TODO: add support for multiple destination passing style initial value
   // operands.
-  // PartialReductionOpInterface::generateInitialTensorForPartialReduction
-  // needs to also support multiple DPS initial operands.
+  assert(op.getNumDpsInits() == 1 && "Multiple initial values not supported.");
   SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands);
   auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
   Value spmdizedInitOperand =
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index df4089d61bfd7..d4805611a68b2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -657,9 +657,9 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
   auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
 
   // Ops implementing PartialReductionOpInterface are not necessarily expected
-  // to implement TilingInterface.. This cast is unsafe atm.
+  // to implement DestinationStyleOpInterface. This cast is unsafe atm.
   // TODO: proper core mechanism to tie interfaces together.
-  // TODO: this function requires a pair of interfaces ..
+  // TODO: this function requires a pair of interfaces.
   auto destinationStyleOp =
       dyn_cast<DestinationStyleOpInterface>(op.getOperation());
   if (!destinationStyleOp)
@@ -671,9 +671,6 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
     return b.notifyMatchFailure(op, "not a linalg op");
 
   SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
-  if (op->getNumResults() != 1)
-    return b.notifyMatchFailure(
-        op, "don't support ops with multiple results for now");
 
   SmallVector<utils::IteratorType> iterators =
       tilingInterfaceOp.getLoopIteratorTypes();
@@ -692,12 +689,17 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
         op, "reduction dimension must be mapped to threads");
 
   // 1. Create the inital tensor value.
-  FailureOr<Operation *> identityTensor =
-      op.generateInitialTensorForPartialReduction(b, loc, numThreads,
-                                                  reductionDim);
-  if (failed(identityTensor))
-    return b.notifyMatchFailure(op,
-                                "cannot create a tensor of identity value.");
+  SmallVector<Value> initTensors;
+  initTensors.reserve(op->getNumResults());
+  for (int idx : llvm::seq<int>(0, op->getNumResults())) {
+    FailureOr<Value> initValue = op.generateInitialTensorForPartialReduction(
+        b, loc, idx, numThreads, reductionDim);
+    if (failed(initValue))
+      return b.notifyMatchFailure(
+          op, "cannot create a tensor of identity value for result " +
+                  std::to_string(idx));
+    initTensors.push_back(initValue.value());
+  }
 
   // Gather destination tensors.
   SmallVector<Value> dest;
@@ -715,8 +717,8 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
 
   // 2. Create the ForallOp with an empty region.
   scf::ForallOp forallOp = b.create<scf::ForallOp>(
-      loc, getAsOpFoldResult(materializedNonZeroNumThreads),
-      (*identityTensor)->getResults(), mapping);
+      loc, getAsOpFoldResult(materializedNonZeroNumThreads), initTensors,
+      mapping);
 
   // 3. Calculate the tile offsets and sizes for the subsequent loop that will
   // be nested under `forallOp`.
@@ -726,7 +728,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
                                /*nominalTileSizes=*/std::nullopt, tiledOffsets,
                                tiledSizes);
 
-  // 4. Clone the tileable op and update its destination operands to use the
+  // 4b. Clone the tileable op and update its destination operands to use the
   // output bbArgs of the ForallOp.
   SmallVector<Value> tilingResults;
   ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
@@ -838,7 +840,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
 
   // 8. Return.
   ForallReductionTilingResult results;
-  results.initialOp = *identityTensor;
+  results.initialValues = initTensors;
   results.loops = forallOp;
   results.parallelTiledOp = tiledOp;
   results.mergeOp = mergeOp;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index bd870d4f982e5..edae03fec72a2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -250,9 +250,9 @@ template <typename LinalgOpTy>
 struct LinalgOpPartialReductionInterface
     : public PartialReductionOpInterface::ExternalModel<
           LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> {
-  FailureOr<Operation *> generateInitialTensorForPartialReduction(
-      Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
-      ArrayRef<int> reductionDims) const {
+  FailureOr<Value> generateInitialTensorForPartialReduction(
+      Operation *op, OpBuilder &b, Location loc, int64_t resultNumber,
+      ArrayRef<OpFoldResult> sizes, ArrayRef<int> reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);
     OpBuilder::InsertionGuard guard(b);
 
@@ -262,7 +262,8 @@ struct LinalgOpPartialReductionInterface
     // loops. This could be controlled by user for more flexibility.
 
     SmallVector<Operation *, 4> combinerOps;
-    if (!matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps) ||
+    if (!matchReduction(linalgOp.getRegionOutputArgs(), resultNumber,
+                        combinerOps) ||
         combinerOps.size() != 1)
       return op->emitOpError("Failed to anaysis the reduction operation.");
 
@@ -273,7 +274,7 @@ struct LinalgOpPartialReductionInterface
           "Failed to get an identity value for the reduction operation.");
 
     ArrayRef<int64_t> oldShape =
-        linalgOp.getShape(linalgOp.getDpsInitOperand(0));
+        linalgOp.getShape(linalgOp.getDpsInitOperand(resultNumber));
 
     // Calculate the new shape, we insert the new dimensions based on the index
     // of the reduction dimensions.
@@ -293,15 +294,15 @@ struct LinalgOpPartialReductionInterface
       newOutputShape.push_back(dim);
       if (ShapedType::isDynamic(dim))
         dynamicDims.push_back(b.create<tensor::DimOp>(
-            loc, linalgOp.getDpsInitOperand(0)->get(), oldIdx));
+            loc, linalgOp.getDpsInitOperand(resultNumber)->get(), oldIdx));
     }
     Value emptyTensor = b.create<tensor::EmptyOp>(
-        loc, newOutputShape, linalgOp.getRegionOutputArgs()[0].getType(),
-        dynamicDims);
+        loc, newOutputShape,
+        linalgOp.getRegionOutputArgs()[resultNumber].getType(), dynamicDims);
     Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
     auto identityTensor =
         b.create<linalg::FillOp>(loc, constantOp, emptyTensor);
-    return identityTensor.getOperation();
+    return identityTensor.getResult(0);
   }
 
   Operation *tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
@@ -312,44 +313,58 @@ struct LinalgOpPartialReductionInterface
     OpBuilder::InsertionGuard guard(b);
     auto linalgOp = cast<LinalgOp>(op);
 
-    AffineMap oldOutputMap =
-        linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
-    SmallVector<AffineExpr> outputExpr(oldOutputMap.getNumResults() +
-                                       reductionDims.size());
-
-    for (int idx : reductionDims)
-      outputExpr[idx] = b.getAffineDimExpr(idx);
-    int currExpr = 0;
-    for (int idx : llvm::seq<int>(0, outputExpr.size())) {
-      if (outputExpr[idx])
-        continue;
-      outputExpr[idx] = oldOutputMap.getResult(currExpr++);
+    // Step 1. Extend init maps to have reduction dimension dims, since we
+    // are converting them to parallel dimensions.
+    SmallVector<AffineMap> newInitMaps;
+    newInitMaps.reserve(linalgOp.getNumDpsInits());
+    for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
+      // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
+      // this with a for range loop when we have it.
+      AffineMap newMap =
+          linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
+      for (int redPos : reductionDims) {
+        newMap = newMap.insertResult(b.getAffineDimExpr(redPos),
+                                     newMap.getNumResults());
+      }
+      newInitMaps.push_back(newMap);
     }
 
-    // Step 1: Extract a slice of the input operands.
-    SmallVector<Value> valuesToTile = linalgOp.getDpsInputs();
-    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
-        b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
+    // Step 2a: Extract a slice of the input operands.
+    SmallVector<Value, 4> tiledInputs = makeTiledShapes(
+        b, loc, linalgOp, linalgOp.getDpsInputs(), offsets, sizes, {}, true);
+
+    // Step 2b: Extract a slice of the init operands.
+    SmallVector<Value, 1> tiledInits;
+    for (Value valueToTile : init) {
+      SmallVector<OpFoldResult> initOffset(offsets.size(), b.getIndexAttr(0));
+      SmallVector<OpFoldResult> initStride(offsets.size(), b.getIndexAttr(1));
+      // TODO: Use SubsetExtractOpInterface here once available.
+      auto extractSlice = b.create<tensor::ExtractSliceOp>(
+          loc, valueToTile, initOffset, sizes, initStride);
+      tiledInits.push_back(extractSlice);
+    }
 
-    // Step 2: Extract the accumulator operands
-    SmallVector<OpFoldResult> strides(offsets.size(), b.getIndexAttr(1));
-    SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
-    // TODO: use SubsetExtractOpInterface once it is available.
-    Value out = b.create<tensor::ExtractSliceOp>(loc, init[0], outOffsets,
-                                                 sizes, strides);
+    // Update the indexing maps.
+    SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
+    // Change the init maps.
+    for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
+      // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
+      // this with a for range loop when we have it.
+      OpOperand *initOperand = linalgOp.getDpsInitOperand(idx);
+      int64_t mapIdx = linalgOp.getIndexingMapIndex(initOperand);
+      newMaps[mapIdx] = newInitMaps[idx];
+    }
 
-    // Step3. Create a generic op where the reduction dimensions are replaced
-    // by a parallel dimension of the size of reduction.
+    // Step 3. Change the reduction dim iterator types.
     SmallVector<utils::IteratorType> newIteratorTypes =
         linalgOp.getIteratorTypesArray();
     for (int dim : reductionDims)
       newIteratorTypes[dim] = utils::IteratorType::parallel;
-    SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
-    newMaps.back() = AffineMap::get(newMaps.back().getNumDims(), 0, outputExpr,
-                                    linalgOp.getContext());
+
+    // Step 4. Create the new generic op.
     auto genericOp =
-        b.create<GenericOp>(loc, TypeRange({out.getType()}), tiledOperands,
-                            ValueRange({out}), newMaps, newIteratorTypes);
+        b.create<GenericOp>(loc, ValueRange(tiledInits).getTypes(), tiledInputs,
+                            tiledInits, newMaps, newIteratorTypes);
     IRMapping mapping;
     op->getRegion(0).cloneInto(&genericOp.getRegion(),
                                genericOp.getRegion().begin(), mapping);
@@ -361,40 +376,53 @@ struct LinalgOpPartialReductionInterface
                              ArrayRef<int> reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);
 
-    DenseSet<int> reductionDimsSet(reductionDims.begin(), reductionDims.end());
-
-    // Then create a new reduction that only reduce the newly added dimensions
-    // from the previous op.
-    int64_t intermRank = cast<ShapedType>(partialReduce[0].getType()).getRank();
-    AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
-    SmallVector<utils::IteratorType> reductionIteratorTypes;
-    SmallVector<AffineExpr> exprs;
-
-    for (int64_t i : llvm::seq<int64_t>(0, intermRank)) {
-      if (reductionDimsSet.contains(i)) {
-        reductionIteratorTypes.push_back(utils::IteratorType::reduction);
-      } else {
-        exprs.push_back(b.getAffineDimExpr(i));
-        reductionIteratorTypes.push_back(utils::IteratorType::parallel);
+    // Step 1. Recover the dims that actually need to be merged from the
+    // original operation. We can classify the original iterators as follows:
+    //
+    // parallel                         --> parallel
+    // reduction + not in reductionDims --> parallel (already reduced)
+    // reduction + in reductionDims     --> reduction (will reduce now)
+    SmallVector<utils::IteratorType> iterators(linalgOp.getNumLoops(),
+                                               utils::IteratorType::parallel);
+    for (int redIdx : reductionDims)
+      iterators[redIdx] = utils::IteratorType::reduction;
+
+    // Step 2. For each partial result, create a map to index it. This map
+    // is simply the indexing map for the original result with reductionDims
+    // appended (as produced in tileToPartialReduction).
+    int64_t numInits = linalgOp.getNumDpsInits();
+    SmallVector<AffineMap> indexingMaps(numInits * 2);
+    for (int idx : llvm::seq<int>(0, numInits)) {
+      AffineMap &inputMap = indexingMaps[idx];
+      AffineMap &outputMap = indexingMaps[numInits + idx];
+
+      outputMap =
+          linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
+      inputMap = outputMap;
+      for (int redPos : reductionDims) {...
[truncated]

@Groverkss Groverkss requested review from kuhar and qedawkins May 20, 2024 12:24
// CHECK: func @reduction_tile_transpose
// CHECK: tensor.empty(%{{.*}}) : tensor<5x?xf32>
// CHECK: linalg.fill {{.*}} : tensor<5x?xf32>) -> tensor<5x?xf32>
// CHECK: scf.for
// CHECK: %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0] [%[[D0:.*]], %[[D1:.*]]] [1, 1] : tensor<5x?xf32> to tensor<?x?xf32>
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[L:.*]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>)
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[L:.*]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>)
Contributor:
I wouldn't expect this to change... this should be NFC for the most part?

Member Author (@Groverkss) replied:
I refactored part of the implementation for multiple reduction dims while adding multiple-result support to the linalg partial tiling interface. Before, it inserted reduction dims in an ad-hoc way; now it always inserts reduction dims at the end. This means that when they are reduced in the final loop, they are all contiguous, and their order is deterministic (the old implementation would not have worked for multiple results).
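Concretely, the "append at the end" scheme looks like this (illustrative maps only, not lifted from the patch):

```mlir
// Tiling the reduction dim d1 of a ("parallel", "reduction") generic:
//   init map before: affine_map<(d0, d1) -> (d0)>
//   init map after:  affine_map<(d0, d1) -> (d0, d1)>   // d1 appended last
// In the partial op, d1's iterator type flips to "parallel"; the final
// merge op then reduces exactly the appended trailing dims, which are
// contiguous and in a deterministic order for every result.
```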

@MaheshRavishankar MaheshRavishankar (Contributor) left a comment:
Ok, I think this looks OK. I'm a bit fuzzy on all the implementation details, but this seems fine to me.

@Groverkss Groverkss merged commit 9329b20 into llvm:main May 22, 2024
4 checks passed
Groverkss added a commit to iree-org/iree that referenced this pull request May 22, 2024

This patch generalizes the tiling implementation for AttentionOp. Before,
only the batch and M dimensions of attention could be tiled. This patch
also allows tiling of the N dimension, as well as transposition based on
indexing maps (hardcoded for now).

Tiling on dimension N is disabled in the CPU backend for now, because the
TileAndDecomposeAttention pass is hardcoded with dimensions. This will
be fixed once we implement the reduction tiling interface for it (after
llvm/llvm-project#92624)
gglangg pushed a commit to gglangg/iree that referenced this pull request Jun 4, 2024
gglangg pushed a commit to gglangg/iree that referenced this pull request Jun 4, 2024
bangtianliu pushed a commit to bangtianliu/iree that referenced this pull request Jun 5, 2024
LLITCHEV pushed a commit to LLITCHEV/iree that referenced this pull request Jul 30, 2024