
Commit e9bafa3

[mlir][tensor] Generalize/restrict GeneralizeOuterUnitDimsPackOpPattern (#114315)
This PR *restricts* `GeneralizeOuterUnitDimsPackOpPattern` to follow its intended purpose (as per the documentation), which is to:

> require all outer dimensions of tensor.pack to be 1.

There was one in-tree test that violated this assumption (and happened to work) – see `@simple_KCRS_to_KRSCsr` in "generalize-tensor-pack.mlir". That test has been updated to satisfy the new requirements of the pattern.

By enforcing the pattern to follow its intended design (i.e., making it stricter), the calculation of shapes and sizes for the various ops that the pattern generates (PadOp, ExtractSliceOp, EmptyOp, TransposeOp, and InsertSliceOp) becomes much simpler and easier to document. This also helped *generalize* the pattern to support cases like the one below:

```mlir
func.func @simple_pad_and_pack_dynamic_tile_cst(
    %src: tensor<5x1xf32>,
    %dest: tensor<1x1x?x2xf32>,
    %pad: f32) -> tensor<1x1x?x2xf32> {

  %tile_dim_0 = arith.constant 8 : index

  %0 = tensor.pack %src
    padding_value(%pad : f32)
    inner_dims_pos = [0, 1]
    inner_tiles = [%tile_dim_0, 2]
    into %dest : tensor<5x1xf32> -> tensor<1x1x?x2xf32>

  return %0 : tensor<1x1x?x2xf32>
}
```

Note that the inner tile size is dynamic but compile-time constant. `getPackOpSourceOrPaddedSource`, which is used to generate PadOp, detects this and generates a PadOp with static shapes. This is a good optimization, but it means that all shapes/sizes for the ops generated by `GeneralizeOuterUnitDimsPackOpPattern` also need to be updated to be constant/static. By restricting the pattern and simplifying the size/shape calculation, supporting the case above becomes much easier.

Notable implementation changes:

* PadOp processes the original source (no change in dimensions/rank). ExtractSliceOp extracts the tile to pack and may reduce the rank. All following ops work on the tile extracted by ExtractSliceOp (possibly rank-reduced).
* All shape/size calculations assume that trailing dimensions match `inner_tiles` from tensor.pack. All leading dimensions (i.e., outer dimensions) are assumed to be 1.
* Dynamic sizes for ops like ExtractSliceOp are taken from `inner_tiles` rather than computed as, for example, `tensor.dim %dest, 2`. It's the responsibility of the "producers" of tensor.pack to ensure that dimensions in `%dest` match the specified tile sizes.
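The core of the restriction is the match precondition: the pattern now bails out unless *every* outer dimension of the packed tensor is 1, not just the tiled ones. A minimal standalone sketch of that check (plain C++, with `allOuterDimsAreOnes` as a hypothetical stand-in for the check on `packOp.getAllOuterDims()`):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the new precondition: the pattern applies only
// when every outer dimension of the pack result equals 1 (previously, only
// the *tiled* outer dims were checked).
bool allOuterDimsAreOnes(const std::vector<int64_t> &outerDims) {
  return std::none_of(outerDims.begin(), outerDims.end(),
                      [](int64_t dim) { return dim != 1; });
}
```

For the `tensor<1x1x?x2xf32>` example above, the outer dims are `[1, 1]`, so the pattern matches; a result like `tensor<1x8x?x2xf32>` would be rejected.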
1 parent 0276621 commit e9bafa3

File tree: 5 files changed (+239, −85 lines)

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 37 additions & 3 deletions
```diff
@@ -1515,9 +1515,43 @@ struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
                              const SmallVector<Value> &dynSizes) const;
 };
 
-/// Rewrites a tensor::PackOp into a sequence of tensor.pad + linalg.transpose +
-/// tensor.insert_slice ops, where the tensor::PackOp has outer dims being all
-/// 1s.
+/// Rewrites a tensor::PackOp into a sequence of:
+///   * tensor::PadOp + linalg::TransposeOp + tensor::ExtractSliceOp +
+///     tensor::EmptyOp + tensor::InsertSliceOp ops.
+///
+/// Requires that all the outer dims of the input tensor::PackOp are 1.
+///
+/// Before:
+/// ```
+///   %packed = tensor.pack %input
+///     padding_value(%pad : f32)
+///     inner_dims_pos = [1, 0]
+///     inner_tiles = [2, %high]
+///     into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32>
+/// ```
+///
+/// After:
+/// ```
+///   // PadOp
+///   %padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
+///     ^bb0(...):
+///       tensor.yield %arg2 : f32
+///   } : tensor<5x1xf32> to tensor<?x2xf32>
+///   // ExtractSliceOp
+///   %extracted_slice = tensor.extract_slice %padded[0, 0] [%tile_dim_1, 2]
+///       [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
+///   // EmptyOp + TransposeOp
+///   %empty = tensor.empty(%arg3) : tensor<2x?xf32>
+///   %transposed = linalg.transpose
+///     ins(%extracted_slice : tensor<?x2xf32>)
+///     outs(%empty : tensor<2x?xf32>)
+///     permutation = [1, 0]
+///   // InsertSliceOp
+///   %inserted_slice = tensor.insert_slice %transposed
+///     into %arg1[0, 0, 0, 0] [1, 1, 2, %tile_dim_1] [1, 1, 1, 1]
+///     : tensor<2x?xf32> into tensor<1x1x2x?xf32>
+/// ```
 struct GeneralizeOuterUnitDimsPackOpPattern
     : public OpRewritePattern<tensor::PackOp> {
   using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
```

mlir/include/mlir/Dialect/Utils/StaticValueUtils.h

Lines changed: 19 additions & 0 deletions
```diff
@@ -60,6 +60,25 @@ void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
                                 SmallVectorImpl<Value> &dynamicVec,
                                 SmallVectorImpl<int64_t> &staticVec);
 
+/// Given an OpFoldResult representing a dim size value (*), generates a pair
+/// of sizes:
+///  * 1st result, static value, contains an int64_t dim size that can be used
+///    to build ShapedType (ShapedType::kDynamic is used for truly dynamic
+///    dims),
+///  * 2nd result, dynamic value, contains OpFoldResult encapsulating the
+///    actual dim size (either the original or the updated input value).
+/// For input sizes for which it is possible to extract a constant Attribute,
+/// replaces the original size value with an integer attribute (unless it's
+/// already a constant Attribute). The 1st return value also becomes the
+/// actual integer size (as opposed to ShapedType::kDynamic).
+///
+/// (*) This hook is usually used when, given input sizes as OpFoldResult,
+/// it's required to generate two vectors:
+///   * sizes as int64_t to generate a shape,
+///   * sizes as OpFoldResult for a sizes-like attribute.
+/// Please update this comment if you identify other use cases.
+std::pair<int64_t, OpFoldResult>
+getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b);
+
 /// Extract integer values from the assumed ArrayAttr of IntegerAttr.
 template <typename IntTy>
 SmallVector<IntTy> extractFromIntegerArrayAttr(Attribute attr) {
```
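The contract of the new helper can be modeled outside MLIR with a small self-contained sketch. Here `Ofr` and `kDynamicSize` are simplified stand-ins for `OpFoldResult` and `ShapedType::kDynamic`, not the real types:

```cpp
#include <cstdint>
#include <optional>
#include <utility>

// Illustrative model only: an OpFoldResult is either a constant Attribute or
// a Value whose constant content may still be discoverable.
constexpr int64_t kDynamicSize = INT64_MIN; // stand-in for ShapedType::kDynamic

struct Ofr {
  std::optional<int64_t> knownConstant; // set when a constant can be extracted
  bool isAttribute = false;             // already an integer attribute?
};

// Models the documented contract of getSimplifiedOfrAndStaticSizePair:
//  * first result: the int64_t usable for building a ShapedType,
//  * second result: the size folded to an attribute whenever a constant is
//    known, otherwise passed through unchanged.
std::pair<int64_t, Ofr> simplifiedOfrAndStaticSize(Ofr ofr) {
  int64_t staticSize = ofr.knownConstant.value_or(kDynamicSize);
  if (staticSize == kDynamicSize)
    return {kDynamicSize, ofr}; // truly dynamic: keep the original value
  return {staticSize, Ofr{staticSize, /*isAttribute=*/true}};
}
```

This is exactly the situation in the `%tile_dim_0 = arith.constant 8` example from the commit message: the tile size arrives as a Value, but a constant can be extracted, so the static side becomes `8` rather than a dynamic marker.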

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 59 additions & 34 deletions
```diff
@@ -1142,75 +1142,100 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
     tensor::PackOp packOp, PatternRewriter &rewriter) const {
   // TODO: support the case that outer dimensions are not all 1s. A
   // tensor.expand_shape will be generated in this case.
-  if (llvm::any_of(packOp.getTiledOuterDims(),
+  if (llvm::any_of(packOp.getAllOuterDims(),
                    [](int64_t dim) { return dim != 1; })) {
     return rewriter.notifyMatchFailure(
-        packOp, "require the tiled outer dimensions of the result are all 1s");
+        packOp, "not all outer dimensions of the result are 1s");
   }
 
-  // 1. Use rank-reduced tensor.extract_slice op to extract the tile and untiled
-  // outer dims.
+  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
+  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
   Location loc = packOp.getLoc();
+
   Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
   auto inputShape = packOp.getSourceType().getShape();
   DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
       packOp.getDimAndTileMapping();
-  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
-  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
   int64_t srcRank = packOp.getSourceRank();
+
+  int64_t destRank = packOp.getDestRank();
+  size_t numTiles = destRank - srcRank;
+
+  // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
+  //    %extracted_tile = tensor.extract_slice(%pack_op_input)
   SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
   SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
-  SmallVector<OpFoldResult> readSizes;
-  SmallVector<OpFoldResult> transShapeForEmpty;
-  SmallVector<int64_t> readShapeForExtractSlice;
+
+  // The sizes attribute for ExtractSliceOp. The leading sizes are set to 1 as
+  // all outer dims are 1.
+  SmallVector<OpFoldResult> extractSliceSizes(srcRank - numTiles, oneIdxAttr);
+  // The shape of the output for ExtractSliceOp. All leading unit dims are
+  // effectively rank-reduced, hence skipped.
+  SmallVector<int64_t> outputShapeForExtractSlice;
+
+  // Extract the trailing sizes and shape dims for ExtractSliceOp. These should
+  // be equal to the inner tile sizes.
   for (auto i : llvm::seq<unsigned>(0, srcRank)) {
     if (dimAndTileMapping.count(i)) {
-      readShapeForExtractSlice.push_back(
-          getConstantIntValue(dimAndTileMapping[i])
-              .value_or(ShapedType::kDynamic));
-      readSizes.push_back(dimAndTileMapping[i]);
-      transShapeForEmpty.push_back(dimAndTileMapping[i]);
-      continue;
-    }
-    if (ShapedType::isDynamic(inputShape[i])) {
-      readSizes.push_back(
-          rewriter.create<tensor::DimOp>(loc, input, i).getResult());
-    } else {
-      readSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
-    }
-    if (inputShape[i] != 1) {
-      readShapeForExtractSlice.push_back(inputShape[i]);
-      transShapeForEmpty.push_back(rewriter.getIndexAttr(inputShape[i]));
+      auto [tileSize, tileSizeOfr] =
+          getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
+      extractSliceSizes.push_back(tileSizeOfr);
+      outputShapeForExtractSlice.push_back(tileSize);
     }
   }
 
   Type elemType = packOp.getSourceType().getElementType();
-  auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
+  auto readType = RankedTensorType::get(outputShapeForExtractSlice, elemType);
 
   Value tile = rewriter.create<tensor::ExtractSliceOp>(
-      loc, readType, input, readOffsets, readSizes, readStrides);
+      loc, readType, input, readOffsets, extractSliceSizes, readStrides);
 
-  // 2. Transpose the tile to match the inner tile order.
+  // 2. Transpose the tile to match the inner tile order:
+  //    %init = tensor.empty()
+  //    %transposed_tile = linalg.transpose ins(%extracted_tile), outs(%init)
+  // NOTE: Outer dims are 1 and hence effectively ignored.
   SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
       inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm());
 
   LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
              llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
 
-  applyPermutationToVector<OpFoldResult>(transShapeForEmpty, perm);
+  // 2.1 Create tensor.empty (init value for TransposeOp)
+  SmallVector<OpFoldResult> transShapeForEmptyOp;
 
+  // Acquire tensor shape required to create EmptyOp. This will match the inner
+  // tile sizes.
+  size_t idx = numTiles;
+  while (idx != 0) {
+    transShapeForEmptyOp.push_back(extractSliceSizes[srcRank - idx]);
+    idx--;
+  }
+
+  applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp, perm);
   Value empty =
-      rewriter.create<tensor::EmptyOp>(loc, transShapeForEmpty, elemType);
+      rewriter.create<tensor::EmptyOp>(loc, transShapeForEmptyOp, elemType);
+
+  // 2.2 Create linalg.transpose
   auto transposedOp =
       rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);
 
-  // 3. Insert the inner tile to the destination.
-  int64_t destRank = packOp.getDestRank();
+  // 3. Insert the inner tile to the destination:
+  //    %inserted_tile = tensor.insert_slice(%transposed_tile)
   SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
   SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
-  SmallVector<OpFoldResult> writeSizes =
-      tensor::getMixedSizes(rewriter, loc, packOp.getDest());
+  // Outer dims are all 1s!
+  SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
+                                       oneIdxAttr);
+  SmallVector<int64_t> writeShape;
+
+  for (auto tileSize : packOp.getMixedTiles()) {
+    auto [tileSizeStatic, tileSizeOfr] =
+        getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
+    writeSizes.push_back(tileSizeOfr);
+    writeShape.push_back(tileSizeStatic);
+  }
 
+  // 4. Replace tensor.packOp with tensor.insert_slice created above
   auto insert = rewriter.create<tensor::InsertSliceOp>(
       loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
       writeSizes, writeStrides);
```

mlir/lib/Dialect/Utils/StaticValueUtils.cpp

Lines changed: 14 additions & 0 deletions
```diff
@@ -58,6 +58,20 @@ void dispatchIndexOpFoldResult(OpFoldResult ofr,
     staticVec.push_back(ShapedType::kDynamic);
 }
 
+std::pair<int64_t, OpFoldResult>
+getSimplifiedOfrAndStaticSizePair(OpFoldResult tileSizeOfr, Builder &b) {
+  int64_t tileSizeForShape =
+      getConstantIntValue(tileSizeOfr).value_or(ShapedType::kDynamic);
+
+  OpFoldResult tileSizeOfrSimplified =
+      (tileSizeForShape != ShapedType::kDynamic)
+          ? b.getIndexAttr(tileSizeForShape)
+          : tileSizeOfr;
+
+  return std::pair<int64_t, OpFoldResult>(tileSizeForShape,
+                                          tileSizeOfrSimplified);
+}
+
 void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
                                 SmallVectorImpl<Value> &dynamicVec,
                                 SmallVectorImpl<int64_t> &staticVec) {
```
