-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][tensor] Generalize/restrict GeneralizeOuterUnitDimsPackOpPattern
#114315
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: Andrzej Warzyński (banach-space) ChangesThis PR restricts > require all outer dims of tensor.pack to be 1. There was one test in-tree that violated that assumption (and, happened By making the pattern follow its intended design (i.e. making it func.func @<!-- -->simple_pad_and_pack_dynamic_tile_cst(
%src: tensor<5x1xf32>,
%dest: tensor<1x1x?x2xf32>,
%pad: f32) -> tensor<1x1x?x2xf32> {
%tile_dim_0 = arith.constant 8 : index
%0 = tensor.pack %src
padding_value(%pad : f32)
inner_dims_pos = [0, 1]
inner_tiles = [%tile_dim_0, 2]
into %dest : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
return %0 : tensor<1x1x?x2xf32>
} Note that the inner tile slice is dynamic, but compile-time constant. Notable implementation changes:
Patch is 26.46 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/114315.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index b5710bd78f0089..a8662a3d6f63be 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1515,9 +1515,43 @@ struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
const SmallVector<Value> &dynSizes) const;
};
-/// Rewrites a tensor::PackOp into a sequence of tensor.pad + linalg.transpose +
-/// tensor.insert_slice ops, where the tensor::PackOp has outer dims being all
-/// 1s.
+/// Rewrites a tensor::PackOp into a sequence of:
+/// * tensor::PadOp + linalg::TransposeOp + tensor::ExtractSliceOp +
+/// tensor::EmptyOp + tensor::InsertSliceOp ops.
+///
+/// Required that all the outer dims of the input tensor::PackOp are 1.
+///
+/// Before:
+/// ```
+/// %packed = tensor.pack %input
+/// padding_value(%pad : f32)
+/// inner_dims_pos = [1, 0]
+/// inner_tiles = [2, %high]
+/// into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32>
+/// ```
+///
+/// After:
+/// ```
+/// // PadOp
+/// %padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
+/// ^bb0(...):
+/// tensor.yield %arg2 : f32
+/// } : tensor<5x1xf32> to tensor<?x2xf32>
+/// // ExtractSliceOp
+/// %extracted_slice = tensor.extract_slice %padded[0, 0] [%tile_dim_1, 2] [1,
+/// 1]
+/// : tensor<?x2xf32> to tensor<?x2xf32>
+/// // EmptyOp + TransposeOp
+/// %empty = tensor.empty(%arg3) : tensor<2x?xf32>
+/// %transposed = linalg.transpose
+/// ins(%extracted_slice : tensor<?x2xf32>)
+/// outs(%empty : tensor<2x?xf32>)
+/// permutation = [1, 0]
+/// // InsertSliceOp
+/// %inserted_slice = tensor.insert_slice %transposed
+/// into %arg1[0, 0, 0, 0] [1, 1, 2, %tile_dim_1] [1, 1, 1, 1]
+/// : tensor<2x?xf32> into tensor<1x1x2x?xf32>
+/// ```
struct GeneralizeOuterUnitDimsPackOpPattern
: public OpRewritePattern<tensor::PackOp> {
using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index da5233049aaf69..ed5f1bd602d7f4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -27,6 +27,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -1138,6 +1139,29 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
return perm;
}
+// A helper function to generate a dim-and-size pair for Ops like
+// ExtractSliceOp that require both:
+// * dims to specify the output shape, and
+// * sizes for the sizes attribute (or similar).
+// For dynamic sizes, if the corresponding size is a compile time constant:
+// * the return size becomes the attribute encapsulating the known size, and
+// * dim is updated from kDynamic to its actual known value.
+static std::pair<int64_t, OpFoldResult>
+getSimplifiedDimSizePair(OpFoldResult tileSizeOfr, PatternRewriter &rewriter) {
+ int64_t tileSizeForShape =
+ getConstantIntValue(tileSizeOfr).value_or(ShapedType::kDynamic);
+
+ OpFoldResult tileSizeOfrSimplified;
+ if (tileSizeForShape != ShapedType::kDynamic) {
+ tileSizeOfrSimplified = rewriter.getIndexAttr(tileSizeForShape);
+ } else {
+ tileSizeOfrSimplified = tileSizeOfr;
+ }
+
+ return std::pair<int64_t, OpFoldResult>(tileSizeForShape,
+ tileSizeOfrSimplified);
+}
+
LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
tensor::PackOp packOp, PatternRewriter &rewriter) const {
// TODO: support the case that outer dimensions are not all 1s. A
@@ -1148,69 +1172,104 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
packOp, "require the tiled outer dimensions of the result are all 1s");
}
- // 1. Use rank-reduced tensor.extract_slice op to extract the tile and untiled
- // outer dims.
+ Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
+ Attribute oneIdxAttr = rewriter.getIndexAttr(1);
Location loc = packOp.getLoc();
+
Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
auto inputShape = packOp.getSourceType().getShape();
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
packOp.getDimAndTileMapping();
- Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
- Attribute oneIdxAttr = rewriter.getIndexAttr(1);
int64_t srcRank = packOp.getSourceRank();
+
+ int64_t destRank = packOp.getDestRank();
+ size_t numTiles = destRank - srcRank;
+
+ // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
+ // %extracted_tile = tensor.extract_slice(%pack_op_input)
SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
- SmallVector<OpFoldResult> readSizes;
- SmallVector<OpFoldResult> transShapeForEmpty;
- SmallVector<int64_t> readShapeForExtractSlice;
+
+ // The sizes attribute for ExtractSliceOp. The leading sizes are set to 1 as
+ // all outer dims are 1.
+ SmallVector<OpFoldResult> extractSliceSizes(srcRank - numTiles, oneIdxAttr);
+ // The shape of the output for ExtractSliceOp. All leading unit dims are
+ // effectively rank-reduced, hence skipped.
+ SmallVector<int64_t> outputShapeForExtractSlice;
+
+ // Extract the trailing sizes and shape dims for ExtractSliceOp. These should
+ // be equal to the inner tile sizes.
for (auto i : llvm::seq<unsigned>(0, srcRank)) {
if (dimAndTileMapping.count(i)) {
- readShapeForExtractSlice.push_back(
- getConstantIntValue(dimAndTileMapping[i])
- .value_or(ShapedType::kDynamic));
- readSizes.push_back(dimAndTileMapping[i]);
- transShapeForEmpty.push_back(dimAndTileMapping[i]);
- continue;
- }
- if (ShapedType::isDynamic(inputShape[i])) {
- readSizes.push_back(
- rewriter.create<tensor::DimOp>(loc, input, i).getResult());
- } else {
- readSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
- }
- if (inputShape[i] != 1) {
- readShapeForExtractSlice.push_back(inputShape[i]);
- transShapeForEmpty.push_back(rewriter.getIndexAttr(inputShape[i]));
+ auto [tileSize, tileSizeOfr] =
+ getSimplifiedDimSizePair(dimAndTileMapping[i], rewriter);
+ extractSliceSizes.push_back(tileSizeOfr);
+ outputShapeForExtractSlice.push_back(tileSize);
}
}
Type elemType = packOp.getSourceType().getElementType();
- auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
+ auto readType = RankedTensorType::get(outputShapeForExtractSlice, elemType);
Value tile = rewriter.create<tensor::ExtractSliceOp>(
- loc, readType, input, readOffsets, readSizes, readStrides);
+ loc, readType, input, readOffsets, extractSliceSizes, readStrides);
- // 2. Transpose the tile to match the inner tile order.
+ // 2. Transpose the tile to match the inner tile order:
+ // %init = tensor.empty()
+ // %transposed_tile = linalg.transpose ins(%extracted_tile), outs(%init)
+ // NOTE: Outer dims are 1 and hence effectively ignored.
SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm());
LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
- applyPermutationToVector<OpFoldResult>(transShapeForEmpty, perm);
+ // 2.1 Create tensor.empty (init value for TransposeOp)
+ SmallVector<OpFoldResult> transShapeForEmptyOpDynamic;
+ SmallVector<int64_t> transShapeForEmptyOpStatic;
+
+ // Acquire tensor shape required to create EmptyOp. This will match the inner
+ // tile sizes, but the actual data format will depend on whether the tile
+ // sizes are static or dynamic (each case leads to a different builder for
+ // EmptyOp). Conservatively, prepare for both scenarios.
+ size_t idx = numTiles;
+ while (idx != 0) {
+ transShapeForEmptyOpDynamic.push_back(extractSliceSizes[srcRank - idx]);
+ transShapeForEmptyOpStatic.push_back(
+ outputShapeForExtractSlice[numTiles - idx]);
+ idx--;
+ }
- Value empty =
- rewriter.create<tensor::EmptyOp>(loc, transShapeForEmpty, elemType);
+ applyPermutationToVector<int64_t>(transShapeForEmptyOpStatic, perm);
+ applyPermutationToVector<OpFoldResult>(transShapeForEmptyOpDynamic, perm);
+
+ Value empty = ShapedType::isDynamicShape(transShapeForEmptyOpStatic)
+ ? rewriter.create<tensor::EmptyOp>(
+ loc, transShapeForEmptyOpDynamic, elemType)
+ : rewriter.create<tensor::EmptyOp>(
+ loc, transShapeForEmptyOpStatic, elemType);
+
+ // 2.2 Create linalg.transpose
auto transposedOp =
rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);
- // 3. Insert the inner tile to the destination.
- int64_t destRank = packOp.getDestRank();
+ // 3. Insert the inner tile to the destination:
+ // %inserted_tile = tensor.insert_slice(%transposed_tile)
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
- SmallVector<OpFoldResult> writeSizes =
- tensor::getMixedSizes(rewriter, loc, packOp.getDest());
+ // Outer dims are all 1s!
+ SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
+ oneIdxAttr);
+ SmallVector<int64_t> writeShape;
+
+ for (auto tileSize : packOp.getMixedTiles()) {
+ auto [tileSizeStatic, tileSizeOfr] =
+ getSimplifiedDimSizePair(tileSize, rewriter);
+ writeSizes.push_back(tileSizeOfr);
+ writeShape.push_back(tileSizeStatic);
+ }
+ // 4. Replace tensor.packOp with tensor.insert_slice created above
auto insert = rewriter.create<tensor::InsertSliceOp>(
loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
writeSizes, writeStrides);
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
index 7f6b5e279f6857..8abf7a11bed5c9 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -1,21 +1,32 @@
// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-generalize-tensor-pack" %s | FileCheck %s
-func.func @simple_KCRS_to_KCRSsr(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<1x1x1x1x8x32xf32>) -> tensor<1x1x1x1x8x32xf32> {
- %0 = tensor.pack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x32x8xf32> -> tensor<1x1x1x1x8x32xf32>
- return %0 : tensor<1x1x1x1x8x32xf32>
+
+func.func @simple_KCRS_to_KCRSsr(%arg0: tensor<?x?xi32>, %arg1: tensor<1x1x?x1xi32>) -> tensor<1x1x?x1xi32> {
+ %c8 = arith.constant 8 : index
+ %c5 = arith.constant 5 : i32
+ %pack = tensor.pack %arg0 padding_value(%c5 : i32) inner_dims_pos = [0, 1] inner_tiles = [%c8, 1] into %arg1 : tensor<?x?xi32> -> tensor<1x1x?x1xi32>
+ return %pack : tensor<1x1x?x1xi32>
}
-// CHECK-LABEL: func.func @simple_KCRS_to_KCRSsr
-// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
-// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
-// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x32xf32>
-// CHECK: %[[TRANSP:.+]] = linalg.transpose
-// CHECK-SAME: ins(%[[TILE]] : tensor<32x8xf32>)
-// CHECK-SAME: outs(%[[EMPTY]] : tensor<8x32xf32>)
-// CHECK-SAME: permutation = [1, 0]
-// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
-// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
-// CHECK: return %[[INSERT]]
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0] -> (-s0 + 8)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<()[s0] -> (-s0 + 1)>
+
+// CHECK-LABEL: func.func @simple_KCRS_to_KCRSsr(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: tensor<?x?xi32>,
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]: tensor<1x1x?x1xi32>) -> tensor<1x1x?x1xi32>
+// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 5 : i32
+// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_5:.*]] = tensor.dim %[[SRC]], %[[VAL_4]] : tensor<?x?xi32>
+// CHECK: %[[VAL_6:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[VAL_5]]]
+// CHECK: %[[VAL_7:.*]] = tensor.dim %[[SRC]], %[[VAL_2]] : tensor<?x?xi32>
+// CHECK: %[[VAL_8:.*]] = affine.apply #[[$ATTR_1]](){{\[}}%[[VAL_7]]]
+// CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[VAL_6]], %[[VAL_8]]] {
+// CHECK: ^bb0(%[[VAL_10:.*]]: index, %[[VAL_11:.*]]: index):
+// CHECK: tensor.yield %[[VAL_3]] : i32
+// CHECK: } : tensor<?x?xi32> to tensor<8x1xi32>
+// CHECK: %[[INSERT:.*]] = tensor.insert_slice %[[PAD:.*]] into %[[DEST]][0, 0, 0, 0] [1, 1, 8, 1] [1, 1, 1, 1] : tensor<8x1xi32> into tensor<1x1x?x1xi32>
+// CHECK: return %[[INSERT]] : tensor<1x1x?x1xi32>
// -----
@@ -39,26 +50,59 @@ func.func @simple_pad_and_pack_static_tiles(%input: tensor<5x1xf32>, %output: te
/// Same as example above, but with 1 dynamic tile size.
-func.func @simple_pad_and_pack_dynamic_tile(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32, %high: index) -> tensor<1x1x?x2xf32> {
- %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%high, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
+func.func @simple_pad_and_pack_dynamic_tile(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32, %tile_dim_0: index) -> tensor<1x1x?x2xf32> {
+ %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%tile_dim_0, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
return %0 : tensor<1x1x?x2xf32>
}
-
// CHECK-LABEL: func.func @simple_pad_and_pack_dynamic_tile(
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]]
-// CHECK-SAME: %[[HIGH_VAL:[a-zA-Z0-9]+]]: index) -> tensor<1x1x?x2xf32> {
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[HIGH_VAL]]]
+// CHECK-SAME: %[[TILE_DIM_0:[a-zA-Z0-9]+]]: index) -> tensor<1x1x?x2xf32> {
+// CHECK: %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[TILE_DIM_0]]]
// CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
// CHECK: tensor.yield %[[PAD_VAL]] : f32
// CHECK-NOT: linalg.transpose
-// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[VAL_10:.*]][0, 0] {{\[}}%[[HIGH_VAL]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
-// CHECK: %[[DIM:.*]] = tensor.dim %[[DEST]], %[[C2]] : tensor<1x1x?x2xf32>
-// CHECK: %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[PAD:.*]][0, 0] {{\[}}%[[TILE_DIM_0]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
+// CHECK: %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_0]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
// CHECK: return %[[RES]] : tensor<1x1x?x2xf32>
+func.func @simple_pad_and_pack_dynamic_tile_cst(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32) -> tensor<1x1x?x2xf32> {
+ %tile_dim_0 = arith.constant 8 : index
+ %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%tile_dim_0, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
+ return %0 : tensor<1x1x?x2xf32>
+}
+// CHECK-LABEL: func.func @simple_pad_and_pack_dynamic_tile_cst(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]]: f32) -> tensor<1x1x?x2xf32> {
+// CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high[3, 1] {
+// CHECK: tensor.yield %[[PAD_VAL]] : f32
+// CHECK: } : tensor<5x1xf32> to tensor<8x2xf32>
+// CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD:.*]] into %[[DEST]][0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1] : tensor<8x2xf32> into tensor<1x1x?x2xf32>
+// CHECK: return %[[RES]] : tensor<1x1x?x2xf32>
+
+func.func @simple_pad_and_pack_dynamic_tile_transpose(%input: tensor<5x1xf32>, %output: tensor<1x1x2x?xf32>, %pad: f32, %tile_dim_1: index) -> tensor<1x1x2x?xf32> {
+ %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [1, 0] inner_tiles = [2, %tile_dim_1] into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32>
+ return %0 : tensor<1x1x2x?xf32>
+}
+// CHECK-LABEL: func.func @simple_pad_and_pack_dynamic_tile_transpose(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[TILE_DIM_1:[a-zA-Z0-9]+]]: index) -> tensor<1x1x2x?xf32> {
+// CHECK: %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[TILE_DIM_1]]]
+// CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
+// CHECK: tensor.yield %[[PAD_VAL]] : f32
+// CHECK-NEXT: } : tensor<5x1xf32> to tensor<?x2xf32>
+// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[PAD:.*]][0, 0] {{\[}}%[[TILE_DIM_1]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
+// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[TILE_DIM_1]]) : tensor<2x?xf32>
+// CHECK: %[[TR:.*]] = linalg.transpose
+// CHECK-SAME: ins(%[[SLICE]] : tensor<?x2xf32>) outs(%[[EMPTY]] : tensor<2x?xf32>)
+// CHECK-SAME: permutation = [1, 0]
+// CHECK: %[[RES:.*]] = tensor.insert_slice %[[TR]] into %[[DEST]][0, 0, 0, 0] [1, 1, 2, %[[TILE_DIM_1]]] [1, 1, 1, 1] : tensor<2x?xf32> into tensor<1x1x2x?xf32>
+// CHECK: return %[[RES]] : tensor<1x1x2x?xf32>
+
/// Same as example above, but with 1 scalable tile size.
/// NOTE: For this example to make sense in practice, the "?" in the output shape
@@ -77,7 +121,6 @@ func.func @simple_pad_and_pack_scalable_tile(%input: tensor<5x1xf32>, %output: t
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: tensor<5x1xf32>,
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]: tensor<1x1x?x2xf32>,
// CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]]: f32) -> tensor<1x1x?x2xf32> {
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
// CHECK-DAG: %[[VS:.+]] = vector.vscale
// CHECK: %[[C8_VS:.+]] = arith.muli %[[VS]], %[[C8]] : index
@@ -86,37 +129,56 @@ func.func @simple_pad_and_pack_scalable_tile(%input: tensor<5x1xf32>, %output: t
// CHECK: tensor.yield %[[PAD_VAL]] : f32
// CHECK-NOT: linalg.transpose
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[PAD:.+]][0, 0] {{\[}}%[[C8_VS]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
-// CHECK: %[[DIM:.+]] = tensor.dim %[[DEST]], %[[C2]] : tensor<1x1x?x2xf32>
-// CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[C8_VS]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
// CHECK: return %[[RES]] : tensor<1x1x?x2xf32>
/// Same as example above, but with both tile sizes dynamic.
-func.func @simple_pad_and_pack_dynamic_tiles(%input: tensor<5x1xf32>, %output: tensor<1x1x?x?xf32>, %pad: f32, %high_1: index, %high_2: index) -> tensor<1x1x?x?xf32> {
- %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%high_1, %high_2] into %output : tensor<5x1xf32> -> tensor<1x1x?x?xf32>
+func.func @simple_pad_and_pack_dynami...
[truncated]
|
7f79675
to
d556706
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This PR restricts GeneralizeOuterUnitDimsPackOpPattern to follow its
intended purpose (as per the documentation),
Does this PR add this restriction in the matcher for the pattern? I didn't see any update to the matching logic.
There was one in-tree test that violated this assumption (and happened
to work) – see @simple_KCRS_to_KRSCsr in
"generalize-tensor-pack.mlir". That test has been updated to satisfy the
new requirements of the pattern.
This test has a non-unit outer dimension, but that dimension is not packed (no inner_tile for it). I think this may actually be intended behavior, since the pattern checks that packOp.getTiledOuterDims()
are all 1. Maybe the comments are just misleading, and this case is meant to be supported.
I'm not sure what the motivation of this pattern was to begin with, so I can't say if this type of case needs to be supported, but I would be wary of removing that functionality without hearing from whoever wrote this pattern. Looks like @qedawkins added this test, so he may have better context.
/// // ExtractSliceOp | ||
/// %extracted_slice = tensor.extract_slice %padded[0, 0] [%tile_dim_1, 2] [1, | ||
/// 1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This may not be in the scope of this PR to change this, but what exactly is this extract_slice for? It seems to me like it is just taking a full slice. Is the slice ever not full?
static std::pair<int64_t, OpFoldResult> | ||
getSimplifiedDimSizePair(OpFoldResult tileSizeOfr, PatternRewriter &rewriter) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems like a useful util, maybe move it to Dialect/Utils/StaticValueUtils.h
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done :)
func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<3x1x32x8xf32>, %arg1: tensor<3x1x1x1x8x32xf32>) -> tensor<3x1x1x1x8x32xf32> { | ||
%0 = tensor.pack %arg0 outer_dims_perm = [0, 2, 3, 1] inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<3x1x32x8xf32> -> tensor<3x1x1x1x8x32xf32> | ||
return %0 : tensor<3x1x1x1x8x32xf32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If there are no other tests like this, can we actually keep this test so we can test that the outer dims must be all 1?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, I don't see anything in this PR that restricts the pattern to fail on this case. Did you mean to update the matching logic too?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, I don't see anything in this PR that restricts the pattern to fail on this case. Did you mean to update the matching logic too?
I did, thanks for catching this. Sending an update shortly.
Thanks for taking a look! Quick reply to your high-level question (emphasis mine):
Indeed, I'd really appreciate if somebody could clarify this 😅
100% agreed, thanks for bringing this up! As you can see, I'm actually struggling a bit to get feedback for this - I really appreciate you not shying away! 🙏🏻 While waiting for Quinn to chime in, I will make a couple of points.
func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<3x1x32x8xf32>, %arg1:
tensor<3x1x1x1x8x32xf32>, %pad: f32) -> tensor<3x1x1x1x8x32xf32> {
%0 = tensor.pack %arg0 padding_value(%pad : f32) outer_dims_perm = [0, 2, 3, 1] inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<3x1x32x8xf32> -> tensor<3x1x1x1x8x32xf32>
return %0 : tensor<3x1x1x1x8x32xf32>
} The logic that files: getPackOpSourceOrPaddedSource (i.e. you will hit the assert in that method). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe this is a case where the term outer
is slightly unclear (but I probably interpreted it wrong when I wrote it). What the docs should have said after I added that test is all tiled outer dims. In that case the typical decomposition of a pack (pad
+ expand_shape
+ transpose
) would only be introducing unit dimensions with the expand_shape
and could use a rank reducing insert_slice
to introduce the new sizes. Whether the untiled outer dims were also unit doesn't affect that aspect of the decomposition.
This also helped generalize the pattern to support cases like the one below:
...
By restricting the pattern and simplifying the size/shape calculation, supporting the case above becomes much easier.
I don't quite follow how exclusion of non-unit untiled outer dimensions simplifies shape calculation logic here (I suspect it could have been more a matter of the previous implementation being more convoluted than necessary).
// CHECK: ^bb0(%[[VAL_6:.*]]: index, %[[VAL_7:.*]]: index, %[[VAL_8:.*]]: index, %[[VAL_9:.*]]: index): | ||
// CHECK: tensor.yield %[[VAL_2]] : f32 | ||
// CHECK: } : tensor<1x1x5x1xf32> to tensor<1x1x?x2xf32> | ||
// CHECK: %[[VAL_10:.*]] = tensor.extract_slice %[[VAL_11:.*]][0, 0, 0, 0] [1, 1, %[[VAL_3]], 2] [1, 1, 1, 1] : tensor<1x1x?x2xf32> to tensor<?x2xf32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This extract_slice seems unnecessary to me. Instead we can just extend the permutation of the transpose to include the outer most untiled dims.
int64_t destRank = packOp.getDestRank(); | ||
size_t numTiles = destRank - srcRank; | ||
|
||
// 1. Use rank-reduced tensor.extract_slice op to extract the tile: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you elaborate on the advantage of slicing out the inner tile here? We end up reintroducing the outer unit dims shortly after with the tensor.insert_slice
so this only serves to remove the unit dims from the linalg.transpose
for permuting the inner dimensions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am merely a messenger here ;-) (as in, tensor.extract_slice
is already a part of this logic, not something added by me).
To me, it makes more sense to only transpose the tile worth of data as that's the intent of tensor.pack
, right? But you will have more experience with this logic. I just wanted to preserve the original logic as much as possible and to avoid making too many changes in one PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah really, that makes sense then. If it's preserving existing behavior leaving it is fine, but should look into removing it at some point.
Ah I see your response before I sent my review.
I believe we were at one point relying on support for this downstream, but that might not be the case anymore (the pass I'm thinking of might have grown it's own simplified pattern for handling cases like this). Dropping support and fixing forward if the current implementation is broken (as you seem to be suggesting with your second point). |
Thank you for taking a look :)
Indeed. In fact, I've implemented #109642 so that we can be clearer about the intent in the future. But this logic predates that PR.
Yes, it's rather convoluted and also incomplete (hence this PR). And then there's getPackOpSourceOrPaddedSource that does require all outer dims to be unit - at least when the pad value is specified. This specifcially makes me believe that this is neither needed nor used? As for simplifying calculation, if we do allow non-unit untiled dims, then we need to track 3 sets of dims (inner, outer tiled, outer untiled). Sure, that can be done, but given there are other issues here ... I am happy to restore that functionality, but if it's not really required by anyone (that part is still unclear to me), then keeping things simple might be to our overall benefit. What are your thoughts? |
This has fallen pretty far out of my cache but the pass that required support for the non-unit untiled outer dim no longer depends on this pattern so I think you should be good to drop it for now. I would still like to see if we can avoid that extract_slice though. I've found rank-reducing extract_slice ops to compose poorly with most other patterns/passes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for explaining the current state and the nice comments. This looks like good cleanup to me now and we can try to fix forward on the case dropped here (unless someone else is relying on that functionality).
// * the return size becomes the attribute encapsulating the known size, and | ||
// * dim is updated from kDynamic to its actual known value. | ||
static std::pair<int64_t, OpFoldResult> | ||
getSimplifiedDimSizePair(OpFoldResult tileSizeOfr, PatternRewriter &rewriter) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Can just use Builder &
instead of a PatternRewriter &
applyPermutationToVector<int64_t>(transShapeForEmptyOpStatic, perm); | ||
applyPermutationToVector<OpFoldResult>(transShapeForEmptyOpDynamic, perm); | ||
|
||
Value empty = ShapedType::isDynamicShape(transShapeForEmptyOpStatic) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't the builder for the static case always produce the same result as the dynamic case? Can we just keep the dynamic path?
I'm thinking that for any case where you needed the static builder, we could have had an additional dynamic dim that would make it take the dynamic path, which should still do the same thing for the static part.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't the builder for the static case always produce the same result as the dynamic case? Can we just keep the dynamic path?
Great point!
It turns out that EmptyOp::build already supports the necessary "magic" via dispatchIndexOpFoldResults :)
d556706
to
99d24b7
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
I've just sent some updates addressing/implementing your suggestions for this revision. Two high-level comments:
Most importantly, thank you for your in-depth review 🙏🏻 |
99d24b7
to
7916a3f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SGTM, thanks!
7916a3f
to
50ab228
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I followed the nice discussion with Quinn and this look good to me now. Thanks for the cleanup!
2c3666d
to
82283ba
Compare
…ern` This PR *restricts* `GeneralizeOuterUnitDimsPackOpPattern` to follow its intended purpose (as per the documentation), which is to: > require all outer dimensions of tensor.pack to be 1. There was one in-tree test that violated this assumption (and happened to work) – see `@simple_KCRS_to_KRSCsr` in "generalize-tensor-pack.mlir". That test has been updated to satisfy the new requirements of the pattern. By enforcing the pattern to follow its intended design (i.e., making it stricter), the calculation of shapes and sizes for various Ops that the pattern generates (PadOp, ExtractSliceOp, EmptyOp, TensorOp, and InsertSliceOp) becomes much simpler and easier to document. This also helped *generalize* the pattern to support cases like the one below: ```mlir func.func @simple_pad_and_pack_dynamic_tile_cst( %src: tensor<5x1xf32>, %dest: tensor<1x1x?x2xf32>, %pad: f32) -> tensor<1x1x?x2xf32> { %tile_dim_0 = arith.constant 8 : index %0 = tensor.pack %src padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%tile_dim_0, 2] into %dest : tensor<5x1xf32> -> tensor<1x1x?x2xf32> return %0 : tensor<1x1x?x2xf32> } ``` Note that the inner tile slice is dynamic but compile-time constant. `getPackOpSourceOrPaddedSource`, which is used to generate PadOp, detects this and generates a PadOp with static shapes. This is a good optimization, but it means that all shapes/sizes for Ops generated by `GeneralizeOuterUnitDimsPackOpPattern` also need to be updated to be constant/static. By restricting the pattern and simplifying the size/shape calculation, supporting the case above becomes much easier. Notable implementation changes: * PadOp processes the original source (no change in dimensions/rank). ExtractSliceOp extracts the tile to pack and may reduce the rank. All following ops work on the tile extracted by ExtractSliceOp (possibly rank-reduced). * All shape/size calculations assume that trailing dimensions match inner_tiles from tensor.pack. All leading dimensions (i.e., outer dimensions) are assumed to be 1. * Dynamic sizes for ops like ExtractSliceOp are taken from inner_tiles rather than computed as, for example, tensor.dim %dest, 2. It’s the responsibility of the "producers" of tensor.pack to ensure that dimensions in %dest match the specified tile sizes.
…kOpPattern` SKip calculating static shapes for EmptyOp
…DimsPackOpPattern` Raname and move getSimplifiedDimSizePair
…terUnitDimsPackOpPattern` Minor tweak
82283ba
to
2d9bc62
Compare
Avoid generating spurious tensor.extract_slice, follow-on for llvm#114315. This is best to demonstrate with an example. Here's input for `GeneralizeOuterUnitDimsPackOpPattern`: ```mlir %pack = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [1, 0] inner_tiles = [2, %tile_dim_1] into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32> ``` Output _before_: ```mlir %padded = tensor.pad %arg0 low[0, 0] high[%0, 1] { ^bb0(%arg4: index, %arg5: index): tensor.yield %arg2 : f32 } : tensor<5x1xf32> to tensor<?x2xf32> %extracted_slice = tensor.extract_slice %padded[0, 0] [%arg3, 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32> %empty = tensor.empty(%arg3) : tensor<2x?xf32> %transposed = linalg.transpose ins(%extracted_slice : tensor<?x2xf32>) outs(%empty : tensor<2x?xf32>) permutation = [1, 0] %inserted_slice = tensor.insert_slice %transposed= into %arg1[0, 0, 0, 0] [1, 1, 2, %arg3] [1, 1, 1, 1] : tensor<2x?xf32> into tensor<1x1x2x?xf32> ``` Output _after_: ```mlir %padded = tensor.pad %arg0 low[0, 0] high[%0, 1] { ^bb0(%arg4: index, %arg5: index): tensor.yield %arg2 : f32 } : tensor<5x1xf32> to tensor<?x2xf32> %empty = tensor.empty(%arg3) : tensor<2x?xf32> %transposed = linalg.transpose ins(%padded : tensor<?x2xf32>) outs(%empty : tensor<2x?xf32>) permutation = [1, 0] %inserted_slice = tensor.insert_slice %transposed into %arg1[0, 0, 0, 0] [1, 1, 2, %arg3] [1, 1, 1, 1] : tensor<2x?xf32> into tensor<1x1x2x?xf32> ``` This PR also adds a check to verify that only the last N (for some value of N) trailing dims that are being tiled. From what I can tell, that's always the case in practice. For this PR, it simplifies how the permutation for linalg.transpose is computed. If needed, this can be relaxed in the future
Avoid generating spurious tensor.extract_slice, follow-on for llvm#114315. This is best to demonstrate with an example. Here's input for `GeneralizeOuterUnitDimsPackOpPattern`: ```mlir %pack = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [1, 0] inner_tiles = [2, %tile_dim_1] into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32> ``` Output _before_: ```mlir %padded = tensor.pad %arg0 low[0, 0] high[%0, 1] { ^bb0(%arg4: index, %arg5: index): tensor.yield %arg2 : f32 } : tensor<5x1xf32> to tensor<?x2xf32> %extracted_slice = tensor.extract_slice %padded[0, 0] [%arg3, 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32> %empty = tensor.empty(%arg3) : tensor<2x?xf32> %transposed = linalg.transpose ins(%extracted_slice : tensor<?x2xf32>) outs(%empty : tensor<2x?xf32>) permutation = [1, 0] %inserted_slice = tensor.insert_slice %transposed= into %arg1[0, 0, 0, 0] [1, 1, 2, %arg3] [1, 1, 1, 1] : tensor<2x?xf32> into tensor<1x1x2x?xf32> ``` Output _after_: ```mlir %padded = tensor.pad %arg0 low[0, 0] high[%0, 1] { ^bb0(%arg4: index, %arg5: index): tensor.yield %arg2 : f32 } : tensor<5x1xf32> to tensor<?x2xf32> %empty = tensor.empty(%arg3) : tensor<2x?xf32> %transposed = linalg.transpose ins(%padded : tensor<?x2xf32>) outs(%empty : tensor<2x?xf32>) permutation = [1, 0] %inserted_slice = tensor.insert_slice %transposed into %arg1[0, 0, 0, 0] [1, 1, 2, %arg3] [1, 1, 1, 1] : tensor<2x?xf32> into tensor<1x1x2x?xf32> ``` This PR also adds a check to verify that only the last N (for some value of N) trailing dims that are being tiled. From what I can tell, that's always the case in practice. For this PR, it simplifies how the permutation for linalg.transpose is computed. If needed, this can be relaxed in the future
Promised refactor for |
Adds an end-to-end test for `tensor.pack` with dynamic inner tile sizes. While relatively simple (e.g., no vectorization), this example required a few non-trivial fixes in handling `tensor.pack`: * llvm#114315, llvm#114559, llvm#113108. The end goal for this test is to incrementally increase its complexity and to work towards scalable tile sizes.
Avoid generating spurious tensor.extract_slice, follow-on for #114315. This is best to demonstrate with an example. Here's input for `GeneralizeOuterUnitDimsPackOpPattern`: ```mlir %pack = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [1, 0] inner_tiles = [2, %tile_dim_1] into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32> ``` Output _before_: ```mlir %padded = tensor.pad %arg0 low[0, 0] high[%0, 1] { ^bb0(%arg4: index, %arg5: index): tensor.yield %arg2 : f32 } : tensor<5x1xf32> to tensor<?x2xf32> // NOTE: skipped in the output _after_ %extracted_slice = tensor.extract_slice %padded[0, 0] [%arg3, 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32> %empty = tensor.empty(%arg3) : tensor<2x?xf32> %transposed = linalg.transpose ins(%extracted_slice : tensor<?x2xf32>) outs(%empty : tensor<2x?xf32>) permutation = [1, 0] %inserted_slice = tensor.insert_slice %transposed= into %arg1[0, 0, 0, 0] [1, 1, 2, %arg3] [1, 1, 1, 1] : tensor<2x?xf32> into tensor<1x1x2x?xf32> ``` Output _after_: ```mlir %padded = tensor.pad %arg0 low[0, 0] high[%0, 1] { ^bb0(%arg4: index, %arg5: index): tensor.yield %arg2 : f32 } : tensor<5x1xf32> to tensor<?x2xf32> %empty = tensor.empty(%arg3) : tensor<2x?xf32> %transposed = linalg.transpose ins(%padded : tensor<?x2xf32>) outs(%empty : tensor<2x?xf32>) permutation = [1, 0] %inserted_slice = tensor.insert_slice %transposed into %arg1[0, 0, 0, 0] [1, 1, 2, %arg3] [1, 1, 1, 1] : tensor<2x?xf32> into tensor<1x1x2x?xf32> ``` This PR also adds a check to verify that only the last N trailing dimensions are tiled (for some value of N). Based on the PR discussion, this restriction seems reasonable - especially as there are no in-tree tests requiring otherwise. For now, it also simplifies the computation of permutations for linalg.transpose. This restriction can be relaxed in the future if needed.
…115698) Adds an end-to-end test for `tensor.pack` with dynamic inner tile sizes. While relatively simple (e.g., no vectorization), this example required a few non-trivial fixes in handling `tensor.pack`: * #114315, #114559, #113108. The end goal for this test is to incrementally increase its complexity and to work towards scalable tile sizes.
This PR restricts
GeneralizeOuterUnitDimsPackOpPattern
to follow itsintended purpose (as per the documentation), which is to:
There was one in-tree test that violated this assumption (and happened
to work) – see
@simple_KCRS_to_KRSCsr
in"generalize-tensor-pack.mlir". That test has been updated to satisfy the
new requirements of the pattern.
By enforcing the pattern to follow its intended design (i.e., making it
stricter), the calculation of shapes and sizes for various Ops that the
pattern generates (PadOp, ExtractSliceOp, EmptyOp, TensorOp, and
InsertSliceOp) becomes much simpler and easier to document. This also
helped generalize the pattern to support cases like the one below:
Note that the inner tile slice is dynamic but compile-time constant.
getPackOpSourceOrPaddedSource
, which is used to generate PadOp,detects this and generates a PadOp with static shapes. This is a good
optimization, but it means that all shapes/sizes for Ops generated by
GeneralizeOuterUnitDimsPackOpPattern
also need to be updated to beconstant/static. By restricting the pattern and simplifying the
size/shape calculation, supporting the case above becomes much easier.
Notable implementation changes:
ExtractSliceOp extracts the tile to pack and may reduce the rank. All
following ops work on the tile extracted by ExtractSliceOp (possibly
rank-reduced).
inner_tiles from tensor.pack. All leading dimensions (i.e., outer
dimensions) are assumed to be 1.
rather than computed as, for example, tensor.dim %dest, 2. It’s the
responsibility of the "producers" of tensor.pack to ensure that
dimensions in %dest match the specified tile sizes.