Skip to content

[mlir] Add pack/unpack transpose foldings for linalg.generic ops, fix bugs #93055

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 82 additions & 41 deletions mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,34 @@ static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
return success();
}

// If the `linalgOp` represents a transpose, return the permutation vector for
// the transpose. Otherwise, return failure.
static FailureOr<SmallVector<int64_t>>
getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation()))
return SmallVector<int64_t>(transposeOp.getPermutation());
if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
return failure();

if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
return failure();
auto mapRange = linalgOp.getIndexingMapsArray();
if (!mapRange.front().isPermutation() || !mapRange.back().isPermutation() ||
mapRange.front() == mapRange.back()) {
return failure();
}
if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious. If we don't add this restriction, we can still replace the block with linalg generic having a one-to-one map.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for now we can keep it restricted to pure transpose ops for these patterns. Maybe we could add a new pattern for something like this in a later PR.

In general, a unary element-wise operation like that will get fused with it's producer/consumer in element-wise fusion. After this you end up with multi-input ops, and it may not be as simple to figure out whether it is beneficial to transpose the generic op. If you do this remapping before element-wise fusion, then I think there could be cases where you add extra transposing that would otherwise fold. Overall, I think it is not a bad idea to try to move the transpose work to the pack op, but I think it takes a bit more careful thought/matching, so I think it is better suited for a later PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could also be useful to use after dispatch formation, since pack ops often get fused with producers. I think it's definitely worth trying at some point.

return failure();
AffineMap outMap = mapRange.back();
AffineMap inMap = mapRange.front();
// To get the permutation, look at each output index and find which
// dimension in the input we're reading from for that index.
return llvm::map_to_vector(outMap.getResults(),
[&](AffineExpr expr) -> int64_t {
return *inMap.getResultPosition(expr);
});
}

/// Packing one-dimensional tensor can be expressed as an expand shape op.
struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
using OpRewritePattern<PackOp>::OpRewritePattern;
Expand Down Expand Up @@ -246,14 +274,10 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,

for (unsigned int i = 0; i < rank; ++i) {
int64_t remappedPosition = permutation[i];

if (!inVec.empty()) {
if (remappedPosition >= rank) {
return false;
}
if (remappedPosition >= rank)
return false;
if (!inVec.empty())
remappedPosition = inVec[remappedPosition];
}

resVec.push_back(remappedPosition);
}

Expand All @@ -263,20 +287,25 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
/// semantics.
struct FoldProducerPackWithConsumerLinalgTransposeOp
: public OpRewritePattern<linalg::TransposeOp> {
using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;

LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
auto packOp = transposeOp.getOperand(0).getDefiningOp<PackOp>();
auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>();

if (!packOp)
return failure();

FailureOr<SmallVector<int64_t>> maybePerm =
getTransposeOpPermutation(linalgOp);
if (failed(maybePerm))
return failure();

auto innerDimsPos = packOp.getInnerDimsPos();
auto mixedInnerTiles = packOp.getMixedTiles();
auto outerDimsPerm = packOp.getOuterDimsPerm();
auto transposePerm = transposeOp.getPermutation();
auto transposePerm = maybePerm.value();
SmallVector<int64_t> newOuterDimsPermVec;
SmallVector<int64_t> newInnerDimsPosVec;
SmallVector<OpFoldResult> newMixedInnerTilesVec;
Expand All @@ -285,7 +314,7 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
srcRank))
return rewriter.notifyMatchFailure(
transposeOp,
linalgOp,
"Cannot fold in tensor.pack if a tile dimension was transposed "
"with a non-tile dimension in linalg.transpose.");

Expand All @@ -297,11 +326,11 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
}

Value output = packOp.createDestinationTensor(
rewriter, transposeOp.getLoc(), packOp.getSource(),
newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec,
newInnerDimsPosVec, newOuterDimsPermVec);

rewriter.replaceOpWithNewOp<PackOp>(
transposeOp, packOp.getSource(), output, newInnerDimsPosVec,
linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);

return success();
Expand All @@ -316,12 +345,16 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp

LogicalResult matchAndRewrite(PackOp packOp,
PatternRewriter &rewriter) const override {
auto transposeOp = packOp.getSource().getDefiningOp<linalg::TransposeOp>();
auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
if (!linalgOp)
return failure();

if (!transposeOp)
FailureOr<SmallVector<int64_t>> maybePerm =
getTransposeOpPermutation(linalgOp);
if (failed(maybePerm))
return failure();

auto transposePermutation = transposeOp.getPermutation();
auto transposePermutation = maybePerm.value();
auto outerDimsPerm = packOp.getOuterDimsPerm();
auto innerDimsPos = packOp.getInnerDimsPos();
SmallVector<int64_t> newInnerDimsPosVec;
Expand All @@ -337,11 +370,11 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
newInnerDimsPosVec.push_back(transposePermutation[dim]);

Value output = packOp.createDestinationTensor(
rewriter, packOp.getLoc(), transposeOp.getOperand(0),
rewriter, packOp.getLoc(), linalgOp->getOperand(0),
packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);

rewriter.replaceOpWithNewOp<PackOp>(
packOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);

return success();
Expand All @@ -351,34 +384,38 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
/// transpose semantics.
struct FoldProducerUnPackWithConsumerLinalgTransposeOp
: public OpRewritePattern<linalg::TransposeOp> {
using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;

LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
auto unPackOp = transposeOp.getOperand(0).getDefiningOp<UnPackOp>();
auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();

if (!unPackOp)
return failure();

auto transposePermutation = transposeOp.getPermutation();
FailureOr<SmallVector<int64_t>> maybePerm =
getTransposeOpPermutation(linalgOp);
if (failed(maybePerm))
return failure();

auto outerDimsPerm = unPackOp.getOuterDimsPerm();
auto innerDimsPos = unPackOp.getInnerDimsPos();
SmallVector<int64_t> newInnerDimsPosVec;
SmallVector<int64_t> newOuterDimsPermVec =
llvm::to_vector(transposePermutation);

if (!outerDimsPerm.empty())
applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
invertPermutationVector(maybePerm.value());

// Can't use applyPermutationToVector for newInnerDimsPosVec since input and
// permutation rank won't necessarily be equal in all cases.
for (auto dim : innerDimsPos)
newInnerDimsPosVec.push_back(transposePermutation[dim]);
newInnerDimsPosVec.push_back(newOuterDimsPermVec[dim]);

if (!outerDimsPerm.empty())
applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);

// Reuse the destination of the transpose op.
rewriter.replaceOpWithNewOp<UnPackOp>(
transposeOp, unPackOp.getSource(), transposeOp.getDpsInits()[0],
linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);

return success();
Expand All @@ -393,13 +430,17 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp

LogicalResult matchAndRewrite(UnPackOp unPackOp,
PatternRewriter &rewriter) const override {
auto transposeOp =
unPackOp.getSource().getDefiningOp<linalg::TransposeOp>();
auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
if (!linalgOp)
return failure();

if (!transposeOp)
FailureOr<SmallVector<int64_t>> maybePerm =
getTransposeOpPermutation(linalgOp);
if (failed(maybePerm))
return failure();

auto transposePermutation = transposeOp.getPermutation();
SmallVector<int64_t> inverseTransposePerm =
invertPermutationVector(maybePerm.value());
auto outerDimsPerm = unPackOp.getOuterDimsPerm();
auto innerDimsPos = unPackOp.getInnerDimsPos();
int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
Expand All @@ -408,26 +449,26 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
SmallVector<int64_t> newInnerDimsPosVec;
SmallVector<OpFoldResult> newMixedInnerTilesVec;

if (!checkAndPermute(transposePermutation, outerDimsPerm,
if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
newOuterDimsPermVec, destRank))
return rewriter.notifyMatchFailure(
unPackOp,
"Cannot fold in tensor.unpack if a tile dimension was transposed "
"with a non-tile dimension in linalg.transpose.");

// Process transpose operation for tiled inner dimensions
for (unsigned int i = destRank; i < transposePermutation.size(); ++i) {
int64_t remappedPosition = transposePermutation[i] - destRank;
for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
int64_t remappedPosition = inverseTransposePerm[i] - destRank;
newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
}

Value output = unPackOp.createDestinationTensor(
rewriter, unPackOp.getLoc(), transposeOp.getOperand(0),
rewriter, unPackOp.getLoc(), linalgOp->getOperand(0),
newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);

rewriter.replaceOpWithNewOp<UnPackOp>(
unPackOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
newMixedInnerTilesVec, newOuterDimsPermVec);

return success();
Expand Down
139 changes: 139 additions & 0 deletions mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -636,3 +636,142 @@ func.func @tensor_padded_unpack_linalg_transpose_fold(%arg0: tensor<71x7x4x16x16
// CHECK-SAME: into %[[OUT:.+]] : tensor<71x7x4x16x16xf32> -> tensor<100x71x64xf32>
// CHECK: return %[[UNPACK]] : tensor<100x71x64xf32>
// CHECK: }

// -----

func.func @non_involution_transpose_unpack_fold(%arg0: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
%0 = tensor.empty() : tensor<5x2x3x16x4xi32>
%transposed = linalg.transpose ins(%arg0 : tensor<2x3x5x4x16xi32>)
outs(%0 : tensor<5x2x3x16x4xi32>)
permutation = [2, 0, 1, 4, 3]
%1 = tensor.empty() : tensor<5x48x8xi32>
%unpack = tensor.unpack %transposed
outer_dims_perm = [0, 2, 1]
inner_dims_pos = [1, 2]
inner_tiles = [16, 4] into
%1 : tensor<5x2x3x16x4xi32> -> tensor<5x48x8xi32>
return %unpack : tensor<5x48x8xi32>
}
//CHECK-LABEL: func.func @non_involution_transpose_unpack_fold(
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<5x48x8xi32>
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME: outer_dims_perm = [2, 1, 0]
// CHECK-SAME: inner_dims_pos = [2, 1]
// CHECK-SAME: inner_tiles = [4, 16]
// CHEKC-SAME: into %[[OUT]] : tensor<2x3x5x4x16xi32> -> tensor<5x48x8xi32>
// CHECK: return %[[UNPACK]] : tensor<5x48x8xi32>
// CHECK: }

// -----

func.func @unpack_non_involution_transpose_fold(%arg0: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
%0 = tensor.empty() : tensor<3x56x3648xf32>
%unpack = tensor.unpack %arg0
outer_dims_perm = [2, 0, 1]
inner_dims_pos = [1, 2]
inner_tiles = [1, 64]
into %0 : tensor<57x3x56x1x64xf32> -> tensor<3x56x3648xf32>

%1 = tensor.empty() : tensor<3648x3x56xf32>
%transposed = linalg.transpose
ins(%unpack : tensor<3x56x3648xf32>)
outs(%1 : tensor<3648x3x56xf32>)
permutation = [2, 0, 1]
return %transposed : tensor<3648x3x56xf32>
}
// CHECK-LABEL: func.func @unpack_non_involution_transpose_fold(
// CHECK-SAME: %[[ARG0:.+]]: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<3648x3x56xf32>
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME: outer_dims_perm = [0, 1, 2]
// CHECK-SAME: inner_dims_pos = [2, 0]
// CHECK-SAME: inner_tiles = [1, 64]
// CHECK-SAME: into %[[OUT:.+]] : tensor<57x3x56x1x64xf32> -> tensor<3648x3x56xf32>
// CHECK: return %[[UNPACK]] : tensor<3648x3x56xf32>
// CHECK: }

// -----

func.func @transpose_unpacked_dims_no_fold(%arg0: tensor<2x16x5x4x3xi32>) -> tensor<5x32x12xi32> {
%0 = tensor.empty() : tensor<5x2x3x16x4xi32>
%transposed = linalg.transpose ins(%arg0 : tensor<2x16x5x4x3xi32>)
outs(%0 : tensor<5x2x3x16x4xi32>)
permutation = [2, 0, 4, 1, 3]
%1 = tensor.empty() : tensor<5x32x12xi32>
%unpack = tensor.unpack %transposed
inner_dims_pos = [1, 2]
inner_tiles = [16, 4] into
%1 : tensor<5x2x3x16x4xi32> -> tensor<5x32x12xi32>
return %unpack : tensor<5x32x12xi32>
}
//CHECK-LABEL: func.func @transpose_unpacked_dims_no_fold(
// CHECK: linalg.transpose
// CHECK: tensor.unpack

// -----

#map = affine_map<(d0, d1, d2, d3, d4)->(d1, d2, d0, d4, d3)>
#map1 = affine_map<(d0, d1, d2, d3, d4)->(d0, d1, d2, d3, d4)>
func.func @generic_transpose_unpack_fold(%arg0: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
%0 = tensor.empty() : tensor<5x2x3x16x4xi32>
%transposed = linalg.generic {
iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
indexing_maps = [#map, #map1]}
ins(%arg0 : tensor<2x3x5x4x16xi32>)
outs(%0 : tensor<5x2x3x16x4xi32>) {
^bb0(%in : i32, %out : i32):
linalg.yield %in : i32
} -> tensor<5x2x3x16x4xi32>
%1 = tensor.empty() : tensor<5x48x8xi32>
%unpack = tensor.unpack %transposed
outer_dims_perm = [0, 2, 1]
inner_dims_pos = [1, 2]
inner_tiles = [16, 4] into
%1 : tensor<5x2x3x16x4xi32> -> tensor<5x48x8xi32>
return %unpack : tensor<5x48x8xi32>
}
//CHECK-LABEL: func.func @generic_transpose_unpack_fold(
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<5x48x8xi32>
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME: outer_dims_perm = [2, 1, 0]
// CHECK-SAME: inner_dims_pos = [2, 1]
// CHECK-SAME: inner_tiles = [4, 16]
// CHEKC-SAME: into %[[OUT]] : tensor<2x3x5x4x16xi32> -> tensor<5x48x8xi32>
// CHECK: return %[[UNPACK]] : tensor<5x48x8xi32>
// CHECK: }

// -----

#map = affine_map<(d0, d1, d2)->(d1, d2, d0)>
#map1 = affine_map<(d0, d1, d2)->(d0, d1, d2)>
func.func @unpack_generic_transpose_fold(%arg0: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
%0 = tensor.empty() : tensor<3x56x3648xf32>
%unpack = tensor.unpack %arg0
outer_dims_perm = [2, 0, 1]
inner_dims_pos = [1, 2]
inner_tiles = [1, 64]
into %0 : tensor<57x3x56x1x64xf32> -> tensor<3x56x3648xf32>

%1 = tensor.empty() : tensor<3648x3x56xf32>
%transposed = linalg.generic {
iterator_types = ["parallel", "parallel", "parallel"],
indexing_maps = [#map, #map1]}
ins(%unpack : tensor<3x56x3648xf32>)
outs(%1 : tensor<3648x3x56xf32>) {
^bb0(%in : f32, %out : f32):
linalg.yield %in : f32
} -> tensor<3648x3x56xf32>
return %transposed : tensor<3648x3x56xf32>
}
// CHECK-LABEL: func.func @unpack_generic_transpose_fold(
// CHECK-SAME: %[[ARG0:.+]]: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<3648x3x56xf32>
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME: outer_dims_perm = [0, 1, 2]
// CHECK-SAME: inner_dims_pos = [2, 0]
// CHECK-SAME: inner_tiles = [1, 64]
// CHECK-SAME: into %[[OUT:.+]] : tensor<57x3x56x1x64xf32> -> tensor<3648x3x56xf32>
// CHECK: return %[[UNPACK]] : tensor<3648x3x56xf32>
// CHECK: }
Loading