Commit d556706
[mlir][tensor] Generalize/restrict GeneralizeOuterUnitDimsPackOpPattern
This PR *restricts* `GeneralizeOuterUnitDimsPackOpPattern` to follow its intended purpose (as per the documentation), which is to:

> require all outer dimensions of tensor.pack to be 1.

There was one in-tree test that violated this assumption (and happened to work) – see `@simple_KCRS_to_KRSCsr` in "generalize-tensor-pack.mlir". That test has been updated to satisfy the new requirements of the pattern.

By enforcing the pattern to follow its intended design (i.e., making it stricter), the calculation of shapes and sizes for the various ops that the pattern generates (PadOp, ExtractSliceOp, EmptyOp, TransposeOp, and InsertSliceOp) becomes much simpler and easier to document. This also helped *generalize* the pattern to support cases like the one below:

```mlir
func.func @simple_pad_and_pack_dynamic_tile_cst(
    %src: tensor<5x1xf32>,
    %dest: tensor<1x1x?x2xf32>,
    %pad: f32) -> tensor<1x1x?x2xf32> {

  %tile_dim_0 = arith.constant 8 : index

  %0 = tensor.pack %src
    padding_value(%pad : f32)
    inner_dims_pos = [0, 1]
    inner_tiles = [%tile_dim_0, 2]
    into %dest : tensor<5x1xf32> -> tensor<1x1x?x2xf32>

  return %0 : tensor<1x1x?x2xf32>
}
```

Note that the inner tile size is dynamic but compile-time constant. `getPackOpSourceOrPaddedSource`, which is used to generate PadOp, detects this and generates a PadOp with static shapes. This is a good optimization, but it means that all shapes/sizes for ops generated by `GeneralizeOuterUnitDimsPackOpPattern` also need to be updated to be constant/static. By restricting the pattern and simplifying the size/shape calculation, supporting the case above becomes much easier.

Notable implementation changes:

* PadOp processes the original source (no change in dimensions/rank). ExtractSliceOp extracts the tile to pack and may reduce the rank. All following ops work on the tile extracted by ExtractSliceOp (possibly rank-reduced).
* All shape/size calculations assume that trailing dimensions match `inner_tiles` from `tensor.pack`. All leading dimensions (i.e., outer dimensions) are assumed to be 1.
* Dynamic sizes for ops like ExtractSliceOp are taken from `inner_tiles` rather than computed as, for example, `tensor.dim %dest, 2`. It's the responsibility of the "producers" of `tensor.pack` to ensure that the dimensions in `%dest` match the specified tile sizes.
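For the example above, the generalized output would look roughly as follows (a sketch inferred from the documentation added in this commit; SSA names and op ordering are illustrative). Because the tile size is a compile-time constant (8), every generated op ends up with static shapes:

```mlir
// Hypothetical decomposition sketch: constant tile size folds to static shapes.
// PadOp: pad 5x1 up to the tile footprint 8x2 (static thanks to the constant).
%padded = tensor.pad %src low[0, 0] high[3, 1] {
^bb0(%i: index, %j: index):
  tensor.yield %pad : f32
} : tensor<5x1xf32> to tensor<8x2xf32>
// ExtractSliceOp: extract the (rank-reduced) tile; sizes come from inner_tiles.
%tile = tensor.extract_slice %padded[0, 0] [8, 2] [1, 1]
    : tensor<8x2xf32> to tensor<8x2xf32>
// EmptyOp + TransposeOp (identity permutation here, as inner_dims_pos = [0, 1]).
%empty = tensor.empty() : tensor<8x2xf32>
%transposed = linalg.transpose
    ins(%tile : tensor<8x2xf32>)
    outs(%empty : tensor<8x2xf32>)
    permutation = [0, 1]
// InsertSliceOp: leading sizes are all 1s (the unit outer dims).
%res = tensor.insert_slice %transposed
    into %dest[0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
    : tensor<8x2xf32> into tensor<1x1x?x2xf32>
```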
1 parent e61a7dc commit d556706

3 files changed: +239 −84 lines

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 37 additions & 3 deletions
```diff
@@ -1515,9 +1515,43 @@ struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
                     const SmallVector<Value> &dynSizes) const;
 };
 
-/// Rewrites a tensor::PackOp into a sequence of tensor.pad + linalg.transpose +
-/// tensor.insert_slice ops, where the tensor::PackOp has outer dims being all
-/// 1s.
+/// Rewrites a tensor::PackOp into a sequence of:
+///   * tensor::PadOp + linalg::TransposeOp + tensor::ExtractSliceOp +
+///     tensor::EmptyOp + tensor::InsertSliceOp ops.
+///
+/// Required that all the outer dims of the input tensor::PackOp are 1.
+///
+/// Before:
+/// ```
+///   %packed = tensor.pack %input
+///     padding_value(%pad : f32)
+///     inner_dims_pos = [1, 0]
+///     inner_tiles = [2, %high]
+///     into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32>
+/// ```
+///
+/// After:
+/// ```
+///   // PadOp
+///   %padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
+///     ^bb0(...):
+///       tensor.yield %arg2 : f32
+///   } : tensor<5x1xf32> to tensor<?x2xf32>
+///   // ExtractSliceOp
+///   %extracted_slice = tensor.extract_slice %padded[0, 0] [%tile_dim_1, 2] [1,
+///   1]
+///     : tensor<?x2xf32> to tensor<?x2xf32>
+///   // EmptyOp + TransposeOp
+///   %empty = tensor.empty(%arg3) : tensor<2x?xf32>
+///   %transposed = linalg.transpose
+///     ins(%extracted_slice : tensor<?x2xf32>)
+///     outs(%empty : tensor<2x?xf32>)
+///     permutation = [1, 0]
+///   // InsertSliceOp
+///   %inserted_slice = tensor.insert_slice %transposed
+///     into %arg1[0, 0, 0, 0] [1, 1, 2, %tile_dim_1] [1, 1, 1, 1]
+///     : tensor<2x?xf32> into tensor<1x1x2x?xf32>
+/// ```
 struct GeneralizeOuterUnitDimsPackOpPattern
     : public OpRewritePattern<tensor::PackOp> {
   using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
```
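Conversely, the stricter pattern now bails out on inputs it was never meant to handle. A hypothetical example (shapes chosen for illustration, not taken from this commit's tests) that fails to match because a tiled outer dimension of the result is 8 rather than 1:

```mlir
// Rejected by the pattern: the first outer dim of %dest is 8, not 1
// ("require the tiled outer dimensions of the result are all 1s").
%0 = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [4, 8]
    into %dest : tensor<32x8xf32> -> tensor<8x1x4x8xf32>
```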

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 92 additions & 33 deletions
```diff
@@ -27,6 +27,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
```
```diff
@@ -1138,6 +1139,29 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
   return perm;
 }
 
+// A helper function to generate a dim-and-size pair for Ops like
+// ExtractSliceOp that require both:
+//  * dims to specify the output shape, and
+//  * sizes for the sizes attribute (or similar).
+// For dynamic sizes, if the corresponding size is a compile time constant:
+//  * the return size becomes the attribute encapsulating the known size, and
+//  * dim is updated from kDynamic to its actual known value.
+static std::pair<int64_t, OpFoldResult>
+getSimplifiedDimSizePair(OpFoldResult tileSizeOfr, PatternRewriter &rewriter) {
+  int64_t tileSizeForShape =
+      getConstantIntValue(tileSizeOfr).value_or(ShapedType::kDynamic);
+
+  OpFoldResult tileSizeOfrSimplified;
+  if (tileSizeForShape != ShapedType::kDynamic) {
+    tileSizeOfrSimplified = rewriter.getIndexAttr(tileSizeForShape);
+  } else {
+    tileSizeOfrSimplified = tileSizeOfr;
+  }
+
+  return std::pair<int64_t, OpFoldResult>(tileSizeForShape,
+                                          tileSizeOfrSimplified);
+}
+
 LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
     tensor::PackOp packOp, PatternRewriter &rewriter) const {
   // TODO: support the case that outer dimensions are not all 1s. A
```
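The effect of this helper is easiest to see in the IR it enables (a hedged sketch, not taken from this commit's tests): when the tile size wraps a compile-time constant, the pair collapses to a static size and a static result type; otherwise the original dynamic SSA value is kept:

```mlir
// Tile size is a compile-time constant (e.g. arith.constant 8 : index):
// the size becomes the attribute 8 and the result type is fully static.
%tile_a = tensor.extract_slice %padded[0, 0] [8, 2] [1, 1]
    : tensor<8x2xf32> to tensor<8x2xf32>

// Tile size is truly dynamic: the SSA value is used as the size and the
// corresponding dim stays ShapedType::kDynamic ('?').
%tile_b = tensor.extract_slice %padded[0, 0] [%tile_dim, 2] [1, 1]
    : tensor<?x2xf32> to tensor<?x2xf32>
```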
```diff
@@ -1148,69 +1172,104 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
         packOp, "require the tiled outer dimensions of the result are all 1s");
   }
 
-  // 1. Use rank-reduced tensor.extract_slice op to extract the tile and untiled
-  // outer dims.
+  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
+  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
   Location loc = packOp.getLoc();
+
   Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
   auto inputShape = packOp.getSourceType().getShape();
   DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
       packOp.getDimAndTileMapping();
-  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
-  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
   int64_t srcRank = packOp.getSourceRank();
+
+  int64_t destRank = packOp.getDestRank();
+  size_t numTiles = destRank - srcRank;
+
+  // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
+  //    %extracted_tile = tensor.extract_slice(%pack_op_input)
   SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
   SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
-  SmallVector<OpFoldResult> readSizes;
-  SmallVector<OpFoldResult> transShapeForEmpty;
-  SmallVector<int64_t> readShapeForExtractSlice;
+
+  // The sizes attribute for ExtractSliceOp. The leading sizes are set to 1 as
+  // all outer dims are 1.
+  SmallVector<OpFoldResult> extractSliceSizes(srcRank - numTiles, oneIdxAttr);
+  // The shape of the output for ExtractSliceOp. All leading unit dims are
+  // effectively rank-reduced, hence skipped.
+  SmallVector<int64_t> outputShapeForExtractSlice;
+
+  // Extract the trailing sizes and shape dims for ExtractSliceOp. These should
+  // be equal to the inner tile sizes.
   for (auto i : llvm::seq<unsigned>(0, srcRank)) {
     if (dimAndTileMapping.count(i)) {
-      readShapeForExtractSlice.push_back(
-          getConstantIntValue(dimAndTileMapping[i])
-              .value_or(ShapedType::kDynamic));
-      readSizes.push_back(dimAndTileMapping[i]);
-      transShapeForEmpty.push_back(dimAndTileMapping[i]);
-      continue;
-    }
-    if (ShapedType::isDynamic(inputShape[i])) {
-      readSizes.push_back(
-          rewriter.create<tensor::DimOp>(loc, input, i).getResult());
-    } else {
-      readSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
-    }
-    if (inputShape[i] != 1) {
-      readShapeForExtractSlice.push_back(inputShape[i]);
-      transShapeForEmpty.push_back(rewriter.getIndexAttr(inputShape[i]));
+      auto [tileSize, tileSizeOfr] =
+          getSimplifiedDimSizePair(dimAndTileMapping[i], rewriter);
+      extractSliceSizes.push_back(tileSizeOfr);
+      outputShapeForExtractSlice.push_back(tileSize);
     }
   }
 
   Type elemType = packOp.getSourceType().getElementType();
-  auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
+  auto readType = RankedTensorType::get(outputShapeForExtractSlice, elemType);
 
   Value tile = rewriter.create<tensor::ExtractSliceOp>(
-      loc, readType, input, readOffsets, readSizes, readStrides);
+      loc, readType, input, readOffsets, extractSliceSizes, readStrides);
 
-  // 2. Transpose the tile to match the inner tile order.
+  // 2. Transpose the tile to match the inner tile order:
+  //    %init = tensor.empty()
+  //    %transposed_tile = linalg.transpose ins(%extracted_tile), outs(%init)
+  // NOTE: Outer dims are 1 and hence effectively ignored.
   SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
       inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm());
 
   LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
              llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
 
-  applyPermutationToVector<OpFoldResult>(transShapeForEmpty, perm);
+  // 2.1 Create tensor.empty (init value for TransposeOp)
+  SmallVector<OpFoldResult> transShapeForEmptyOpDynamic;
+  SmallVector<int64_t> transShapeForEmptyOpStatic;
+
+  // Acquire tensor shape required to create EmptyOp. This will match the inner
+  // tile sizes, but the actual data format will depend on whether the tile
+  // sizes are static or dynamic (each case leads to a different builder for
+  // EmptyOp). Conservatively, prepare for both scenarios.
+  size_t idx = numTiles;
+  while (idx != 0) {
+    transShapeForEmptyOpDynamic.push_back(extractSliceSizes[srcRank - idx]);
+    transShapeForEmptyOpStatic.push_back(
+        outputShapeForExtractSlice[numTiles - idx]);
+    idx--;
+  }
 
-  Value empty =
-      rewriter.create<tensor::EmptyOp>(loc, transShapeForEmpty, elemType);
+  applyPermutationToVector<int64_t>(transShapeForEmptyOpStatic, perm);
+  applyPermutationToVector<OpFoldResult>(transShapeForEmptyOpDynamic, perm);
+
+  Value empty = ShapedType::isDynamicShape(transShapeForEmptyOpStatic)
+                    ? rewriter.create<tensor::EmptyOp>(
+                          loc, transShapeForEmptyOpDynamic, elemType)
+                    : rewriter.create<tensor::EmptyOp>(
+                          loc, transShapeForEmptyOpStatic, elemType);
+
+  // 2.2 Create linalg.transpose
   auto transposedOp =
       rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);
 
-  // 3. Insert the inner tile to the destination.
-  int64_t destRank = packOp.getDestRank();
+  // 3. Insert the inner tile to the destination:
+  //    %inserted_tile = tensor.insert_slice(%transposed_tile)
   SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
   SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
-  SmallVector<OpFoldResult> writeSizes =
-      tensor::getMixedSizes(rewriter, loc, packOp.getDest());
+  // Outer dims are all 1s!
+  SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
+                                       oneIdxAttr);
+  SmallVector<int64_t> writeShape;
+
+  for (auto tileSize : packOp.getMixedTiles()) {
+    auto [tileSizeStatic, tileSizeOfr] =
+        getSimplifiedDimSizePair(tileSize, rewriter);
+    writeSizes.push_back(tileSizeOfr);
+    writeShape.push_back(tileSizeStatic);
+  }
 
+  // 4. Replace tensor.packOp with tensor.insert_slice created above
   auto insert = rewriter.create<tensor::InsertSliceOp>(
       loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
       writeSizes, writeStrides);
```
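To make the last bullet of the commit message concrete, here is a hedged IR sketch of how the InsertSliceOp sizes change (illustrative, not from this commit's tests). Previously, dynamic sizes were read back from the destination via `tensor.dim`; now they are taken directly from `inner_tiles`:

```mlir
// Before (conceptually): the dynamic size was queried from %dest.
%c2 = arith.constant 2 : index
%dim = tensor.dim %dest, %c2 : tensor<1x1x?x2xf32>
%r0 = tensor.insert_slice %transposed
    into %dest[0, 0, 0, 0] [1, 1, %dim, 2] [1, 1, 1, 1]
    : tensor<?x2xf32> into tensor<1x1x?x2xf32>

// After: the size is taken from inner_tiles (%tile_dim_0 here), and folds to
// a static size when the tile size is a compile-time constant.
%r1 = tensor.insert_slice %transposed
    into %dest[0, 0, 0, 0] [1, 1, %tile_dim_0, 2] [1, 1, 1, 1]
    : tensor<?x2xf32> into tensor<1x1x?x2xf32>
```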
