-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] Add pack/unpack transpose foldings for linalg.generic ops, fix bugs #93055
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-tensor Author: None (Max191) Changes: This PR adds transpose + pack/unpack folding support for transpose ops in the form of `linalg.generic` ops, and fixes bugs in the existing folding patterns. Full diff: https://github.com/llvm/llvm-project/pull/93055.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index ebcb34e9ef024..ce5fda8e79d65 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -48,6 +48,34 @@ static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
return success();
}
+// If the `linalgOp` represents a transpose, return the permutation vector for the transpose. A `linalg::TransposeOp` returns a copy of its own permutation; any other LinalgOp is accepted only when all of its loops are parallel, it has exactly one DPS input and one DPS init, it has exactly two indexing maps that are both permutations and differ from each other, and its block holds a single operation.
+// the transpose. Otherwise, return failure. NOTE(review): the single-operation check does not inspect what that operation yields; presumably it is a `linalg.yield` of the input block argument - confirm.
+static FailureOr<SmallVector<int64_t>>
+getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
+ if (auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation()))
+ return SmallVector<int64_t>(transposeOp.getPermutation());
+ if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
+ return failure();
+
+ if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
+ return failure();
+ auto mapRange = linalgOp.getIndexingMapsArray();
+ if (mapRange.size() != 2 || !mapRange.front().isPermutation() ||
+ !mapRange.back().isPermutation() || mapRange.front() == mapRange.back()) {
+ return failure();
+ }
+ if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
+ return failure();
+ AffineMap outMap = mapRange.back();
+ AffineMap inMap = mapRange.front();
+ // To get the permutation, look at each output index and find which dimension in the input we're reading from for that index; entry i of the result is the position of the i-th output-map expression among the input-map results.
+ // dimension in the input we're reading from for that index. The value returned by getResultPosition is dereferenced unchecked; this presumably relies on both maps being permutations (checked above) so that every expression is found - TODO confirm.
+ return llvm::map_to_vector(outMap.getResults(),
+ [&](AffineExpr expr) -> int64_t {
+ return *inMap.getResultPosition(expr);
+ });
+}
+
/// Packing one-dimensional tensor can be expressed as an expand shape op.
struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
using OpRewritePattern<PackOp>::OpRewritePattern;
@@ -244,14 +272,10 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
for (unsigned int i = 0; i < rank; ++i) {
int64_t remappedPosition = permutation[i];
-
- if (!inVec.empty()) {
- if (remappedPosition >= rank) {
- return false;
- }
+ if (remappedPosition >= rank)
+ return false;
+ if (!inVec.empty())
remappedPosition = inVec[remappedPosition];
- }
-
resVec.push_back(remappedPosition);
}
@@ -261,20 +285,26 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
/// semantics.
struct FoldProducerPackWithConsumerLinalgTransposeOp
- : public OpRewritePattern<linalg::TransposeOp> {
- using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+ : public OpInterfaceRewritePattern<linalg::LinalgOp> {
+ using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
- LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+ LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
- auto packOp = transposeOp.getOperand(0).getDefiningOp<PackOp>();
+ auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>();
if (!packOp)
return failure();
+ FailureOr<SmallVector<int64_t>> maybePerm =
+ getTransposeOpPermutation(linalgOp);
+ if (failed(maybePerm)) {
+ return failure();
+ }
+
auto innerDimsPos = packOp.getInnerDimsPos();
auto mixedInnerTiles = packOp.getMixedTiles();
auto outerDimsPerm = packOp.getOuterDimsPerm();
- auto transposePerm = transposeOp.getPermutation();
+ auto transposePerm = maybePerm.value();
SmallVector<int64_t> newOuterDimsPermVec;
SmallVector<int64_t> newInnerDimsPosVec;
SmallVector<OpFoldResult> newMixedInnerTilesVec;
@@ -283,7 +313,7 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
srcRank))
return rewriter.notifyMatchFailure(
- transposeOp,
+ linalgOp,
"Cannot fold in tensor.pack if a tile dimension was transposed "
"with a non-tile dimension in linalg.transpose.");
@@ -295,11 +325,11 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
}
Value output = packOp.createDestinationTensor(
- rewriter, transposeOp.getLoc(), packOp.getSource(),
- newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
+ rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec,
+ newInnerDimsPosVec, newOuterDimsPermVec);
rewriter.replaceOpWithNewOp<PackOp>(
- transposeOp, packOp.getSource(), output, newInnerDimsPosVec,
+ linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
return success();
@@ -314,12 +344,17 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
LogicalResult matchAndRewrite(PackOp packOp,
PatternRewriter &rewriter) const override {
- auto transposeOp = packOp.getSource().getDefiningOp<linalg::TransposeOp>();
+ auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
+ if (!linalgOp)
+ return failure();
- if (!transposeOp)
+ FailureOr<SmallVector<int64_t>> maybePerm =
+ getTransposeOpPermutation(linalgOp);
+ if (failed(maybePerm)) {
return failure();
+ }
- auto transposePermutation = transposeOp.getPermutation();
+ auto transposePermutation = maybePerm.value();
auto outerDimsPerm = packOp.getOuterDimsPerm();
auto innerDimsPos = packOp.getInnerDimsPos();
SmallVector<int64_t> newInnerDimsPosVec;
@@ -335,11 +370,11 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
newInnerDimsPosVec.push_back(transposePermutation[dim]);
Value output = packOp.createDestinationTensor(
- rewriter, packOp.getLoc(), transposeOp.getOperand(0),
+ rewriter, packOp.getLoc(), linalgOp->getOperand(0),
packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
rewriter.replaceOpWithNewOp<PackOp>(
- packOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
+ packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);
return success();
@@ -349,22 +384,29 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
/// transpose semantics.
struct FoldProducerUnPackWithConsumerLinalgTransposeOp
- : public OpRewritePattern<linalg::TransposeOp> {
- using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+ : public OpInterfaceRewritePattern<linalg::LinalgOp> {
+ using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
- LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+ LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
- auto unPackOp = transposeOp.getOperand(0).getDefiningOp<UnPackOp>();
+ auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();
if (!unPackOp)
return failure();
- auto transposePermutation = transposeOp.getPermutation();
+ FailureOr<SmallVector<int64_t>> maybePerm =
+ getTransposeOpPermutation(linalgOp);
+ if (failed(maybePerm)) {
+ return failure();
+ }
+
+ auto transposePermutation = maybePerm.value();
+ SmallVector<int64_t> inverseTransposePerm =
+ invertPermutationVector(transposePermutation);
auto outerDimsPerm = unPackOp.getOuterDimsPerm();
auto innerDimsPos = unPackOp.getInnerDimsPos();
SmallVector<int64_t> newInnerDimsPosVec;
- SmallVector<int64_t> newOuterDimsPermVec =
- llvm::to_vector(transposePermutation);
+ SmallVector<int64_t> newOuterDimsPermVec = inverseTransposePerm;
if (!outerDimsPerm.empty())
applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
@@ -372,11 +414,11 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
// Can't use applyPermutationToVector for newInnerDimsPosVec since input and
// permutation rank won't necessarily be equal in all cases.
for (auto dim : innerDimsPos)
- newInnerDimsPosVec.push_back(transposePermutation[dim]);
+ newInnerDimsPosVec.push_back(inverseTransposePerm[dim]);
// Reuse the destination of the transpose op.
rewriter.replaceOpWithNewOp<UnPackOp>(
- transposeOp, unPackOp.getSource(), transposeOp.getDpsInits()[0],
+ linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);
return success();
@@ -391,13 +433,19 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
LogicalResult matchAndRewrite(UnPackOp unPackOp,
PatternRewriter &rewriter) const override {
- auto transposeOp =
- unPackOp.getSource().getDefiningOp<linalg::TransposeOp>();
+ auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
+ if (!linalgOp)
+ return failure();
- if (!transposeOp)
+ FailureOr<SmallVector<int64_t>> maybePerm =
+ getTransposeOpPermutation(linalgOp);
+ if (failed(maybePerm)) {
return failure();
+ }
- auto transposePermutation = transposeOp.getPermutation();
+ auto transposePermutation = maybePerm.value();
+ SmallVector<int64_t> inverseTransposePerm =
+ invertPermutationVector(transposePermutation);
auto outerDimsPerm = unPackOp.getOuterDimsPerm();
auto innerDimsPos = unPackOp.getInnerDimsPos();
int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
@@ -406,7 +454,7 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
SmallVector<int64_t> newInnerDimsPosVec;
SmallVector<OpFoldResult> newMixedInnerTilesVec;
- if (!checkAndPermute(transposePermutation, outerDimsPerm,
+ if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
newOuterDimsPermVec, destRank))
return rewriter.notifyMatchFailure(
unPackOp,
@@ -414,18 +462,18 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
"with a non-tile dimension in linalg.transpose.");
// Process transpose operation for tiled inner dimensions
- for (unsigned int i = destRank; i < transposePermutation.size(); ++i) {
- int64_t remappedPosition = transposePermutation[i] - destRank;
+ for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
+ int64_t remappedPosition = inverseTransposePerm[i] - destRank;
newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
}
Value output = unPackOp.createDestinationTensor(
- rewriter, unPackOp.getLoc(), transposeOp.getOperand(0),
+ rewriter, unPackOp.getLoc(), linalgOp->getOperand(0),
newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
rewriter.replaceOpWithNewOp<UnPackOp>(
- unPackOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
+ unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
newMixedInnerTilesVec, newOuterDimsPermVec);
return success();
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 9f486f9146ad8..fca6eddaca436 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -636,3 +636,142 @@ func.func @tensor_padded_unpack_linalg_transpose_fold(%arg0: tensor<71x7x4x16x16
// CHECK-SAME: into %[[OUT:.+]] : tensor<71x7x4x16x16xf32> -> tensor<100x71x64xf32>
// CHECK: return %[[UNPACK]] : tensor<100x71x64xf32>
// CHECK: }
+
+// -----
+
+func.func @non_involution_transpose_unpack_fold(%arg0: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+ %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
+ %transposed = linalg.transpose ins(%arg0 : tensor<2x3x5x4x16xi32>)
+ outs(%0 : tensor<5x2x3x16x4xi32>)
+ permutation = [2, 0, 1, 4, 3]
+ %1 = tensor.empty() : tensor<5x48x8xi32>
+ %unpack = tensor.unpack %transposed
+ outer_dims_perm = [0, 2, 1]
+ inner_dims_pos = [1, 2]
+ inner_tiles = [16, 4] into
+ %1 : tensor<5x2x3x16x4xi32> -> tensor<5x48x8xi32>
+ return %unpack : tensor<5x48x8xi32>
+}
+//CHECK-LABEL: func.func @non_involution_transpose_unpack_fold(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<5x48x8xi32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [2, 1, 0]
+// CHECK-SAME: inner_dims_pos = [2, 1]
+// CHECK-SAME: inner_tiles = [4, 16]
+// CHECK-SAME: into %[[OUT]] : tensor<2x3x5x4x16xi32> -> tensor<5x48x8xi32>
+// CHECK: return %[[UNPACK]] : tensor<5x48x8xi32>
+// CHECK: }
+
+// -----
+
+func.func @unpack_non_involution_transpose_fold(%arg0: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+ %0 = tensor.empty() : tensor<3x56x3648xf32>
+ %unpack = tensor.unpack %arg0
+ outer_dims_perm = [2, 0, 1]
+ inner_dims_pos = [1, 2]
+ inner_tiles = [1, 64]
+ into %0 : tensor<57x3x56x1x64xf32> -> tensor<3x56x3648xf32>
+
+ %1 = tensor.empty() : tensor<3648x3x56xf32>
+ %transposed = linalg.transpose
+ ins(%unpack : tensor<3x56x3648xf32>)
+ outs(%1 : tensor<3648x3x56xf32>)
+ permutation = [2, 0, 1]
+ return %transposed : tensor<3648x3x56xf32>
+}
+// CHECK-LABEL: func.func @unpack_non_involution_transpose_fold(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<3648x3x56xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 1, 2]
+// CHECK-SAME: inner_dims_pos = [2, 0]
+// CHECK-SAME: inner_tiles = [1, 64]
+// CHECK-SAME: into %[[OUT:.+]] : tensor<57x3x56x1x64xf32> -> tensor<3648x3x56xf32>
+// CHECK: return %[[UNPACK]] : tensor<3648x3x56xf32>
+// CHECK: }
+
+// -----
+
+func.func @transpose_unpacked_dims_no_fold(%arg0: tensor<2x16x5x4x3xi32>) -> tensor<5x32x12xi32> {
+ %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
+ %transposed = linalg.transpose ins(%arg0 : tensor<2x16x5x4x3xi32>)
+ outs(%0 : tensor<5x2x3x16x4xi32>)
+ permutation = [2, 0, 4, 1, 3]
+ %1 = tensor.empty() : tensor<5x32x12xi32>
+ %unpack = tensor.unpack %transposed
+ inner_dims_pos = [1, 2]
+ inner_tiles = [16, 4] into
+ %1 : tensor<5x2x3x16x4xi32> -> tensor<5x32x12xi32>
+ return %unpack : tensor<5x32x12xi32>
+}
+//CHECK-LABEL: func.func @transpose_unpacked_dims_no_fold(
+// CHECK: linalg.transpose
+// CHECK: tensor.unpack
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4)->(d1, d2, d0, d4, d3)>
+#map1 = affine_map<(d0, d1, d2, d3, d4)->(d0, d1, d2, d3, d4)>
+func.func @generic_transpose_unpack_fold(%arg0: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+ %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
+ %transposed = linalg.generic {
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+ indexing_maps = [#map, #map1]}
+ ins(%arg0 : tensor<2x3x5x4x16xi32>)
+ outs(%0 : tensor<5x2x3x16x4xi32>) {
+ ^bb0(%in : i32, %out : i32):
+ linalg.yield %in : i32
+ } -> tensor<5x2x3x16x4xi32>
+ %1 = tensor.empty() : tensor<5x48x8xi32>
+ %unpack = tensor.unpack %transposed
+ outer_dims_perm = [0, 2, 1]
+ inner_dims_pos = [1, 2]
+ inner_tiles = [16, 4] into
+ %1 : tensor<5x2x3x16x4xi32> -> tensor<5x48x8xi32>
+ return %unpack : tensor<5x48x8xi32>
+}
+//CHECK-LABEL: func.func @generic_transpose_unpack_fold(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<5x48x8xi32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [2, 1, 0]
+// CHECK-SAME: inner_dims_pos = [2, 1]
+// CHECK-SAME: inner_tiles = [4, 16]
+// CHECK-SAME: into %[[OUT]] : tensor<2x3x5x4x16xi32> -> tensor<5x48x8xi32>
+// CHECK: return %[[UNPACK]] : tensor<5x48x8xi32>
+// CHECK: }
+
+// -----
+
+#map = affine_map<(d0, d1, d2)->(d1, d2, d0)>
+#map1 = affine_map<(d0, d1, d2)->(d0, d1, d2)>
+func.func @unpack_generic_transpose_fold(%arg0: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+ %0 = tensor.empty() : tensor<3x56x3648xf32>
+ %unpack = tensor.unpack %arg0
+ outer_dims_perm = [2, 0, 1]
+ inner_dims_pos = [1, 2]
+ inner_tiles = [1, 64]
+ into %0 : tensor<57x3x56x1x64xf32> -> tensor<3x56x3648xf32>
+
+ %1 = tensor.empty() : tensor<3648x3x56xf32>
+ %transposed = linalg.generic {
+ iterator_types = ["parallel", "parallel", "parallel"],
+ indexing_maps = [#map, #map1]}
+ ins(%unpack : tensor<3x56x3648xf32>)
+ outs(%1 : tensor<3648x3x56xf32>) {
+ ^bb0(%in : f32, %out : f32):
+ linalg.yield %in : f32
+ } -> tensor<3648x3x56xf32>
+ return %transposed : tensor<3648x3x56xf32>
+}
+// CHECK-LABEL: func.func @unpack_generic_transpose_fold(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<3648x3x56xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 1, 2]
+// CHECK-SAME: inner_dims_pos = [2, 0]
+// CHECK-SAME: inner_tiles = [1, 64]
+// CHECK-SAME: into %[[OUT:.+]] : tensor<57x3x56x1x64xf32> -> tensor<3648x3x56xf32>
+// CHECK: return %[[UNPACK]] : tensor<3648x3x56xf32>
+// CHECK: }
|
@llvm/pr-subscribers-mlir Author: None (Max191) Changes: This PR adds transpose + pack/unpack folding support for transpose ops in the form of `linalg.generic` ops, and fixes bugs in the existing folding patterns. Full diff: https://github.com/llvm/llvm-project/pull/93055.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index ebcb34e9ef024..ce5fda8e79d65 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -48,6 +48,34 @@ static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
return success();
}
+// If the `linalgOp` represents a transpose, return the permutation vector for the transpose. A `linalg::TransposeOp` returns a copy of its own permutation; any other LinalgOp is accepted only when all of its loops are parallel, it has exactly one DPS input and one DPS init, it has exactly two indexing maps that are both permutations and differ from each other, and its block holds a single operation.
+// the transpose. Otherwise, return failure. NOTE(review): the single-operation check does not inspect what that operation yields; presumably it is a `linalg.yield` of the input block argument - confirm.
+static FailureOr<SmallVector<int64_t>>
+getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
+ if (auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation()))
+ return SmallVector<int64_t>(transposeOp.getPermutation());
+ if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
+ return failure();
+
+ if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
+ return failure();
+ auto mapRange = linalgOp.getIndexingMapsArray();
+ if (mapRange.size() != 2 || !mapRange.front().isPermutation() ||
+ !mapRange.back().isPermutation() || mapRange.front() == mapRange.back()) {
+ return failure();
+ }
+ if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
+ return failure();
+ AffineMap outMap = mapRange.back();
+ AffineMap inMap = mapRange.front();
+ // To get the permutation, look at each output index and find which dimension in the input we're reading from for that index; entry i of the result is the position of the i-th output-map expression among the input-map results.
+ // dimension in the input we're reading from for that index. The value returned by getResultPosition is dereferenced unchecked; this presumably relies on both maps being permutations (checked above) so that every expression is found - TODO confirm.
+ return llvm::map_to_vector(outMap.getResults(),
+ [&](AffineExpr expr) -> int64_t {
+ return *inMap.getResultPosition(expr);
+ });
+}
+
/// Packing one-dimensional tensor can be expressed as an expand shape op.
struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
using OpRewritePattern<PackOp>::OpRewritePattern;
@@ -244,14 +272,10 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
for (unsigned int i = 0; i < rank; ++i) {
int64_t remappedPosition = permutation[i];
-
- if (!inVec.empty()) {
- if (remappedPosition >= rank) {
- return false;
- }
+ if (remappedPosition >= rank)
+ return false;
+ if (!inVec.empty())
remappedPosition = inVec[remappedPosition];
- }
-
resVec.push_back(remappedPosition);
}
@@ -261,20 +285,26 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
/// semantics.
struct FoldProducerPackWithConsumerLinalgTransposeOp
- : public OpRewritePattern<linalg::TransposeOp> {
- using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+ : public OpInterfaceRewritePattern<linalg::LinalgOp> {
+ using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
- LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+ LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
- auto packOp = transposeOp.getOperand(0).getDefiningOp<PackOp>();
+ auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>();
if (!packOp)
return failure();
+ FailureOr<SmallVector<int64_t>> maybePerm =
+ getTransposeOpPermutation(linalgOp);
+ if (failed(maybePerm)) {
+ return failure();
+ }
+
auto innerDimsPos = packOp.getInnerDimsPos();
auto mixedInnerTiles = packOp.getMixedTiles();
auto outerDimsPerm = packOp.getOuterDimsPerm();
- auto transposePerm = transposeOp.getPermutation();
+ auto transposePerm = maybePerm.value();
SmallVector<int64_t> newOuterDimsPermVec;
SmallVector<int64_t> newInnerDimsPosVec;
SmallVector<OpFoldResult> newMixedInnerTilesVec;
@@ -283,7 +313,7 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
srcRank))
return rewriter.notifyMatchFailure(
- transposeOp,
+ linalgOp,
"Cannot fold in tensor.pack if a tile dimension was transposed "
"with a non-tile dimension in linalg.transpose.");
@@ -295,11 +325,11 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
}
Value output = packOp.createDestinationTensor(
- rewriter, transposeOp.getLoc(), packOp.getSource(),
- newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
+ rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec,
+ newInnerDimsPosVec, newOuterDimsPermVec);
rewriter.replaceOpWithNewOp<PackOp>(
- transposeOp, packOp.getSource(), output, newInnerDimsPosVec,
+ linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
return success();
@@ -314,12 +344,17 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
LogicalResult matchAndRewrite(PackOp packOp,
PatternRewriter &rewriter) const override {
- auto transposeOp = packOp.getSource().getDefiningOp<linalg::TransposeOp>();
+ auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
+ if (!linalgOp)
+ return failure();
- if (!transposeOp)
+ FailureOr<SmallVector<int64_t>> maybePerm =
+ getTransposeOpPermutation(linalgOp);
+ if (failed(maybePerm)) {
return failure();
+ }
- auto transposePermutation = transposeOp.getPermutation();
+ auto transposePermutation = maybePerm.value();
auto outerDimsPerm = packOp.getOuterDimsPerm();
auto innerDimsPos = packOp.getInnerDimsPos();
SmallVector<int64_t> newInnerDimsPosVec;
@@ -335,11 +370,11 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
newInnerDimsPosVec.push_back(transposePermutation[dim]);
Value output = packOp.createDestinationTensor(
- rewriter, packOp.getLoc(), transposeOp.getOperand(0),
+ rewriter, packOp.getLoc(), linalgOp->getOperand(0),
packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
rewriter.replaceOpWithNewOp<PackOp>(
- packOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
+ packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);
return success();
@@ -349,22 +384,29 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
/// transpose semantics.
struct FoldProducerUnPackWithConsumerLinalgTransposeOp
- : public OpRewritePattern<linalg::TransposeOp> {
- using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+ : public OpInterfaceRewritePattern<linalg::LinalgOp> {
+ using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
- LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+ LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
- auto unPackOp = transposeOp.getOperand(0).getDefiningOp<UnPackOp>();
+ auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();
if (!unPackOp)
return failure();
- auto transposePermutation = transposeOp.getPermutation();
+ FailureOr<SmallVector<int64_t>> maybePerm =
+ getTransposeOpPermutation(linalgOp);
+ if (failed(maybePerm)) {
+ return failure();
+ }
+
+ auto transposePermutation = maybePerm.value();
+ SmallVector<int64_t> inverseTransposePerm =
+ invertPermutationVector(transposePermutation);
auto outerDimsPerm = unPackOp.getOuterDimsPerm();
auto innerDimsPos = unPackOp.getInnerDimsPos();
SmallVector<int64_t> newInnerDimsPosVec;
- SmallVector<int64_t> newOuterDimsPermVec =
- llvm::to_vector(transposePermutation);
+ SmallVector<int64_t> newOuterDimsPermVec = inverseTransposePerm;
if (!outerDimsPerm.empty())
applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
@@ -372,11 +414,11 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
// Can't use applyPermutationToVector for newInnerDimsPosVec since input and
// permutation rank won't necessarily be equal in all cases.
for (auto dim : innerDimsPos)
- newInnerDimsPosVec.push_back(transposePermutation[dim]);
+ newInnerDimsPosVec.push_back(inverseTransposePerm[dim]);
// Reuse the destination of the transpose op.
rewriter.replaceOpWithNewOp<UnPackOp>(
- transposeOp, unPackOp.getSource(), transposeOp.getDpsInits()[0],
+ linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);
return success();
@@ -391,13 +433,19 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
LogicalResult matchAndRewrite(UnPackOp unPackOp,
PatternRewriter &rewriter) const override {
- auto transposeOp =
- unPackOp.getSource().getDefiningOp<linalg::TransposeOp>();
+ auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
+ if (!linalgOp)
+ return failure();
- if (!transposeOp)
+ FailureOr<SmallVector<int64_t>> maybePerm =
+ getTransposeOpPermutation(linalgOp);
+ if (failed(maybePerm)) {
return failure();
+ }
- auto transposePermutation = transposeOp.getPermutation();
+ auto transposePermutation = maybePerm.value();
+ SmallVector<int64_t> inverseTransposePerm =
+ invertPermutationVector(transposePermutation);
auto outerDimsPerm = unPackOp.getOuterDimsPerm();
auto innerDimsPos = unPackOp.getInnerDimsPos();
int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
@@ -406,7 +454,7 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
SmallVector<int64_t> newInnerDimsPosVec;
SmallVector<OpFoldResult> newMixedInnerTilesVec;
- if (!checkAndPermute(transposePermutation, outerDimsPerm,
+ if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
newOuterDimsPermVec, destRank))
return rewriter.notifyMatchFailure(
unPackOp,
@@ -414,18 +462,18 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
"with a non-tile dimension in linalg.transpose.");
// Process transpose operation for tiled inner dimensions
- for (unsigned int i = destRank; i < transposePermutation.size(); ++i) {
- int64_t remappedPosition = transposePermutation[i] - destRank;
+ for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
+ int64_t remappedPosition = inverseTransposePerm[i] - destRank;
newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
}
Value output = unPackOp.createDestinationTensor(
- rewriter, unPackOp.getLoc(), transposeOp.getOperand(0),
+ rewriter, unPackOp.getLoc(), linalgOp->getOperand(0),
newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
rewriter.replaceOpWithNewOp<UnPackOp>(
- unPackOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
+ unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
newMixedInnerTilesVec, newOuterDimsPermVec);
return success();
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 9f486f9146ad8..fca6eddaca436 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -636,3 +636,142 @@ func.func @tensor_padded_unpack_linalg_transpose_fold(%arg0: tensor<71x7x4x16x16
// CHECK-SAME: into %[[OUT:.+]] : tensor<71x7x4x16x16xf32> -> tensor<100x71x64xf32>
// CHECK: return %[[UNPACK]] : tensor<100x71x64xf32>
// CHECK: }
+
+// -----
+
+func.func @non_involution_transpose_unpack_fold(%arg0: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+ %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
+ %transposed = linalg.transpose ins(%arg0 : tensor<2x3x5x4x16xi32>)
+ outs(%0 : tensor<5x2x3x16x4xi32>)
+ permutation = [2, 0, 1, 4, 3]
+ %1 = tensor.empty() : tensor<5x48x8xi32>
+ %unpack = tensor.unpack %transposed
+ outer_dims_perm = [0, 2, 1]
+ inner_dims_pos = [1, 2]
+ inner_tiles = [16, 4] into
+ %1 : tensor<5x2x3x16x4xi32> -> tensor<5x48x8xi32>
+ return %unpack : tensor<5x48x8xi32>
+}
+//CHECK-LABEL: func.func @non_involution_transpose_unpack_fold(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<5x48x8xi32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [2, 1, 0]
+// CHECK-SAME: inner_dims_pos = [2, 1]
+// CHECK-SAME: inner_tiles = [4, 16]
+// CHECK-SAME: into %[[OUT]] : tensor<2x3x5x4x16xi32> -> tensor<5x48x8xi32>
+// CHECK: return %[[UNPACK]] : tensor<5x48x8xi32>
+// CHECK: }
+
+// -----
+
+func.func @unpack_non_involution_transpose_fold(%arg0: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+ %0 = tensor.empty() : tensor<3x56x3648xf32>
+ %unpack = tensor.unpack %arg0
+ outer_dims_perm = [2, 0, 1]
+ inner_dims_pos = [1, 2]
+ inner_tiles = [1, 64]
+ into %0 : tensor<57x3x56x1x64xf32> -> tensor<3x56x3648xf32>
+
+ %1 = tensor.empty() : tensor<3648x3x56xf32>
+ %transposed = linalg.transpose
+ ins(%unpack : tensor<3x56x3648xf32>)
+ outs(%1 : tensor<3648x3x56xf32>)
+ permutation = [2, 0, 1]
+ return %transposed : tensor<3648x3x56xf32>
+}
+// CHECK-LABEL: func.func @unpack_non_involution_transpose_fold(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<3648x3x56xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 1, 2]
+// CHECK-SAME: inner_dims_pos = [2, 0]
+// CHECK-SAME: inner_tiles = [1, 64]
+// CHECK-SAME: into %[[OUT:.+]] : tensor<57x3x56x1x64xf32> -> tensor<3648x3x56xf32>
+// CHECK: return %[[UNPACK]] : tensor<3648x3x56xf32>
+// CHECK: }
+
+// -----
+
+func.func @transpose_unpacked_dims_no_fold(%arg0: tensor<2x16x5x4x3xi32>) -> tensor<5x32x12xi32> {
+ %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
+ %transposed = linalg.transpose ins(%arg0 : tensor<2x16x5x4x3xi32>)
+ outs(%0 : tensor<5x2x3x16x4xi32>)
+ permutation = [2, 0, 4, 1, 3]
+ %1 = tensor.empty() : tensor<5x32x12xi32>
+ %unpack = tensor.unpack %transposed
+ inner_dims_pos = [1, 2]
+ inner_tiles = [16, 4] into
+ %1 : tensor<5x2x3x16x4xi32> -> tensor<5x32x12xi32>
+ return %unpack : tensor<5x32x12xi32>
+}
+//CHECK-LABEL: func.func @transpose_unpacked_dims_no_fold(
+// CHECK: linalg.transpose
+// CHECK: tensor.unpack
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4)->(d1, d2, d0, d4, d3)>
+#map1 = affine_map<(d0, d1, d2, d3, d4)->(d0, d1, d2, d3, d4)>
+func.func @generic_transpose_unpack_fold(%arg0: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+ %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
+ %transposed = linalg.generic {
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+ indexing_maps = [#map, #map1]}
+ ins(%arg0 : tensor<2x3x5x4x16xi32>)
+ outs(%0 : tensor<5x2x3x16x4xi32>) {
+ ^bb0(%in : i32, %out : i32):
+ linalg.yield %in : i32
+ } -> tensor<5x2x3x16x4xi32>
+ %1 = tensor.empty() : tensor<5x48x8xi32>
+ %unpack = tensor.unpack %transposed
+ outer_dims_perm = [0, 2, 1]
+ inner_dims_pos = [1, 2]
+ inner_tiles = [16, 4] into
+ %1 : tensor<5x2x3x16x4xi32> -> tensor<5x48x8xi32>
+ return %unpack : tensor<5x48x8xi32>
+}
+//CHECK-LABEL: func.func @generic_transpose_unpack_fold(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<5x48x8xi32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [2, 1, 0]
+// CHECK-SAME: inner_dims_pos = [2, 1]
+// CHECK-SAME: inner_tiles = [4, 16]
+// CHECK-SAME: into %[[OUT]] : tensor<2x3x5x4x16xi32> -> tensor<5x48x8xi32>
+// CHECK: return %[[UNPACK]] : tensor<5x48x8xi32>
+// CHECK: }
+
+// -----
+
+#map = affine_map<(d0, d1, d2)->(d1, d2, d0)>
+#map1 = affine_map<(d0, d1, d2)->(d0, d1, d2)>
+func.func @unpack_generic_transpose_fold(%arg0: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+ %0 = tensor.empty() : tensor<3x56x3648xf32>
+ %unpack = tensor.unpack %arg0
+ outer_dims_perm = [2, 0, 1]
+ inner_dims_pos = [1, 2]
+ inner_tiles = [1, 64]
+ into %0 : tensor<57x3x56x1x64xf32> -> tensor<3x56x3648xf32>
+
+ %1 = tensor.empty() : tensor<3648x3x56xf32>
+ %transposed = linalg.generic {
+ iterator_types = ["parallel", "parallel", "parallel"],
+ indexing_maps = [#map, #map1]}
+ ins(%unpack : tensor<3x56x3648xf32>)
+ outs(%1 : tensor<3648x3x56xf32>) {
+ ^bb0(%in : f32, %out : f32):
+ linalg.yield %in : f32
+ } -> tensor<3648x3x56xf32>
+ return %transposed : tensor<3648x3x56xf32>
+}
+// CHECK-LABEL: func.func @unpack_generic_transpose_fold(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<3648x3x56xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 1, 2]
+// CHECK-SAME: inner_dims_pos = [2, 0]
+// CHECK-SAME: inner_tiles = [1, 64]
+// CHECK-SAME: into %[[OUT:.+]] : tensor<57x3x56x1x64xf32> -> tensor<3648x3x56xf32>
+// CHECK: return %[[UNPACK]] : tensor<3648x3x56xf32>
+// CHECK: }
|
!mapRange.back().isPermutation() || mapRange.front() == mapRange.back()) { | ||
return failure(); | ||
} | ||
if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations())) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm curious. If we don't add this restriction, we can still replace the block with linalg generic having a one-to-one map.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think for now we can keep it restricted to pure transpose ops for these patterns. Maybe we could add a new pattern for something like this in a later PR.
In general, a unary element-wise operation like that will get fused with its producer/consumer in element-wise fusion. After this you end up with multi-input ops, and it may not be as simple to figure out whether it is beneficial to transpose the generic op. If you do this remapping before element-wise fusion, then I think there could be cases where you add extra transposing that would otherwise fold. Overall, I think it is not a bad idea to try to move the transpose work to the pack op, but I think it takes a bit more careful thought/matching, so I think it is better suited for a later PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It could also be useful to use after dispatch formation, since pack ops often get fused with producers. I think it's definitely worth trying at some point.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we can propagate the transpose semantics from generic with payload operations, this would make it more general.
if (mapRange.size() != 2 || !mapRange.front().isPermutation() || | ||
!mapRange.back().isPermutation() || mapRange.front() == mapRange.back()) { | ||
return failure(); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't need to check mapRange.size() != 2
because it should already be checked by the verifier. Perhaps use an assertion instead?
if (failed(maybePerm)) { | ||
return failure(); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
llvm style nit: do not use braces for single if-statement.
if (failed(maybePerm)) { | ||
return failure(); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style nit: remove {}
if (failed(maybePerm)) { | ||
return failure(); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
auto transposePermutation = maybePerm.value(); | ||
SmallVector<int64_t> inverseTransposePerm = | ||
invertPermutationVector(transposePermutation); | ||
auto outerDimsPerm = unPackOp.getOuterDimsPerm(); | ||
auto innerDimsPos = unPackOp.getInnerDimsPos(); | ||
SmallVector<int64_t> newInnerDimsPosVec; | ||
SmallVector<int64_t> newOuterDimsPermVec = | ||
llvm::to_vector(transposePermutation); | ||
SmallVector<int64_t> newOuterDimsPermVec = inverseTransposePerm; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can collapse your changes into a single line. Having variables does not really help the readability because the functions document it by their names.
SmallVector<int64_t> newOuterDimsPermVec = invertPermutationVector(maybePerm.value());
if (!transposeOp) | ||
FailureOr<SmallVector<int64_t>> maybePerm = | ||
getTransposeOpPermutation(linalgOp); | ||
if (failed(maybePerm)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto, remove {}
auto transposePermutation = maybePerm.value(); | ||
SmallVector<int64_t> inverseTransposePerm = | ||
invertPermutationVector(transposePermutation); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here, I think collapsing them into a single statement is better.
1e7b0e6
to
e422446
Compare
…ops, fix bugs (llvm#93055)" This reverts commit 7ef83f5.
This PR adds transpose + pack/unpack folding support for transpose ops in the form of
linalg.generic
ops. There were also some bugs with the permutation composing in the previous patterns, so this PR fixes these bugs and adds tests for them as well.