Skip to content

[mlir][vector] Split TransposeOpLowering into 2 patterns #91935

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 64 additions & 22 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,10 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
VectorType inputType = op.getSourceVectorType();
VectorType resType = op.getResultVectorType();

if (inputType.isScalable())
return rewriter.notifyMatchFailure(
op, "This lowering does not support scalable vectors");

// Set up convenience transposition table.
ArrayRef<int64_t> transp = op.getPermutation();

Expand All @@ -334,28 +338,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
return rewriter.notifyMatchFailure(
op, "Options specifies lowering to shuffle");

// Replace:
// vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
// vector<1xnxelty>
// with:
// vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
//
// Source with leading unit dim (inverse) is also replaced. Unit dim must
// be fixed. Non-unit can be scalable.
if (resType.getRank() == 2 &&
((resType.getShape().front() == 1 &&
!resType.getScalableDims().front()) ||
(resType.getShape().back() == 1 &&
!resType.getScalableDims().back())) &&
transp == ArrayRef<int64_t>({1, 0})) {
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
return success();
}

// TODO: Add support for scalable vectors
if (inputType.isScalable())
return failure();

// Handle a true 2-D matrix transpose differently when requested.
if (vectorTransformOptions.vectorTransposeLowering ==
vector::VectorTransposeLowering::Flat &&
Expand Down Expand Up @@ -411,6 +393,64 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
vector::VectorTransformsOptions vectorTransformOptions;
};

/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
/// to 2D vectors with at least one unit dim. For example:
///
/// Replace:
/// vector.transpose %0, [1, 0] : vector<4x1xi32>> to
/// vector<1x4xi32>
/// with:
/// vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
///
/// Source with leading unit dim (inverse) is also replaced. Unit dim must
/// be fixed. Non-unit dim can be scalable.
///
/// TODO: This pattern was introduced specifically to help lower scalable
/// vectors. In hindsight, a more specialised canonicalization (for shape_cast's
/// to cancel out) would be preferable:
///
/// BEFORE:
/// %0 = some_op
/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
/// %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
/// AFTER:
/// %0 = some_op
/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
///
/// Given the context above, we may want to consider (re-)moving this pattern
/// at some later time. I am leaving it for now in case there are other users
/// that I am not aware of.
class Transpose2DWithUnitDimToShapeCast
: public OpRewritePattern<vector::TransposeOp> {
public:
using OpRewritePattern::OpRewritePattern;

Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
PatternBenefit benefit = 1)
: OpRewritePattern<vector::TransposeOp>(context, benefit) {}

LogicalResult matchAndRewrite(vector::TransposeOp op,
PatternRewriter &rewriter) const override {
Value input = op.getVector();
VectorType resType = op.getResultVectorType();

// Set up convenience transposition table.
ArrayRef<int64_t> transp = op.getPermutation();

if (resType.getRank() == 2 &&
((resType.getShape().front() == 1 &&
!resType.getScalableDims().front()) ||
(resType.getShape().back() == 1 &&
!resType.getScalableDims().back())) &&
transp == ArrayRef<int64_t>({1, 0})) {
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
return success();
}

return failure();
}
};

/// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
/// If the strategy is Shuffle1D, it will be lowered to:
/// vector.shape_cast 2D -> 1D
Expand Down Expand Up @@ -483,6 +523,8 @@ class TransposeOp2DToShuffleLowering
void mlir::vector::populateVectorTransposeLoweringPatterns(
RewritePatternSet &patterns, VectorTransformsOptions options,
PatternBenefit benefit) {
patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
benefit);
patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
options, patterns.getContext(), benefit);
}
Loading