
[MLIR] Add patterns to bubble-up pack and push-down unpack through collapse/expand shape ops #85297

Merged
merged 9 commits into llvm:main from pzread:pack-bubble-reshape on Mar 28, 2024

Conversation


@pzread pzread commented Mar 14, 2024

Add DataLayoutPropagation patterns to bubble-up pack and push-down unpack through collapse/expand shape ops.

This is possible when the inner tile sizes of the pack/unpack op evenly divide the dims they project onto after the op is swapped with the collapse/expand shape op. For example:

%collapsed = tensor.collapse_shape %1 [[0], [1, 2]] : tensor<3072x64x4xf32> into tensor<3072x256xf32>
%pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %2 : tensor<3072x256xf32> -> tensor<384x256x8x1xf32>

can be transformed into:

%pack = tensor.pack %1 outer_dims_perm = [0, 1, 2] inner_dims_pos = [0, 2] inner_tiles = [8, 1] into %2 : tensor<3072x64x4xf32> -> tensor<384x64x4x8x1xf32>
%collapsed = tensor.collapse_shape %pack [[0], [1, 2], [3], [4]] : tensor<384x64x4x8x1xf32> into tensor<384x256x8x1xf32>
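
To see why this is legal: inner dim 1 of the collapsed tensor maps back to source dims [1, 2], the projection picks the inner-most non-unit source dim, and the inner tile size must divide it. Below is a minimal standalone sketch of that check, specialized to the example above (illustrative only; it mirrors the projectToInnerMostNonUnitDimsPos and isDimsDivisibleByTileSizes helpers added by this patch but is not the patch code itself):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
  // tensor<3072x64x4xf32> collapsed with reassociation [[0], [1, 2]].
  std::vector<std::int64_t> srcShape = {3072, 64, 4};
  std::vector<std::vector<std::int64_t>> reassoc = {{0}, {1, 2}};
  std::vector<std::int64_t> innerDimsPos = {0, 1}; // dims of the collapsed tensor
  std::vector<std::int64_t> innerTiles = {8, 1};

  std::vector<std::int64_t> projected;
  for (std::size_t i = 0; i < innerDimsPos.size(); ++i) {
    const std::vector<std::int64_t> &group = reassoc[innerDimsPos[i]];
    // Project to the inner-most non-unit dim of the reassociation group.
    std::int64_t pos = group.back();
    for (auto it = group.rbegin(); it != group.rend(); ++it) {
      if (srcShape[*it] > 1) {
        pos = *it;
        break;
      }
    }
    // The inner tile size must evenly divide the projected source dim.
    assert(srcShape[pos] % innerTiles[i] == 0 && "cannot swap");
    projected.push_back(pos);
  }
  // projected == {0, 2}: 8 divides 3072 and 1 divides 4, so the pack can
  // move above the collapse with inner_dims_pos = [0, 2], as shown above.
  return 0;
}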

The currently supported cases are:

  • Bubble-up pack op through tensor.collapse_shape
  • Push-down unpack op through tensor.expand_shape
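
For reference, a rough sketch of how a pass might pick up these patterns (populateDataLayoutPropagationPatterns and ControlPropagationFn come from this patch; the driver boilerplate around them is illustrative and assumes it runs inside a pass with a funcOp in scope):

// Illustrative only: register and greedily apply the propagation patterns.
RewritePatternSet patterns(funcOp.getContext());
// Accept every candidate pack/unpack op; a real control function can filter.
linalg::populateDataLayoutPropagationPatterns(
    patterns, [](Operation *op) { return true; });
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
  signalPassFailure();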

@pzread pzread force-pushed the pack-bubble-reshape branch 2 times, most recently from 569fee5 to 370ae55 on March 18, 2024 at 22:45
@pzread pzread changed the title from "[WIP][MLIR] Add patterns to bubble-up/push-down pack/unpack through reshape" to "[MLIR] Add patterns to bubble-up/push-down pack/unpack through reshape" on Mar 19, 2024
@pzread pzread changed the title from "[MLIR] Add patterns to bubble-up/push-down pack/unpack through reshape" to "[MLIR] Add patterns to bubble-up pack and push-down unpack through reshape" on Mar 19, 2024
@pzread pzread changed the title from "[MLIR] Add patterns to bubble-up pack and push-down unpack through reshape" to "[MLIR] Add patterns to bubble-up pack and push-down unpack through collapse/expand shape ops" on Mar 19, 2024
@pzread pzread marked this pull request as ready for review March 19, 2024 17:27
@pzread pzread requested review from qedawkins and dcaballe March 19, 2024 17:27

llvmbot commented Mar 19, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Jerry Wu (pzread)

Changes

Add DataLayoutPropagation patterns to bubble-up pack and push-down unpack through collapse/expand shape ops.

The currently supported cases are:

  • Bubble-up pack op through tensor.collapse_shape
  • Push-down unpack op through tensor.expand_shape

Patch is 22.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/85297.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (+286-1)
  • (modified) mlir/test/Dialect/Linalg/data-layout-propagation.mlir (+124)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 5ceb85e7d9903b..992a8916d90937 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include <optional>
 
@@ -552,6 +553,289 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
   ControlPropagationFn controlFn;
 };
 
+/// Project dimsPos to the inner-most non-unit dim pos with reassocIndices.
+/// For example: Given dimsPos: [0, 2], reassocIndices: [[0, 1], [2, 3]], and
+/// targetShape: [3, 4, 5, 1], it returns [1, 2], because for pos 0 the
+/// inner-most non-unit projected dim in [0, 1] is 1, and for pos 2 the
+/// inner-most non-unit projected dim in [2, 3] is 2.
+///
+/// If all projected dims are unit dims, it chooses the inner-most dim pos.
+static SmallVector<int64_t>
+projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
+                                 ArrayRef<ReassociationIndices> reassocIndices,
+                                 ArrayRef<int64_t> targetShape) {
+  SmallVector<int64_t> projectedDimsPos;
+  for (auto pos : dimsPos) {
+    // If all dims in the group are unit, fall back to the inner-most one.
+    int64_t projectedPos = reassocIndices[pos].back();
+    for (auto it = reassocIndices[pos].rbegin();
+         it != reassocIndices[pos].rend(); ++it) {
+      int64_t dim = targetShape[*it];
+      if (dim > 1 || ShapedType::isDynamic(dim)) {
+        projectedPos = *it;
+        break;
+      }
+    }
+    projectedDimsPos.push_back(projectedPos);
+  }
+  return projectedDimsPos;
+}
+
+/// Check if all dims in dimsPos are divisible by the corresponding tile sizes.
+static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
+                                       ArrayRef<int64_t> shape,
+                                       ArrayRef<int64_t> tileSizes) {
+  for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
+    int64_t dim = shape[pos];
+    if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
+      return false;
+  }
+  return true;
+}
+
+/// Permute the reassociation indices and reindex them in sequential order.
+/// For example: given reassociationIndices: [[0, 1], [2]] and permutation: [1,
+/// 0], it applies the permutation to get [[2], [0, 1]] and reindexes the
+/// indices into [[0], [1, 2]].
+static int64_t applyPermutationAndReindexReassoc(
+    SmallVector<ReassociationIndices> &reassociationIndices,
+    ArrayRef<int64_t> permutation) {
+  applyPermutationToVector<ReassociationIndices>(reassociationIndices,
+                                                 permutation);
+  int64_t lastPos = 0;
+  for (ReassociationIndices &indices : reassociationIndices) {
+    for (auto &index : indices) {
+      index = lastPos;
+      lastPos += 1;
+    }
+  }
+  return lastPos;
+}
+
+/// Bubble up pack op through collapse shape op when the packed dims can be
+/// projected to the dims before collapsing. This is possible when the inner
+/// tile sizes can divide the projected dims.
+///
+/// For example:
+///
+/// %collapsed = tensor.collapse_shape %in [[0, 1], [2]]
+///     : tensor<?x16x4xf32> into tensor<?x4xf32>
+/// %pack = tensor.pack %collapsed outer_dims_perm = [0, 1]
+///     inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
+///     : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
+///
+/// Can be transformed into:
+///
+/// %pack = tensor.pack %in outer_dims_perm = [1, 2]
+///     inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
+///     : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
+/// %collapsed = tensor.collapse_shape %pack [[0, 1], [2], [3], [4]]
+///     : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
+static LogicalResult
+bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
+                                   tensor::PackOp packOp,
+                                   PatternRewriter &rewriter) {
+  SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
+  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
+
+  ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
+  SmallVector<ReassociationIndices> reassocIndices =
+      collapseOp.getReassociationIndices();
+  SmallVector<int64_t> projectedInnerDimsPos =
+      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);
+
+  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
+                                  innerTileSizes)) {
+    return failure();
+  }
+  // Expand the outer dims permutation with the associated source dims for the
+  // new permutation after bubbling. This is because moving a collapsed dim is
+  // equivalent to moving the associated source dims together.
+  SmallVector<int64_t> newOuterDimsPerm;
+  for (auto outerPos : outerDimsPerm) {
+    newOuterDimsPerm.insert(newOuterDimsPerm.end(),
+                            reassocIndices[outerPos].begin(),
+                            reassocIndices[outerPos].end());
+  }
+
+  auto emptyOp = tensor::PackOp::createDestinationTensor(
+      rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
+      projectedInnerDimsPos, newOuterDimsPerm);
+  auto newPackOp = rewriter.create<tensor::PackOp>(
+      packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
+      packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);
+
+  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
+  // First apply the permutation on the reassociations of the outer dims.
+  // For example given the permutation [1, 0], the reassociations: [[0, 1], [2]]
+  // -> [[0], [1, 2]]
+  int64_t lastPos =
+      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
+  // Then add direct mapping for the inner tile dims.
+  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
+    newReassocIndices.push_back({lastPos});
+    lastPos += 1;
+  }
+
+  auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
+      collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices);
+  rewriter.replaceOp(packOp, newCollapseOp);
+
+  return success();
+}
+
+class BubbleUpPackOpThroughReshapeOp final
+    : public OpRewritePattern<tensor::PackOp> {
+public:
+  BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
+      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
+
+  LogicalResult matchAndRewrite(tensor::PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    // User controlled propagation function.
+    if (!controlFn(packOp))
+      return failure();
+
+    Operation *srcOp = packOp.getSource().getDefiningOp();
+    // Currently only supported when the pack is the source's only user.
+    if (!srcOp || srcOp->getNumResults() != 1 ||
+        !srcOp->getResult(0).hasOneUse()) {
+      return failure();
+    }
+    // Currently only support static inner tile sizes.
+    if (llvm::any_of(packOp.getStaticTiles(), [](int64_t size) {
+          return ShapedType::isDynamic(size);
+        })) {
+      return failure();
+    }
+
+    return TypeSwitch<Operation *, LogicalResult>(srcOp)
+        .Case([&](tensor::CollapseShapeOp op) {
+          return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
+        })
+        .Default([](Operation *) { return failure(); });
+  }
+
+private:
+  ControlPropagationFn controlFn;
+};
+
+/// Push down unpack op through expand shape op when the packed dims can be
+/// projected to the dims after expanding. This is possible when the inner tile
+/// sizes can divide the projected dims.
+///
+/// For example:
+///
+/// %unpack = tensor.unpack %in outer_dims_perm = [0, 1]
+///     inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty
+///     : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
+/// %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
+///     : tensor<?x256xf32> into tensor<?x256x256xf32>
+///
+/// Can be transformed into:
+///
+/// %expanded = tensor.expand_shape %in [[0, 1], [2], [3], [4]]
+///     : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
+/// %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2]
+///     inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
+///     : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
+static LogicalResult
+pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
+                                   tensor::ExpandShapeOp expandOp,
+                                   PatternRewriter &rewriter) {
+  SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
+  ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
+  ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
+
+  ArrayRef<int64_t> dstShape = expandOp.getType().getShape();
+  SmallVector<ReassociationIndices> reassocIndices =
+      expandOp.getReassociationIndices();
+  SmallVector<int64_t> projectedInnerDimsPos =
+      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);
+
+  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
+                                  innerTileSizes)) {
+    return failure();
+  }
+  // Expand the outer dims permutation with the associated expanded dims for the
+  // new permutation after pushing. This is because moving a source dim is
+  // equivalent to moving the associated expanded dims together.
+  SmallVector<int64_t> newOuterDimsPerm;
+  for (auto outerPos : outerDimsPerm) {
+    newOuterDimsPerm.insert(newOuterDimsPerm.end(),
+                            reassocIndices[outerPos].begin(),
+                            reassocIndices[outerPos].end());
+  }
+
+  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
+  // First apply the permutation on the reassociations of the outer dims.
+  // For example given the permutation [1, 0], the reassociations: [[0, 1], [2]]
+  // -> [[0], [1, 2]]
+  int64_t lastPos =
+      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
+  // Then add direct mapping for the inner tile dims.
+  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
+    newReassocIndices.push_back({lastPos});
+    lastPos += 1;
+  }
+
+  RankedTensorType newExpandType =
+      tensor::PackOp::inferPackedType(expandOp.getType(), innerTileSizes,
+                                      projectedInnerDimsPos, newOuterDimsPerm);
+  auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
+      expandOp.getLoc(), newExpandType, unPackOp.getSource(),
+      newReassocIndices);
+
+  auto emptyOp = tensor::UnPackOp::createDestinationTensor(
+      rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
+      projectedInnerDimsPos, newOuterDimsPerm);
+  auto newUnPackOp = rewriter.create<tensor::UnPackOp>(
+      unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
+      projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
+  rewriter.replaceOp(expandOp, newUnPackOp);
+
+  return success();
+}
+
+class PushDownUnPackOpThroughReshapeOp final
+    : public OpRewritePattern<tensor::UnPackOp> {
+public:
+  PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
+                                   ControlPropagationFn fun)
+      : OpRewritePattern<tensor::UnPackOp>(context), controlFn(std::move(fun)) {
+  }
+
+  LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
+                                PatternRewriter &rewriter) const override {
+    // User controlled propagation function.
+    if (!controlFn(unPackOp))
+      return failure();
+
+    Value result = unPackOp.getResult();
+    // Currently only supported when the unpack op has a single user.
+    if (!result.hasOneUse()) {
+      return failure();
+    }
+    // Currently only support static inner tile sizes.
+    if (llvm::any_of(unPackOp.getStaticTiles(), [](int64_t size) {
+          return ShapedType::isDynamic(size);
+        })) {
+      return failure();
+    }
+
+    Operation *userOp = *result.user_begin();
+    return TypeSwitch<Operation *, LogicalResult>(userOp)
+        .Case([&](tensor::ExpandShapeOp op) {
+          return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter);
+        })
+        .Default([](Operation *) { return failure(); });
+  }
+
+private:
+  ControlPropagationFn controlFn;
+};
+
 // TODO: Relax this restriction. We should unpack a generic op also
 // in the presence of multiple unpack ops as producers.
 /// Return the unpacked operand, if present, for the current generic op.
@@ -774,6 +1058,7 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
     const ControlPropagationFn &controlPackUnPackPropagation) {
   patterns
       .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
-              PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp>(
+              BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
+              PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
           patterns.getContext(), controlPackUnPackPropagation);
 }
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index e036695a2ac9fd..10c9f5bafb5c03 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -905,3 +905,127 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
 // CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [16]
 // CHECK-SAME:      into %[[UNPACK_NEW_DEST]]
 // CHECK:         return %[[UNPACK]] : tensor<16x540x960xi32>
+
+// -----
+
+func.func @bubble_up_pack_through_collapse(%1: tensor<?x16x4xf32>, %dim : index) -> tensor<?x4x8x1xf32> {
+  %collapsed = tensor.collapse_shape %1 [[0, 1], [2]] : tensor<?x16x4xf32> into tensor<?x4xf32>
+  %2 = tensor.empty(%dim) : tensor<?x4x8x1xf32>
+  %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %2 : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
+  func.return %pack : tensor<?x4x8x1xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_through_collapse
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:      %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK:         %[[C0:.+]] = arith.constant 0 : index
+// CHECK:         %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x16x4xf32>
+// CHECK:         %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x2x4x8x1xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
+// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]] : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
+// CHECK:         return %[[COLLAPSED]] : tensor<?x4x8x1xf32>
+
+// -----
+
+func.func @bubble_up_permuted_pack_through_collapse(%1: tensor<4x192x16x256xf32>) -> tensor<4x32x3072x8x1xf32> {
+  %collapsed = tensor.collapse_shape %1 [[0], [1, 2], [3]] : tensor<4x192x16x256xf32> into tensor<4x3072x256xf32>
+  %2 = tensor.empty() : tensor<4x32x3072x8x1xf32>
+  %pack = tensor.pack %collapsed outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 1] into %2 : tensor<4x3072x256xf32> -> tensor<4x32x3072x8x1xf32>
+  func.return %pack : tensor<4x32x3072x8x1xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_permuted_pack_through_collapse
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<4x32x192x16x8x1xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3, 2] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<4x192x16x256xf32> -> tensor<4x32x192x16x8x1xf32>
+// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0], [1], [2, 3], [4], [5]] : tensor<4x32x192x16x8x1xf32> into tensor<4x32x3072x8x1xf32>
+// CHECK:         return %[[COLLAPSED]] : tensor<4x32x3072x8x1xf32>
+
+// -----
+
+func.func @bubble_up_pack_through_unit_collapse(%1: tensor<1x64x1x4xf32>) -> tensor<8x4x8x1xf32> {
+  %collapsed = tensor.collapse_shape %1 [[0, 1, 2], [3]] : tensor<1x64x1x4xf32> into tensor<64x4xf32>
+  %2 = tensor.empty() : tensor<8x4x8x1xf32>
+  %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %2 : tensor<64x4xf32> -> tensor<8x4x8x1xf32>
+  func.return %pack : tensor<8x4x8x1xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_through_unit_collapse
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x8x1x4x8x1xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [1, 3] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<1x64x1x4xf32> -> tensor<1x8x1x4x8x1xf32>
+// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0, 1, 2], [3], [4], [5]] : tensor<1x8x1x4x8x1xf32> into tensor<8x4x8x1xf32>
+// CHECK:         return %[[COLLAPSED]] : tensor<8x4x8x1xf32>
+
+// -----
+
+func.func @no_bubble_up_pack_through_non_divisible_collapse(%1: tensor<3072x64x4xf32>) -> tensor<384x32x8x8xf32> {
+  %collapsed = tensor.collapse_shape %1 [[0], [1, 2]] : tensor<3072x64x4xf32> into tensor<3072x256xf32>
+  %2 = tensor.empty() : tensor<384x32x8x8xf32>
+  %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %2 : tensor<3072x256xf32> -> tensor<384x32x8x8xf32>
+  func.return %pack : tensor<384x32x8x8xf32>
+}
+// CHECK-LABEL: func.func @no_bubble_up_pack_through_non_divisible_collapse
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<3072x64x4xf32> into tensor<3072x256xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[COLLAPSED]]
+// CHECK:         return %[[PACK]] : tensor<384x32x8x8xf32>
+
+// -----
+
+func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index) -> tensor<?x256x256xf32> {
+  %6 = tensor.empty(%dim) : tensor<?x256xf32>
+  %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
+  %expanded = tensor.expand_shape %unpack [[0, 1], [2]] : tensor<?x256xf32> into tensor<?x256x256xf32>
+  func.return %expanded : tensor<?x256x256xf32>
+}
+// CHECK-LABEL: func.func @push_down_unpack_through_expand
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:      %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK:         %[[C0:.+]] = arith.constant 0 : index
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3], [4]] : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
+// CHECK:         %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor<?x32x32x8x8xf32>
+// CHECK:         %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x256x256xf32>
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
+// CHECK:         return %[[UNPACK]] : tensor<?x256x256xf32>
+
+// -----
+
+func.func @push_down_permuted_unpack_through_expand(%5: tensor<4x32x384x8x8xf32>) -> tensor<4x12x256x256xf32> {
+  %6 = tensor.empty() : tensor<4x3072x256xf32>
+  %unpack = tensor.unpack %5 outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 8] into %6 : tensor<4x32x384x8x8xf32> -> tensor<4x3072x256xf32>
+  %expanded = tensor.expand_shape %unpack [[0], [1, 2], [3]] : tensor<4x3072x256xf32> into tensor<4x12x256x256xf32>
+  func.return %expanded : tensor<4x12x256x256xf32>
+}
+// CHECK-LABEL: @push_down_permuted_unpack_through_expand
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3], [4], [5]] : tensor<4x32x384x8x8xf32> into tensor<4x32x12x32x8x8xf32>
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<4x12x256x256xf32>
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] ...
[truncated]

@pzread pzread force-pushed the pack-bubble-reshape branch 3 times, most recently from 01d6a8a to fb0a191 on March 19, 2024 at 20:23
@pzread pzread force-pushed the pack-bubble-reshape branch 2 times, most recently from 65b3aef to 4277610 on March 19, 2024 at 20:33

pzread commented Mar 26, 2024

Kindly ping : )

@qedawkins qedawkins (Contributor) left a comment

Sorry for the late review, LGTM, thanks!

@pzread pzread force-pushed the pack-bubble-reshape branch from b732bf8 to fe58b86 on March 27, 2024 at 17:45
@pzread pzread force-pushed the pack-bubble-reshape branch from fe58b86 to 86024e8 on March 27, 2024 at 17:46
@pzread pzread merged commit 0c1c0d5 into llvm:main Mar 28, 2024