
Commit e56f5cb

[mlir][linalg] Fix and Refactor DecomposeOuterUnitDimsUnPackOpPattern
This PR fixes an issue identified in llvm#118786, where the following example triggers a verification error:

```mlir
func.func @foo(
    %src: tensor<1x1x2x?xf32>,
    %dest: tensor<5x1xf32>,
    %tile_dim: index) -> tensor<5x1xf32> {
  %0 = tensor.unpack %src
      inner_dims_pos = [1, 0]
      inner_tiles = [2, %tile_dim]
      into %dest : tensor<1x1x2x?xf32> -> tensor<5x1xf32>
  return %0 : tensor<5x1xf32>
}
```

The error occurs because of faulty logic when computing dynamic sizes for the `tensor::EmptyOp` that initializes the tensor for `linalg::transpose`. This specific example fails due to:

* a dynamic inner tile size, combined with
* a non-identity permutation.

For reference, here's the verification error:

```bash
error: 'tensor.empty' op incorrect number of dynamic sizes, has 0, expected 1
```

and here's the problematic `tensor.empty` (printed in generic form):

```mlir
%1 = "tensor.empty"() : () -> tensor<?x2xf32>
```

**Fix**

This PR refactors how the dynamic sizes for `tensor::EmptyOp` are computed. Instead of maintaining a separate vector of dynamic sizes alongside the static output shape, it adopts a common MLIR idiom: passing a single vector of `OpFoldResult`s directly to the `EmptyOp` builder. Previously, only the dynamic sizes corresponding to outer dimensions were tracked (see `dynamicDims`), while dynamic inner tile sizes were skipped, so the verifier saw fewer dynamic sizes than the result type required.

**Refactoring changes**

Variable names have been updated for better readability:

* `readShape` → `readShapeForExtractSlice`
* `readSizes` → `extractSliceSizes`
* `readOffsets` → `extractSliceOffsets`
* `readStrides` → `extractSliceStrides`

Additional comments have been added for clarity.

**Remaining inconsistencies**

The patterns for `tensor::PackOp` and `tensor::UnpackOp` remain partially inconsistent: `DecomposeOuterUnitDimsPackOpPattern` requires that all outer dimensions be unit-sized, while `DecomposeOuterUnitDimsUnPackOpPattern` does not, and the two implementations have diverged in logic. I plan to address these inconsistencies in a follow-up PR to further unify the patterns.
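For readers less familiar with the idiom, below is a minimal sketch contrasting the two `tensor::EmptyOp` builder styles. This is illustrative code, not taken from the patch; `tileDim`, `loc`, `elemType`, and `rewriter` are assumed to be in scope, with `tileDim` standing in for the dynamic `%tile_dim` size.

```cpp
// Old style: a static shape plus a separate list of dynamic size values.
// The caller must keep the ShapedType::kDynamic entries and the Value list
// in sync by hand; updating one but not the other breaks the invariant
// that the tensor.empty verifier checks.
SmallVector<int64_t> staticShape = {ShapedType::kDynamic, 2};
SmallVector<Value> dynamicSizes = {tileDim}; // one Value per kDynamic entry
Value before =
    rewriter.create<tensor::EmptyOp>(loc, staticShape, elemType, dynamicSizes);

// New style: a single vector of OpFoldResults, where each entry is either an
// Attribute (static size) or a Value (dynamic size). The builder derives the
// static shape and the dynamic operands itself, so they cannot drift apart.
SmallVector<OpFoldResult> mixedSizes = {tileDim, rewriter.getIndexAttr(2)};
Value after = rewriter.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
```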
1 parent 1885886 commit e56f5cb

2 files changed: +65 −31 lines

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 45 additions & 21 deletions
```diff
@@ -1252,64 +1252,88 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
         "require the tiled outer dimensions of the result are all 1s");
   }
 
-  // 1. Use rank-reduced tensor.extract_slice op to extract the tile.
+  // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
+  //    %extracted_tile = tensor.extract_slice(%unpack_op_input)
   Location loc = unpackOp.getLoc();
   Value source = unpackOp.getSource();
   DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
       unpackOp.getDimAndTileMapping();
   Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
   Attribute oneIdxAttr = rewriter.getIndexAttr(1);
-  SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
-  SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
-  SmallVector<OpFoldResult> readSizes;
-  SmallVector<int64_t> readShape;
-  SmallVector<Value> dynamicDims;
+
+  // The sizes, offsets and strides attributes for ExtractSliceOp.
+  SmallVector<OpFoldResult> extractSliceSizes;
+  SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr);
+  SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr);
+  // The shape for ExtractSliceOp (due to rank-reducing, this is likely !=
+  // extractSliceSizes).
+  SmallVector<int64_t> readShapeForExtractSlice;
+
+  // Shape for EmptyOp that's used as the init value for TransposeOp below.
+  // This should match tile size + transposition.
+  SmallVector<OpFoldResult> shapeForEmptyOp;
+
   for (auto i : llvm::seq<unsigned>(0, destRank)) {
+    // Given the assumption that all outer tiled dims are 1, the corresponding
+    // slice size to read is also 1. As this will be rank-reducing "extract
+    // slice" (i.e. the unit dims will be "collapsed"), there's no need to
+    // update:
+    //  * the output shape for ExtractSliceOp, nor
+    //  * the shape for EmptyOp.
     if (dimAndTileMapping.count(i)) {
-      readSizes.push_back(oneIdxAttr);
+      extractSliceSizes.push_back(oneIdxAttr);
       continue;
     }
 
+    // Compute the sizes attribute for ExtractSliceOp + EmptyOp.
     if (ShapedType::isDynamic(srcShape[i])) {
-      Value dynamicDim =
+      OpFoldResult dynamicDim =
           rewriter.create<tensor::DimOp>(loc, source, i).getResult();
-      readSizes.push_back(dynamicDim);
-      dynamicDims.push_back(dynamicDim);
+      extractSliceSizes.push_back(dynamicDim);
+      shapeForEmptyOp.push_back(dynamicDim);
     } else {
-      readSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
+      extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
+      if (srcShape[i] != 1)
+        shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i]));
+    }
+    // Compute the output shape for ExtractSliceOp (take rank-reducing into
+    // account).
+    if (srcShape[i] != 1) {
+      readShapeForExtractSlice.push_back(srcShape[i]);
     }
-    if (srcShape[i] != 1)
-      readShape.push_back(srcShape[i]);
   }
   auto mixedTiles = unpackOp.getMixedTiles();
-  readSizes.append(mixedTiles.begin(), mixedTiles.end());
+  // TODO: This effectively assumes that the tile sizes match the trailing
+  // sizes for ExtractSliceOp and EmptyOp - document this.
+  extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
+  shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());
 
   // Explicitly create the type for extract_slice op because the inner tile
   // size could be 1. We want to represent the whole inner tile in this case.
   auto tileShape = srcShape.drop_front(destRank);
   // Append the inner tile shape to the permuted and rank-reduced outer shape.
-  readShape.append(tileShape.begin(), tileShape.end());
+  readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
   Type elemType = unpackOp.getSourceType().getElementType();
-  auto readType = RankedTensorType::get(readShape, elemType);
+  auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
   Value innerTile = rewriter.create<tensor::ExtractSliceOp>(
-      loc, readType, unpackOp.getSource(), readOffsets, readSizes, readStrides);
+      loc, readType, unpackOp.getSource(), extractSliceOffsets,
+      extractSliceSizes, extractSliceStrides);
 
   // 2. Transpose the tile to match the outer corresponding tile order.
   SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
       srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
   // Unpack is a transition out of packed space so we invert the permutation.
   perm = invertPermutationVector(perm);
-  SmallVector<int64_t> transpShape(readShape);
-  applyPermutationToVector<int64_t>(transpShape, perm);
+  applyPermutationToVector<OpFoldResult>(shapeForEmptyOp, perm);
 
   Value empty =
-      rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType, dynamicDims);
+      rewriter.create<tensor::EmptyOp>(loc, shapeForEmptyOp, elemType);
   auto transposedOp =
       rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm);
 
   // 3. Handle in-complete tiles if needed. It truncates trailing data from the
   // transposed tile.
-  int numLoops = transpShape.size();
+  int numLoops = shapeForEmptyOp.size();
   SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
   SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
   SmallVector<OpFoldResult> tileSizes;
```
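The crux of the fix is visible in step 2 above: the shape permuted for the `tensor.empty` init value is now a vector of `OpFoldResult`s, so each dynamic size travels with its dimension through the permutation. Below is a minimal sketch of the motivating case; it is illustrative only, with `tileDim` standing for the `%tile_dim` value and the other names assumed to be in scope.

```cpp
// inner_tiles = [2, %tile_dim] contributes {IndexAttr(2), tileDim} to the
// shape, and the inverted permutation for this example is [1, 0].
SmallVector<int64_t> perm = {1, 0};
SmallVector<OpFoldResult> shapeForEmptyOp = {rewriter.getIndexAttr(2), tileDim};
applyPermutationToVector<OpFoldResult>(shapeForEmptyOp, perm);
// shapeForEmptyOp is now {tileDim, IndexAttr(2)}, and the builder emits
// `tensor.empty(%tile_dim) : tensor<?x2xf32>` - the dynamic size lands in
// the right position. The old code permuted only a SmallVector<int64_t>
// shape while `dynamicDims` (populated solely from the outer dims) stayed
// empty, so the verifier saw 0 dynamic sizes where 1 was required.
Value empty = rewriter.create<tensor::EmptyOp>(loc, shapeForEmptyOp, elemType);
```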

mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir

Lines changed: 20 additions & 10 deletions
```diff
@@ -35,15 +35,15 @@ func.func @simple_unpack_static_tiles(%input: tensor<1x1x8x2xf32>, %output: tens
 
 /// Same as example above, but with 1 dynamic tile size.
 
-func.func @simple_unpack_dynamic_tile(%input: tensor<1x1x?x2xf32>, %output: tensor<5x1xf32>, %tile_dim_0: index) -> tensor<5x1xf32> {
-  %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [%tile_dim_0, 2] into %output : tensor<1x1x?x2xf32> -> tensor<5x1xf32>
+func.func @simple_unpack_dynamic_tile(%input: tensor<1x1x?x2xf32>, %output: tensor<5x1xf32>, %tile_dim: index) -> tensor<5x1xf32> {
+  %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [%tile_dim, 2] into %output : tensor<1x1x?x2xf32> -> tensor<5x1xf32>
   return %0 : tensor<5x1xf32>
 }
 // CHECK-LABEL: func.func @simple_unpack_dynamic_tile
 // CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
-// CHECK-SAME:    %[[TILE_DIM_1:[a-zA-Z0-9]+]]
-// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_1]], 2] [1, 1, 1, 1]
+// CHECK-SAME:    %[[TILE_DIM:[a-zA-Z0-9]+]]
+// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, %[[TILE_DIM]], 2] [1, 1, 1, 1]
 // CHECK-NOT:     linalg.transpose
 //                They have the same type, so the insert_slice op is folded
 //                away.
@@ -52,13 +52,23 @@ func.func @simple_unpack_dynamic_tile(%input: tensor<1x1x?x2xf32>, %output: tens
 
 /// Same as example above, but with 1 dynamic tile size and a transpose.
 
-/// FIXME: This is currently broken:
-///  * 'tensor.empty' op incorrect number of dynamic sizes, has 0, expected 1
+func.func @simple_unpack_dynamic_tile_transpose(%src: tensor<1x1x2x?xf32>, %dest: tensor<5x1xf32>, %tile_dim: index) -> tensor<5x1xf32> {
+  %0 = tensor.unpack %src inner_dims_pos = [1, 0] inner_tiles = [2, %tile_dim] into %dest : tensor<1x1x2x?xf32> -> tensor<5x1xf32>
+  return %0 : tensor<5x1xf32>
+}
+// CHECK-LABEL: func.func @simple_unpack_dynamic_tile_transpose
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[TILE_DIM:[a-zA-Z0-9]+]]
+// CHECK:         %[[TILE:.*]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 2, %[[TILE_DIM]]] [1, 1, 1, 1] : tensor<1x1x2x?xf32> to tensor<2x?xf32>
+// CHECK:         %[[EMPTY:.*]] = tensor.empty(%[[TILE_DIM]]) : tensor<?x2xf32>
+// CHECK:         %[[TRANSP:.*]] = linalg.transpose
+// CHECK-SAME:      ins(%[[TILE]] : tensor<2x?xf32>)
+// CHECK-SAME:      outs(%[[EMPTY]] : tensor<?x2xf32>)
+// CHECK-SAME:      permutation = [1, 0]
+// CHECK:         %[[SLICE:.*]] = tensor.extract_slice %[[TRANSP]][0, 0] [5, 1] [1, 1] : tensor<?x2xf32> to tensor<5x1xf32>
+// CHECK:         return %[[SLICE]] : tensor<5x1xf32>
 
-//func.func @simple_unpack_dynamic_tile_transpose(%input: tensor<1x1x2x?xf32>, %output: tensor<5x1xf32>, %tile_dim_0: index) -> tensor<5x1xf32> {
-//  %0 = tensor.unpack %input inner_dims_pos = [1, 0] inner_tiles = [2, %tile_dim_0] into %output : tensor<1x1x2x?xf32> -> tensor<5x1xf32>
-//  return %0 : tensor<5x1xf32>
-//}
 
 /// Same as example above, but with 1 scalable tile size.
 
```