
[mlir][tensor] Update GeneralizeOuterUnitDimsPackOpPattern #115312


Merged
2 commits merged into llvm:main on Nov 12, 2024

Conversation

@banach-space (Contributor) commented Nov 7, 2024

Avoid generating spurious tensor.extract_slice; a follow-on for #114315.

This is best demonstrated with an example. Here's the input for
`GeneralizeOuterUnitDimsPackOpPattern`:

```mlir
%pack = tensor.pack %input
  padding_value(%pad : f32)
  inner_dims_pos = [1, 0]
  inner_tiles = [2, %tile_dim_1]
  into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32>
```

Output before:

```mlir
%padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
^bb0(%arg4: index, %arg5: index):
  tensor.yield %arg2 : f32
} : tensor<5x1xf32> to tensor<?x2xf32>
// NOTE: skipped in the output _after_
%extracted_slice = tensor.extract_slice
  %padded[0, 0] [%arg3, 2] [1, 1] :
  tensor<?x2xf32> to tensor<?x2xf32>
%empty = tensor.empty(%arg3) : tensor<2x?xf32>
%transposed = linalg.transpose
  ins(%extracted_slice : tensor<?x2xf32>)
  outs(%empty : tensor<2x?xf32>)
  permutation = [1, 0]
%inserted_slice = tensor.insert_slice %transposed
  into %arg1[0, 0, 0, 0] [1, 1, 2, %arg3] [1, 1, 1, 1] :
  tensor<2x?xf32> into tensor<1x1x2x?xf32>
```

Output after:

```mlir
%padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
^bb0(%arg4: index, %arg5: index):
  tensor.yield %arg2 : f32
} : tensor<5x1xf32> to tensor<?x2xf32>
%empty = tensor.empty(%arg3) : tensor<2x?xf32>
%transposed = linalg.transpose
  ins(%padded : tensor<?x2xf32>)
  outs(%empty : tensor<2x?xf32>) permutation = [1, 0]
%inserted_slice = tensor.insert_slice %transposed
  into %arg1[0, 0, 0, 0] [1, 1, 2, %arg3] [1, 1, 1, 1] :
  tensor<2x?xf32> into tensor<1x1x2x?xf32>
```

This PR also adds a check to verify that only the last N trailing dimensions
are tiled (for some value of N). Based on the PR discussion, this
restriction seems reasonable, especially as there are no in-tree tests
requiring otherwise. For now, it also simplifies the computation of
permutations for linalg.transpose. This restriction can be relaxed in
the future if needed.
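
To make the restriction concrete, here is a hypothetical example (shapes invented, not taken from the PR) of a pack op that the new check should reject: the only tiled dim is the leading source dim rather than a trailing one, so the pattern bails out with "Attempting to tile non-trailing source dims!".

```mlir
// Hypothetical: dim 0, a non-trailing source dim, is the only tiled dim.
%pack = tensor.pack %input
  inner_dims_pos = [0]
  inner_tiles = [8]
  into %output : tensor<8x1x1xf32> -> tensor<1x1x1x8xf32>
```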

@llvmbot (Member) commented Nov 7, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Andrzej Warzyński (banach-space)

Changes

Avoid generating spurious tensor.extract_slice; a follow-on for #114315.

This is best demonstrated with an example. Here's the input for
`GeneralizeOuterUnitDimsPackOpPattern`:

```mlir
%pack = tensor.pack %input
  padding_value(%pad : f32)
  inner_dims_pos = [1, 0]
  inner_tiles = [2, %tile_dim_1]
  into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32>
```

Output before:

```mlir
%padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
^bb0(%arg4: index, %arg5: index):
  tensor.yield %arg2 : f32
} : tensor<5x1xf32> to tensor<?x2xf32>
%extracted_slice = tensor.extract_slice %padded[0, 0] [%arg3, 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
%empty = tensor.empty(%arg3) : tensor<2x?xf32>
%transposed = linalg.transpose
  ins(%extracted_slice : tensor<?x2xf32>)
  outs(%empty : tensor<2x?xf32>)
  permutation = [1, 0]
%inserted_slice = tensor.insert_slice %transposed
  into %arg1[0, 0, 0, 0] [1, 1, 2, %arg3] [1, 1, 1, 1] :
  tensor<2x?xf32> into tensor<1x1x2x?xf32>
```

Output after:

```mlir
%padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
^bb0(%arg4: index, %arg5: index):
  tensor.yield %arg2 : f32
} : tensor<5x1xf32> to tensor<?x2xf32>
%empty = tensor.empty(%arg3) : tensor<2x?xf32>
%transposed = linalg.transpose
  ins(%padded : tensor<?x2xf32>)
  outs(%empty : tensor<2x?xf32>) permutation = [1, 0]
%inserted_slice = tensor.insert_slice %transposed
  into %arg1[0, 0, 0, 0] [1, 1, 2, %arg3] [1, 1, 1, 1] :
  tensor<2x?xf32> into tensor<1x1x2x?xf32>
```

This PR also adds a check to verify that only the last N trailing dims are
tiled (for some value of N). From what I can tell, that's always the case
in practice. For this PR, it simplifies how the permutation for
linalg.transpose is computed. If needed, this can be relaxed in the future.


Full diff: https://github.com/llvm/llvm-project/pull/115312.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+1-5)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+35-43)
  • (modified) mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir (+9-9)
  • (modified) mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir (+17-19)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index a8662a3d6f63be..5209e1145506b9 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1516,7 +1516,7 @@ struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
 };
 
 /// Rewrites a tensor::PackOp into a sequence of:
-///   * tensor::PadOp + linalg::TransposeOp + tensor::ExtractSliceOp +
+///   * tensor::PadOp + linalg::TransposeOp +
 ///     tensor::EmptyOp + tensor::InsertSliceOp ops.
 ///
 /// Required that all the outer dims of the input tensor::PackOp are 1.
@@ -1537,10 +1537,6 @@ struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
 ///     ^bb0(...):
 ///       tensor.yield %arg2 : f32
 ///   } : tensor<5x1xf32> to tensor<?x2xf32>
-///   // ExtractSliceOp
-///   %extracted_slice = tensor.extract_slice %padded[0, 0] [%tile_dim_1, 2] [1,
-///   1]
-///     : tensor<?x2xf32> to tensor<?x2xf32>
 ///   // EmptyOp + TransposeOp
 ///   %empty = tensor.empty(%arg3) : tensor<2x?xf32>
 ///   %transposed = linalg.transpose
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 64096954f56b95..0be8799f327441 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1153,71 +1153,63 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   Location loc = packOp.getLoc();
 
   Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
-  auto inputShape = packOp.getSourceType().getShape();
   DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
       packOp.getDimAndTileMapping();
   int64_t srcRank = packOp.getSourceRank();
-
   int64_t destRank = packOp.getDestRank();
-  size_t numTiles = destRank - srcRank;
-
-  // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
-  //    %extracted_tile = tensor.extract_slice(%pack_op_input)
-  SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
-  SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
+  int64_t numTiles = destRank - srcRank;
 
-  // The sizes attribute for ExtractSliceOp. The leading sizes are set to 1 as
-  // all outer dims are 1.
-  SmallVector<OpFoldResult> extractSliceSizes(srcRank - numTiles, oneIdxAttr);
-  // The shape of the output for ExtractSliceOp. All leading unit dims are
-  // effectively rank-reduced, hence skipped.
-  SmallVector<int64_t> outputShapeForExtractSlice;
+  if (!llvm::all_of(packOp.getInnerDimsPos(),
+                    [&srcRank, &numTiles](int64_t dimPos) {
+                      return dimPos >= (srcRank - numTiles - 1);
+                    }))
+    return rewriter.notifyMatchFailure(
+        packOp, "Attempting to tile non-trailing source dims!");
 
-  // Extract the trailing sizes and shape dims for ExtractSliceOp. These should
-  // be equal to the inner tile sizes.
+  // 1. Extract the inner tile sizes.
+  // Where possible, values are replaced with constant attributes (to match the
+  // behaviour of `getPackOpSourceOrPaddedSource`).
+  SmallVector<OpFoldResult> tileSizes;
   for (auto i : llvm::seq<unsigned>(0, srcRank)) {
     if (dimAndTileMapping.count(i)) {
-      auto [tileSize, tileSizeOfr] =
+      // Rather than taking the tile size as is, extact the actual constant
+      // value Attribute where possible, e.g.:
+      //    [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
+      auto [_, tileSize] =
           getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
-      extractSliceSizes.push_back(tileSizeOfr);
-      outputShapeForExtractSlice.push_back(tileSize);
+      tileSizes.push_back(tileSize);
     }
   }
 
-  Type elemType = packOp.getSourceType().getElementType();
-  auto readType = RankedTensorType::get(outputShapeForExtractSlice, elemType);
-
-  Value tile = rewriter.create<tensor::ExtractSliceOp>(
-      loc, readType, input, readOffsets, extractSliceSizes, readStrides);
-
-  // 2. Transpose the tile to match the inner tile order:
+  // 2. Transpose the input to match the inner tile order:
   //    %init = tensor.empty()
   //    %transposed_tile = linalg.transpose ins(%extracted_tile), outs(%init)
-  // NOTE: Outer dims are 1 and hence effectively ignored.
-  SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
-      inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm());
+  // Two assumptions are made:
+  //  1. All outer dims are 1 - the corresponding transposition doesn't matter.
+  //  2. Inner dims position correspond to the trailing `numTiles` dims.
+  SmallVector<int64_t> tilesPermNormalized =
+      getPackUnpackNormalizedPerm(srcRank, packOp.getInnerDimsPos());
+  SmallVector<int64_t> srcPermForTranspose;
+  for (int64_t i = 0; i < (srcRank - numTiles); i++)
+    srcPermForTranspose.push_back(i);
+
+  srcPermForTranspose.append(SmallVector<int64_t>(packOp.getInnerDimsPos()));
 
   LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
-             llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
+             llvm::interleaveComma(srcPermForTranspose, DBGS() << "perm: "); DBGSNL(););
 
   // 2.1 Create tensor.empty (init value for TransposeOp)
-  SmallVector<OpFoldResult> transShapeForEmptyOp;
-
-  // Acquire tensor shape required to create EmptyOp. This will match the inner
-  // tile sizes.
-  size_t idx = numTiles;
-  while (idx != 0) {
-    transShapeForEmptyOp.push_back(extractSliceSizes[srcRank - idx]);
-    idx--;
-  }
+  SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
+                                                 oneIdxAttr);
+  transShapeForEmptyOp.append(tileSizes);
 
-  applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp, perm);
-  Value empty =
-      rewriter.create<tensor::EmptyOp>(loc, transShapeForEmptyOp, elemType);
+  applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp, srcPermForTranspose);
+  Value empty = rewriter.create<tensor::EmptyOp>(
+      loc, transShapeForEmptyOp, packOp.getSourceType().getElementType());
 
   // 2.2 Create linalg.transpose
   auto transposedOp =
-      rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);
+      rewriter.create<linalg::TransposeOp>(loc, input, empty, srcPermForTranspose);
 
   // 3. Insert the inner tile to the destination:
   //  %inserted_tile = tensor.insert_slice(%transposed_tile)
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
index d0c53ae4680013..8be3e7413bfc81 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
@@ -9,19 +9,19 @@ func.func @KCRS_to_KCRSsr(%arg0: tensor<1x1x128x64xf32>, %arg1: tensor<1x1x4x8x8
 // CHECK:       func.func @KCRS_to_KCRSsr
 // CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
-// CHECK:         %{{.+}} = scf.for %[[R:[a-zA-Z0-9]+]] =
-// CHECK:           %{{.+}} = scf.for %[[S:[a-zA-Z0-9]+]] =
+// CHECK:         scf.for %[[R:[a-zA-Z0-9]+]] =
+// CHECK:           scf.for %[[S:[a-zA-Z0-9]+]] {{.*}} iter_args(%[[ITER_SLICE:.*]] =
 // CHECK:             %[[IN_R:.+]] = affine.apply #[[MAP0]](%[[R]])
 // CHECK:             %[[IN_S:.+]] = affine.apply #[[MAP2]](%[[S]])
 // CHECK:             %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
 // CHECK-SAME:          [0, 0, %[[IN_R]], %[[IN_S]]] [1, 1, 32, 8] [1, 1, 1, 1]
-// CHECK:             %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
-// CHECK-SAME:          [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x32x8xf32> to tensor<32x8xf32>
-// CHECK:             %[[EMPTY:.+]] = tensor.empty() : tensor<8x32xf32>
-// CHECK:             %[[TRANSP:.+]] =  linalg.transpose
-// CHECK-SAME:          ins(%[[TILE]]
-// CHECK-SAME:          outs(%[[EMPTY]]
-// CHECK-SAME:          permutation = [1, 0]
+// CHECK:             %[[TILE:.*]] = tensor.extract_slice %[[ITER_SLICE]]
+// CHECK-SAME:          [0, 0, %[[R]], %[[S]], 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1] : tensor<1x1x4x8x8x32xf32> to tensor<1x1x1x1x8x32xf32>
+// CHECK:             %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x8x32xf32>
+// CHECK:             %[[TRANSP:.*]] = linalg.transpose
+// CHECK-SAME:          ins(%[[SRC_SLICE]] : tensor<1x1x32x8xf32>)
+// CHECK-SAME:          outs(%[[EMPTY]] : tensor<1x1x8x32xf32>)
+// CHECK-SAME:          permutation = [0, 1, 3, 2]
 // CHECK:             %{{.+}} = tensor.insert_slice %[[TRANSP]] into %{{.+}}
 
 module attributes {transform.with_named_sequence} {
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
index 8abf7a11bed5c9..f4b1d9a55f0914 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -63,8 +63,7 @@ func.func @simple_pad_and_pack_dynamic_tile(%input: tensor<5x1xf32>, %output: te
 // CHECK:           %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
 // CHECK:             tensor.yield %[[PAD_VAL]] : f32
 // CHECK-NOT:       linalg.transpose
-// CHECK:           %[[SLICE:.*]] = tensor.extract_slice %[[PAD:.*]][0, 0] {{\[}}%[[TILE_DIM_0]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
-// CHECK:           %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_0]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK:           %[[RES:.*]] = tensor.insert_slice %[[PAD:.*]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_0]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
 // CHECK:           return %[[RES]] : tensor<1x1x?x2xf32>
 
 func.func @simple_pad_and_pack_dynamic_tile_cst(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32) -> tensor<1x1x?x2xf32> {
@@ -95,10 +94,10 @@ func.func @simple_pad_and_pack_dynamic_tile_transpose(%input: tensor<5x1xf32>, %
 // CHECK:           %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
 // CHECK:            tensor.yield %[[PAD_VAL]] : f32
 // CHECK-NEXT:      } : tensor<5x1xf32> to tensor<?x2xf32>
-// CHECK:           %[[SLICE:.*]] = tensor.extract_slice %[[PAD:.*]][0, 0] {{\[}}%[[TILE_DIM_1]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
 // CHECK:           %[[EMPTY:.*]] = tensor.empty(%[[TILE_DIM_1]]) : tensor<2x?xf32>
 // CHECK:           %[[TR:.*]] = linalg.transpose
-// CHECK-SAME:        ins(%[[SLICE]] : tensor<?x2xf32>) outs(%[[EMPTY]] : tensor<2x?xf32>)
+// CHECK-SAME:        ins(%[[PAD:.*]] : tensor<?x2xf32>)
+// CHECK-SAME:        outs(%[[EMPTY]] : tensor<2x?xf32>)
 // CHECK-SAME:        permutation = [1, 0]
 // CHECK:           %[[RES:.*]] = tensor.insert_slice %[[TR]] into %[[DEST]][0, 0, 0, 0] [1, 1, 2, %[[TILE_DIM_1]]] [1, 1, 1, 1] : tensor<2x?xf32> into tensor<1x1x2x?xf32>
 // CHECK:           return %[[RES]] : tensor<1x1x2x?xf32>
@@ -128,10 +127,10 @@ func.func @simple_pad_and_pack_scalable_tile(%input: tensor<5x1xf32>, %output: t
 // CHECK:           %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
 // CHECK:             tensor.yield %[[PAD_VAL]] : f32
 // CHECK-NOT:       linalg.transpose
-// CHECK:           %[[SLICE:.+]] = tensor.extract_slice %[[PAD:.+]][0, 0] {{\[}}%[[C8_VS]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
-// CHECK:           %[[RES:.+]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[C8_VS]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK:           %[[RES:.*]] = tensor.insert_slice %[[PAD:.*]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[C8_VS]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
 // CHECK:           return %[[RES]] : tensor<1x1x?x2xf32>
 
+
 /// Same as example above, but with both tile sizes dynamic.
 
 func.func @simple_pad_and_pack_dynamic_tiles(%input: tensor<5x1xf32>, %output: tensor<1x1x?x?xf32>, %pad: f32, %tile_dim_0: index, %tile_dim_1: index) -> tensor<1x1x?x?xf32> {
@@ -149,8 +148,7 @@ func.func @simple_pad_and_pack_dynamic_tiles(%input: tensor<5x1xf32>, %output: t
 // CHECK:           %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH_1]], %[[PAD_HIGH_2]]] {
 // CHECK:             tensor.yield %[[PAD_VAL]] : f32
 // CHECK-NOT:       linalg.transpose
-// CHECK:           %[[SLICE:.*]] = tensor.extract_slice %[[PAD:.*]][0, 0] {{\[}}%[[TILE_DIM_0]], %[[TILE_DIM_1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK:           %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_0]], %[[TILE_DIM_1]]] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<1x1x?x?xf32>
+// CHECK:           %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_0]], %[[TILE_DIM_1]]] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<1x1x?x?xf32>
 // CHECK:           return %[[RES]] : tensor<1x1x?x?xf32>
 
 // -----
@@ -170,12 +168,13 @@ func.func @simple_pad_and_pack_dynamic_tile_not_all_dims_tiled(%input: tensor<1x
 // CHECK:           ^bb0(%[[VAL_6:.*]]: index, %[[VAL_7:.*]]: index, %[[VAL_8:.*]]: index, %[[VAL_9:.*]]: index):
 // CHECK:             tensor.yield %[[VAL_2]] : f32
 // CHECK:           } : tensor<1x1x5x1xf32> to tensor<1x1x?x2xf32>
-// CHECK:           %[[VAL_10:.*]] = tensor.extract_slice %[[VAL_11:.*]][0, 0, 0, 0] [1, 1, %[[VAL_3]], 2] [1, 1, 1, 1] : tensor<1x1x?x2xf32> to tensor<?x2xf32>
-// CHECK:           %[[VAL_12:.*]] = tensor.empty(%[[VAL_3]]) : tensor<2x?xf32>
-// CHECK:           %[[VAL_13:.*]] = linalg.transpose ins(%[[VAL_10]] : tensor<?x2xf32>) outs(%[[VAL_12]] : tensor<2x?xf32>) permutation = [1, 0]
-// CHECK:           %[[VAL_14:.*]] = tensor.insert_slice %[[VAL_13]] into %[[VAL_1]][0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 2, %[[VAL_3]]] [1, 1, 1, 1, 1, 1] : tensor<2x?xf32> into tensor<1x1x1x1x2x?xf32>
-// CHECK:           return %[[VAL_14]] : tensor<1x1x1x1x2x?xf32>
-// CHECK:         }
+// CHECK:           %[[VAL_10:.*]] = tensor.empty(%[[VAL_3]]) : tensor<1x1x2x?xf32>
+// CHECK:           %[[VAL_11:.*]] = linalg.transpose
+// CHECK-SAME:        ins(%[[VAL_12:.*]] : tensor<1x1x?x2xf32>)
+// CHECK-SAME:        outs(%[[VAL_10]] : tensor<1x1x2x?xf32>)
+// CHECK-SAME:        permutation = [0, 1, 3, 2]
+// CHECK:           %[[VAL_13:.*]] = tensor.insert_slice %[[VAL_11]] into %[[VAL_1]][0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 2, %[[VAL_3]]] [1, 1, 1, 1, 1, 1] : tensor<1x1x2x?xf32> into tensor<1x1x1x1x2x?xf32>
+// CHECK:           return %[[VAL_13]] : tensor<1x1x1x1x2x?xf32>
 
 // -----
 
@@ -218,12 +217,11 @@ func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<1x1x
 // CHECK-LABEL: func.func @simple_KCRS_to_KRSCsr
 // CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
-// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
-// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x32xf32>
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x8x32xf32>
 // CHECK:         %[[TRANSP:.+]] =  linalg.transpose
-// CHECK-SAME:      ins(%[[TILE]] : tensor<32x8xf32>)
-// CHECK-SAME:      outs(%[[EMPTY]] : tensor<8x32xf32>)
-// CHECK-SAME:      permutation = [1, 0]
+// CHECK-SAME:      ins(%[[SRC]] : tensor<1x1x32x8xf32>
+// CHECK-SAME:      outs(%[[EMPTY]] : tensor<1x1x8x32xf32>)
+// CHECK-SAME:      permutation = [0, 1, 3, 2]
 // CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
 // CHECK-SAME:      [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
 // CHECK:         return %[[INSERT]]


github-actions bot commented Nov 7, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@Max191 (Contributor) left a comment

This looks good to me, but I'm not sure that this assumption about the tensor.pack op is actually true in practice:

> This PR also adds a check to verify that only the last N trailing dims are tiled (for some value of N). From what I can tell, that's always the case in practice.

I have seen pack ops show up in practice that do not follow this restriction, so it may be worth supporting this case, but I'm not sure we rely much on this pattern anymore anyway. I'll approve, but pinging @hanhanW, who might have a better idea of whether or not support for this case would be desired as a follow-up.

@hanhanW (Contributor) commented Nov 11, 2024

> > This PR also adds a check to verify that only the last N trailing dims are tiled (for some value of N). From what I can tell, that's always the case in practice.
>
> I have seen pack ops show up in practice that do not follow this restriction, so it may be worth supporting this case, but I'm not sure we rely much on this pattern anymore anyway. I'll approve, but pinging @hanhanW, who might have a better idea of whether or not support for this case would be desired as a follow-up.

It is not a restriction in our use cases because we can data-tile any op that implements ContractionOpInterface, e.g., a contraction generic op whose batch dimension is the innermost dimension. In that case, we don't data-tile the batch dimension.

These patterns were built for the first take of pack/unpack vectorization, and they are no longer used on the IREE CPU x86 codegen path, because today we have direct vectorization and masking support.
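
For illustration, here is a minimal hypothetical sketch (shapes invented) of the layout described above: an (M, N, B) tensor whose innermost dim is the batch dim B; data-tiling packs M and N but leaves the trailing batch dim untiled:

```mlir
// Hypothetical: M and N are tiled; the trailing batch dim B (= 16) is not.
%pack = tensor.pack %src
  inner_dims_pos = [0, 1]
  inner_tiles = [8, 4]
  into %dst : tensor<8x4x16xf32> -> tensor<1x1x16x8x4xf32>
```

Whether the generalize pattern applies to such an op also depends on its other preconditions, e.g. all outer dims of the result being unit dims.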

@banach-space (Contributor, Author) commented:

> It is not a restriction in our use cases because we can data-tile any op that implements ContractionOpInterface, e.g., a contraction generic op whose batch dimension is the innermost dimension. In that case, we don't data-tile the batch dimension.
>
> These patterns were built for the first take of pack/unpack vectorization, and they are no longer used on the IREE CPU x86 codegen path, because today we have direct vectorization and masking support.

Thanks for the context and the explanation! Any preference on how to proceed here?

From my perspective, without in-tree tests it tends to be tricky to produce good reference examples. So, if I can simplify things without breaking any tests, I'd go with that.

@hanhanW (Contributor) left a comment

LG, just a couple of nits.

   loc, readType, input, readOffsets, extractSliceSizes, readStrides);

-  // 2. Transpose the tile to match the inner tile order:
+  // 2. Transpose the input to match the inner tile order:
   //    %init = tensor.empty()
   //    %transposed_tile = linalg.transpose ins(%extracted_tile), outs(%init)
@hanhanW (Contributor) commented:

This comment needs to be updated. There is no extracted_tile anymore?

Comment on lines 18 to 19:

// CHECK:             %[[TILE:.*]] = tensor.extract_slice %[[ITER_SLICE]]
// CHECK-SAME:          [0, 0, %[[R]], %[[S]], 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1] : tensor<1x1x4x8x8x32xf32> to tensor<1x1x1x1x8x32xf32>

@hanhanW (Contributor) commented:

Is this TILE generated by the tiling? Perhaps we can drop the checks because it is not used at all.

@banach-space (Contributor, Author) commented:

Yeah, it's the tile for the tensor.insert_slice below. We were not capturing it in the expected output before, so I shouldn't be capturing it after. Let me remove it; thanks for catching this!

@banach-space force-pushed the andrzej/skip_extract_slice branch from 4e6236f to b7d3d99 on November 12, 2024 at 15:27
@hanhanW (Contributor) left a comment

LGTM, thanks

@banach-space merged commit 7ebfbf9 into llvm:main on Nov 12, 2024
8 checks passed