
Fix transpose->unpack folding pattern for the partial-tile case of unpack #107271


Merged: 1 commit merged into llvm:main on Sep 4, 2024

Conversation

@bjacob (Contributor) commented on Sep 4, 2024

Just directly create the empty tensor of the appropriate shape instead of relying on UnPackOp::createDestinationTensor, which tries to infer the destination shape; that inference isn't possible in general with the set of parameters it takes.
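To make the partial-tile case concrete, here is a small unpack (adapted from the test added in this PR) where the destination is smaller than what the outer dims and tile sizes would suggest, so the result shape cannot be reconstructed from the source shape and tile sizes alone:

```mlir
// Partial-tile unpack: the 1x1x16x4 source covers a 16x4 area, but only a
// 15x3 slice of it is extracted into the result (extract-slice semantics).
%dest = tensor.empty() : tensor<15x3xi32>
%unpacked = tensor.unpack %src
    inner_dims_pos = [0, 1] inner_tiles = [16, 4]
    into %dest : tensor<1x1x16x4xi32> -> tensor<15x3xi32>
```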

@bjacob bjacob marked this pull request as ready for review September 4, 2024 17:41
@llvmbot (Member) commented on Sep 4, 2024

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: Benoit Jacob (bjacob)

Changes

In the partial-tile case, where unpack has extract-slice semantics, UnPackOp::createDestinationTensor was trying to infer the destination shape, which isn't possible in general with the set of parameters it takes.

Added an optional (default-empty) parameter to UnPackOp::createDestinationTensor that allows passing the destination shape explicitly. Went over the existing callers: only one needed to pass it explicitly; the others are in the full-tile case, where the existing code was fine.
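For illustration, a caller in the partial-tile case can reify the result shape of the original unpack and pass it through the new parameter. A minimal sketch of such a call site, with placeholder variable names (the real one is in PackAndUnpackPatterns.cpp in the diff below):

```c++
// Reify the exact result shape of the unpack being rewritten, then pass it
// to createDestinationTensor so it emits tensor.empty with those sizes
// instead of trying to infer them from the source shape and tile sizes.
SmallVector<SmallVector<OpFoldResult>> resultDims;
if (failed(reifyResultShapes(rewriter, unPackOp, resultDims)))
  return failure();

Value dest = tensor::UnPackOp::createDestinationTensor(
    rewriter, unPackOp.getLoc(), source, innerTileSizes, innerDimsPos,
    outerDimsPerm, resultDims[0]);
```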


Full diff: https://github.com/llvm/llvm-project/pull/107271.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+2-1)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+8-5)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp (+9-3)
  • (modified) mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir (+29-6)
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index cafc3d91fd1e9d..8040cc97cd8bc4 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -2076,7 +2076,8 @@ def Tensor_UnPackOp : Tensor_RelayoutOp<"unpack"> {
   let extraClassDeclaration = commonExtraClassDeclaration # [{
     static Value createDestinationTensor(OpBuilder &b, Location loc,
         Value source, ArrayRef<OpFoldResult> innerTileSizes,
-        ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
+        ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm,
+        SmallVector<OpFoldResult> mixedSizes = {});
 
     /// Build and return a new UnPackOp that is a clone of the current UnPackOp
     /// with (innerDimsPos, innerTiles) (resp. outerDimsPerm) are permuted by
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 996de530c255d4..41afbbe840352c 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4360,15 +4360,19 @@ Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
                                         Value source,
                                         ArrayRef<OpFoldResult> innerTileSizes,
                                         ArrayRef<int64_t> innerDimsPos,
-                                        ArrayRef<int64_t> outerDimsPerm) {
+                                        ArrayRef<int64_t> outerDimsPerm,
+                                        SmallVector<OpFoldResult> mixedSizes) {
+  auto srcType = llvm::cast<RankedTensorType>(source.getType());
+  auto elemType = srcType.getElementType();
+  if (!mixedSizes.empty()) {
+    return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
+  }
+
   AffineExpr sym0, sym1;
   bindSymbols(b.getContext(), sym0, sym1);
   auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
     return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
   };
-
-  SmallVector<OpFoldResult> mixedSizes;
-  auto srcType = llvm::cast<RankedTensorType>(source.getType());
   for (auto i :
        llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
     if (srcType.isDynamicDim(i))
@@ -4384,7 +4388,6 @@ Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
   for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
     mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
 
-  auto elemType = srcType.getElementType();
   return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
 }
 
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index c681cadcb27cb2..fdd6ff47f3bb5e 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -439,6 +439,11 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
     if (failed(maybePerm))
       return failure();
 
+    SmallVector<SmallVector<OpFoldResult>> unpackOpResultDims;
+    if (failed(reifyResultShapes(rewriter, unPackOp, unpackOpResultDims))) {
+      return failure();
+    }
+
     SmallVector<int64_t> inverseTransposePerm =
         invertPermutationVector(maybePerm.value());
     auto outerDimsPerm = unPackOp.getOuterDimsPerm();
@@ -448,13 +453,13 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
     SmallVector<int64_t> newOuterDimsPermVec;
     SmallVector<int64_t> newInnerDimsPosVec;
     SmallVector<OpFoldResult> newMixedInnerTilesVec;
-
     if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
-                         newOuterDimsPermVec, destRank))
+                         newOuterDimsPermVec, destRank)) {
       return rewriter.notifyMatchFailure(
           unPackOp,
           "Cannot fold in tensor.unpack if a tile dimension was transposed "
           "with a non-tile dimension in linalg.transpose.");
+    }
 
     // Process transpose operation for tiled inner dimensions
     for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
@@ -465,7 +470,8 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
 
     Value output = unPackOp.createDestinationTensor(
         rewriter, unPackOp.getLoc(), linalgOp->getOperand(0),
-        newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
+        newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec,
+        unpackOpResultDims[0]);
 
     rewriter.replaceOpWithNewOp<UnPackOp>(
         unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 629a4c21357207..bff913f5f55feb 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -550,6 +550,32 @@ func.func @linalg_transpose_tensor_unpack_fold(%arg0: tensor<1x1x4x16xi32>) -> t
 
 // -----
 
+func.func @linalg_transpose_tensor_unpack_fold_partial_tile(%arg0: tensor<1x1x4x16xi32>) -> tensor<15x3xi32> {
+  %0 = tensor.empty() : tensor<1x1x16x4xi32>
+  %transposed = linalg.transpose ins(%arg0 : tensor<1x1x4x16xi32>)
+                outs(%0 : tensor<1x1x16x4xi32>)
+                permutation = [1, 0, 3, 2]
+  %1 = tensor.empty() : tensor<15x3xi32>
+  %unpack = tensor.unpack %transposed
+            outer_dims_perm = [0, 1]
+            inner_dims_pos = [0, 1]
+            inner_tiles = [16, 4] into
+            %1 : tensor<1x1x16x4xi32> -> tensor<15x3xi32>
+  return %unpack : tensor<15x3xi32>
+}
+//CHECK-LABEL:  func.func @linalg_transpose_tensor_unpack_fold_partial_tile(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<1x1x4x16xi32>) -> tensor<15x3xi32> {
+//      CHECK:     %[[OUT:.+]] = tensor.empty() : tensor<15x3xi32>
+//      CHECK:     %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME:        outer_dims_perm = [1, 0]
+// CHECK-SAME:        inner_dims_pos = [1, 0]
+// CHECK-SAME:        inner_tiles = [4, 16]
+// CHECK-SAME:        into %[[OUT]] : tensor<1x1x4x16xi32> -> tensor<15x3xi32>
+//      CHECK:     return %[[UNPACK]] : tensor<15x3xi32>
+//      CHECK:   }
+
+// -----
+
 func.func @linalg_transpose_tensor_unpack_fold_dynamic_outer_dims_tile_dims_tile_sizes(%arg0: tensor<?x?x?x?xf32>, %transpose_dest: tensor<?x?x?x?xf32>, %unpack_dest: tensor<?x?xf32>, %tile_p : index, %tile_q : index) -> tensor<?x?xf32> {
   %transposed = linalg.transpose
     ins(%arg0 : tensor<?x?x?x?xf32>)
@@ -563,17 +589,14 @@ func.func @linalg_transpose_tensor_unpack_fold_dynamic_outer_dims_tile_dims_tile
     into %unpack_dest : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
   return %unpack : tensor<?x?xf32>
 }
-//       CHECK:    #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
 // CHECK-LABEL:   func.func @linalg_transpose_tensor_unpack_fold_dynamic_outer_dims_tile_dims_tile_sizes(
 //  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>,
 //  CHECK-SAME:     %[[IDX1:.+]]: index, %[[IDX2:.+]]: index) -> tensor<?x?xf32> {
 //   CHECK-DAG:       %[[CST1:.+]] = arith.constant 1 : index
 //   CHECK-DAG:       %[[CST0:.+]] = arith.constant 0 : index
-//   CHECK-DAG:       %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[CST0]] : tensor<?x?x?x?xf32>
-//   CHECK-DAG:       %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[CST1]] : tensor<?x?x?x?xf32>
-//   CHECK-DAG:       %[[AMAP0:.+]] = affine.apply #[[$MAP]]()[%[[DIM1]], %[[IDX2]]]
-//   CHECK-DAG:       %[[AMAP1:.+]] = affine.apply #[[$MAP]]()[%[[DIM0]], %[[IDX1]]]
-//       CHECK:       %[[OUT:.+]] = tensor.empty(%[[AMAP1]], %[[AMAP0]]) : tensor<?x?xf32>
+//   CHECK-DAG:       %[[DIM0:.+]] = tensor.dim %[[ARG2]], %[[CST0]] : tensor<?x?xf32>
+//   CHECK-DAG:       %[[DIM1:.+]] = tensor.dim %[[ARG2]], %[[CST1]] : tensor<?x?xf32>
+//       CHECK:       %[[OUT:.+]] = tensor.empty(%[[DIM0]], %[[DIM1]]) : tensor<?x?xf32>
 //       CHECK:       %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
 //  CHECK-SAME:         outer_dims_perm = [0, 1]
 //  CHECK-SAME:         inner_dims_pos = [1, 0]

@bjacob force-pushed the pr-fold-transpose-unpack-partial-tile branch from c80fb59 to 20c2882 on September 4, 2024 18:14
@bjacob bjacob merged commit c1667f9 into llvm:main Sep 4, 2024
6 of 7 checks passed
bjacob added a commit to iree-org/iree that referenced this pull request Sep 5, 2024
* Changes to `IREEGPU_DataTiledMMAAttr`:
  * Add `unroll_{m,n,k}` parameters.
  * Drop custom builder/parser/printer code.
  * Implement `getABCVectorTypes()` by falling back on `MMAAttr::getABCVectorTypes()`, since that is unchanged.
* Changes to `MaterializeEncodingInfo`:
  * Drop `intrinsicSize`, which was too specific to particular levels of swizzling. Also drop `innerTileShapes` and `srcRank`, which are redundant with the general expand_shape array introduced anyway (next point).
  * Introduce a single optional encapsulated `swizzle` containing just generic expand_shape and permutation arrays, so they are agnostic and don't need to be generalized again.
* Changes to `MaterializeEncoding` logic:
  * Just create the `expand_shape` and `transpose` dictated by the `swizzle` structure. This is completely general and no longer introduces unit dimensions to fit a rigid shape structure.
  * The removal of the unit dims accounts for the bulk of the lit test differences, particularly in cases where the unit dims used to prevent folding. Now the folding happens, so the tests observe a single folded `pack` or `unpack` in some cases where the `expand_shape` and `transpose` were folded.
  * This relies on an upstream fix, llvm/llvm-project#107271. This PR currently cherry-picks it, so you will need `git submodule update`.
  * Generalize some places that were hardcoding 2D matrices.
  * Generally drop lots of logic that was inferring kernel shapes; those are now consistently something that the user passes, not something that we compute (in the diff, search for `std::sqrt` to see the kind of code being dropped here).
* Changes to `multi_mma` op:
  * Update the verifier.
Signed-off-by: Benoit Jacob <[email protected]>
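As a rough illustration of the `swizzle` structure described in that commit message, it pairs an expand_shape reassociation with a permutation. The struct and field names below are assumptions for this sketch, not IREE's actual declarations:

```c++
// Hypothetical sketch of the generic swizzle described above.
struct TileSwizzle {
  // Shape that each source dimension expands into, feeding tensor.expand_shape.
  SmallVector<SmallVector<int64_t>> expandShape;
  // Permutation applied to the expanded dimensions, feeding linalg.transpose.
  SmallVector<int64_t> permutation;
};
```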