
[mlir][tensor] Generalize/restrict GeneralizeOuterUnitDimsPackOpPattern #114315


Merged: 4 commits into llvm:main on Nov 6, 2024

Conversation

@banach-space (Contributor) commented Oct 30, 2024

This PR restricts GeneralizeOuterUnitDimsPackOpPattern to follow its
intended purpose (as per the documentation), which is to:

require all outer dimensions of tensor.pack to be 1.

There was one in-tree test that violated this assumption (and happened
to work) – see @simple_KCRS_to_KRSCsr in
"generalize-tensor-pack.mlir". That test has been updated to satisfy the
new requirements of the pattern.

By enforcing that the pattern follows its intended design (i.e., making it
stricter), the calculation of shapes and sizes for the various Ops that the
pattern generates (PadOp, ExtractSliceOp, EmptyOp, TransposeOp, and
InsertSliceOp) becomes much simpler and easier to document. This also
helped generalize the pattern to support cases like the one below:

```mlir
func.func @simple_pad_and_pack_dynamic_tile_cst(
    %src: tensor<5x1xf32>,
    %dest: tensor<1x1x?x2xf32>,
    %pad: f32) -> tensor<1x1x?x2xf32> {

  %tile_dim_0 = arith.constant 8 : index
  %0 = tensor.pack %src
    padding_value(%pad : f32)
    inner_dims_pos = [0, 1]
    inner_tiles = [%tile_dim_0, 2]
    into %dest : tensor<5x1xf32> -> tensor<1x1x?x2xf32>

  return %0 : tensor<1x1x?x2xf32>
}
```

Note that the inner tile size is dynamic but a compile-time constant.
getPackOpSourceOrPaddedSource, which is used to generate the PadOp,
detects this and generates a PadOp with static shapes. This is a good
optimization, but it means that all shapes/sizes for Ops generated by
GeneralizeOuterUnitDimsPackOpPattern also need to be updated to be
constant/static. By restricting the pattern and simplifying the
size/shape calculation, supporting the case above becomes much easier.
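
For reference, this is the IR produced for the example above, reconstructed from the CHECK lines of the updated test (SSA names are illustrative). Everything is static, including the slice inserted into the dynamically shaped %dest:

```mlir
%padded = tensor.pad %src low[0, 0] high[3, 1] {
^bb0(%i: index, %j: index):
  tensor.yield %pad : f32
} : tensor<5x1xf32> to tensor<8x2xf32>
%res = tensor.insert_slice %padded into %dest[0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
    : tensor<8x2xf32> into tensor<1x1x?x2xf32>
```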

Notable implementation changes:

  • PadOp processes the original source (no change in dimensions/rank).
    ExtractSliceOp extracts the tile to pack and may reduce the rank. All
    following ops work on the tile extracted by ExtractSliceOp (possibly
    rank-reduced).
  • All shape/size calculations assume that trailing dimensions match
    inner_tiles from tensor.pack. All leading dimensions (i.e., outer
    dimensions) are assumed to be 1.
  • Dynamic sizes for ops like ExtractSliceOp are taken from inner_tiles
    rather than computed as, for example, tensor.dim %dest, 2. It’s the
    responsibility of the "producers" of tensor.pack to ensure that
    dimensions in %dest match the specified tile sizes (see the
    before/after snippet below).
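
To illustrate the last point, a before/after sketch of the InsertSliceOp sizes, reconstructed from the test updates in this PR (SSA names are illustrative):

```mlir
// Before: the dynamic size was recomputed from the destination.
%c2 = arith.constant 2 : index
%dim = tensor.dim %dest, %c2 : tensor<1x1x?x2xf32>
%res = tensor.insert_slice %slice into %dest[0, 0, 0, 0] [1, 1, %dim, 2] [1, 1, 1, 1]
    : tensor<?x2xf32> into tensor<1x1x?x2xf32>

// After: the dynamic size is taken directly from inner_tiles.
%res = tensor.insert_slice %slice into %dest[0, 0, 0, 0] [1, 1, %tile_dim_0, 2] [1, 1, 1, 1]
    : tensor<?x2xf32> into tensor<1x1x?x2xf32>
```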

@llvmbot (Member) commented Oct 30, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Andrzej Warzyński (banach-space)

Patch is 26.46 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/114315.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+37-3)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+92-33)
  • (modified) mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir (+110-48)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index b5710bd78f0089..a8662a3d6f63be 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1515,9 +1515,43 @@ struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
                                const SmallVector<Value> &dynSizes) const;
 };
 
-/// Rewrites a tensor::PackOp into a sequence of tensor.pad + linalg.transpose +
-/// tensor.insert_slice ops, where the tensor::PackOp has outer dims being all
-/// 1s.
+/// Rewrites a tensor::PackOp into a sequence of:
+///   * tensor::PadOp + linalg::TransposeOp + tensor::ExtractSliceOp +
+///     tensor::EmptyOp + tensor::InsertSliceOp ops.
+///
+/// Required that all the outer dims of the input tensor::PackOp are 1.
+///
+/// Before:
+/// ```
+///   %packed = tensor.pack %input
+///     padding_value(%pad : f32)
+///     inner_dims_pos = [1, 0]
+///     inner_tiles = [2, %high]
+///     into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32>
+/// ```
+///
+/// After:
+/// ```
+///   // PadOp
+///   %padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
+///     ^bb0(...):
+///       tensor.yield %arg2 : f32
+///   } : tensor<5x1xf32> to tensor<?x2xf32>
+///   // ExtractSliceOp
+///   %extracted_slice = tensor.extract_slice %padded[0, 0] [%tile_dim_1, 2] [1,
+///   1]
+///     : tensor<?x2xf32> to tensor<?x2xf32>
+///   // EmptyOp + TransposeOp
+///   %empty = tensor.empty(%arg3) : tensor<2x?xf32>
+///   %transposed = linalg.transpose
+///     ins(%extracted_slice : tensor<?x2xf32>)
+///     outs(%empty : tensor<2x?xf32>)
+///     permutation = [1, 0]
+///   // InsertSliceOp
+///   %inserted_slice = tensor.insert_slice %transposed
+///     into %arg1[0, 0, 0, 0] [1, 1, 2, %tile_dim_1] [1, 1, 1, 1]
+///     : tensor<2x?xf32> into tensor<1x1x2x?xf32>
+/// ```
 struct GeneralizeOuterUnitDimsPackOpPattern
     : public OpRewritePattern<tensor::PackOp> {
   using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index da5233049aaf69..ed5f1bd602d7f4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -27,6 +27,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -1138,6 +1139,29 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
   return perm;
 }
 
+// A helper function to generate a dim-and-size pair for Ops like
+// ExtractSliceOp that require both:
+//  * dims to specify the output shape, and
+//  * sizes for the sizes attribute (or similar).
+// For dynamic sizes, if the corresponding size is a compile time constant:
+//  * the return size becomes the attribute encapsulating the known size, and
+//  * dim is updated from kDynamic to its actual known value.
+static std::pair<int64_t, OpFoldResult>
+getSimplifiedDimSizePair(OpFoldResult tileSizeOfr, PatternRewriter &rewriter) {
+  int64_t tileSizeForShape =
+      getConstantIntValue(tileSizeOfr).value_or(ShapedType::kDynamic);
+
+  OpFoldResult tileSizeOfrSimplified;
+  if (tileSizeForShape != ShapedType::kDynamic) {
+    tileSizeOfrSimplified = rewriter.getIndexAttr(tileSizeForShape);
+  } else {
+    tileSizeOfrSimplified = tileSizeOfr;
+  }
+
+  return std::pair<int64_t, OpFoldResult>(tileSizeForShape,
+                                          tileSizeOfrSimplified);
+}
+
 LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
     tensor::PackOp packOp, PatternRewriter &rewriter) const {
   // TODO: support the case that outer dimensions are not all 1s. A
@@ -1148,69 +1172,104 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
         packOp, "require the tiled outer dimensions of the result are all 1s");
   }
 
-  // 1. Use rank-reduced tensor.extract_slice op to extract the tile and untiled
-  // outer dims.
+  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
+  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
   Location loc = packOp.getLoc();
+
   Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
   auto inputShape = packOp.getSourceType().getShape();
   DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
       packOp.getDimAndTileMapping();
-  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
-  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
   int64_t srcRank = packOp.getSourceRank();
+
+  int64_t destRank = packOp.getDestRank();
+  size_t numTiles = destRank - srcRank;
+
+  // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
+  //    %extracted_tile = tensor.extract_slice(%pack_op_input)
   SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
   SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
-  SmallVector<OpFoldResult> readSizes;
-  SmallVector<OpFoldResult> transShapeForEmpty;
-  SmallVector<int64_t> readShapeForExtractSlice;
+
+  // The sizes attribute for ExtractSliceOp. The leading sizes are set to 1 as
+  // all outer dims are 1.
+  SmallVector<OpFoldResult> extractSliceSizes(srcRank - numTiles, oneIdxAttr);
+  // The shape of the output for ExtractSliceOp. All leading unit dims are
+  // effectively rank-reduced, hence skipped.
+  SmallVector<int64_t> outputShapeForExtractSlice;
+
+  // Extract the trailing sizes and shape dims for ExtractSliceOp. These should
+  // be equal to the inner tile sizes.
   for (auto i : llvm::seq<unsigned>(0, srcRank)) {
     if (dimAndTileMapping.count(i)) {
-      readShapeForExtractSlice.push_back(
-          getConstantIntValue(dimAndTileMapping[i])
-              .value_or(ShapedType::kDynamic));
-      readSizes.push_back(dimAndTileMapping[i]);
-      transShapeForEmpty.push_back(dimAndTileMapping[i]);
-      continue;
-    }
-    if (ShapedType::isDynamic(inputShape[i])) {
-      readSizes.push_back(
-          rewriter.create<tensor::DimOp>(loc, input, i).getResult());
-    } else {
-      readSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
-    }
-    if (inputShape[i] != 1) {
-      readShapeForExtractSlice.push_back(inputShape[i]);
-      transShapeForEmpty.push_back(rewriter.getIndexAttr(inputShape[i]));
+      auto [tileSize, tileSizeOfr] =
+          getSimplifiedDimSizePair(dimAndTileMapping[i], rewriter);
+      extractSliceSizes.push_back(tileSizeOfr);
+      outputShapeForExtractSlice.push_back(tileSize);
     }
   }
 
   Type elemType = packOp.getSourceType().getElementType();
-  auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
+  auto readType = RankedTensorType::get(outputShapeForExtractSlice, elemType);
 
   Value tile = rewriter.create<tensor::ExtractSliceOp>(
-      loc, readType, input, readOffsets, readSizes, readStrides);
+      loc, readType, input, readOffsets, extractSliceSizes, readStrides);
 
-  // 2. Transpose the tile to match the inner tile order.
+  // 2. Transpose the tile to match the inner tile order:
+  //    %init = tensor.empty()
+  //    %transposed_tile = linalg.transpose ins(%extracted_tile), outs(%init)
+  //  NOTE: Outer dims are 1 and hence effectively ignored.
   SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
       inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm());
 
   LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
              llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
 
-  applyPermutationToVector<OpFoldResult>(transShapeForEmpty, perm);
+  // 2.1 Create tensor.empty (init value for TransposeOp)
+  SmallVector<OpFoldResult> transShapeForEmptyOpDynamic;
+  SmallVector<int64_t> transShapeForEmptyOpStatic;
+
+  // Acquire tensor shape required to create EmptyOp. This will match the inner
+  // tile sizes, but the actual data format will depend on whether the tile
+  // sizes are static or dynamic (each case leads to a different builder for
+  // EmptyOp). Conservatively, prepare for both scenarios.
+  size_t idx = numTiles;
+  while (idx != 0) {
+    transShapeForEmptyOpDynamic.push_back(extractSliceSizes[srcRank - idx]);
+    transShapeForEmptyOpStatic.push_back(
+        outputShapeForExtractSlice[numTiles - idx]);
+    idx--;
+  }
 
-  Value empty =
-      rewriter.create<tensor::EmptyOp>(loc, transShapeForEmpty, elemType);
+  applyPermutationToVector<int64_t>(transShapeForEmptyOpStatic, perm);
+  applyPermutationToVector<OpFoldResult>(transShapeForEmptyOpDynamic, perm);
+
+  Value empty = ShapedType::isDynamicShape(transShapeForEmptyOpStatic)
+                    ? rewriter.create<tensor::EmptyOp>(
+                          loc, transShapeForEmptyOpDynamic, elemType)
+                    : rewriter.create<tensor::EmptyOp>(
+                          loc, transShapeForEmptyOpStatic, elemType);
+
+  // 2.2 Create linalg.transpose
   auto transposedOp =
       rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);
 
-  // 3. Insert the inner tile to the destination.
-  int64_t destRank = packOp.getDestRank();
+  // 3. Insert the inner tile to the destination:
+  //  %inserted_tile = tensor.insert_slice(%transposed_tile)
   SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
   SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
-  SmallVector<OpFoldResult> writeSizes =
-      tensor::getMixedSizes(rewriter, loc, packOp.getDest());
+  // Outer dims are all 1s!
+  SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
+                                       oneIdxAttr);
+  SmallVector<int64_t> writeShape;
+
+  for (auto tileSize : packOp.getMixedTiles()) {
+    auto [tileSizeStatic, tileSizeOfr] =
+        getSimplifiedDimSizePair(tileSize, rewriter);
+    writeSizes.push_back(tileSizeOfr);
+    writeShape.push_back(tileSizeStatic);
+  }
 
+  // 4. Replace tensor.packOp with tensor.insert_slice created above
   auto insert = rewriter.create<tensor::InsertSliceOp>(
       loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
       writeSizes, writeStrides);
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
index 7f6b5e279f6857..8abf7a11bed5c9 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -1,21 +1,32 @@
 // RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-generalize-tensor-pack"  %s | FileCheck %s
 
-func.func @simple_KCRS_to_KCRSsr(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<1x1x1x1x8x32xf32>) -> tensor<1x1x1x1x8x32xf32> {
-  %0 = tensor.pack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x32x8xf32> -> tensor<1x1x1x1x8x32xf32>
-  return %0 : tensor<1x1x1x1x8x32xf32>
+
+func.func @simple_KCRS_to_KCRSsr(%arg0: tensor<?x?xi32>, %arg1: tensor<1x1x?x1xi32>) -> tensor<1x1x?x1xi32> {
+  %c8 = arith.constant 8 : index
+  %c5 = arith.constant 5 : i32
+  %pack = tensor.pack %arg0 padding_value(%c5 : i32) inner_dims_pos = [0, 1] inner_tiles = [%c8, 1] into %arg1 : tensor<?x?xi32> -> tensor<1x1x?x1xi32>
+  return %pack : tensor<1x1x?x1xi32>
 }
-// CHECK-LABEL: func.func @simple_KCRS_to_KCRSsr
-// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
-// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
-// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
-// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x32xf32>
-// CHECK:         %[[TRANSP:.+]] =  linalg.transpose
-// CHECK-SAME:      ins(%[[TILE]] : tensor<32x8xf32>)
-// CHECK-SAME:      outs(%[[EMPTY]] : tensor<8x32xf32>)
-// CHECK-SAME:      permutation = [1, 0]
-// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
-// CHECK-SAME:      [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
-// CHECK:         return %[[INSERT]]
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0] -> (-s0 + 8)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<()[s0] -> (-s0 + 1)>
+
+// CHECK-LABEL:   func.func @simple_KCRS_to_KCRSsr(
+// CHECK-SAME:      %[[SRC:[a-zA-Z0-9]+]]: tensor<?x?xi32>,
+// CHECK-SAME:      %[[DEST:[a-zA-Z0-9]+]]: tensor<1x1x?x1xi32>) -> tensor<1x1x?x1xi32>
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 5 : i32
+// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_5:.*]] = tensor.dim %[[SRC]], %[[VAL_4]] : tensor<?x?xi32>
+// CHECK:           %[[VAL_6:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[VAL_5]]]
+// CHECK:           %[[VAL_7:.*]] = tensor.dim %[[SRC]], %[[VAL_2]] : tensor<?x?xi32>
+// CHECK:           %[[VAL_8:.*]] = affine.apply #[[$ATTR_1]](){{\[}}%[[VAL_7]]]
+// CHECK:           %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[VAL_6]], %[[VAL_8]]] {
+// CHECK:           ^bb0(%[[VAL_10:.*]]: index, %[[VAL_11:.*]]: index):
+// CHECK:             tensor.yield %[[VAL_3]] : i32
+// CHECK:           } : tensor<?x?xi32> to tensor<8x1xi32>
+// CHECK:           %[[INSERT:.*]] = tensor.insert_slice %[[PAD:.*]] into %[[DEST]][0, 0, 0, 0] [1, 1, 8, 1] [1, 1, 1, 1] : tensor<8x1xi32> into tensor<1x1x?x1xi32>
+// CHECK:           return %[[INSERT]] : tensor<1x1x?x1xi32>
 
 // -----
 
@@ -39,26 +50,59 @@ func.func @simple_pad_and_pack_static_tiles(%input: tensor<5x1xf32>, %output: te
 
 /// Same as example above, but with 1 dynamic tile size.
 
-func.func @simple_pad_and_pack_dynamic_tile(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32, %high: index) -> tensor<1x1x?x2xf32> {
-  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%high, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
+func.func @simple_pad_and_pack_dynamic_tile(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32, %tile_dim_0: index) -> tensor<1x1x?x2xf32> {
+  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%tile_dim_0, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
   return %0 : tensor<1x1x?x2xf32>
 }
-
 // CHECK-LABEL:   func.func @simple_pad_and_pack_dynamic_tile(
 // CHECK-SAME:      %[[SRC:[a-zA-Z0-9]+]]
 // CHECK-SAME:      %[[DEST:[a-zA-Z0-9]+]]
 // CHECK-SAME:      %[[PAD_VAL:[a-zA-Z0-9]+]]
-// CHECK-SAME:      %[[HIGH_VAL:[a-zA-Z0-9]+]]: index) -> tensor<1x1x?x2xf32> {
-// CHECK:           %[[C2:.*]] = arith.constant 2 : index
-// CHECK:           %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[HIGH_VAL]]]
+// CHECK-SAME:      %[[TILE_DIM_0:[a-zA-Z0-9]+]]: index) -> tensor<1x1x?x2xf32> {
+// CHECK:           %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[TILE_DIM_0]]]
 // CHECK:           %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
 // CHECK:             tensor.yield %[[PAD_VAL]] : f32
 // CHECK-NOT:       linalg.transpose
-// CHECK:           %[[SLICE:.*]] = tensor.extract_slice %[[VAL_10:.*]][0, 0] {{\[}}%[[HIGH_VAL]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
-// CHECK:           %[[DIM:.*]] = tensor.dim %[[DEST]], %[[C2]] : tensor<1x1x?x2xf32>
-// CHECK:           %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK:           %[[SLICE:.*]] = tensor.extract_slice %[[PAD:.*]][0, 0] {{\[}}%[[TILE_DIM_0]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
+// CHECK:           %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_0]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
 // CHECK:           return %[[RES]] : tensor<1x1x?x2xf32>
 
+func.func @simple_pad_and_pack_dynamic_tile_cst(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32) -> tensor<1x1x?x2xf32> {
+  %tile_dim_0 = arith.constant 8 : index
+  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%tile_dim_0, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
+  return %0 : tensor<1x1x?x2xf32>
+}
+// CHECK-LABEL:   func.func @simple_pad_and_pack_dynamic_tile_cst(
+// CHECK-SAME:      %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:      %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME:      %[[PAD_VAL:[a-zA-Z0-9]+]]: f32) -> tensor<1x1x?x2xf32> {
+// CHECK:           %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high[3, 1] {
+// CHECK:             tensor.yield %[[PAD_VAL]] : f32
+// CHECK:           } : tensor<5x1xf32> to tensor<8x2xf32>
+// CHECK:           %[[RES:.*]] = tensor.insert_slice %[[PAD:.*]] into %[[DEST]][0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1] : tensor<8x2xf32> into tensor<1x1x?x2xf32>
+// CHECK:           return %[[RES]] : tensor<1x1x?x2xf32>
+
+func.func @simple_pad_and_pack_dynamic_tile_transpose(%input: tensor<5x1xf32>, %output: tensor<1x1x2x?xf32>, %pad: f32, %tile_dim_1: index) -> tensor<1x1x2x?xf32> {
+  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [1, 0] inner_tiles = [2, %tile_dim_1] into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32>
+  return %0 : tensor<1x1x2x?xf32>
+}
+// CHECK-LABEL:   func.func @simple_pad_and_pack_dynamic_tile_transpose(
+// CHECK-SAME:      %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:      %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME:      %[[PAD_VAL:[a-zA-Z0-9]+]]
+// CHECK-SAME:      %[[TILE_DIM_1:[a-zA-Z0-9]+]]: index) -> tensor<1x1x2x?xf32> {
+// CHECK:           %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[TILE_DIM_1]]]
+// CHECK:           %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
+// CHECK:            tensor.yield %[[PAD_VAL]] : f32
+// CHECK-NEXT:      } : tensor<5x1xf32> to tensor<?x2xf32>
+// CHECK:           %[[SLICE:.*]] = tensor.extract_slice %[[PAD:.*]][0, 0] {{\[}}%[[TILE_DIM_1]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
+// CHECK:           %[[EMPTY:.*]] = tensor.empty(%[[TILE_DIM_1]]) : tensor<2x?xf32>
+// CHECK:           %[[TR:.*]] = linalg.transpose
+// CHECK-SAME:        ins(%[[SLICE]] : tensor<?x2xf32>) outs(%[[EMPTY]] : tensor<2x?xf32>)
+// CHECK-SAME:        permutation = [1, 0]
+// CHECK:           %[[RES:.*]] = tensor.insert_slice %[[TR]] into %[[DEST]][0, 0, 0, 0] [1, 1, 2, %[[TILE_DIM_1]]] [1, 1, 1, 1] : tensor<2x?xf32> into tensor<1x1x2x?xf32>
+// CHECK:           return %[[RES]] : tensor<1x1x2x?xf32>
+
 /// Same as example above, but with 1 scalable tile size.
 
 /// NOTE: For this example to make sense in practice, the "?" in the output shape
@@ -77,7 +121,6 @@ func.func @simple_pad_and_pack_scalable_tile(%input: tensor<5x1xf32>, %output: t
 // CHECK-SAME:      %[[SRC:[a-zA-Z0-9]+]]: tensor<5x1xf32>,
 // CHECK-SAME:      %[[DEST:[a-zA-Z0-9]+]]: tensor<1x1x?x2xf32>,
 // CHECK-SAME:      %[[PAD_VAL:[a-zA-Z0-9]+]]: f32) -> tensor<1x1x?x2xf32> {
-// CHECK-DAG:       %[[C2:.+]] = arith.constant 2 : index
 // CHECK-DAG:       %[[C8:.+]] = arith.constant 8 : index
 // CHECK-DAG:       %[[VS:.+]] = vector.vscale
 // CHECK:           %[[C8_VS:.+]] = arith.muli %[[VS]], %[[C8]] : index
@@ -86,37 +129,56 @@ func.func @simple_pad_and_pack_scalable_tile(%input: tensor<5x1xf32>, %output: t
 // CHECK:             tensor.yield %[[PAD_VAL]] : f32
 // CHECK-NOT:       linalg.transpose
 // CHECK:           %[[SLICE:.+]] = tensor.extract_slice %[[PAD:.+]][0, 0] {{\[}}%[[C8_VS]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
-// CHECK:           %[[DIM:.+]] = tensor.dim %[[DEST]], %[[C2]] : tensor<1x1x?x2xf32>
-// CHECK:           %[[RES:.+]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK:           %[[RES:.+]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[C8_VS]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
 // CHECK:           return %[[RES]] : tensor<1x1x?x2xf32>
 
 /// Same as example above, but with both tile sizes dynamic.
 
-func.func @simple_pad_and_pack_dynamic_tiles(%input: tensor<5x1xf32>, %output: tensor<1x1x?x?xf32>, %pad: f32, %high_1: index, %high_2: index) -> tensor<1x1x?x?xf32> {
-  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%high_1, %high_2] into %output : tensor<5x1xf32> -> tensor<1x1x?x?xf32>
+func.func @simple_pad_and_pack_dynami...
[truncated]

@banach-space force-pushed the andrzej/generalize_fix branch from 7f79675 to d556706 on October 30, 2024 22:27
@hanhanW requested a review from Max191 on October 30, 2024 22:27
@Max191 (Contributor) left a comment

> This PR restricts GeneralizeOuterUnitDimsPackOpPattern to follow its
> intended purpose (as per the documentation),

Does this PR add this restriction in the matcher for the pattern? I didn't see any update to the matching logic.

> There was one in-tree test that violated this assumption (and happened
> to work) – see @simple_KCRS_to_KRSCsr in
> "generalize-tensor-pack.mlir". That test has been updated to satisfy the
> new requirements of the pattern.

This test has a non-unit outer dimension, but that dimension is not packed (no inner_tile for it). I think this may actually be intended behavior, since the pattern checks that packOp.getTiledOuterDims() are all 1. Maybe the comments are just misleading, and this case is meant to be supported.

I'm not sure what the motivation of this pattern was to begin with, so I can't say if this type of case needs to be supported, but I would be wary of removing that functionality without hearing from whoever wrote this pattern. Looks like @qedawkins added this test, so he may have better context.
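
For concreteness, the distinction in terms of the test under discussion (the pack op is quoted verbatim from the deleted test; the analysis in the comments is mine):

```mlir
// The result's outer dims are [3, 1, 1, 1]. With outer_dims_perm = [0, 2, 3, 1],
// result dims 1 and 2 come from the tiled source dims (inner_dims_pos = [3, 2])
// and are the *tiled* outer dims -- both 1. Result dim 0 is an *untiled* outer
// dim and keeps the source size 3.
%0 = tensor.pack %arg0 outer_dims_perm = [0, 2, 3, 1]
    inner_dims_pos = [3, 2] inner_tiles = [8, 32]
    into %arg1 : tensor<3x1x32x8xf32> -> tensor<3x1x1x1x8x32xf32>
```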

Comment on lines +1540 to +1542
```cpp
/// // ExtractSliceOp
/// %extracted_slice = tensor.extract_slice %padded[0, 0] [%tile_dim_1, 2] [1,
/// 1]
```
Max191 (Contributor):

It may not be in scope for this PR to change this, but what exactly is this extract_slice for? It seems to me like it is just taking a full slice. Is the slice ever not full?

Comment on lines 1149 to 1150
```cpp
static std::pair<int64_t, OpFoldResult>
getSimplifiedDimSizePair(OpFoldResult tileSizeOfr, PatternRewriter &rewriter) {
```
Max191 (Contributor):

This seems like a useful util, maybe move it to Dialect/Utils/StaticValueUtils.h?

banach-space (Author):

Done :)

Comment on lines -152 to -154
```mlir
func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<3x1x32x8xf32>, %arg1: tensor<3x1x1x1x8x32xf32>) -> tensor<3x1x1x1x8x32xf32> {
  %0 = tensor.pack %arg0 outer_dims_perm = [0, 2, 3, 1] inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<3x1x32x8xf32> -> tensor<3x1x1x1x8x32xf32>
  return %0 : tensor<3x1x1x1x8x32xf32>
```
Max191 (Contributor):

If there are no other tests like this, can we actually keep this test so we can test that the outer dims must be all 1?

Max191 (Contributor):

Actually, I don't see anything in this PR that restricts the pattern to fail on this case. Did you mean to update the matching logic too?

banach-space (Author):

> Actually, I don't see anything in this PR that restricts the pattern to fail on this case. Did you mean to update the matching logic too?

I did, thanks for catching this. Sending an update shortly.

@Max191 requested a review from qedawkins on November 5, 2024 14:57
@banach-space (Author) commented

Thanks for taking a look! Quick reply to your high-level question (emphasis mine):

> This test has a non-unit outer dimension, but that dimension is not packed (no inner_tile for it). I think this may actually be intended behavior, since the pattern checks that packOp.getTiledOuterDims() are all 1. Maybe the comments are just misleading, and this case is meant to be supported.

Indeed, I'd really appreciate if somebody could clarify this 😅

> I'm not sure what the motivation of this pattern was to begin with, so I can't say if this type of case needs to be supported, but I would be wary of removing that functionality without hearing from whoever wrote this pattern.

100% agreed, thanks for bringing this up! As you can see, I'm actually struggling a bit to get feedback for this - I really appreciate you not shying away! 🙏🏻

While waiting for Quinn to chime in, I will make a couple of points.

  1. The current logic to compute the necessary sizes is quite convoluted. Adding support for the case mentioned above has been quite tricky. I can try to add support for non-unit not-tiled-outer-dims, but would really prefer to avoid complexities that are not required.

  2. The current logic for non-unit not-tiled-outer-dims is quite limited and breaks when the padding value is set. You can try this example:

```mlir
func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<3x1x32x8xf32>, %arg1: tensor<3x1x1x1x8x32xf32>, %pad: f32) -> tensor<3x1x1x1x8x32xf32> {
  %0 = tensor.pack %arg0 padding_value(%pad : f32) outer_dims_perm = [0, 2, 3, 1] inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<3x1x32x8xf32> -> tensor<3x1x1x1x8x32xf32>
  return %0 : tensor<3x1x1x1x8x32xf32>
}
```

The logic that fails is getPackOpSourceOrPaddedSource (i.e., you will hit the assert in that method).

@qedawkins (Contributor) left a comment

I believe this is a case where the term outer is slightly unclear (but I probably interpreted it wrong when I wrote it). What the docs should have said after I added that test is all tiled outer dims. In that case the typical decomposition of a pack (pad + expand_shape + transpose) would only be introducing unit dimensions with the expand_shape and could use a rank reducing insert_slice to introduce the new sizes. Whether the untiled outer dims were also unit doesn't affect that aspect of the decomposition.
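
A sketch of the decomposition described above, for a pack whose tiled outer dims are all 1 (the shapes are illustrative, not taken from the PR). Under that assumption the expand_shape introduces only unit dims, which is why a rank-reducing insert_slice can take its place:

```mlir
// tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 2]
//   into %dest : tensor<5x1xf32> -> tensor<1x1x8x2xf32>
%padded = tensor.pad %src low[0, 0] high[3, 1] {
^bb0(%i: index, %j: index):
  tensor.yield %pad_val : f32
} : tensor<5x1xf32> to tensor<8x2xf32>
// The expand_shape only introduces unit dims here...
%exp = tensor.expand_shape %padded [[0, 1], [2, 3]] output_shape [1, 8, 1, 2]
    : tensor<8x2xf32> into tensor<1x8x1x2xf32>
// ...so the transpose merely interleaves them with the tile dims.
%init = tensor.empty() : tensor<1x1x8x2xf32>
%packed = linalg.transpose ins(%exp : tensor<1x8x1x2xf32>)
    outs(%init : tensor<1x1x8x2xf32>) permutation = [0, 2, 1, 3]
```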

> This also helped generalize the pattern to support cases like the one below:
> ...
> By restricting the pattern and simplifying the size/shape calculation, supporting the case above becomes much easier.

I don't quite follow how exclusion of non-unit untiled outer dimensions simplifies shape calculation logic here (I suspect it could have been more a matter of the previous implementation being more convoluted than necessary).

```
// CHECK: ^bb0(%[[VAL_6:.*]]: index, %[[VAL_7:.*]]: index, %[[VAL_8:.*]]: index, %[[VAL_9:.*]]: index):
// CHECK: tensor.yield %[[VAL_2]] : f32
// CHECK: } : tensor<1x1x5x1xf32> to tensor<1x1x?x2xf32>
// CHECK: %[[VAL_10:.*]] = tensor.extract_slice %[[VAL_11:.*]][0, 0, 0, 0] [1, 1, %[[VAL_3]], 2] [1, 1, 1, 1] : tensor<1x1x?x2xf32> to tensor<?x2xf32>
```
qedawkins (Contributor):

This extract_slice seems unnecessary to me. Instead we can just extend the permutation of the transpose to include the outermost untiled dims.

```cpp
int64_t destRank = packOp.getDestRank();
size_t numTiles = destRank - srcRank;

// 1. Use rank-reduced tensor.extract_slice op to extract the tile:
```
qedawkins (Contributor):

Can you elaborate on the advantage of slicing out the inner tile here? We end up reintroducing the outer unit dims shortly after with the tensor.insert_slice so this only serves to remove the unit dims from the linalg.transpose for permuting the inner dimensions.

banach-space (Author):

I am merely a messenger here ;-) (as in, tensor.extract_slice is already a part of this logic, not something added by me).

To me, it makes more sense to only transpose the tile worth of data as that's the intent of tensor.pack, right? But you will have more experience with this logic. I just wanted to preserve the original logic as much as possible and to avoid making too many changes in one PR.

qedawkins (Contributor):

Ah really, that makes sense then. If it's preserving existing behavior, leaving it is fine, but we should look into removing it at some point.

@qedawkins (Contributor) commented

Ah, I see your response arrived before I sent my review.

> The current logic to compute the necessary sizes is quite convoluted. Adding support for the case mentioned above has been quite tricky. I can try to add support for non-unit not-tiled-outer-dims, but would really prefer to avoid complexities that are not required.

I believe we were at one point relying on support for this downstream, but that might not be the case anymore (the pass I'm thinking of might have grown its own simplified pattern for handling cases like this). Dropping support and fixing forward seems fine if the current implementation is broken (as you seem to be suggesting with your second point).

@banach-space (Author) commented Nov 5, 2024

Thank you for taking a look :)

> I believe this is a case where the term outer is slightly unclear

Indeed. In fact, I've implemented #109642 so that we can be clearer about the intent in the future. But this logic predates that PR.

> I don't quite follow how exclusion of non-unit untiled outer dimensions simplifies shape calculation logic here (I suspect it could have been more a matter of the previous implementation being more convoluted than necessary).

Yes, it's rather convoluted and also incomplete (hence this PR). And then there's getPackOpSourceOrPaddedSource, which does require all outer dims to be unit - at least when the pad value is specified. This specifically makes me believe that this is neither needed nor used?

As for simplifying calculation, if we do allow non-unit untiled dims, then we need to track 3 sets of dims (inner, outer tiled, outer untiled). Sure, that can be done, but given there are other issues here ...

I am happy to restore that functionality, but if it's not really required by anyone (that part is still unclear to me), then keeping things simple might be to our overall benefit.

What are your thoughts?

@qedawkins (Contributor) commented Nov 5, 2024

This has fallen pretty far out of my cache, but the pass that required support for the non-unit untiled outer dim no longer depends on this pattern, so I think you should be good to drop it for now.

I would still like to see if we can avoid that extract_slice though. I've found rank-reducing extract_slice ops to compose poorly with most other patterns/passes.

@qedawkins (Contributor) left a comment

Thanks for explaining the current state and the nice comments. This looks like good cleanup to me now and we can try to fix forward on the case dropped here (unless someone else is relying on that functionality).

```cpp
// * the return size becomes the attribute encapsulating the known size, and
// * dim is updated from kDynamic to its actual known value.
static std::pair<int64_t, OpFoldResult>
getSimplifiedDimSizePair(OpFoldResult tileSizeOfr, PatternRewriter &rewriter) {
```
qedawkins (Contributor):

nit: Can just use Builder & instead of a PatternRewriter &

```cpp
applyPermutationToVector<int64_t>(transShapeForEmptyOpStatic, perm);
applyPermutationToVector<OpFoldResult>(transShapeForEmptyOpDynamic, perm);

Value empty = ShapedType::isDynamicShape(transShapeForEmptyOpStatic)
```
qedawkins (Contributor):

Shouldn't the builder for the static case always produce the same result as the dynamic case? Can we just keep the dynamic path?

I'm thinking that for any case where you needed the static builder, we could have had an additional dynamic dim that would make it take the dynamic path, which should still do the same thing for the static part.

banach-space (Author):

> Shouldn't the builder for the static case always produce the same result as the dynamic case? Can we just keep the dynamic path?

Great point!

It turns out that EmptyOp::build already supports the necessary "magic" via dispatchIndexOpFoldResults :)
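
For the record, a minimal sketch of what that boils down to (not code from the PR): the OpFoldResult-based EmptyOp builder accepts mixed static/dynamic sizes and internally uses dispatchIndexOpFoldResults to split them into a static shape and the dynamic size operands, so a single code path suffices:

```cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// `sizes` may mix IntegerAttr (static) and Value (dynamic) entries; the
// builder derives the result shape (with kDynamic placeholders) and the
// dynamic size operands itself, so callers never branch on staticness.
static Value createEmpty(OpBuilder &b, Location loc,
                         ArrayRef<OpFoldResult> sizes, Type elemType) {
  return b.create<tensor::EmptyOp>(loc, sizes, elemType);
}
```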

@banach-space force-pushed the andrzej/generalize_fix branch from d556706 to 99d24b7 on November 6, 2024 10:51

github-actions bot commented Nov 6, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@banach-space (Author) commented

I've just sent some updates addressing/implementing your suggestions for this revision. Two high-level comments:

  • I am happy to remove tensor.extract_slice but will leave that change for a follow-on patch to avoid too many updates in a single PR.
  • If there's a user for whom restricting GeneralizeOuterUnitDimsPackOpPattern causes issues, I am open to an immediate revert and will work on a more generic solution.

Most importantly, thank you for your in-depth review 🙏🏻

@banach-space force-pushed the andrzej/generalize_fix branch from 99d24b7 to 7916a3f on November 6, 2024 12:45
@qedawkins (Contributor) left a comment

SGTM, thanks!

@banach-space force-pushed the andrzej/generalize_fix branch from 7916a3f to 50ab228 on November 6, 2024 15:18
@Max191 self-requested a review on November 6, 2024 16:06
@Max191 (Contributor) left a comment

I followed the nice discussion with Quinn and this looks good to me now. Thanks for the cleanup!

@banach-space force-pushed the andrzej/generalize_fix branch from 2c3666d to 82283ba on November 6, 2024 19:08
…ern`

…kOpPattern`

Skip calculating static shapes for EmptyOp
…DimsPackOpPattern`

Rename and move getSimplifiedDimSizePair
@banach-space force-pushed the andrzej/generalize_fix branch from 82283ba to 2d9bc62 on November 6, 2024 19:59
@banach-space merged commit e9bafa3 into llvm:main on Nov 6, 2024
6 of 7 checks passed
@banach-space deleted the andrzej/generalize_fix branch on November 6, 2024 21:14
banach-space added a commit to banach-space/llvm-project that referenced this pull request Nov 7, 2024
Avoid generating spurious tensor.extract_slice, follow-on for llvm#114315.

This is best demonstrated with an example. Here's the input for
`GeneralizeOuterUnitDimsPackOpPattern`:
```mlir
%pack = tensor.pack %input
  padding_value(%pad : f32)
  inner_dims_pos = [1, 0]
  inner_tiles = [2, %tile_dim_1]
  into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32>
```

Output _before_:
```mlir
%padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
^bb0(%arg4: index, %arg5: index):
  tensor.yield %arg2 : f32
} : tensor<5x1xf32> to tensor<?x2xf32>
%extracted_slice = tensor.extract_slice %padded[0, 0] [%arg3, 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
%empty = tensor.empty(%arg3) : tensor<2x?xf32>
%transposed = linalg.transpose
  ins(%extracted_slice : tensor<?x2xf32>)
  outs(%empty : tensor<2x?xf32>)
  permutation = [1, 0]
%inserted_slice = tensor.insert_slice %transposed
  into %arg1[0, 0, 0, 0] [1, 1, 2, %arg3] [1, 1, 1, 1] :
  tensor<2x?xf32> into tensor<1x1x2x?xf32>
```

Output _after_:
```mlir
%padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
^bb0(%arg4: index, %arg5: index):
  tensor.yield %arg2 : f32
} : tensor<5x1xf32> to tensor<?x2xf32>
%empty = tensor.empty(%arg3) : tensor<2x?xf32>
%transposed = linalg.transpose
  ins(%padded : tensor<?x2xf32>)
  outs(%empty : tensor<2x?xf32>) permutation = [1, 0]
%inserted_slice = tensor.insert_slice %transposed
  into %arg1[0, 0, 0, 0] [1, 1, 2, %arg3] [1, 1, 1, 1] :
  tensor<2x?xf32> into tensor<1x1x2x?xf32>
```

This PR also adds a check to verify that only the last N (for some value
of N) trailing dims are tiled. From what I can tell, that's always the
case in practice. For this PR, it simplifies how the permutation for
linalg.transpose is computed. If needed, this can be relaxed in the
future.
@banach-space (Author) commented

Promised refactor for tensor.extract_slice: #115312

banach-space added a commit to banach-space/llvm-project that referenced this pull request Nov 11, 2024
Adds an end-to-end test for `tensor.pack` with dynamic inner tile sizes.
While relatively simple (e.g., no vectorization), this example required
a few non-trivial fixes in handling `tensor.pack`:

* llvm#114315, llvm#114559, llvm#113108.

The end goal for this test is to incrementally increase its complexity
and to work towards scalable tile sizes.
banach-space added a commit that referenced this pull request Nov 12, 2024
Avoid generating spurious tensor.extract_slice, follow-on for #114315.

This is best demonstrated with an example. Here's the input for
`GeneralizeOuterUnitDimsPackOpPattern`:
```mlir
%pack = tensor.pack %input
  padding_value(%pad : f32)
  inner_dims_pos = [1, 0]
  inner_tiles = [2, %tile_dim_1]
  into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32>
```

Output _before_:
```mlir
%padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
^bb0(%arg4: index, %arg5: index):
  tensor.yield %arg2 : f32
} : tensor<5x1xf32> to tensor<?x2xf32>
// NOTE: skipped in the output _after_
%extracted_slice = tensor.extract_slice
  %padded[0, 0] [%arg3, 2] [1, 1] :
  tensor<?x2xf32> to tensor<?x2xf32>
%empty = tensor.empty(%arg3) : tensor<2x?xf32>
%transposed = linalg.transpose
  ins(%extracted_slice : tensor<?x2xf32>)
  outs(%empty : tensor<2x?xf32>)
  permutation = [1, 0]
%inserted_slice = tensor.insert_slice %transposed
  into %arg1[0, 0, 0, 0] [1, 1, 2, %arg3] [1, 1, 1, 1] :
  tensor<2x?xf32> into tensor<1x1x2x?xf32>
```

Output _after_:
```mlir
%padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
^bb0(%arg4: index, %arg5: index):
  tensor.yield %arg2 : f32
} : tensor<5x1xf32> to tensor<?x2xf32>
%empty = tensor.empty(%arg3) : tensor<2x?xf32>
%transposed = linalg.transpose
  ins(%padded : tensor<?x2xf32>)
  outs(%empty : tensor<2x?xf32>) permutation = [1, 0]
%inserted_slice = tensor.insert_slice %transposed
  into %arg1[0, 0, 0, 0] [1, 1, 2, %arg3] [1, 1, 1, 1] :
  tensor<2x?xf32> into tensor<1x1x2x?xf32>
```

This PR also adds a check to verify that only the last N trailing
dimensions are tiled (for some value of N). Based on the PR
discussion, this restriction seems reasonable - especially as there
are no in-tree tests requiring otherwise. For now, it also simplifies
the computation of permutations for linalg.transpose. This
restriction can be relaxed in the future if needed.
banach-space added a commit that referenced this pull request Nov 14, 2024
…115698)

Adds an end-to-end test for `tensor.pack` with dynamic inner tile sizes.
While relatively simple (e.g., no vectorization), this example required
a few non-trivial fixes in handling `tensor.pack`:

* #114315, #114559, #113108.

The end goal for this test is to incrementally increase its complexity
and to work towards scalable tile sizes.