Commit 58da789 (parent: 1be4a67)

[mlir][linalg] Fix and Refactor DecomposeOuterUnitDimsUnPackOpPattern (llvm#119379)

2 files changed: +75 −31 lines
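For context, this pattern decomposes a tensor.unpack op whose outer tiled dimensions are all unit dims into a rank-reduced tensor.extract_slice, a linalg.transpose, and a final tensor.extract_slice that trims any incomplete tile. A minimal sketch of the rewrite, adapted from the @simple_unpack_dynamic_tile_transpose test added in this commit (SSA names are illustrative):

    // Before: unpack with unit outer dims and one dynamic inner tile size.
    %0 = tensor.unpack %src inner_dims_pos = [1, 0] inner_tiles = [2, %tile_dim]
        into %dest : tensor<1x1x2x?xf32> -> tensor<5x1xf32>

    // After:
    // 1. A rank-reducing extract_slice collapses the unit outer dims.
    %tile = tensor.extract_slice %src[0, 0, 0, 0] [1, 1, 2, %tile_dim] [1, 1, 1, 1]
        : tensor<1x1x2x?xf32> to tensor<2x?xf32>
    // 2. Transpose into the destination dim order; the dynamic tile size is
    //    now threaded through to tensor.empty (the fix in this commit).
    %empty = tensor.empty(%tile_dim) : tensor<?x2xf32>
    %transp = linalg.transpose ins(%tile : tensor<2x?xf32>)
                 outs(%empty : tensor<?x2xf32>) permutation = [1, 0]
    // 3. Trim the incomplete tile down to the destination shape.
    %res = tensor.extract_slice %transp[0, 0] [5, 1] [1, 1]
        : tensor<?x2xf32> to tensor<5x1xf32>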

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (55 additions, 21 deletions)
@@ -1254,64 +1254,98 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
         "require the tiled outer dimensions of the result are all 1s");
   }
 
-  // 1. Use rank-reduced tensor.extract_slice op to extract the tile.
+  // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
+  //    %extracted_tile = tensor.extract_slice(%unpack_op_input)
   Location loc = unpackOp.getLoc();
   Value source = unpackOp.getSource();
   DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
       unpackOp.getDimAndTileMapping();
   Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
   Attribute oneIdxAttr = rewriter.getIndexAttr(1);
-  SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
-  SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
-  SmallVector<OpFoldResult> readSizes;
-  SmallVector<int64_t> readShape;
-  SmallVector<Value> dynamicDims;
+
+  // The shape for ExtractSliceOp. Note that this will consist of 3 blocks of
+  // dims:
+  //    [ outer-untiled-dims, outer-tiled-dims, tile-sizes ]
+  SmallVector<int64_t> readShapeForExtractSlice;
+  // The sizes attribute for ExtractSliceOp. Due to rank-reducing (and
+  // outer-tiled-dims being all 1), this will be
+  //    [ outer-untiled-dims, tile-sizes ]
+  SmallVector<OpFoldResult> extractSliceSizes;
+  // The offset and strides attributes for ExtractSliceOp.
+  SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr);
+  SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr);
+
+  // Shape for EmptyOp that's used as the init value for TransposeOp below.
+  // This should be:
+  //    [ outer-untiled-dims, tile-sizes ]
+  // However, skip unit dims - TransposeOp (below) applies a rank-reduced
+  // permutation.
+  SmallVector<OpFoldResult> shapeForEmptyOp;
+
   for (auto i : llvm::seq<unsigned>(0, destRank)) {
+    // Compute the sizes attribute for ExtractSliceOp - outer-tiled-dims.
+    //
+    // As all outer tiled dims are 1, the corresponding slice size to read
+    // will also be 1. As this will be a rank-reducing "extract slice" (i.e.
+    // the unit dims will be "collapsed"), there's no need to update:
+    //  * the output shape for ExtractSliceOp, nor
+    //  * the shape for EmptyOp.
     if (dimAndTileMapping.count(i)) {
-      readSizes.push_back(oneIdxAttr);
+      extractSliceSizes.push_back(oneIdxAttr);
       continue;
     }
 
+    // Compute the sizes attribute for ExtractSliceOp + EmptyOp -
+    // outer-untiled-dims.
     if (ShapedType::isDynamic(srcShape[i])) {
-      Value dynamicDim =
+      OpFoldResult dynamicDim =
           rewriter.create<tensor::DimOp>(loc, source, i).getResult();
-      readSizes.push_back(dynamicDim);
-      dynamicDims.push_back(dynamicDim);
+      extractSliceSizes.push_back(dynamicDim);
+      shapeForEmptyOp.push_back(dynamicDim);
     } else {
-      readSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
+      extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
+      if (srcShape[i] != 1)
+        shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i]));
+    }
+    // Compute the output shape for ExtractSliceOp - outer-untiled-dims (take
+    // rank-reducing into account).
+    if (srcShape[i] != 1) {
+      readShapeForExtractSlice.push_back(srcShape[i]);
     }
-    if (srcShape[i] != 1)
-      readShape.push_back(srcShape[i]);
   }
+  // Append the tile sizes to the "sizes attribute" for ExtractSliceOp and to
+  // the shape for EmptyOp.
   auto mixedTiles = unpackOp.getMixedTiles();
-  readSizes.append(mixedTiles.begin(), mixedTiles.end());
+  extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
+  shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());
 
   // Explicitly create the type for extract_slice op because the inner tile
   // size could be 1. We want to represent the whole inner tile in this case.
   auto tileShape = srcShape.drop_front(destRank);
   // Append the inner tile shape to the permuted and rank-reduced outer shape.
-  readShape.append(tileShape.begin(), tileShape.end());
+  readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
   Type elemType = unpackOp.getSourceType().getElementType();
-  auto readType = RankedTensorType::get(readShape, elemType);
+  auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
   Value innerTile = rewriter.create<tensor::ExtractSliceOp>(
-      loc, readType, unpackOp.getSource(), readOffsets, readSizes, readStrides);
+      loc, readType, unpackOp.getSource(), extractSliceOffsets,
+      extractSliceSizes, extractSliceStrides);
 
   // 2. Transpose the tile to match the outer corresponding tile order.
   SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
       srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
   // Unpack is a transition out of packed space so we invert the permutation.
   perm = invertPermutationVector(perm);
-  SmallVector<int64_t> transpShape(readShape);
-  applyPermutationToVector<int64_t>(transpShape, perm);
+  applyPermutationToVector<OpFoldResult>(shapeForEmptyOp, perm);
 
   Value empty =
-      rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType, dynamicDims);
+      rewriter.create<tensor::EmptyOp>(loc, shapeForEmptyOp, elemType);
   auto transposedOp =
       rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm);
 
   // 3. Handle incomplete tiles if needed. It truncates trailing data from the
   // transposed tile.
-  int numLoops = transpShape.size();
+  int numLoops = shapeForEmptyOp.size();
   SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
   SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
   SmallVector<OpFoldResult> tileSizes;
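To make the shape bookkeeping above concrete, here is a walkthrough (inferred from the test below, not output of the code) for a source of type tensor<1x1x2x?xf32> with destRank = 2, inner_dims_pos = [1, 0], and inner_tiles = [2, %tile_dim], where both outer dims are tiled unit dims:

    extractSliceSizes        = [1, 1, 2, %tile_dim]  // unit reads for outer-tiled dims + tile sizes
    readShapeForExtractSlice = [2, ?]                // rank-reduced: unit outer dims collapsed
    shapeForEmptyOp          = [2, %tile_dim]        // OpFoldResult, so the dynamic SSA value survives
    perm                     = [1, 0]                // inverted rank-reduced permutation

    // After applyPermutationToVector(shapeForEmptyOp, perm):
    shapeForEmptyOp = [%tile_dim, 2]  =>  tensor.empty(%tile_dim) : tensor<?x2xf32>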

mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir (20 additions, 10 deletions)
@@ -35,15 +35,15 @@ func.func @simple_unpack_static_tiles(%input: tensor<1x1x8x2xf32>, %output: tens
 
 /// Same as example above, but with 1 dynamic tile size.
 
-func.func @simple_unpack_dynamic_tile(%input: tensor<1x1x?x2xf32>, %output: tensor<5x1xf32>, %tile_dim_0: index) -> tensor<5x1xf32> {
-  %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [%tile_dim_0, 2] into %output : tensor<1x1x?x2xf32> -> tensor<5x1xf32>
+func.func @simple_unpack_dynamic_tile(%input: tensor<1x1x?x2xf32>, %output: tensor<5x1xf32>, %tile_dim: index) -> tensor<5x1xf32> {
+  %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [%tile_dim, 2] into %output : tensor<1x1x?x2xf32> -> tensor<5x1xf32>
   return %0 : tensor<5x1xf32>
 }
 // CHECK-LABEL: func.func @simple_unpack_dynamic_tile
 // CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
-// CHECK-SAME:    %[[TILE_DIM_1:[a-zA-Z0-9]+]]
-// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_1]], 2] [1, 1, 1, 1]
+// CHECK-SAME:    %[[TILE_DIM:[a-zA-Z0-9]+]]
+// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, %[[TILE_DIM]], 2] [1, 1, 1, 1]
 // CHECK-NOT:     linalg.transpose
 //                They have the same type, so the insert_slice op is folded
 //                away.
@@ -52,13 +52,23 @@ func.func @simple_unpack_dynamic_tile(%input: tensor<1x1x?x2xf32>, %output: tens
 
 /// Same as example above, but with 1 dynamic tile size and a transpose
 
-/// FIXME: This is currently broken:
-///  * 'tensor.empty' op incorrect number of dynamic sizes, has 0, expected 1
+func.func @simple_unpack_dynamic_tile_transpose(%src: tensor<1x1x2x?xf32>, %dest: tensor<5x1xf32>, %tile_dim: index) -> tensor<5x1xf32> {
+  %0 = tensor.unpack %src inner_dims_pos = [1, 0] inner_tiles = [2, %tile_dim] into %dest : tensor<1x1x2x?xf32> -> tensor<5x1xf32>
+  return %0 : tensor<5x1xf32>
+}
+// CHECK-LABEL: func.func @simple_unpack_dynamic_tile_transpose
+// CHECK-SAME:   %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[TILE_DIM:[a-zA-Z0-9]+]]
+// CHECK:        %[[TILE:.*]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 2, %[[TILE_DIM]]] [1, 1, 1, 1] : tensor<1x1x2x?xf32> to tensor<2x?xf32>
+// CHECK:        %[[EMPTY:.*]] = tensor.empty(%[[TILE_DIM]]) : tensor<?x2xf32>
+// CHECK:        %[[TRANSP:.*]] = linalg.transpose
+// CHECK-SAME:     ins(%[[TILE]] : tensor<2x?xf32>)
+// CHECK-SAME:     outs(%[[EMPTY]] : tensor<?x2xf32>)
+// CHECK-SAME:     permutation = [1, 0]
+// CHECK:        %[[SLICE:.*]] = tensor.extract_slice %[[TRANSP]][0, 0] [5, 1] [1, 1] : tensor<?x2xf32> to tensor<5x1xf32>
+// CHECK:        return %[[SLICE]] : tensor<5x1xf32>
 
-//func.func @simple_unpack_dynamic_tile_transpose(%input: tensor<1x1x2x?xf32>, %output: tensor<5x1xf32>, %tile_dim_0: index) -> tensor<5x1xf32> {
-//  %0 = tensor.unpack %input inner_dims_pos = [1, 0] inner_tiles = [2, %tile_dim_0] into %output : tensor<1x1x2x?xf32> -> tensor<5x1xf32>
-//  return %0 : tensor<5x1xf32>
-//}
 
 /// Same as example above, but with 1 scalable tile size.
 
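The FIXME deleted above quotes the verifier error this patch fixes. Previously the transpose init was built from a static transpShape plus a separately collected dynamicDims list that never included dynamic inner tile sizes, so for this test the pattern emitted, roughly (reconstructed for illustration, not taken from an actual trace):

    %empty = tensor.empty() : tensor<?x2xf32>
    // error: 'tensor.empty' op incorrect number of dynamic sizes, has 0, expected 1

With shapeForEmptyOp carrying OpFoldResults through the permutation, the dynamic size now reaches tensor.empty, as the new CHECK lines verify:

    %empty = tensor.empty(%tile_dim) : tensor<?x2xf32>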
