Skip to content

Commit ad4ed60

Browse files
author
git apple-llvm automerger
committed
Merge commit 'e9bafa35d270' from llvm.org/main into next
2 parents 548b7fc + e9bafa3 commit ad4ed60

File tree

5 files changed

+239
-85
lines changed

5 files changed

+239
-85
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1515,9 +1515,43 @@ struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
15151515
const SmallVector<Value> &dynSizes) const;
15161516
};
15171517

1518-
/// Rewrites a tensor::PackOp into a sequence of tensor.pad + linalg.transpose +
1519-
/// tensor.insert_slice ops, where the tensor::PackOp has outer dims being all
1520-
/// 1s.
1518+
/// Rewrites a tensor::PackOp into a sequence of:
1519+
/// * tensor::PadOp + linalg::TransposeOp + tensor::ExtractSliceOp +
1520+
/// tensor::EmptyOp + tensor::InsertSliceOp ops.
1521+
///
1522+
/// Requires that all the outer dims of the input tensor::PackOp are 1.
1523+
///
1524+
/// Before:
1525+
/// ```
1526+
/// %packed = tensor.pack %input
1527+
/// padding_value(%pad : f32)
1528+
/// inner_dims_pos = [1, 0]
1529+
/// inner_tiles = [2, %high]
1530+
/// into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32>
1531+
/// ```
1532+
///
1533+
/// After:
1534+
/// ```
1535+
/// // PadOp
1536+
/// %padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
1537+
/// ^bb0(...):
1538+
/// tensor.yield %arg2 : f32
1539+
/// } : tensor<5x1xf32> to tensor<?x2xf32>
1540+
/// // ExtractSliceOp
1541+
/// %extracted_slice = tensor.extract_slice %padded[0, 0] [%tile_dim_1, 2] [1,
1542+
/// 1]
1543+
/// : tensor<?x2xf32> to tensor<?x2xf32>
1544+
/// // EmptyOp + TransposeOp
1545+
/// %empty = tensor.empty(%arg3) : tensor<2x?xf32>
1546+
/// %transposed = linalg.transpose
1547+
/// ins(%extracted_slice : tensor<?x2xf32>)
1548+
/// outs(%empty : tensor<2x?xf32>)
1549+
/// permutation = [1, 0]
1550+
/// // InsertSliceOp
1551+
/// %inserted_slice = tensor.insert_slice %transposed
1552+
/// into %arg1[0, 0, 0, 0] [1, 1, 2, %tile_dim_1] [1, 1, 1, 1]
1553+
/// : tensor<2x?xf32> into tensor<1x1x2x?xf32>
1554+
/// ```
15211555
struct GeneralizeOuterUnitDimsPackOpPattern
15221556
: public OpRewritePattern<tensor::PackOp> {
15231557
using OpRewritePattern<tensor::PackOp>::OpRewritePattern;

mlir/include/mlir/Dialect/Utils/StaticValueUtils.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,25 @@ void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
6060
SmallVectorImpl<Value> &dynamicVec,
6161
SmallVectorImpl<int64_t> &staticVec);
6262

63+
/// Given an OpFoldResult representing a dim size value (*), generates a pair of
sizes:
65+
/// * 1st result, static value, contains an int64_t dim size that can be used
66+
/// to build ShapedType (ShapedType::kDynamic is used for truly dynamic dims),
67+
/// * 2nd result, dynamic value, contains OpFoldResult encapsulating the
68+
/// actual dim size (either original or updated input value).
69+
/// For input sizes for which it is possible to extract a constant Attribute,
70+
/// replaces the original size value with an integer attribute (unless it's
71+
/// already a constant Attribute). The 1st return value also becomes the actual
72+
/// integer size (as opposed to ShapedType::kDynamic).
73+
///
74+
/// (*) This hook is usually used when, given input sizes as OpFoldResult,
75+
/// it's required to generate two vectors:
76+
/// * sizes as int64_t to generate a shape,
77+
/// * sizes as OpFoldResult for sizes-like attribute.
78+
/// Please update this comment if you identify other use cases.
79+
std::pair<int64_t, OpFoldResult>
80+
getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b);
81+
6382
/// Extract integer values from the assumed ArrayAttr of IntegerAttr.
6483
template <typename IntTy>
6584
SmallVector<IntTy> extractFromIntegerArrayAttr(Attribute attr) {

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 59 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,75 +1142,100 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
11421142
tensor::PackOp packOp, PatternRewriter &rewriter) const {
11431143
// TODO: support the case that outer dimensions are not all 1s. A
11441144
// tensor.expand_shape will be generated in this case.
1145-
if (llvm::any_of(packOp.getTiledOuterDims(),
1145+
if (llvm::any_of(packOp.getAllOuterDims(),
11461146
[](int64_t dim) { return dim != 1; })) {
11471147
return rewriter.notifyMatchFailure(
1148-
packOp, "require the tiled outer dimensions of the result are all 1s");
1148+
packOp, "not all outer dimensions of the result are 1s");
11491149
}
11501150

1151-
// 1. Use rank-reduced tensor.extract_slice op to extract the tile and untiled
1152-
// outer dims.
1151+
Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
1152+
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
11531153
Location loc = packOp.getLoc();
1154+
11541155
Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
11551156
auto inputShape = packOp.getSourceType().getShape();
11561157
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
11571158
packOp.getDimAndTileMapping();
1158-
Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
1159-
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
11601159
int64_t srcRank = packOp.getSourceRank();
1160+
1161+
int64_t destRank = packOp.getDestRank();
1162+
size_t numTiles = destRank - srcRank;
1163+
1164+
// 1. Use rank-reduced tensor.extract_slice op to extract the tile:
1165+
// %extracted_tile = tensor.extract_slice(%pack_op_input)
11611166
SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
11621167
SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
1163-
SmallVector<OpFoldResult> readSizes;
1164-
SmallVector<OpFoldResult> transShapeForEmpty;
1165-
SmallVector<int64_t> readShapeForExtractSlice;
1168+
1169+
// The sizes attribute for ExtractSliceOp. The leading sizes are set to 1 as
1170+
// all outer dims are 1.
1171+
SmallVector<OpFoldResult> extractSliceSizes(srcRank - numTiles, oneIdxAttr);
1172+
// The shape of the output for ExtractSliceOp. All leading unit dims are
1173+
// effectively rank-reduced, hence skipped.
1174+
SmallVector<int64_t> outputShapeForExtractSlice;
1175+
1176+
// Extract the trailing sizes and shape dims for ExtractSliceOp. These should
1177+
// be equal to the inner tile sizes.
11661178
for (auto i : llvm::seq<unsigned>(0, srcRank)) {
11671179
if (dimAndTileMapping.count(i)) {
1168-
readShapeForExtractSlice.push_back(
1169-
getConstantIntValue(dimAndTileMapping[i])
1170-
.value_or(ShapedType::kDynamic));
1171-
readSizes.push_back(dimAndTileMapping[i]);
1172-
transShapeForEmpty.push_back(dimAndTileMapping[i]);
1173-
continue;
1174-
}
1175-
if (ShapedType::isDynamic(inputShape[i])) {
1176-
readSizes.push_back(
1177-
rewriter.create<tensor::DimOp>(loc, input, i).getResult());
1178-
} else {
1179-
readSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
1180-
}
1181-
if (inputShape[i] != 1) {
1182-
readShapeForExtractSlice.push_back(inputShape[i]);
1183-
transShapeForEmpty.push_back(rewriter.getIndexAttr(inputShape[i]));
1180+
auto [tileSize, tileSizeOfr] =
1181+
getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
1182+
extractSliceSizes.push_back(tileSizeOfr);
1183+
outputShapeForExtractSlice.push_back(tileSize);
11841184
}
11851185
}
11861186

11871187
Type elemType = packOp.getSourceType().getElementType();
1188-
auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
1188+
auto readType = RankedTensorType::get(outputShapeForExtractSlice, elemType);
11891189

11901190
Value tile = rewriter.create<tensor::ExtractSliceOp>(
1191-
loc, readType, input, readOffsets, readSizes, readStrides);
1191+
loc, readType, input, readOffsets, extractSliceSizes, readStrides);
11921192

1193-
// 2. Transpose the tile to match the inner tile order.
1193+
// 2. Transpose the tile to match the inner tile order:
1194+
// %init = tensor.empty()
1195+
// %transposed_tile = linalg.transpose ins(%extracted_tile), outs(%init)
1196+
// NOTE: Outer dims are 1 and hence effectively ignored.
11941197
SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
11951198
inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm());
11961199

11971200
LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
11981201
llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
11991202

1200-
applyPermutationToVector<OpFoldResult>(transShapeForEmpty, perm);
1203+
// 2.1 Create tensor.empty (init value for TransposeOp)
1204+
SmallVector<OpFoldResult> transShapeForEmptyOp;
12011205

1206+
// Acquire tensor shape required to create EmptyOp. This will match the inner
1207+
// tile sizes.
1208+
size_t idx = numTiles;
1209+
while (idx != 0) {
1210+
transShapeForEmptyOp.push_back(extractSliceSizes[srcRank - idx]);
1211+
idx--;
1212+
}
1213+
1214+
applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp, perm);
12021215
Value empty =
1203-
rewriter.create<tensor::EmptyOp>(loc, transShapeForEmpty, elemType);
1216+
rewriter.create<tensor::EmptyOp>(loc, transShapeForEmptyOp, elemType);
1217+
1218+
// 2.2 Create linalg.transpose
12041219
auto transposedOp =
12051220
rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);
12061221

1207-
// 3. Insert the inner tile to the destination.
1208-
int64_t destRank = packOp.getDestRank();
1222+
// 3. Insert the inner tile to the destination:
1223+
// %inserted_tile = tensor.insert_slice(%transposed_tile)
12091224
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
12101225
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
1211-
SmallVector<OpFoldResult> writeSizes =
1212-
tensor::getMixedSizes(rewriter, loc, packOp.getDest());
1226+
// Outer dims are all 1s!
1227+
SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
1228+
oneIdxAttr);
1229+
SmallVector<int64_t> writeShape;
1230+
1231+
for (auto tileSize : packOp.getMixedTiles()) {
1232+
auto [tileSizeStatic, tileSizeOfr] =
1233+
getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
1234+
writeSizes.push_back(tileSizeOfr);
1235+
writeShape.push_back(tileSizeStatic);
1236+
}
12131237

1238+
// 4. Replace tensor.packOp with tensor.insert_slice created above
12141239
auto insert = rewriter.create<tensor::InsertSliceOp>(
12151240
loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
12161241
writeSizes, writeStrides);

mlir/lib/Dialect/Utils/StaticValueUtils.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,20 @@ void dispatchIndexOpFoldResult(OpFoldResult ofr,
5858
staticVec.push_back(ShapedType::kDynamic);
5959
}
6060

61+
std::pair<int64_t, OpFoldResult>
62+
getSimplifiedOfrAndStaticSizePair(OpFoldResult tileSizeOfr, Builder &b) {
63+
int64_t tileSizeForShape =
64+
getConstantIntValue(tileSizeOfr).value_or(ShapedType::kDynamic);
65+
66+
OpFoldResult tileSizeOfrSimplified =
67+
(tileSizeForShape != ShapedType::kDynamic)
68+
? b.getIndexAttr(tileSizeForShape)
69+
: tileSizeOfr;
70+
71+
return std::pair<int64_t, OpFoldResult>(tileSizeForShape,
72+
tileSizeOfrSimplified);
73+
}
74+
6175
void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
6276
SmallVectorImpl<Value> &dynamicVec,
6377
SmallVectorImpl<int64_t> &staticVec) {

0 commit comments

Comments
 (0)