[MLIR] Add patterns to bubble-up pack and push-down unpack through collapse/expand shape ops #85297

Merged: 9 commits, Mar 28, 2024
mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (302 additions, 1 deletion)
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>

@@ -552,6 +553,305 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
ControlPropagationFn controlFn;
};

/// Project each dim in dimsPos to the inner-most non-unit dim of its
/// reassociation group, based on targetShape.
///
/// For example, given dimsPos [0, 1], reassocIndices [[0, 1], [2, 3]], and
/// targetShape [16, 16, 32, 1], it returns [1, 2]: for pos 0, the inner-most
/// non-unit projected dim in [0, 1] is 1; for pos 1, the inner-most non-unit
/// projected dim in [2, 3] is 2.
///
/// If all candidates in a reassociation are unit dims, it chooses the
/// inner-most dim pos.
static SmallVector<int64_t>
projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
ArrayRef<ReassociationIndices> reassocIndices,
ArrayRef<int64_t> targetShape) {
SmallVector<int64_t> projectedDimsPos;
for (auto pos : dimsPos) {
// If all dims in the group are unit dims, fall back to the inner-most one.
int64_t projectedPos = reassocIndices[pos].back();
for (auto i : llvm::reverse(reassocIndices[pos])) {
int64_t dim = targetShape[i];
if (dim > 1 || ShapedType::isDynamic(dim)) {
projectedPos = i;
break;
}
}
projectedDimsPos.push_back(projectedPos);
}
return projectedDimsPos;
}
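
For reference, a minimal standalone sketch (plain C++, hypothetical and not part of this patch; projectSketch is an illustrative name) that mirrors the projection logic above and checks the example from the comment:

#include <cassert>
#include <cstdint>
#include <vector>

// Standalone re-implementation of the projection on plain vectors, used only
// to check the documented example; -1 stands in for a dynamic dim.
static std::vector<int64_t>
projectSketch(const std::vector<int64_t> &dimsPos,
              const std::vector<std::vector<int64_t>> &reassoc,
              const std::vector<int64_t> &shape) {
  std::vector<int64_t> projected;
  for (int64_t pos : dimsPos) {
    // Fall back to the inner-most index when every candidate is a unit dim.
    int64_t proj = reassoc[pos].back();
    for (auto it = reassoc[pos].rbegin(); it != reassoc[pos].rend(); ++it) {
      if (shape[*it] != 1) { // non-unit static dim or dynamic (-1)
        proj = *it;
        break;
      }
    }
    projected.push_back(proj);
  }
  return projected;
}

int main() {
  // dimsPos [0, 1], reassociations [[0, 1], [2, 3]], shape [16, 16, 32, 1]
  // project to [1, 2], matching the comment above.
  assert((projectSketch({0, 1}, {{0, 1}, {2, 3}}, {16, 16, 32, 1}) ==
          std::vector<int64_t>{1, 2}));
  return 0;
}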

/// Check if all dims in dimsPos are divisible by the corresponding tile sizes.
static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
ArrayRef<int64_t> shape,
ArrayRef<int64_t> tileSizes) {
for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
int64_t dim = shape[pos];
if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
return false;
}
return true;
}

/// Permute the reassociation indices and reindex them in sequential order.
/// Returns the next available dim pos after reindexing.
///
/// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it
/// applies the permutation to get [[2], [0, 1]] and reindexes the indices into
/// [[0], [1, 2]].
static int64_t applyPermutationAndReindexReassoc(
SmallVector<ReassociationIndices> &reassocIndices,
ArrayRef<int64_t> permutation) {
applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
int64_t nextPos = 0;
for (ReassociationIndices &indices : reassocIndices) {
for (auto &index : indices) {
index = nextPos;
nextPos += 1;
}
}
return nextPos;
}
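
A similar standalone check (plain C++, hypothetical and not part of this patch) of the permute-and-reindex example above, using the non-identity permutation [1, 0]:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  using Group = std::vector<int64_t>;
  std::vector<Group> reassoc = {{0, 1}, {2}};
  std::vector<int64_t> perm = {1, 0};

  // Apply the permutation: result[i] = reassoc[perm[i]] -> [[2], [0, 1]].
  std::vector<Group> permuted;
  for (int64_t p : perm)
    permuted.push_back(reassoc[p]);

  // Reindex all indices in sequence order -> [[0], [1, 2]]; next free pos is 3.
  int64_t next = 0;
  for (Group &group : permuted)
    for (int64_t &index : group)
      index = next++;

  assert((permuted == std::vector<Group>{{0}, {1, 2}}));
  assert(next == 3);
  return 0;
}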

/// Bubble up pack op through collapse shape op when the packed dims can be
/// projected to the dims before collapsing. This is possible when the inner
/// tile sizes can divide the projected dims.
///
/// For example:
///
/// %collapsed = tensor.collapse_shape %in [[0, 1], [2]]
/// : tensor<?x16x4xf32> into tensor<?x4xf32>
/// %pack = tensor.pack %collapsed outer_dims_perm = [0, 1]
/// inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
/// : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
///
/// can be transformed into:
///
/// %pack = tensor.pack %in outer_dims_perm = [0, 1, 2]
/// inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
/// : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
/// %collapsed = tensor.collapse_shape %pack [[0, 1], [2], [3], [4]]
/// : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
static LogicalResult
bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
tensor::PackOp packOp,
PatternRewriter &rewriter) {
SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();

ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
SmallVector<ReassociationIndices> reassocIndices =
collapseOp.getReassociationIndices();
// Project the inner tile positions to dim positions before collapsing. For
// example, if dims [x, y] are collapsed into [z], packing on dim z can be
// projected back to packing on dim y.
//
// Project to the inner-most non-unit dims to increase the chance that they
// can be divided by the inner tile sizes. This is correct because for
// [..., x, 1], packing on the size-1 dim is equivalent to packing on dim x.
SmallVector<int64_t> projectedInnerDimsPos =
projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);

if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
innerTileSizes)) {
return failure();
}
// Expand the outer dims permutation with the associated source dims for the
// new permutation after bubbling. This is because moving a collapsed dim is
// equivalent to moving the associated source dims together.
SmallVector<int64_t> newOuterDimsPerm;
for (auto outerPos : outerDimsPerm) {
newOuterDimsPerm.insert(newOuterDimsPerm.end(),
reassocIndices[outerPos].begin(),
reassocIndices[outerPos].end());
}

auto emptyOp = tensor::PackOp::createDestinationTensor(
rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
projectedInnerDimsPos, newOuterDimsPerm);
auto newPackOp = rewriter.create<tensor::PackOp>(
packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);

SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
// First apply the permutation on the reassociations of the outer dims.
// For example, given the permutation [1, 0], the reassociations [[0, 1], [2]]
// become [[0], [1, 2]].
int64_t nextPos =
applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
// Then add direct mapping for the inner tile dims.
for (size_t i = 0; i < innerDimsPos.size(); ++i) {
newReassocIndices.push_back({nextPos});
nextPos += 1;
}

auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices);
rewriter.replaceOp(packOp, newCollapseOp);

return success();
}
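
To make the index bookkeeping concrete, here is a hypothetical standalone trace (plain C++, not part of this patch) for the example in the comment above, where the collapse reassociation is [[0, 1], [2]] and the original pack uses outer_dims_perm = [0, 1] with two inner tiles:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  using Group = std::vector<int64_t>;
  std::vector<Group> reassoc = {{0, 1}, {2}};  // collapse_shape reassociation
  std::vector<int64_t> outerDimsPerm = {0, 1}; // original pack outer perm
  // Result of projecting inner_dims_pos [0, 1] onto srcShape ?x16x4.
  std::vector<int64_t> projectedInnerDimsPos = {1, 2};

  // Expand the outer permutation: each collapsed dim is replaced by the
  // source dims it was built from -> [0, 1, 2].
  std::vector<int64_t> newOuterDimsPerm;
  for (int64_t pos : outerDimsPerm)
    newOuterDimsPerm.insert(newOuterDimsPerm.end(), reassoc[pos].begin(),
                            reassoc[pos].end());
  assert((newOuterDimsPerm == std::vector<int64_t>{0, 1, 2}));

  // Permute (identity here) and reindex the reassociation, then append one
  // singleton group per inner tile dim -> [[0, 1], [2], [3], [4]].
  std::vector<Group> newReassoc = reassoc;
  int64_t next = 0;
  for (Group &group : newReassoc)
    for (int64_t &index : group)
      index = next++;
  for (size_t i = 0; i < projectedInnerDimsPos.size(); ++i)
    newReassoc.push_back({next++});
  assert((newReassoc == std::vector<Group>{{0, 1}, {2}, {3}, {4}}));
  return 0;
}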

class BubbleUpPackOpThroughReshapeOp final
: public OpRewritePattern<tensor::PackOp> {
public:
BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
: OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}

LogicalResult matchAndRewrite(tensor::PackOp packOp,
PatternRewriter &rewriter) const override {
Operation *srcOp = packOp.getSource().getDefiningOp();
// Currently only supports the case where the pack op is the only user.
if (!srcOp || srcOp->getNumResults() != 1 ||
!srcOp->getResult(0).hasOneUse()) {
return failure();
}
// Currently only support static inner tile sizes.
if (llvm::any_of(packOp.getStaticTiles(), [](int64_t size) {
return ShapedType::isDynamic(size);
})) {
return failure();
}

// User controlled propagation function.
if (!controlFn(srcOp))
return failure();

return TypeSwitch<Operation *, LogicalResult>(srcOp)
.Case([&](tensor::CollapseShapeOp op) {
return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
})
.Default([](Operation *) { return failure(); });
}

private:
ControlPropagationFn controlFn;
};

/// Push down unpack op through expand shape op when the packed dims can be
/// projected to the dims after expanding. This is possible when the inner tile
/// sizes can divide the projected dims.
///
/// For example:
///
/// %unpack = tensor.unpack %in outer_dims_perm = [0, 1]
/// inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty
/// : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
/// %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
/// : tensor<?x256xf32> into tensor<?x256x256xf32>
///
/// can be transformed into:
///
/// %expanded = tensor.expand_shape %in [[0, 1], [2], [3], [4]]
/// : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
/// %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2]
/// inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
/// : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
static LogicalResult
pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
tensor::ExpandShapeOp expandOp,
PatternRewriter &rewriter) {
SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();

ArrayRef<int64_t> dstShape = expandOp.getType().getShape();
SmallVector<ReassociationIndices> reassocIndices =
expandOp.getReassociationIndices();
// Project the inner tile positions to dim positions after expanding. For
// example, if dim z is expanded into dims [x, y], unpacking on dim z can be
// projected to unpacking on dim y.
//
// Project to the inner-most non-unit dims to increase the chance that they
// can be divided by the inner tile sizes. This is correct because for
// [..., x, 1], unpacking on the size-1 dim is equivalent to unpacking on dim x.
SmallVector<int64_t> projectedInnerDimsPos =
projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);

if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
innerTileSizes)) {
return failure();
}
// Expand the outer dims permutation with the associated expanded dims for the
// new permutation after pushing. This is because moving a source dim is
// equivalent to moving the associated expanded dims together.
SmallVector<int64_t> newOuterDimsPerm;
for (auto outerPos : outerDimsPerm) {
newOuterDimsPerm.insert(newOuterDimsPerm.end(),
reassocIndices[outerPos].begin(),
reassocIndices[outerPos].end());
}

SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
// First apply the permutation on the reassociations of the outer dims.
// For example, given the permutation [1, 0], the reassociations [[0, 1], [2]]
// become [[0], [1, 2]].
int64_t nextPos =
applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
// Then add direct mapping for the inner tile dims.
for (size_t i = 0; i < innerDimsPos.size(); ++i) {
newReassocIndices.push_back({nextPos});
nextPos += 1;
}

RankedTensorType newExpandType =
tensor::PackOp::inferPackedType(expandOp.getType(), innerTileSizes,
projectedInnerDimsPos, newOuterDimsPerm);
auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
expandOp.getLoc(), newExpandType, unPackOp.getSource(),
newReassocIndices);

auto emptyOp = tensor::UnPackOp::createDestinationTensor(
rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
projectedInnerDimsPos, newOuterDimsPerm);
auto newUnPackOp = rewriter.create<tensor::UnPackOp>(
unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
rewriter.replaceOp(expandOp, newUnPackOp);

return success();
}

class PushDownUnPackOpThroughReshapeOp final
: public OpRewritePattern<tensor::UnPackOp> {
public:
PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
ControlPropagationFn fun)
: OpRewritePattern<tensor::UnPackOp>(context), controlFn(std::move(fun)) {
}

LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
PatternRewriter &rewriter) const override {
Value result = unPackOp.getResult();
// Currently only supports an unpack op with a single user.
if (!result.hasOneUse()) {
return failure();
}
// Currently only support static inner tile sizes.
if (llvm::any_of(unPackOp.getStaticTiles(), [](int64_t size) {
return ShapedType::isDynamic(size);
})) {
return failure();
}

Operation *consumerOp = *result.user_begin();
// User controlled propagation function.
if (!controlFn(consumerOp))
return failure();

return TypeSwitch<Operation *, LogicalResult>(consumerOp)
.Case([&](tensor::ExpandShapeOp op) {
return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter);
})
.Default([](Operation *) { return failure(); });
}

private:
ControlPropagationFn controlFn;
};

// TODO: Relax this restriction. We should unpack a generic op also
// in the presence of multiple unpack ops as producers.
/// Return the unpacked operand, if present, for the current generic op.
@@ -774,6 +1074,7 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
const ControlPropagationFn &controlPackUnPackPropagation) {
patterns
.insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
- PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp>(
+ BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
+ PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
patterns.getContext(), controlPackUnPackPropagation);
}
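
As a usage note, a minimal sketch (hypothetical, not part of this patch) of how a downstream pass might apply these patterns with the greedy rewrite driver; propagateLayouts is an illustrative name, and the control callback shown simply allows propagation through every candidate op:

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult propagateLayouts(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // A real pass could filter on op kind or a cost model in this callback.
  mlir::linalg::populateDataLayoutPropagationPatterns(
      patterns, [](mlir::Operation *) { return true; });
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}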