-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR] Add patterns to bubble-up pack and push-down unpack through collapse/expand shape ops #85297
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
569fee5
to
370ae55
Compare
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: Jerry Wu (pzread) Changes: Add DataLayoutPropagation patterns to bubble-up pack and push-down unpack through collapse/expand shape ops. The currently supported cases are:
Patch is 22.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/85297.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 5ceb85e7d9903b..992a8916d90937 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>
@@ -552,6 +553,289 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
ControlPropagationFn controlFn;
};
+/// Map every position in `dimsPos` to one dim of its reassociation group: the
+/// inner-most dim whose size in `targetShape` is dynamic or greater than one.
+///
+/// For example, given dimsPos: [0, 1], reassocIndices: [[0, 1], [2, 3]], and
+/// targetShape: [3, 4, 5, 1], the result is [1, 2]. For pos 0 the inner-most
+/// non-unit dim of group [0, 1] is 1. For pos 1 the inner-most dim of group
+/// [2, 3] is a unit dim, so the next one outwards (2) is taken.
+///
+/// When every dim of a group is a unit dim, the inner-most dim of that group
+/// is returned.
+static SmallVector<int64_t>
+projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
+                                 ArrayRef<ReassociationIndices> reassocIndices,
+                                 ArrayRef<int64_t> targetShape) {
+  SmallVector<int64_t> projectedDimsPos;
+  for (int64_t pos : dimsPos) {
+    const ReassociationIndices &group = reassocIndices[pos];
+    // Walk the group from the inner-most dim outwards and stop at the first
+    // dim that is dynamic or larger than one.
+    auto candidate = group.rbegin();
+    while (candidate != group.rend() &&
+           !ShapedType::isDynamic(targetShape[*candidate]) &&
+           targetShape[*candidate] <= 1)
+      ++candidate;
+    // Nothing found means all dims are unit dims: use the inner-most one.
+    projectedDimsPos.push_back(candidate != group.rend() ? *candidate
+                                                         : group.back());
+  }
+  return projectedDimsPos;
+}
+
+/// Return true only if each dim of `shape` selected by `dimsPos` is static and
+/// an exact multiple of the tile size found at the same index in `tileSizes`.
+/// `dimsPos` and `tileSizes` must have the same length.
+static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
+                                       ArrayRef<int64_t> shape,
+                                       ArrayRef<int64_t> tileSizes) {
+  for (auto [dimPos, tile] : llvm::zip_equal(dimsPos, tileSizes)) {
+    const int64_t extent = shape[dimPos];
+    // A dynamic extent cannot be proven divisible, so it is rejected too.
+    const bool divisible =
+        !ShapedType::isDynamic(extent) && (extent % tile) == 0;
+    if (!divisible)
+      return false;
+  }
+  return true;
+}
+
+/// Permute the reassociation indices and reindex them in the sequence order.
+/// Returns the next free position, i.e. the total number of reindexed dims.
+///
+/// For example: given reassociationIndices: [[0, 1], [2]] and permutation: [1,
+/// 0], it applies the permutation to get [[2], [0, 1]] and reindexes the
+/// indices into [[0], [1, 2]]; the returned value is 3.
+///
+/// An empty `permutation` is treated as the identity: the groups keep their
+/// order and are only reindexed.
+static int64_t applyPermutationAndReindexReassoc(
+    SmallVector<ReassociationIndices> &reassociationIndices,
+    ArrayRef<int64_t> permutation) {
+  // Callers forward the op's outer_dims_perm unchanged. When that is empty its
+  // size does not match the number of groups, so it must not be applied.
+  if (!permutation.empty())
+    applyPermutationToVector<ReassociationIndices>(reassociationIndices,
+                                                   permutation);
+  int64_t nextPos = 0;
+  for (ReassociationIndices &indices : reassociationIndices) {
+    for (auto &index : indices)
+      index = nextPos++;
+  }
+  return nextPos;
+}
+
+/// Bubble up pack op through collapse shape op when the packed dims can be
+/// projected to the dims before collapsing. This is possible when the inner
+/// tile sizes can divide the projected dims.
+///
+/// For example:
+///
+/// %collapsed = tensor.collapse_shape %in [[0, 1], [2]]
+/// : tensor<?x16x4xf32> into tensor<?x4xf32>
+/// %pack = tensor.pack %collapsed outer_dims_perm = [0, 1]
+/// inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
+/// : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
+///
+/// Can be transformed into:
+///
+/// %pack = tensor.pack %in outer_dims_perm = [0, 1, 2]
+/// inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
+/// : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
+/// %collapsed = tensor.collapse_shape %pack [[0, 1], [2], [3], [4]]
+/// : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
+///
+/// Returns failure, before creating any op, when a projected dim is dynamic
+/// or is not divisible by its inner tile size.
+static LogicalResult
+bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
+ tensor::PackOp packOp,
+ PatternRewriter &rewriter) {
+ SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
+ ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+ // NOTE(review): presumably empty when the pack op has no outer_dims_perm;
+ // confirm applyPermutationAndReindexReassoc tolerates an empty permutation.
+ ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
+
+ // Each packed (collapsed) dim is mapped to one dim of the collapse source:
+ // the inner-most non-unit dim of its reassociation group.
+ ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
+ SmallVector<ReassociationIndices> reassocIndices =
+ collapseOp.getReassociationIndices();
+ SmallVector<int64_t> projectedInnerDimsPos =
+ projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);
+
+ if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
+ innerTileSizes)) {
+ return failure();
+ }
+ // Expand the outer dims permutation with the associated source dims for the
+ // new permutation after bubbling. This is because moving a collapsed dim is
+ // equivalent to moving the associated source dims together.
+ SmallVector<int64_t> newOuterDimsPerm;
+ for (auto outerPos : outerDimsPerm) {
+ newOuterDimsPerm.insert(newOuterDimsPerm.end(),
+ reassocIndices[outerPos].begin(),
+ reassocIndices[outerPos].end());
+ }
+
+ // The new pack consumes the collapse source directly and reuses the original
+ // tiles and padding value with the projected inner dims.
+ auto emptyOp = tensor::PackOp::createDestinationTensor(
+ rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
+ projectedInnerDimsPos, newOuterDimsPerm);
+ auto newPackOp = rewriter.create<tensor::PackOp>(
+ packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
+ packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);
+
+ SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
+ // First apply the permutation on the reassociations of the outer dims.
+ // For example given the permutation [1, 0], the reassociations: [[0, 1], [2]]
+ // -> [[0], [1, 2]]
+ int64_t lastPos =
+ applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
+ // Then add direct mapping for the inner tile dims.
+ for (size_t i = 0; i < innerDimsPos.size(); ++i) {
+ newReassocIndices.push_back({lastPos});
+ lastPos += 1;
+ }
+
+ // The new collapse produces the original pack result type, so it can replace
+ // the pack op directly.
+ auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
+ collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices);
+ rewriter.replaceOp(packOp, newCollapseOp);
+
+ return success();
+}
+
+/// Rewrite pattern rooted at tensor.pack that dispatches on the op producing
+/// the pack source. Only tensor.collapse_shape producers are handled; any
+/// other producer makes the pattern fail.
+class BubbleUpPackOpThroughReshapeOp final
+    : public OpRewritePattern<tensor::PackOp> {
+public:
+  BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
+      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
+
+  LogicalResult matchAndRewrite(tensor::PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    // The user-provided control function may veto the propagation.
+    if (!controlFn(packOp))
+      return failure();
+
+    // The source must be defined by an op with exactly one result, and that
+    // result must have a single use.
+    Operation *producer = packOp.getSource().getDefiningOp();
+    if (!producer)
+      return failure();
+    if (producer->getNumResults() != 1 || !producer->getResult(0).hasOneUse())
+      return failure();
+
+    // Dynamic inner tile sizes are not supported yet.
+    auto isDynamicTile = [](int64_t tile) {
+      return ShapedType::isDynamic(tile);
+    };
+    if (llvm::any_of(packOp.getStaticTiles(), isDynamicTile))
+      return failure();
+
+    return TypeSwitch<Operation *, LogicalResult>(producer)
+        .Case([&](tensor::CollapseShapeOp collapseOp) {
+          return bubbleUpPackOpThroughCollapseShape(collapseOp, packOp,
+                                                    rewriter);
+        })
+        .Default([](Operation *) { return failure(); });
+  }
+
+private:
+  ControlPropagationFn controlFn;
+};
+
+/// Push down unpack op through expand shape op when the packed dims can be
+/// projected to the dims after expanding. This is possible when the inner tile
+/// sizes can divide the projected dims.
+///
+/// For example:
+///
+/// %unpack = tensor.unpack %in outer_dims_perm = [0, 1]
+/// inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty
+/// : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
+/// %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
+/// : tensor<?x256xf32> into tensor<?x256x256xf32>
+///
+/// Can be transformed into:
+///
+/// %expanded = tensor.expand_shape %in [[0, 1], [2], [3], [4]]
+/// : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
+/// %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2]
+/// inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
+/// : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
+///
+/// Returns failure, before creating any op, when a projected dim is dynamic
+/// or is not divisible by its inner tile size.
+static LogicalResult
+pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
+ tensor::ExpandShapeOp expandOp,
+ PatternRewriter &rewriter) {
+ SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
+ ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
+ // NOTE(review): presumably empty when the unpack op has no outer_dims_perm;
+ // confirm applyPermutationAndReindexReassoc tolerates an empty permutation.
+ ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
+
+ // Each unpacked dim is mapped to one dim of the expanded result: the
+ // inner-most non-unit dim of its reassociation group.
+ ArrayRef<int64_t> dstShape = expandOp.getType().getShape();
+ SmallVector<ReassociationIndices> reassocIndices =
+ expandOp.getReassociationIndices();
+ SmallVector<int64_t> projectedInnerDimsPos =
+ projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);
+
+ if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
+ innerTileSizes)) {
+ return failure();
+ }
+ // Expand the outer dims permutation with the associated expanded dims for the
+ // new permutation after pushing. This is because moving a source dim is
+ // equivalent to moving the associated expanded dims together.
+ SmallVector<int64_t> newOuterDimsPerm;
+ for (auto outerPos : outerDimsPerm) {
+ newOuterDimsPerm.insert(newOuterDimsPerm.end(),
+ reassocIndices[outerPos].begin(),
+ reassocIndices[outerPos].end());
+ }
+
+ SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
+ // First apply the permutation on the reassociations of the outer dims.
+ // For example given the permutation [1, 0], the reassociations: [[0, 1], [2]]
+ // -> [[0], [1, 2]]
+ int64_t lastPos =
+ applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
+ // Then add direct mapping for the inner tile dims.
+ for (size_t i = 0; i < innerDimsPos.size(); ++i) {
+ newReassocIndices.push_back({lastPos});
+ lastPos += 1;
+ }
+
+ // The new expand_shape works on the packed source, so its result type is the
+ // packed form of the original expanded type.
+ RankedTensorType newExpandType =
+ tensor::PackOp::inferPackedType(expandOp.getType(), innerTileSizes,
+ projectedInnerDimsPos, newOuterDimsPerm);
+ auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
+ expandOp.getLoc(), newExpandType, unPackOp.getSource(),
+ newReassocIndices);
+
+ // The new unpack consumes the new expand_shape and replaces the original
+ // expand_shape op.
+ auto emptyOp = tensor::UnPackOp::createDestinationTensor(
+ rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
+ projectedInnerDimsPos, newOuterDimsPerm);
+ auto newUnPackOp = rewriter.create<tensor::UnPackOp>(
+ unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
+ projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
+ rewriter.replaceOp(expandOp, newUnPackOp);
+
+ return success();
+}
+
+/// Rewrite pattern rooted at tensor.unpack that dispatches on the single user
+/// of the unpack result. Only tensor.expand_shape users are handled; any other
+/// user makes the pattern fail.
+class PushDownUnPackOpThroughReshapeOp final
+    : public OpRewritePattern<tensor::UnPackOp> {
+public:
+  PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
+                                   ControlPropagationFn fun)
+      : OpRewritePattern<tensor::UnPackOp>(context), controlFn(std::move(fun)) {
+  }
+
+  LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
+                                PatternRewriter &rewriter) const override {
+    // The user-provided control function may veto the propagation.
+    if (!controlFn(unPackOp))
+      return failure();
+
+    // The unpack result must have exactly one use.
+    Value unpacked = unPackOp.getResult();
+    if (!unpacked.hasOneUse())
+      return failure();
+
+    // Dynamic inner tile sizes are not supported yet.
+    auto isDynamicTile = [](int64_t tile) {
+      return ShapedType::isDynamic(tile);
+    };
+    if (llvm::any_of(unPackOp.getStaticTiles(), isDynamicTile))
+      return failure();
+
+    Operation *consumer = *unpacked.user_begin();
+    return TypeSwitch<Operation *, LogicalResult>(consumer)
+        .Case([&](tensor::ExpandShapeOp expandOp) {
+          return pushDownUnPackOpThroughExpandShape(unPackOp, expandOp,
+                                                    rewriter);
+        })
+        .Default([](Operation *) { return failure(); });
+  }
+
+private:
+  ControlPropagationFn controlFn;
+};
+
// TODO: Relax this restriction. We should unpack a generic op also
// in the presence of multiple unpack ops as producers.
/// Return the unpacked operand, if present, for the current generic op.
@@ -774,6 +1058,7 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
const ControlPropagationFn &controlPackUnPackPropagation) {
patterns
.insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
- PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp>(
+ BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
+ PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
patterns.getContext(), controlPackUnPackPropagation);
}
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index e036695a2ac9fd..10c9f5bafb5c03 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -905,3 +905,127 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16]
// CHECK-SAME: into %[[UNPACK_NEW_DEST]]
// CHECK: return %[[UNPACK]] : tensor<16x540x960xi32>
+
+// -----
+
+// Pack of a collapse_shape result with a dynamic outer dim. The CHECK lines
+// expect the pack to be applied to the collapse source with
+// inner_dims_pos = [1, 2], followed by a collapse_shape of the packed tensor.
+func.func @bubble_up_pack_through_collapse(%1: tensor<?x16x4xf32>, %dim : index) -> tensor<?x4x8x1xf32> {
+ %collapsed = tensor.collapse_shape %1 [[0, 1], [2]] : tensor<?x16x4xf32> into tensor<?x4xf32>
+ %2 = tensor.empty(%dim) : tensor<?x4x8x1xf32>
+ %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %2 : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
+ func.return %pack : tensor<?x4x8x1xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_through_collapse
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x16x4xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x2x4x8x1xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]] : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
+// CHECK: return %[[COLLAPSED]] : tensor<?x4x8x1xf32>
+
+// -----
+
+// Pack with a non-identity outer_dims_perm on a collapse_shape result. The
+// CHECK lines expect the permutation to be expanded to the source dims
+// ([0, 3, 1, 2]) and the collapse_shape to consume the new pack result.
+func.func @bubble_up_permuted_pack_through_collapse(%1: tensor<4x192x16x256xf32>) -> tensor<4x32x3072x8x1xf32> {
+ %collapsed = tensor.collapse_shape %1 [[0], [1, 2], [3]] : tensor<4x192x16x256xf32> into tensor<4x3072x256xf32>
+ %2 = tensor.empty() : tensor<4x32x3072x8x1xf32>
+ %pack = tensor.pack %collapsed outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 1] into %2 : tensor<4x3072x256xf32> -> tensor<4x32x3072x8x1xf32>
+ func.return %pack : tensor<4x32x3072x8x1xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_permuted_pack_through_collapse
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x32x192x16x8x1xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3, 2] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<4x192x16x256xf32> -> tensor<4x32x192x16x8x1xf32>
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0], [1], [2, 3], [4], [5]] : tensor<4x32x192x16x8x1xf32> into tensor<4x32x3072x8x1xf32>
+// CHECK: return %[[COLLAPSED]] : tensor<4x32x3072x8x1xf32>
+
+// -----
+
+// The collapse group [0, 1, 2] has unit dims on both sides of the 64-sized
+// dim. The CHECK lines expect packed dim 0 to be projected to source dim 1
+// (inner_dims_pos = [1, 3]).
+func.func @bubble_up_pack_through_unit_collapse(%1: tensor<1x64x1x4xf32>) -> tensor<8x4x8x1xf32> {
+ %collapsed = tensor.collapse_shape %1 [[0, 1, 2], [3]] : tensor<1x64x1x4xf32> into tensor<64x4xf32>
+ %2 = tensor.empty() : tensor<8x4x8x1xf32>
+ %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %2 : tensor<64x4xf32> -> tensor<8x4x8x1xf32>
+ func.return %pack : tensor<8x4x8x1xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_through_unit_collapse
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x8x1x4x8x1xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [1, 3] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<1x64x1x4xf32> -> tensor<1x8x1x4x8x1xf32>
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0, 1, 2], [3], [4], [5]] : tensor<1x8x1x4x8x1xf32> into tensor<8x4x8x1xf32>
+// CHECK: return %[[COLLAPSED]] : tensor<8x4x8x1xf32>
+
+// -----
+
+// Negative test: packed dim 1 comes from source dims [1, 2], whose inner-most
+// dim has size 4 and is not divisible by the tile size 8. The CHECK lines
+// expect the collapse_shape to stay in front of the pack.
+func.func @no_bubble_up_pack_through_non_divisible_collapse(%1: tensor<3072x64x4xf32>) -> tensor<384x32x8x8xf32> {
+ %collapsed = tensor.collapse_shape %1 [[0], [1, 2]] : tensor<3072x64x4xf32> into tensor<3072x256xf32>
+ %2 = tensor.empty() : tensor<384x32x8x8xf32>
+ %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %2 : tensor<3072x256xf32> -> tensor<384x32x8x8xf32>
+ func.return %pack : tensor<384x32x8x8xf32>
+}
+// CHECK-LABEL: func.func @no_bubble_up_pack_through_non_divisible_collapse
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<3072x64x4xf32> into tensor<3072x256xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[COLLAPSED]]
+// CHECK: return %[[PACK]] : tensor<384x32x8x8xf32>
+
+// -----
+
+// Unpack with a dynamic outer dim followed by expand_shape. The CHECK lines
+// expect the expand_shape to be applied to the packed source and the unpack
+// to consume the expanded value with inner_dims_pos = [1, 2].
+func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index) -> tensor<?x256x256xf32> {
+ %6 = tensor.empty(%dim) : tensor<?x256xf32>
+ %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
+ %expanded = tensor.expand_shape %unpack [[0, 1], [2]] : tensor<?x256xf32> into tensor<?x256x256xf32>
+ func.return %expanded : tensor<?x256x256xf32>
+}
+// CHECK-LABEL: func.func @push_down_unpack_through_expand
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3], [4]] : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
+// CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor<?x32x32x8x8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x256x256xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
+// CHECK: return %[[UNPACK]] : tensor<?x256x256xf32>
+
+// -----
+
+func.func @push_down_permuted_unpack_through_expand(%5: tensor<4x32x384x8x8xf32>) -> tensor<4x12x256x256xf32> {
+ %6 = tensor.empty() : tensor<4x3072x256xf32>
+ %unpack = tensor.unpack %5 outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 8] into %6 : tensor<4x32x384x8x8xf32> -> tensor<4x3072x256xf32>
+ %expanded = tensor.expand_shape %unpack [[0], [1, 2], [3]] : tensor<4x3072x256xf32> into tensor<4x12x256x256xf32>
+ func.return %expanded : tensor<4x12x256x256xf32>
+}
+// CHECK-LABEL: @push_down_permuted_unpack_through_expand
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3], [4], [5]] : tensor<4x32x384x8x8xf32> into tensor<4x32x12x32x8x8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x12x256x256xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] ...
[truncated]
|
01d6a8a
to
fb0a191
Compare
65b3aef
to
4277610
Compare
Kindly ping : ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the late review, LGTM, thanks!
b732bf8
to
fe58b86
Compare
fe58b86
to
86024e8
Compare
Co-authored-by: Quinn Dawkins <[email protected]>
Add DataLayoutPropagation patterns to bubble-up pack and push-down unpack through collapse/expand shape ops.
This is possible when the inner tile sizes of pack/unpack can divide the projected dims after being swapped with collapse/expand shape ops. For example:
can be transformed into:
The currently supported cases are:
tensor.collapse_shape
tensor.expand_shape