[mlir][tensor] Generalize/restrict GeneralizeOuterUnitDimsPackOpPattern #114315

Merged (4 commits, Nov 6, 2024)
40 changes: 37 additions & 3 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1515,9 +1515,43 @@ struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
const SmallVector<Value> &dynSizes) const;
};

/// Rewrites a tensor::PackOp into a sequence of tensor.pad + linalg.transpose +
/// tensor.insert_slice ops, where the tensor::PackOp has outer dims being all
/// 1s.
/// Rewrites a tensor::PackOp into a sequence of:
/// * tensor::PadOp + tensor::ExtractSliceOp + tensor::EmptyOp +
/// linalg::TransposeOp + tensor::InsertSliceOp ops.
///
/// Requires that all the outer dims of the input tensor::PackOp are 1.
///
/// Before:
/// ```
/// %packed = tensor.pack %input
/// padding_value(%pad : f32)
/// inner_dims_pos = [1, 0]
/// inner_tiles = [2, %high]
/// into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32>
/// ```
///
/// After:
/// ```
// PadOp
%padded = tensor.pad %input low[0, 0] high[%0, 1] {
^bb0(...):
  tensor.yield %pad : f32
} : tensor<5x1xf32> to tensor<?x2xf32>
// ExtractSliceOp
%extracted_slice = tensor.extract_slice %padded[0, 0] [%tile_dim_1, 2] [1, 1]
  : tensor<?x2xf32> to tensor<?x2xf32>
// EmptyOp + TransposeOp
%empty = tensor.empty(%tile_dim_1) : tensor<2x?xf32>
%transposed = linalg.transpose
  ins(%extracted_slice : tensor<?x2xf32>)
  outs(%empty : tensor<2x?xf32>)
  permutation = [1, 0]
// InsertSliceOp
%inserted_slice = tensor.insert_slice %transposed
  into %output[0, 0, 0, 0] [1, 1, 2, %tile_dim_1] [1, 1, 1, 1]
  : tensor<2x?xf32> into tensor<1x1x2x?xf32>
```

Comment (Contributor) on lines +1540 to +1542 (the ExtractSliceOp above): This may not be in the scope of this PR to change this, but what exactly is this extract_slice for? It seems to me like it is just taking a full slice. Is the slice ever not full?

struct GeneralizeOuterUnitDimsPackOpPattern
: public OpRewritePattern<tensor::PackOp> {
using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
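For context, a minimal sketch (not part of this PR) of how the pattern might be driven from a pass, assuming the standard greedy rewrite driver; the wrapper function name is hypothetical:

```
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical helper: run only this pattern over `root`. tensor.pack ops
// whose outer dims are not all 1s fail to match and are left untouched.
static mlir::LogicalResult
generalizeUnitOuterDimPacks(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  patterns.add<mlir::linalg::GeneralizeOuterUnitDimsPackOpPattern>(
      root->getContext());
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```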
19 changes: 19 additions & 0 deletions mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -60,6 +60,25 @@ void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
SmallVectorImpl<Value> &dynamicVec,
SmallVectorImpl<int64_t> &staticVec);

/// Given an OpFoldResult representing a dim size value (*), generates a pair
/// of sizes:
/// * 1st result, the static value: an int64_t dim size that can be used to
///   build a ShapedType (ShapedType::kDynamic is used for truly dynamic dims),
/// * 2nd result, the dynamic value: an OpFoldResult encapsulating the actual
///   dim size (either the original input value or an updated one).
/// For input sizes from which a constant Attribute can be extracted, the
/// original size value is replaced with an integer attribute (unless it is
/// already a constant Attribute). In that case, the 1st return value also
/// holds the actual integer size (as opposed to ShapedType::kDynamic).
///
/// (*) This hook is usually used when, given input sizes as OpFoldResult,
/// it's required to generate two vectors:
/// * sizes as int64_t to generate a shape,
/// * sizes as OpFoldResult for sizes-like attribute.
/// Please update this comment if you identify other use cases.
std::pair<int64_t, OpFoldResult>
getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b);

/// Extract integer values from the assumed ArrayAttr of IntegerAttr.
template <typename IntTy>
SmallVector<IntTy> extractFromIntegerArrayAttr(Attribute attr) {
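To make the intended use above concrete, a hedged sketch (not from this PR) of building the two size vectors in tandem; `mixedInputSizes`, `b` (a Builder), and `elemType` are assumed to be in scope:

```
// Split mixed sizes into (a) a static shape for a RankedTensorType and
// (b) simplified OpFoldResults for a sizes-like attribute.
SmallVector<int64_t> staticShape;
SmallVector<OpFoldResult> sizes;
for (OpFoldResult size : mixedInputSizes) {
  auto [staticSize, ofr] = getSimplifiedOfrAndStaticSizePair(size, b);
  staticShape.push_back(staticSize); // int64_t, or ShapedType::kDynamic
  sizes.push_back(ofr); // Attribute when constant, Value otherwise
}
auto resultType = RankedTensorType::get(staticShape, elemType);
```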
93 changes: 59 additions & 34 deletions mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1142,75 +1142,100 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
tensor::PackOp packOp, PatternRewriter &rewriter) const {
// TODO: support the case that outer dimensions are not all 1s. A
// tensor.expand_shape will be generated in this case.
if (llvm::any_of(packOp.getTiledOuterDims(),
if (llvm::any_of(packOp.getAllOuterDims(),
[](int64_t dim) { return dim != 1; })) {
return rewriter.notifyMatchFailure(
packOp, "require the tiled outer dimensions of the result are all 1s");
packOp, "not all outer dimensions of the result are 1s");
}

// 1. Use rank-reduced tensor.extract_slice op to extract the tile and untiled
// outer dims.
Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
Location loc = packOp.getLoc();

Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
auto inputShape = packOp.getSourceType().getShape();
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
packOp.getDimAndTileMapping();
Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
int64_t srcRank = packOp.getSourceRank();

int64_t destRank = packOp.getDestRank();
size_t numTiles = destRank - srcRank;

Comment (Contributor): Can you elaborate on the advantage of slicing out the inner tile here? We end up reintroducing the outer unit dims shortly after with the tensor.insert_slice, so this only serves to remove the unit dims from the linalg.transpose for permuting the inner dimensions.

Reply (Contributor, PR author): I am merely a messenger here ;-) (as in, tensor.extract_slice is already a part of this logic, not something added by me). To me, it makes more sense to only transpose a tile's worth of data, as that's the intent of tensor.pack, right? But you will have more experience with this logic. I just wanted to preserve the original logic as much as possible and avoid making too many changes in one PR.

Reply (Contributor): Ah, that makes sense then. If it's preserving existing behavior, leaving it is fine, but we should look into removing it at some point.

// 1. Use rank-reduced tensor.extract_slice op to extract the tile:
// %extracted_tile = tensor.extract_slice(%pack_op_input)
SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
SmallVector<OpFoldResult> readSizes;
SmallVector<OpFoldResult> transShapeForEmpty;
SmallVector<int64_t> readShapeForExtractSlice;

// The sizes attribute for ExtractSliceOp. The leading sizes are set to 1 as
// all outer dims are 1.
SmallVector<OpFoldResult> extractSliceSizes(srcRank - numTiles, oneIdxAttr);
// The shape of the output for ExtractSliceOp. All leading unit dims are
// effectively rank-reduced, hence skipped.
SmallVector<int64_t> outputShapeForExtractSlice;

// Extract the trailing sizes and shape dims for ExtractSliceOp. These should
// be equal to the inner tile sizes.
for (auto i : llvm::seq<unsigned>(0, srcRank)) {
if (dimAndTileMapping.count(i)) {
readShapeForExtractSlice.push_back(
getConstantIntValue(dimAndTileMapping[i])
.value_or(ShapedType::kDynamic));
readSizes.push_back(dimAndTileMapping[i]);
transShapeForEmpty.push_back(dimAndTileMapping[i]);
continue;
}
if (ShapedType::isDynamic(inputShape[i])) {
readSizes.push_back(
rewriter.create<tensor::DimOp>(loc, input, i).getResult());
} else {
readSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
}
if (inputShape[i] != 1) {
readShapeForExtractSlice.push_back(inputShape[i]);
transShapeForEmpty.push_back(rewriter.getIndexAttr(inputShape[i]));
auto [tileSize, tileSizeOfr] =
getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
extractSliceSizes.push_back(tileSizeOfr);
outputShapeForExtractSlice.push_back(tileSize);
}
}

Type elemType = packOp.getSourceType().getElementType();
auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
auto readType = RankedTensorType::get(outputShapeForExtractSlice, elemType);

Value tile = rewriter.create<tensor::ExtractSliceOp>(
loc, readType, input, readOffsets, readSizes, readStrides);
loc, readType, input, readOffsets, extractSliceSizes, readStrides);

// 2. Transpose the tile to match the inner tile order.
// 2. Transpose the tile to match the inner tile order:
// %init = tensor.empty()
// %transposed_tile = linalg.transpose ins(%extracted_tile), outs(%init)
// NOTE: Outer dims are 1 and hence effectively ignored.
SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm());

LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););

applyPermutationToVector<OpFoldResult>(transShapeForEmpty, perm);
// 2.1 Create tensor.empty (init value for TransposeOp)
SmallVector<OpFoldResult> transShapeForEmptyOp;

// Acquire tensor shape required to create EmptyOp. This will match the inner
// tile sizes.
size_t idx = numTiles;
while (idx != 0) {
transShapeForEmptyOp.push_back(extractSliceSizes[srcRank - idx]);
idx--;
}
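// (Illustration: with srcRank = 2 and numTiles = 2, as in the 5x1 ->
// 1x1x2x? example from the docs, this copies extractSliceSizes[0] and
// extractSliceSizes[1], i.e. the two inner tile sizes, in order.)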

applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp, perm);
Value empty =
rewriter.create<tensor::EmptyOp>(loc, transShapeForEmpty, elemType);
rewriter.create<tensor::EmptyOp>(loc, transShapeForEmptyOp, elemType);

// 2.2 Create linalg.transpose
auto transposedOp =
rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);

// 3. Insert the inner tile to the destination.
int64_t destRank = packOp.getDestRank();
// 3. Insert the inner tile to the destination:
// %inserted_tile = tensor.insert_slice(%transposed_tile)
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
SmallVector<OpFoldResult> writeSizes =
tensor::getMixedSizes(rewriter, loc, packOp.getDest());
// Outer dims are all 1s!
SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
oneIdxAttr);
SmallVector<int64_t> writeShape;

for (auto tileSize : packOp.getMixedTiles()) {
auto [tileSizeStatic, tileSizeOfr] =
getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
writeSizes.push_back(tileSizeOfr);
writeShape.push_back(tileSizeStatic);
}
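// (Illustration: for the 5x1 -> 1x1x2x? example from the docs, writeSizes
// becomes [1, 1, 2, <dynamic tile size>] and writeShape becomes
// [2, ShapedType::kDynamic].)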

// 4. Replace tensor.pack with the tensor.insert_slice created above
auto insert = rewriter.create<tensor::InsertSliceOp>(
loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
writeSizes, writeStrides);
14 changes: 14 additions & 0 deletions mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -58,6 +58,20 @@ void dispatchIndexOpFoldResult(OpFoldResult ofr,
staticVec.push_back(ShapedType::kDynamic);
}

std::pair<int64_t, OpFoldResult>
getSimplifiedOfrAndStaticSizePair(OpFoldResult tileSizeOfr, Builder &b) {
int64_t tileSizeForShape =
getConstantIntValue(tileSizeOfr).value_or(ShapedType::kDynamic);

OpFoldResult tileSizeOfrSimplified =
(tileSizeForShape != ShapedType::kDynamic)
? b.getIndexAttr(tileSizeForShape)
: tileSizeOfr;

return std::pair<int64_t, OpFoldResult>(tileSizeForShape,
tileSizeOfrSimplified);
}

void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
SmallVectorImpl<Value> &dynamicVec,
SmallVectorImpl<int64_t> &staticVec) {
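A hedged illustration of the helper's two outcomes, assuming a Builder `b` and a truly dynamic index Value `dynSize` (e.g. a block argument) are in scope:

```
// Already-constant input: the attribute is returned as-is.
auto [s1, ofr1] =
    getSimplifiedOfrAndStaticSizePair(OpFoldResult(b.getIndexAttr(4)), b);
// s1 == 4; ofr1 holds the index attribute 4.

// Truly dynamic input: no constant can be extracted.
auto [s2, ofr2] = getSimplifiedOfrAndStaticSizePair(OpFoldResult(dynSize), b);
// s2 == ShapedType::kDynamic; ofr2 is `dynSize`, unchanged.
```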