Skip to content

[mlir] Add pack/unpack transpose foldings for linalg.generic ops, fix bugs #93055

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 6, 2024

Conversation

Max191
Copy link
Contributor

@Max191 Max191 commented May 22, 2024

This PR adds transpose + pack/unpack folding support for transpose ops in the form of linalg.generic ops. There were also some bugs with the permutation composing in the previous patterns, so this PR fixes these bugs and adds tests for them as well.

@llvmbot
Copy link
Member

llvmbot commented May 22, 2024

@llvm/pr-subscribers-mlir-tensor

Author: None (Max191)

Changes

This PR adds transpose + pack/unpack folding support for transpose ops in the form of linalg.generic ops. There were also some bugs with the permutation composing in the previous patterns, so this PR fixes these bugs and adds tests for them as well.


Full diff: https://github.com/llvm/llvm-project/pull/93055.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp (+87-39)
  • (modified) mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir (+139)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index ebcb34e9ef024..ce5fda8e79d65 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -48,6 +48,34 @@ static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
   return success();
 }
 
+// If the `linalgOp` represents a transpose, return the permutation vector for
+// the transpose. Otherwise, return failure.
+static FailureOr<SmallVector<int64_t>>
+getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
+  if (auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation()))
+    return SmallVector<int64_t>(transposeOp.getPermutation());
+  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
+    return failure();
+
+  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
+    return failure();
+  auto mapRange = linalgOp.getIndexingMapsArray();
+  if (mapRange.size() != 2 || !mapRange.front().isPermutation() ||
+      !mapRange.back().isPermutation() || mapRange.front() == mapRange.back()) {
+    return failure();
+  }
+  if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
+    return failure();
+  AffineMap outMap = mapRange.back();
+  AffineMap inMap = mapRange.front();
+  // To get the permutation, look at each output index and find which
+  // dimension in the input we're reading from for that index.
+  return llvm::map_to_vector(outMap.getResults(),
+                             [&](AffineExpr expr) -> int64_t {
+                               return *inMap.getResultPosition(expr);
+                             });
+}
+
 /// Packing one-dimensional tensor can be expressed as an expand shape op.
 struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
   using OpRewritePattern<PackOp>::OpRewritePattern;
@@ -244,14 +272,10 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
 
   for (unsigned int i = 0; i < rank; ++i) {
     int64_t remappedPosition = permutation[i];
-
-    if (!inVec.empty()) {
-      if (remappedPosition >= rank) {
-        return false;
-      }
+    if (remappedPosition >= rank)
+      return false;
+    if (!inVec.empty())
       remappedPosition = inVec[remappedPosition];
-    }
-
     resVec.push_back(remappedPosition);
   }
 
@@ -261,20 +285,26 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
 /// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
 /// semantics.
 struct FoldProducerPackWithConsumerLinalgTransposeOp
-    : public OpRewritePattern<linalg::TransposeOp> {
-  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
+  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
 
-  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override {
-    auto packOp = transposeOp.getOperand(0).getDefiningOp<PackOp>();
+    auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>();
 
     if (!packOp)
       return failure();
 
+    FailureOr<SmallVector<int64_t>> maybePerm =
+        getTransposeOpPermutation(linalgOp);
+    if (failed(maybePerm)) {
+      return failure();
+    }
+
     auto innerDimsPos = packOp.getInnerDimsPos();
     auto mixedInnerTiles = packOp.getMixedTiles();
     auto outerDimsPerm = packOp.getOuterDimsPerm();
-    auto transposePerm = transposeOp.getPermutation();
+    auto transposePerm = maybePerm.value();
     SmallVector<int64_t> newOuterDimsPermVec;
     SmallVector<int64_t> newInnerDimsPosVec;
     SmallVector<OpFoldResult> newMixedInnerTilesVec;
@@ -283,7 +313,7 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
     if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
                          srcRank))
       return rewriter.notifyMatchFailure(
-          transposeOp,
+          linalgOp,
           "Cannot fold in tensor.pack if a tile dimension was transposed "
           "with a non-tile dimension in linalg.transpose.");
 
@@ -295,11 +325,11 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
     }
 
     Value output = packOp.createDestinationTensor(
-        rewriter, transposeOp.getLoc(), packOp.getSource(),
-        newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
+        rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec,
+        newInnerDimsPosVec, newOuterDimsPermVec);
 
     rewriter.replaceOpWithNewOp<PackOp>(
-        transposeOp, packOp.getSource(), output, newInnerDimsPosVec,
+        linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
         newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
 
     return success();
@@ -314,12 +344,17 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
 
   LogicalResult matchAndRewrite(PackOp packOp,
                                 PatternRewriter &rewriter) const override {
-    auto transposeOp = packOp.getSource().getDefiningOp<linalg::TransposeOp>();
+    auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
+    if (!linalgOp)
+      return failure();
 
-    if (!transposeOp)
+    FailureOr<SmallVector<int64_t>> maybePerm =
+        getTransposeOpPermutation(linalgOp);
+    if (failed(maybePerm)) {
       return failure();
+    }
 
-    auto transposePermutation = transposeOp.getPermutation();
+    auto transposePermutation = maybePerm.value();
     auto outerDimsPerm = packOp.getOuterDimsPerm();
     auto innerDimsPos = packOp.getInnerDimsPos();
     SmallVector<int64_t> newInnerDimsPosVec;
@@ -335,11 +370,11 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
       newInnerDimsPosVec.push_back(transposePermutation[dim]);
 
     Value output = packOp.createDestinationTensor(
-        rewriter, packOp.getLoc(), transposeOp.getOperand(0),
+        rewriter, packOp.getLoc(), linalgOp->getOperand(0),
         packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
 
     rewriter.replaceOpWithNewOp<PackOp>(
-        packOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
+        packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
         packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);
 
     return success();
@@ -349,22 +384,29 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
 /// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
 /// transpose semantics.
 struct FoldProducerUnPackWithConsumerLinalgTransposeOp
-    : public OpRewritePattern<linalg::TransposeOp> {
-  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
+  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
 
-  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override {
-    auto unPackOp = transposeOp.getOperand(0).getDefiningOp<UnPackOp>();
+    auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();
 
     if (!unPackOp)
       return failure();
 
-    auto transposePermutation = transposeOp.getPermutation();
+    FailureOr<SmallVector<int64_t>> maybePerm =
+        getTransposeOpPermutation(linalgOp);
+    if (failed(maybePerm)) {
+      return failure();
+    }
+
+    auto transposePermutation = maybePerm.value();
+    SmallVector<int64_t> inverseTransposePerm =
+        invertPermutationVector(transposePermutation);
     auto outerDimsPerm = unPackOp.getOuterDimsPerm();
     auto innerDimsPos = unPackOp.getInnerDimsPos();
     SmallVector<int64_t> newInnerDimsPosVec;
-    SmallVector<int64_t> newOuterDimsPermVec =
-        llvm::to_vector(transposePermutation);
+    SmallVector<int64_t> newOuterDimsPermVec = inverseTransposePerm;
 
     if (!outerDimsPerm.empty())
       applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
@@ -372,11 +414,11 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
     // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
     // permutation rank won't necessarily be equal in all cases.
     for (auto dim : innerDimsPos)
-      newInnerDimsPosVec.push_back(transposePermutation[dim]);
+      newInnerDimsPosVec.push_back(inverseTransposePerm[dim]);
 
     // Reuse the destination of the transpose op.
     rewriter.replaceOpWithNewOp<UnPackOp>(
-        transposeOp, unPackOp.getSource(), transposeOp.getDpsInits()[0],
+        linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
         newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);
 
     return success();
@@ -391,13 +433,19 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
 
   LogicalResult matchAndRewrite(UnPackOp unPackOp,
                                 PatternRewriter &rewriter) const override {
-    auto transposeOp =
-        unPackOp.getSource().getDefiningOp<linalg::TransposeOp>();
+    auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
+    if (!linalgOp)
+      return failure();
 
-    if (!transposeOp)
+    FailureOr<SmallVector<int64_t>> maybePerm =
+        getTransposeOpPermutation(linalgOp);
+    if (failed(maybePerm)) {
       return failure();
+    }
 
-    auto transposePermutation = transposeOp.getPermutation();
+    auto transposePermutation = maybePerm.value();
+    SmallVector<int64_t> inverseTransposePerm =
+        invertPermutationVector(transposePermutation);
     auto outerDimsPerm = unPackOp.getOuterDimsPerm();
     auto innerDimsPos = unPackOp.getInnerDimsPos();
     int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
@@ -406,7 +454,7 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
     SmallVector<int64_t> newInnerDimsPosVec;
     SmallVector<OpFoldResult> newMixedInnerTilesVec;
 
-    if (!checkAndPermute(transposePermutation, outerDimsPerm,
+    if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
                          newOuterDimsPermVec, destRank))
       return rewriter.notifyMatchFailure(
           unPackOp,
@@ -414,18 +462,18 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
           "with a non-tile dimension in linalg.transpose.");
 
     // Process transpose operation for tiled inner dimensions
-    for (unsigned int i = destRank; i < transposePermutation.size(); ++i) {
-      int64_t remappedPosition = transposePermutation[i] - destRank;
+    for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
+      int64_t remappedPosition = inverseTransposePerm[i] - destRank;
       newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
       newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
     }
 
     Value output = unPackOp.createDestinationTensor(
-        rewriter, unPackOp.getLoc(), transposeOp.getOperand(0),
+        rewriter, unPackOp.getLoc(), linalgOp->getOperand(0),
         newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
 
     rewriter.replaceOpWithNewOp<UnPackOp>(
-        unPackOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
+        unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
         newMixedInnerTilesVec, newOuterDimsPermVec);
 
     return success();
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 9f486f9146ad8..fca6eddaca436 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -636,3 +636,142 @@ func.func @tensor_padded_unpack_linalg_transpose_fold(%arg0: tensor<71x7x4x16x16
 //  CHECK-SAME:        into %[[OUT:.+]] : tensor<71x7x4x16x16xf32> -> tensor<100x71x64xf32>
 //       CHECK:       return %[[UNPACK]] : tensor<100x71x64xf32>
 //       CHECK:    }
+
+// -----
+
+func.func @non_involution_transpose_unpack_fold(%arg0: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+  %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
+  %transposed = linalg.transpose ins(%arg0 : tensor<2x3x5x4x16xi32>)
+                outs(%0 : tensor<5x2x3x16x4xi32>)
+                permutation = [2, 0, 1, 4, 3]
+  %1 = tensor.empty() : tensor<5x48x8xi32>
+  %unpack = tensor.unpack %transposed
+            outer_dims_perm = [0, 2, 1]
+            inner_dims_pos = [1, 2]
+            inner_tiles = [16, 4] into
+            %1 : tensor<5x2x3x16x4xi32> -> tensor<5x48x8xi32>
+  return %unpack : tensor<5x48x8xi32>
+}
+//CHECK-LABEL:  func.func @non_involution_transpose_unpack_fold(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+//      CHECK:     %[[OUT:.+]] = tensor.empty() : tensor<5x48x8xi32>
+//      CHECK:     %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME:        outer_dims_perm = [2, 1, 0]
+// CHECK-SAME:        inner_dims_pos = [2, 1]
+// CHECK-SAME:        inner_tiles = [4, 16]
+// CHEKC-SAME:        into %[[OUT]] : tensor<2x3x5x4x16xi32> -> tensor<5x48x8xi32>
+//      CHECK:     return %[[UNPACK]] : tensor<5x48x8xi32>
+//      CHECK:   }
+
+// -----
+
+func.func @unpack_non_involution_transpose_fold(%arg0: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+  %0 = tensor.empty() : tensor<3x56x3648xf32>
+  %unpack = tensor.unpack %arg0
+    outer_dims_perm = [2, 0, 1]
+    inner_dims_pos = [1, 2]
+    inner_tiles = [1, 64]
+    into %0 : tensor<57x3x56x1x64xf32> -> tensor<3x56x3648xf32>
+
+  %1 = tensor.empty() : tensor<3648x3x56xf32>
+  %transposed = linalg.transpose
+    ins(%unpack : tensor<3x56x3648xf32>)
+    outs(%1 : tensor<3648x3x56xf32>)
+    permutation = [2, 0, 1]
+  return %transposed : tensor<3648x3x56xf32>
+}
+// CHECK-LABEL:  func.func @unpack_non_involution_transpose_fold(
+//  CHECK-SAME:    %[[ARG0:.+]]: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+//       CHECK:        %[[OUT:.+]] = tensor.empty() : tensor<3648x3x56xf32>
+//       CHECK:        %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+//  CHECK-SAME:        outer_dims_perm = [0, 1, 2]
+//  CHECK-SAME:        inner_dims_pos = [2, 0]
+//  CHECK-SAME:        inner_tiles = [1, 64]
+//  CHECK-SAME:        into %[[OUT:.+]] : tensor<57x3x56x1x64xf32> -> tensor<3648x3x56xf32>
+//       CHECK:       return %[[UNPACK]] : tensor<3648x3x56xf32>
+//       CHECK:    }
+
+// -----
+
+func.func @transpose_unpacked_dims_no_fold(%arg0: tensor<2x16x5x4x3xi32>) -> tensor<5x32x12xi32> {
+  %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
+  %transposed = linalg.transpose ins(%arg0 : tensor<2x16x5x4x3xi32>)
+                outs(%0 : tensor<5x2x3x16x4xi32>)
+                permutation = [2, 0, 4, 1, 3]
+  %1 = tensor.empty() : tensor<5x32x12xi32>
+  %unpack = tensor.unpack %transposed
+            inner_dims_pos = [1, 2]
+            inner_tiles = [16, 4] into
+            %1 : tensor<5x2x3x16x4xi32> -> tensor<5x32x12xi32>
+  return %unpack : tensor<5x32x12xi32>
+}
+//CHECK-LABEL:  func.func @transpose_unpacked_dims_no_fold(
+//      CHECK:     linalg.transpose
+//      CHECK:     tensor.unpack
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4)->(d1, d2, d0, d4, d3)>
+#map1 = affine_map<(d0, d1, d2, d3, d4)->(d0, d1, d2, d3, d4)>
+func.func @generic_transpose_unpack_fold(%arg0: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+  %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
+  %transposed = linalg.generic {
+                iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+                indexing_maps = [#map, #map1]}
+                ins(%arg0 : tensor<2x3x5x4x16xi32>)
+                outs(%0 : tensor<5x2x3x16x4xi32>) {
+  ^bb0(%in : i32, %out : i32):
+    linalg.yield %in : i32
+  } -> tensor<5x2x3x16x4xi32>
+  %1 = tensor.empty() : tensor<5x48x8xi32>
+  %unpack = tensor.unpack %transposed
+            outer_dims_perm = [0, 2, 1]
+            inner_dims_pos = [1, 2]
+            inner_tiles = [16, 4] into
+            %1 : tensor<5x2x3x16x4xi32> -> tensor<5x48x8xi32>
+  return %unpack : tensor<5x48x8xi32>
+}
+//CHECK-LABEL:  func.func @generic_transpose_unpack_fold(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+//      CHECK:     %[[OUT:.+]] = tensor.empty() : tensor<5x48x8xi32>
+//      CHECK:     %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME:        outer_dims_perm = [2, 1, 0]
+// CHECK-SAME:        inner_dims_pos = [2, 1]
+// CHECK-SAME:        inner_tiles = [4, 16]
+// CHEKC-SAME:        into %[[OUT]] : tensor<2x3x5x4x16xi32> -> tensor<5x48x8xi32>
+//      CHECK:     return %[[UNPACK]] : tensor<5x48x8xi32>
+//      CHECK:   }
+
+// -----
+
+#map = affine_map<(d0, d1, d2)->(d1, d2, d0)>
+#map1 = affine_map<(d0, d1, d2)->(d0, d1, d2)>
+func.func @unpack_generic_transpose_fold(%arg0: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+  %0 = tensor.empty() : tensor<3x56x3648xf32>
+  %unpack = tensor.unpack %arg0
+    outer_dims_perm = [2, 0, 1]
+    inner_dims_pos = [1, 2]
+    inner_tiles = [1, 64]
+    into %0 : tensor<57x3x56x1x64xf32> -> tensor<3x56x3648xf32>
+
+  %1 = tensor.empty() : tensor<3648x3x56xf32>
+  %transposed = linalg.generic {
+                iterator_types = ["parallel", "parallel", "parallel"],
+                indexing_maps = [#map, #map1]}
+                ins(%unpack : tensor<3x56x3648xf32>)
+                outs(%1 : tensor<3648x3x56xf32>) {
+  ^bb0(%in : f32, %out : f32):
+    linalg.yield %in : f32
+  } -> tensor<3648x3x56xf32>
+  return %transposed : tensor<3648x3x56xf32>
+}
+// CHECK-LABEL:  func.func @unpack_generic_transpose_fold(
+//  CHECK-SAME:    %[[ARG0:.+]]: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+//       CHECK:        %[[OUT:.+]] = tensor.empty() : tensor<3648x3x56xf32>
+//       CHECK:        %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+//  CHECK-SAME:        outer_dims_perm = [0, 1, 2]
+//  CHECK-SAME:        inner_dims_pos = [2, 0]
+//  CHECK-SAME:        inner_tiles = [1, 64]
+//  CHECK-SAME:        into %[[OUT:.+]] : tensor<57x3x56x1x64xf32> -> tensor<3648x3x56xf32>
+//       CHECK:       return %[[UNPACK]] : tensor<3648x3x56xf32>
+//       CHECK:    }

@llvmbot
Copy link
Member

llvmbot commented May 22, 2024

@llvm/pr-subscribers-mlir

Author: None (Max191)

Changes

This PR adds transpose + pack/unpack folding support for transpose ops in the form of linalg.generic ops. There were also some bugs with the permutation composing in the previous patterns, so this PR fixes these bugs and adds tests for them as well.


Full diff: https://github.com/llvm/llvm-project/pull/93055.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp (+87-39)
  • (modified) mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir (+139)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index ebcb34e9ef024..ce5fda8e79d65 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -48,6 +48,34 @@ static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
   return success();
 }
 
+// If the `linalgOp` represents a transpose, return the permutation vector for
+// the transpose. Otherwise, return failure.
+static FailureOr<SmallVector<int64_t>>
+getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
+  if (auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation()))
+    return SmallVector<int64_t>(transposeOp.getPermutation());
+  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
+    return failure();
+
+  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
+    return failure();
+  auto mapRange = linalgOp.getIndexingMapsArray();
+  if (mapRange.size() != 2 || !mapRange.front().isPermutation() ||
+      !mapRange.back().isPermutation() || mapRange.front() == mapRange.back()) {
+    return failure();
+  }
+  if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
+    return failure();
+  AffineMap outMap = mapRange.back();
+  AffineMap inMap = mapRange.front();
+  // To get the permutation, look at each output index and find which
+  // dimension in the input we're reading from for that index.
+  return llvm::map_to_vector(outMap.getResults(),
+                             [&](AffineExpr expr) -> int64_t {
+                               return *inMap.getResultPosition(expr);
+                             });
+}
+
 /// Packing one-dimensional tensor can be expressed as an expand shape op.
 struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
   using OpRewritePattern<PackOp>::OpRewritePattern;
@@ -244,14 +272,10 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
 
   for (unsigned int i = 0; i < rank; ++i) {
     int64_t remappedPosition = permutation[i];
-
-    if (!inVec.empty()) {
-      if (remappedPosition >= rank) {
-        return false;
-      }
+    if (remappedPosition >= rank)
+      return false;
+    if (!inVec.empty())
       remappedPosition = inVec[remappedPosition];
-    }
-
     resVec.push_back(remappedPosition);
   }
 
@@ -261,20 +285,26 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
 /// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
 /// semantics.
 struct FoldProducerPackWithConsumerLinalgTransposeOp
-    : public OpRewritePattern<linalg::TransposeOp> {
-  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
+  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
 
-  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override {
-    auto packOp = transposeOp.getOperand(0).getDefiningOp<PackOp>();
+    auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>();
 
     if (!packOp)
       return failure();
 
+    FailureOr<SmallVector<int64_t>> maybePerm =
+        getTransposeOpPermutation(linalgOp);
+    if (failed(maybePerm)) {
+      return failure();
+    }
+
     auto innerDimsPos = packOp.getInnerDimsPos();
     auto mixedInnerTiles = packOp.getMixedTiles();
     auto outerDimsPerm = packOp.getOuterDimsPerm();
-    auto transposePerm = transposeOp.getPermutation();
+    auto transposePerm = maybePerm.value();
     SmallVector<int64_t> newOuterDimsPermVec;
     SmallVector<int64_t> newInnerDimsPosVec;
     SmallVector<OpFoldResult> newMixedInnerTilesVec;
@@ -283,7 +313,7 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
     if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
                          srcRank))
       return rewriter.notifyMatchFailure(
-          transposeOp,
+          linalgOp,
           "Cannot fold in tensor.pack if a tile dimension was transposed "
           "with a non-tile dimension in linalg.transpose.");
 
@@ -295,11 +325,11 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
     }
 
     Value output = packOp.createDestinationTensor(
-        rewriter, transposeOp.getLoc(), packOp.getSource(),
-        newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
+        rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec,
+        newInnerDimsPosVec, newOuterDimsPermVec);
 
     rewriter.replaceOpWithNewOp<PackOp>(
-        transposeOp, packOp.getSource(), output, newInnerDimsPosVec,
+        linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
         newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
 
     return success();
@@ -314,12 +344,17 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
 
   LogicalResult matchAndRewrite(PackOp packOp,
                                 PatternRewriter &rewriter) const override {
-    auto transposeOp = packOp.getSource().getDefiningOp<linalg::TransposeOp>();
+    auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
+    if (!linalgOp)
+      return failure();
 
-    if (!transposeOp)
+    FailureOr<SmallVector<int64_t>> maybePerm =
+        getTransposeOpPermutation(linalgOp);
+    if (failed(maybePerm)) {
       return failure();
+    }
 
-    auto transposePermutation = transposeOp.getPermutation();
+    auto transposePermutation = maybePerm.value();
     auto outerDimsPerm = packOp.getOuterDimsPerm();
     auto innerDimsPos = packOp.getInnerDimsPos();
     SmallVector<int64_t> newInnerDimsPosVec;
@@ -335,11 +370,11 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
       newInnerDimsPosVec.push_back(transposePermutation[dim]);
 
     Value output = packOp.createDestinationTensor(
-        rewriter, packOp.getLoc(), transposeOp.getOperand(0),
+        rewriter, packOp.getLoc(), linalgOp->getOperand(0),
         packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
 
     rewriter.replaceOpWithNewOp<PackOp>(
-        packOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
+        packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
         packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);
 
     return success();
@@ -349,22 +384,29 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
 /// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
 /// transpose semantics.
 struct FoldProducerUnPackWithConsumerLinalgTransposeOp
-    : public OpRewritePattern<linalg::TransposeOp> {
-  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
+  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
 
-  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override {
-    auto unPackOp = transposeOp.getOperand(0).getDefiningOp<UnPackOp>();
+    auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();
 
     if (!unPackOp)
       return failure();
 
-    auto transposePermutation = transposeOp.getPermutation();
+    FailureOr<SmallVector<int64_t>> maybePerm =
+        getTransposeOpPermutation(linalgOp);
+    if (failed(maybePerm)) {
+      return failure();
+    }
+
+    auto transposePermutation = maybePerm.value();
+    SmallVector<int64_t> inverseTransposePerm =
+        invertPermutationVector(transposePermutation);
     auto outerDimsPerm = unPackOp.getOuterDimsPerm();
     auto innerDimsPos = unPackOp.getInnerDimsPos();
     SmallVector<int64_t> newInnerDimsPosVec;
-    SmallVector<int64_t> newOuterDimsPermVec =
-        llvm::to_vector(transposePermutation);
+    SmallVector<int64_t> newOuterDimsPermVec = inverseTransposePerm;
 
     if (!outerDimsPerm.empty())
       applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
@@ -372,11 +414,11 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
     // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
     // permutation rank won't necessarily be equal in all cases.
     for (auto dim : innerDimsPos)
-      newInnerDimsPosVec.push_back(transposePermutation[dim]);
+      newInnerDimsPosVec.push_back(inverseTransposePerm[dim]);
 
     // Reuse the destination of the transpose op.
     rewriter.replaceOpWithNewOp<UnPackOp>(
-        transposeOp, unPackOp.getSource(), transposeOp.getDpsInits()[0],
+        linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
         newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);
 
     return success();
@@ -391,13 +433,19 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
 
   LogicalResult matchAndRewrite(UnPackOp unPackOp,
                                 PatternRewriter &rewriter) const override {
-    auto transposeOp =
-        unPackOp.getSource().getDefiningOp<linalg::TransposeOp>();
+    auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
+    if (!linalgOp)
+      return failure();
 
-    if (!transposeOp)
+    FailureOr<SmallVector<int64_t>> maybePerm =
+        getTransposeOpPermutation(linalgOp);
+    if (failed(maybePerm)) {
       return failure();
+    }
 
-    auto transposePermutation = transposeOp.getPermutation();
+    auto transposePermutation = maybePerm.value();
+    SmallVector<int64_t> inverseTransposePerm =
+        invertPermutationVector(transposePermutation);
     auto outerDimsPerm = unPackOp.getOuterDimsPerm();
     auto innerDimsPos = unPackOp.getInnerDimsPos();
     int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
@@ -406,7 +454,7 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
     SmallVector<int64_t> newInnerDimsPosVec;
     SmallVector<OpFoldResult> newMixedInnerTilesVec;
 
-    if (!checkAndPermute(transposePermutation, outerDimsPerm,
+    if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
                          newOuterDimsPermVec, destRank))
       return rewriter.notifyMatchFailure(
           unPackOp,
@@ -414,18 +462,18 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
           "with a non-tile dimension in linalg.transpose.");
 
     // Process transpose operation for tiled inner dimensions
-    for (unsigned int i = destRank; i < transposePermutation.size(); ++i) {
-      int64_t remappedPosition = transposePermutation[i] - destRank;
+    for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
+      int64_t remappedPosition = inverseTransposePerm[i] - destRank;
       newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
       newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
     }
 
     Value output = unPackOp.createDestinationTensor(
-        rewriter, unPackOp.getLoc(), transposeOp.getOperand(0),
+        rewriter, unPackOp.getLoc(), linalgOp->getOperand(0),
         newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
 
     rewriter.replaceOpWithNewOp<UnPackOp>(
-        unPackOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
+        unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
         newMixedInnerTilesVec, newOuterDimsPermVec);
 
     return success();
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 9f486f9146ad8..fca6eddaca436 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -636,3 +636,142 @@ func.func @tensor_padded_unpack_linalg_transpose_fold(%arg0: tensor<71x7x4x16x16
 //  CHECK-SAME:        into %[[OUT:.+]] : tensor<71x7x4x16x16xf32> -> tensor<100x71x64xf32>
 //       CHECK:       return %[[UNPACK]] : tensor<100x71x64xf32>
 //       CHECK:    }
+
+// -----
+
+func.func @non_involution_transpose_unpack_fold(%arg0: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+  %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
+  %transposed = linalg.transpose ins(%arg0 : tensor<2x3x5x4x16xi32>)
+                outs(%0 : tensor<5x2x3x16x4xi32>)
+                permutation = [2, 0, 1, 4, 3]
+  %1 = tensor.empty() : tensor<5x48x8xi32>
+  %unpack = tensor.unpack %transposed
+            outer_dims_perm = [0, 2, 1]
+            inner_dims_pos = [1, 2]
+            inner_tiles = [16, 4] into
+            %1 : tensor<5x2x3x16x4xi32> -> tensor<5x48x8xi32>
+  return %unpack : tensor<5x48x8xi32>
+}
+//CHECK-LABEL:  func.func @non_involution_transpose_unpack_fold(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+//      CHECK:     %[[OUT:.+]] = tensor.empty() : tensor<5x48x8xi32>
+//      CHECK:     %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME:        outer_dims_perm = [2, 1, 0]
+// CHECK-SAME:        inner_dims_pos = [2, 1]
+// CHECK-SAME:        inner_tiles = [4, 16]
+// CHEKC-SAME:        into %[[OUT]] : tensor<2x3x5x4x16xi32> -> tensor<5x48x8xi32>
+//      CHECK:     return %[[UNPACK]] : tensor<5x48x8xi32>
+//      CHECK:   }
+
+// -----
+
+func.func @unpack_non_involution_transpose_fold(%arg0: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+  %0 = tensor.empty() : tensor<3x56x3648xf32>
+  %unpack = tensor.unpack %arg0
+    outer_dims_perm = [2, 0, 1]
+    inner_dims_pos = [1, 2]
+    inner_tiles = [1, 64]
+    into %0 : tensor<57x3x56x1x64xf32> -> tensor<3x56x3648xf32>
+
+  %1 = tensor.empty() : tensor<3648x3x56xf32>
+  %transposed = linalg.transpose
+    ins(%unpack : tensor<3x56x3648xf32>)
+    outs(%1 : tensor<3648x3x56xf32>)
+    permutation = [2, 0, 1]
+  return %transposed : tensor<3648x3x56xf32>
+}
+// CHECK-LABEL:  func.func @unpack_non_involution_transpose_fold(
+//  CHECK-SAME:    %[[ARG0:.+]]: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+//       CHECK:        %[[OUT:.+]] = tensor.empty() : tensor<3648x3x56xf32>
+//       CHECK:        %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+//  CHECK-SAME:        outer_dims_perm = [0, 1, 2]
+//  CHECK-SAME:        inner_dims_pos = [2, 0]
+//  CHECK-SAME:        inner_tiles = [1, 64]
+//  CHECK-SAME:        into %[[OUT:.+]] : tensor<57x3x56x1x64xf32> -> tensor<3648x3x56xf32>
+//       CHECK:       return %[[UNPACK]] : tensor<3648x3x56xf32>
+//       CHECK:    }
+
+// -----
+
+func.func @transpose_unpacked_dims_no_fold(%arg0: tensor<2x16x5x4x3xi32>) -> tensor<5x32x12xi32> {
+  %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
+  %transposed = linalg.transpose ins(%arg0 : tensor<2x16x5x4x3xi32>)
+                outs(%0 : tensor<5x2x3x16x4xi32>)
+                permutation = [2, 0, 4, 1, 3]
+  %1 = tensor.empty() : tensor<5x32x12xi32>
+  %unpack = tensor.unpack %transposed
+            inner_dims_pos = [1, 2]
+            inner_tiles = [16, 4] into
+            %1 : tensor<5x2x3x16x4xi32> -> tensor<5x32x12xi32>
+  return %unpack : tensor<5x32x12xi32>
+}
+//CHECK-LABEL:  func.func @transpose_unpacked_dims_no_fold(
+//      CHECK:     linalg.transpose
+//      CHECK:     tensor.unpack
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4)->(d1, d2, d0, d4, d3)>
+#map1 = affine_map<(d0, d1, d2, d3, d4)->(d0, d1, d2, d3, d4)>
+func.func @generic_transpose_unpack_fold(%arg0: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+  %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
+  %transposed = linalg.generic {
+                iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+                indexing_maps = [#map, #map1]}
+                ins(%arg0 : tensor<2x3x5x4x16xi32>)
+                outs(%0 : tensor<5x2x3x16x4xi32>) {
+  ^bb0(%in : i32, %out : i32):
+    linalg.yield %in : i32
+  } -> tensor<5x2x3x16x4xi32>
+  %1 = tensor.empty() : tensor<5x48x8xi32>
+  %unpack = tensor.unpack %transposed
+            outer_dims_perm = [0, 2, 1]
+            inner_dims_pos = [1, 2]
+            inner_tiles = [16, 4] into
+            %1 : tensor<5x2x3x16x4xi32> -> tensor<5x48x8xi32>
+  return %unpack : tensor<5x48x8xi32>
+}
+//CHECK-LABEL:  func.func @generic_transpose_unpack_fold(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+//      CHECK:     %[[OUT:.+]] = tensor.empty() : tensor<5x48x8xi32>
+//      CHECK:     %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME:        outer_dims_perm = [2, 1, 0]
+// CHECK-SAME:        inner_dims_pos = [2, 1]
+// CHECK-SAME:        inner_tiles = [4, 16]
+// CHEKC-SAME:        into %[[OUT]] : tensor<2x3x5x4x16xi32> -> tensor<5x48x8xi32>
+//      CHECK:     return %[[UNPACK]] : tensor<5x48x8xi32>
+//      CHECK:   }
+
+// -----
+
+#map = affine_map<(d0, d1, d2)->(d1, d2, d0)>
+#map1 = affine_map<(d0, d1, d2)->(d0, d1, d2)>
+func.func @unpack_generic_transpose_fold(%arg0: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+  %0 = tensor.empty() : tensor<3x56x3648xf32>
+  %unpack = tensor.unpack %arg0
+    outer_dims_perm = [2, 0, 1]
+    inner_dims_pos = [1, 2]
+    inner_tiles = [1, 64]
+    into %0 : tensor<57x3x56x1x64xf32> -> tensor<3x56x3648xf32>
+
+  %1 = tensor.empty() : tensor<3648x3x56xf32>
+  %transposed = linalg.generic {
+                iterator_types = ["parallel", "parallel", "parallel"],
+                indexing_maps = [#map, #map1]}
+                ins(%unpack : tensor<3x56x3648xf32>)
+                outs(%1 : tensor<3648x3x56xf32>) {
+  ^bb0(%in : f32, %out : f32):
+    linalg.yield %in : f32
+  } -> tensor<3648x3x56xf32>
+  return %transposed : tensor<3648x3x56xf32>
+}
+// CHECK-LABEL:  func.func @unpack_generic_transpose_fold(
+//  CHECK-SAME:    %[[ARG0:.+]]: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+//       CHECK:        %[[OUT:.+]] = tensor.empty() : tensor<3648x3x56xf32>
+//       CHECK:        %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+//  CHECK-SAME:        outer_dims_perm = [0, 1, 2]
+//  CHECK-SAME:        inner_dims_pos = [2, 0]
+//  CHECK-SAME:        inner_tiles = [1, 64]
+//  CHECK-SAME:        into %[[OUT:.+]] : tensor<57x3x56x1x64xf32> -> tensor<3648x3x56xf32>
+//       CHECK:       return %[[UNPACK]] : tensor<3648x3x56xf32>
+//       CHECK:    }

@hanhanW hanhanW requested review from pashu123 and chelini May 22, 2024 18:20
!mapRange.back().isPermutation() || mapRange.front() == mapRange.back()) {
return failure();
}
if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious. If we don't add this restriction, we can still replace the block with linalg generic having a one-to-one map.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for now we can keep it restricted to pure transpose ops for these patterns. Maybe we could add a new pattern for something like this in a later PR.

In general, a unary element-wise operation like that will get fused with it's producer/consumer in element-wise fusion. After this you end up with multi-input ops, and it may not be as simple to figure out whether it is beneficial to transpose the generic op. If you do this remapping before element-wise fusion, then I think there could be cases where you add extra transposing that would otherwise fold. Overall, I think it is not a bad idea to try to move the transpose work to the pack op, but I think it takes a bit more careful thought/matching, so I think it is better suited for a later PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could also be useful to use after dispatch formation, since pack ops often get fused with producers. I think it's definitely worth trying at some point.

Copy link
Member

@pashu123 pashu123 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we can propagate the tranpose semantics from generic with payload operations this would make it more general.

Comment on lines 63 to 66
if (mapRange.size() != 2 || !mapRange.front().isPermutation() ||
!mapRange.back().isPermutation() || mapRange.front() == mapRange.back()) {
return failure();
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We dont need to check mapRange.size() != 2 because it should already be checked by verifier. Perhaps using assertion instead?

Comment on lines 300 to 304
if (failed(maybePerm)) {
return failure();
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

llvm style nit: do not use braces for single if-statement.

Comment on lines 353 to 357
if (failed(maybePerm)) {
return failure();
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style nit: remove {}

Comment on lines 399 to 403
if (failed(maybePerm)) {
return failure();
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Comment on lines 403 to 411
auto transposePermutation = maybePerm.value();
SmallVector<int64_t> inverseTransposePerm =
invertPermutationVector(transposePermutation);
auto outerDimsPerm = unPackOp.getOuterDimsPerm();
auto innerDimsPos = unPackOp.getInnerDimsPos();
SmallVector<int64_t> newInnerDimsPosVec;
SmallVector<int64_t> newOuterDimsPermVec =
llvm::to_vector(transposePermutation);
SmallVector<int64_t> newOuterDimsPermVec = inverseTransposePerm;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can collapse your changes into a single line. Having variables does not really help the readability because the functions document it by their names.

SmallVector<int64_t> newOuterDimsPermVec = invertPermutationVector(maybePerm.value());

if (!transposeOp)
FailureOr<SmallVector<int64_t>> maybePerm =
getTransposeOpPermutation(linalgOp);
if (failed(maybePerm)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto, remove {}

Comment on lines 446 to 450
auto transposePermutation = maybePerm.value();
SmallVector<int64_t> inverseTransposePerm =
invertPermutationVector(transposePermutation);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, I think collapsing them into a single statement is better.

@Max191 Max191 force-pushed the add-generic-pack-transpose-folding branch from 1e7b0e6 to e422446 Compare June 4, 2024 20:16
@Max191 Max191 requested a review from hanhanW June 4, 2024 20:17
@Max191 Max191 merged commit 7ef83f5 into llvm:main Jun 6, 2024
7 checks passed
nirvedhmeshram added a commit to iree-org/llvm-project that referenced this pull request Jun 20, 2024
nirvedhmeshram added a commit to iree-org/iree that referenced this pull request Jun 20, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants