[mlir][tensor] Update GeneralizeOuterUnitDimsPackOpPattern
#115312
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg

Author: Andrzej Warzyński (banach-space)

Changes

Avoid generating spurious tensor.extract_slice; follow-on for #114315. This is best demonstrated with an example. Here's the input for `GeneralizeOuterUnitDimsPackOpPattern`:

```mlir
%pack = tensor.pack %input
    padding_value(%pad : f32)
    inner_dims_pos = [1, 0]
    inner_tiles = [2, %tile_dim_1]
    into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32>
```

Output _before_:

```mlir
%padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
^bb0(%arg4: index, %arg5: index):
  tensor.yield %arg2 : f32
} : tensor<5x1xf32> to tensor<?x2xf32>
%extracted_slice = tensor.extract_slice %padded[0, 0] [%arg3, 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
%empty = tensor.empty(%arg3) : tensor<2x?xf32>
%transposed = linalg.transpose
  ins(%extracted_slice : tensor<?x2xf32>)
  outs(%empty : tensor<2x?xf32>)
  permutation = [1, 0]
%inserted_slice = tensor.insert_slice %transposed
  into %arg1[0, 0, 0, 0] [1, 1, 2, %arg3] [1, 1, 1, 1] :
  tensor<2x?xf32> into tensor<1x1x2x?xf32>
```

Output _after_:

```mlir
%padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
^bb0(%arg4: index, %arg5: index):
  tensor.yield %arg2 : f32
} : tensor<5x1xf32> to tensor<?x2xf32>
%empty = tensor.empty(%arg3) : tensor<2x?xf32>
%transposed = linalg.transpose
  ins(%padded : tensor<?x2xf32>)
  outs(%empty : tensor<2x?xf32>)
  permutation = [1, 0]
%inserted_slice = tensor.insert_slice %transposed
  into %arg1[0, 0, 0, 0] [1, 1, 2, %arg3] [1, 1, 1, 1] :
  tensor<2x?xf32> into tensor<1x1x2x?xf32>
```

This PR also adds a check to verify that only the last N (for some value of N) trailing source dims are being tiled. From what I can tell, that's always the case in practice. For this PR, this simplifies how the permutation for linalg.transpose is computed. If needed, this can be relaxed in the future.

Full diff: https://github.com/llvm/llvm-project/pull/115312.diff

4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index a8662a3d6f63be..5209e1145506b9 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1516,7 +1516,7 @@ struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
};
/// Rewrites a tensor::PackOp into a sequence of:
-/// * tensor::PadOp + linalg::TransposeOp + tensor::ExtractSliceOp +
+/// * tensor::PadOp + linalg::TransposeOp +
/// tensor::EmptyOp + tensor::InsertSliceOp ops.
///
/// Required that all the outer dims of the input tensor::PackOp are 1.
@@ -1537,10 +1537,6 @@ struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
/// ^bb0(...):
/// tensor.yield %arg2 : f32
/// } : tensor<5x1xf32> to tensor<?x2xf32>
-/// // ExtractSliceOp
-/// %extracted_slice = tensor.extract_slice %padded[0, 0] [%tile_dim_1, 2] [1,
-/// 1]
-/// : tensor<?x2xf32> to tensor<?x2xf32>
/// // EmptyOp + TransposeOp
/// %empty = tensor.empty(%arg3) : tensor<2x?xf32>
/// %transposed = linalg.transpose
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 64096954f56b95..0be8799f327441 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1153,71 +1153,63 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
Location loc = packOp.getLoc();
Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
- auto inputShape = packOp.getSourceType().getShape();
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
packOp.getDimAndTileMapping();
int64_t srcRank = packOp.getSourceRank();
-
int64_t destRank = packOp.getDestRank();
- size_t numTiles = destRank - srcRank;
-
- // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
- // %extracted_tile = tensor.extract_slice(%pack_op_input)
- SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
- SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
+ int64_t numTiles = destRank - srcRank;
- // The sizes attribute for ExtractSliceOp. The leading sizes are set to 1 as
- // all outer dims are 1.
- SmallVector<OpFoldResult> extractSliceSizes(srcRank - numTiles, oneIdxAttr);
- // The shape of the output for ExtractSliceOp. All leading unit dims are
- // effectively rank-reduced, hence skipped.
- SmallVector<int64_t> outputShapeForExtractSlice;
+ if (!llvm::all_of(packOp.getInnerDimsPos(),
+ [&srcRank, &numTiles](int64_t dimPos) {
+ return dimPos >= (srcRank - numTiles - 1);
+ }))
+ return rewriter.notifyMatchFailure(
+ packOp, "Attempting to tile non-trailing source dims!");
- // Extract the trailing sizes and shape dims for ExtractSliceOp. These should
- // be equal to the inner tile sizes.
+ // 1. Extract the inner tile sizes.
+ // Where possible, values are replaced with constant attributes (to match the
+ // behaviour of `getPackOpSourceOrPaddedSource`).
+ SmallVector<OpFoldResult> tileSizes;
for (auto i : llvm::seq<unsigned>(0, srcRank)) {
if (dimAndTileMapping.count(i)) {
- auto [tileSize, tileSizeOfr] =
+ // Rather than taking the tile size as is, extract the actual constant
+ // value Attribute where possible, e.g.:
+ // [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
+ auto [_, tileSize] =
getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
- extractSliceSizes.push_back(tileSizeOfr);
- outputShapeForExtractSlice.push_back(tileSize);
+ tileSizes.push_back(tileSize);
}
}
- Type elemType = packOp.getSourceType().getElementType();
- auto readType = RankedTensorType::get(outputShapeForExtractSlice, elemType);
-
- Value tile = rewriter.create<tensor::ExtractSliceOp>(
- loc, readType, input, readOffsets, extractSliceSizes, readStrides);
-
- // 2. Transpose the tile to match the inner tile order:
+ // 2. Transpose the input to match the inner tile order:
// %init = tensor.empty()
// %transposed_tile = linalg.transpose ins(%extracted_tile), outs(%init)
- // NOTE: Outer dims are 1 and hence effectively ignored.
- SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
- inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm());
+ // Two assumptions are made:
+ // 1. All outer dims are 1 - the corresponding transposition doesn't matter.
+ // 2. Inner dim positions correspond to the trailing `numTiles` dims.
+ SmallVector<int64_t> tilesPermNormalized =
+ getPackUnpackNormalizedPerm(srcRank, packOp.getInnerDimsPos());
+ SmallVector<int64_t> srcPermForTranspose;
+ for (int64_t i = 0; i < (srcRank - numTiles); i++)
+ srcPermForTranspose.push_back(i);
+
+ srcPermForTranspose.append(SmallVector<int64_t>(packOp.getInnerDimsPos()));
LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
- llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
+ llvm::interleaveComma(srcPermForTranspose, DBGS() << "perm: "); DBGSNL(););
// 2.1 Create tensor.empty (init value for TransposeOp)
- SmallVector<OpFoldResult> transShapeForEmptyOp;
-
- // Acquire tensor shape required to create EmptyOp. This will match the inner
- // tile sizes.
- size_t idx = numTiles;
- while (idx != 0) {
- transShapeForEmptyOp.push_back(extractSliceSizes[srcRank - idx]);
- idx--;
- }
+ SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
+ oneIdxAttr);
+ transShapeForEmptyOp.append(tileSizes);
- applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp, perm);
- Value empty =
- rewriter.create<tensor::EmptyOp>(loc, transShapeForEmptyOp, elemType);
+ applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp, srcPermForTranspose);
+ Value empty = rewriter.create<tensor::EmptyOp>(
+ loc, transShapeForEmptyOp, packOp.getSourceType().getElementType());
// 2.2 Create linalg.transpose
auto transposedOp =
- rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);
+ rewriter.create<linalg::TransposeOp>(loc, input, empty, srcPermForTranspose);
// 3. Insert the inner tile to the destination:
// %inserted_tile = tensor.insert_slice(%transposed_tile)
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
index d0c53ae4680013..8be3e7413bfc81 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
@@ -9,19 +9,19 @@ func.func @KCRS_to_KCRSsr(%arg0: tensor<1x1x128x64xf32>, %arg1: tensor<1x1x4x8x8
// CHECK: func.func @KCRS_to_KCRSsr
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
-// CHECK: %{{.+}} = scf.for %[[R:[a-zA-Z0-9]+]] =
-// CHECK: %{{.+}} = scf.for %[[S:[a-zA-Z0-9]+]] =
+// CHECK: scf.for %[[R:[a-zA-Z0-9]+]] =
+// CHECK: scf.for %[[S:[a-zA-Z0-9]+]] {{.*}} iter_args(%[[ITER_SLICE:.*]] =
// CHECK: %[[IN_R:.+]] = affine.apply #[[MAP0]](%[[R]])
// CHECK: %[[IN_S:.+]] = affine.apply #[[MAP2]](%[[S]])
// CHECK: %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]]
// CHECK-SAME: [0, 0, %[[IN_R]], %[[IN_S]]] [1, 1, 32, 8] [1, 1, 1, 1]
-// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]]
-// CHECK-SAME: [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x32x8xf32> to tensor<32x8xf32>
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x32xf32>
-// CHECK: %[[TRANSP:.+]] = linalg.transpose
-// CHECK-SAME: ins(%[[TILE]]
-// CHECK-SAME: outs(%[[EMPTY]]
-// CHECK-SAME: permutation = [1, 0]
+// CHECK: %[[TILE:.*]] = tensor.extract_slice %[[ITER_SLICE]]
+// CHECK-SAME: [0, 0, %[[R]], %[[S]], 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1] : tensor<1x1x4x8x8x32xf32> to tensor<1x1x1x1x8x32xf32>
+// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x8x32xf32>
+// CHECK: %[[TRANSP:.*]] = linalg.transpose
+// CHECK-SAME: ins(%[[SRC_SLICE]] : tensor<1x1x32x8xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x1x8x32xf32>)
+// CHECK-SAME: permutation = [0, 1, 3, 2]
// CHECK: %{{.+}} = tensor.insert_slice %[[TRANSP]] into %{{.+}}
module attributes {transform.with_named_sequence} {
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
index 8abf7a11bed5c9..f4b1d9a55f0914 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -63,8 +63,7 @@ func.func @simple_pad_and_pack_dynamic_tile(%input: tensor<5x1xf32>, %output: te
// CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
// CHECK: tensor.yield %[[PAD_VAL]] : f32
// CHECK-NOT: linalg.transpose
-// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[PAD:.*]][0, 0] {{\[}}%[[TILE_DIM_0]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
-// CHECK: %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_0]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD:.*]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_0]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
// CHECK: return %[[RES]] : tensor<1x1x?x2xf32>
func.func @simple_pad_and_pack_dynamic_tile_cst(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32) -> tensor<1x1x?x2xf32> {
@@ -95,10 +94,10 @@ func.func @simple_pad_and_pack_dynamic_tile_transpose(%input: tensor<5x1xf32>, %
// CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
// CHECK: tensor.yield %[[PAD_VAL]] : f32
// CHECK-NEXT: } : tensor<5x1xf32> to tensor<?x2xf32>
-// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[PAD:.*]][0, 0] {{\[}}%[[TILE_DIM_1]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[TILE_DIM_1]]) : tensor<2x?xf32>
// CHECK: %[[TR:.*]] = linalg.transpose
-// CHECK-SAME: ins(%[[SLICE]] : tensor<?x2xf32>) outs(%[[EMPTY]] : tensor<2x?xf32>)
+// CHECK-SAME: ins(%[[PAD:.*]] : tensor<?x2xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<2x?xf32>)
// CHECK-SAME: permutation = [1, 0]
// CHECK: %[[RES:.*]] = tensor.insert_slice %[[TR]] into %[[DEST]][0, 0, 0, 0] [1, 1, 2, %[[TILE_DIM_1]]] [1, 1, 1, 1] : tensor<2x?xf32> into tensor<1x1x2x?xf32>
// CHECK: return %[[RES]] : tensor<1x1x2x?xf32>
@@ -128,10 +127,10 @@ func.func @simple_pad_and_pack_scalable_tile(%input: tensor<5x1xf32>, %output: t
// CHECK: %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
// CHECK: tensor.yield %[[PAD_VAL]] : f32
// CHECK-NOT: linalg.transpose
-// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[PAD:.+]][0, 0] {{\[}}%[[C8_VS]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
-// CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[C8_VS]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD:.*]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[C8_VS]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
// CHECK: return %[[RES]] : tensor<1x1x?x2xf32>
+
/// Same as example above, but with both tile sizes dynamic.
func.func @simple_pad_and_pack_dynamic_tiles(%input: tensor<5x1xf32>, %output: tensor<1x1x?x?xf32>, %pad: f32, %tile_dim_0: index, %tile_dim_1: index) -> tensor<1x1x?x?xf32> {
@@ -149,8 +148,7 @@ func.func @simple_pad_and_pack_dynamic_tiles(%input: tensor<5x1xf32>, %output: t
// CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH_1]], %[[PAD_HIGH_2]]] {
// CHECK: tensor.yield %[[PAD_VAL]] : f32
// CHECK-NOT: linalg.transpose
-// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[PAD:.*]][0, 0] {{\[}}%[[TILE_DIM_0]], %[[TILE_DIM_1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK: %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_0]], %[[TILE_DIM_1]]] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<1x1x?x?xf32>
+// CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_0]], %[[TILE_DIM_1]]] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<1x1x?x?xf32>
// CHECK: return %[[RES]] : tensor<1x1x?x?xf32>
// -----
@@ -170,12 +168,13 @@ func.func @simple_pad_and_pack_dynamic_tile_not_all_dims_tiled(%input: tensor<1x
// CHECK: ^bb0(%[[VAL_6:.*]]: index, %[[VAL_7:.*]]: index, %[[VAL_8:.*]]: index, %[[VAL_9:.*]]: index):
// CHECK: tensor.yield %[[VAL_2]] : f32
// CHECK: } : tensor<1x1x5x1xf32> to tensor<1x1x?x2xf32>
-// CHECK: %[[VAL_10:.*]] = tensor.extract_slice %[[VAL_11:.*]][0, 0, 0, 0] [1, 1, %[[VAL_3]], 2] [1, 1, 1, 1] : tensor<1x1x?x2xf32> to tensor<?x2xf32>
-// CHECK: %[[VAL_12:.*]] = tensor.empty(%[[VAL_3]]) : tensor<2x?xf32>
-// CHECK: %[[VAL_13:.*]] = linalg.transpose ins(%[[VAL_10]] : tensor<?x2xf32>) outs(%[[VAL_12]] : tensor<2x?xf32>) permutation = [1, 0]
-// CHECK: %[[VAL_14:.*]] = tensor.insert_slice %[[VAL_13]] into %[[VAL_1]][0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 2, %[[VAL_3]]] [1, 1, 1, 1, 1, 1] : tensor<2x?xf32> into tensor<1x1x1x1x2x?xf32>
-// CHECK: return %[[VAL_14]] : tensor<1x1x1x1x2x?xf32>
-// CHECK: }
+// CHECK: %[[VAL_10:.*]] = tensor.empty(%[[VAL_3]]) : tensor<1x1x2x?xf32>
+// CHECK: %[[VAL_11:.*]] = linalg.transpose
+// CHECK-SAME: ins(%[[VAL_12:.*]] : tensor<1x1x?x2xf32>)
+// CHECK-SAME: outs(%[[VAL_10]] : tensor<1x1x2x?xf32>)
+// CHECK-SAME: permutation = [0, 1, 3, 2]
+// CHECK: %[[VAL_13:.*]] = tensor.insert_slice %[[VAL_11]] into %[[VAL_1]][0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 2, %[[VAL_3]]] [1, 1, 1, 1, 1, 1] : tensor<1x1x2x?xf32> into tensor<1x1x1x1x2x?xf32>
+// CHECK: return %[[VAL_13]] : tensor<1x1x1x1x2x?xf32>
// -----
@@ -218,12 +217,11 @@ func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<1x1x
// CHECK-LABEL: func.func @simple_KCRS_to_KRSCsr
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
-// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x32xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x8x32xf32>
// CHECK: %[[TRANSP:.+]] = linalg.transpose
-// CHECK-SAME: ins(%[[TILE]] : tensor<32x8xf32>)
-// CHECK-SAME: outs(%[[EMPTY]] : tensor<8x32xf32>)
-// CHECK-SAME: permutation = [1, 0]
+// CHECK-SAME: ins(%[[SRC]] : tensor<1x1x32x8xf32>
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x1x8x32xf32>)
+// CHECK-SAME: permutation = [0, 1, 3, 2]
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
// CHECK: return %[[INSERT]]
✅ With the latest revision this PR passed the C/C++ code formatter.
Force-pushed from 8ed63a2 to adc0fc7.
This looks good to me, but I'm not sure that this assumption about tensor.pack op is actually true in practice:
> This PR also adds a check to verify that only the last N (for some value of N) trailing source dims are being tiled. From what I can tell, that's always the case in practice.
I have seen pack ops show up in practice that do not follow this restriction, so it may be worth supporting this case, but I'm not sure we rely much on this pattern anymore anyway. I'll approve, but pinging @hanhanW, who might have a better idea of whether or not support for this case would be desired as a follow up.
It is not a restriction in our use cases because we can data-tile any op that implements ContractionOpInterface, e.g., a contraction generic op which has the batch dimension as the innermost dimension. In this case, we don't data-tile the batch dimension. These patterns were built for the first take of pack/unpack vectorization, and they are no longer used on IREE's x86 CPU codegen path, because today we have direct vectorization and masking support.
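For illustration, a pack op of the kind described above might look like the following hypothetical sketch (the shapes and tile sizes are invented here, not taken from IREE or from this PR's tests). The two leading dims of an (M, N, B) tensor are tiled while the innermost batch dim stays untiled, so `inner_dims_pos` does not cover the trailing source dim:

```mlir
// Hypothetical pack: M (dim 0) and N (dim 1) are tiled by 8 and 32, while
// the innermost batch dim B (dim 2) is left untiled. The outer dims of the
// result are [128/8, 256/32, 16] = [16, 8, 16], followed by the inner tiles.
%pack = tensor.pack %input
    inner_dims_pos = [0, 1]
    inner_tiles = [8, 32]
    into %output : tensor<128x256x16xf32> -> tensor<16x8x16x8x32xf32>
```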
Thanks for the context and for the explanation! Any preference on how to proceed here? From my perspective, without tests in-tree, it tends to be tricky to produce good reference examples. So, if I can simplify things without breaking any tests, I'd go with that.
LG, just a couple nits.
      loc, readType, input, readOffsets, extractSliceSizes, readStrides);
-  // 2. Transpose the tile to match the inner tile order:
+  // 2. Transpose the input to match the inner tile order:
   //  %init = tensor.empty()
   //  %transposed_tile = linalg.transpose ins(%extracted_tile), outs(%init)
This comment needs to be updated. There is no `extracted_tile` anymore?
// CHECK: %[[TILE:.*]] = tensor.extract_slice %[[ITER_SLICE]]
// CHECK-SAME: [0, 0, %[[R]], %[[S]], 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1] : tensor<1x1x4x8x8x32xf32> to tensor<1x1x1x1x8x32xf32>
Is this `TILE` generated by the tiling? Perhaps we can drop the checks because it is not used at all.
Yeah, it's the tile for the `tensor.insert_slice` below. We were not capturing it in the expected output before, so I shouldn't be capturing it after. Let me remove it, thanks for catching this!
Force-pushed from 4e6236f to b7d3d99.
LGTM, thanks
Avoid generating spurious tensor.extract_slice, follow-on for #114315. This is best demonstrated with an example; see the input and the before/after output for `GeneralizeOuterUnitDimsPackOpPattern` at the top of this PR.

This PR also adds a check to verify that only the last N trailing dimensions are tiled (for some value of N). Based on the PR discussion, this restriction seems reasonable, especially as there are no in-tree tests requiring otherwise. For now, it also simplifies the computation of permutations for linalg.transpose. This restriction can be relaxed in the future if needed.
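To make the new check concrete, here is a hypothetical pack op (shapes invented for this sketch, not taken from the PR's tests) that the updated pattern now rejects: all outer dims are 1, as the pattern requires, but the tiled dim is the leading source dim rather than a trailing one.

```mlir
// Hypothetical example: inner_dims_pos = [0] points at the leading
// (non-trailing) source dim, so GeneralizeOuterUnitDimsPackOpPattern
// fails to match with "Attempting to tile non-trailing source dims!".
%pack = tensor.pack %input
    inner_dims_pos = [0]
    inner_tiles = [8]
    into %output : tensor<8x1x1xf32> -> tensor<1x1x1x8xf32>
```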