-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][vector] Split TransposeOpLowering
into 2 patterns
#91935
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Split TransposeOpLowering
into 2 patterns
#91935
Conversation
Splits `TransposeOpLowering` into two patterns: 1. `Transpose2DWithUnitDimToShapeCast` - rewrites 2D `vector.transpose` as `vector.shape_cast` (there has to be at least one unit dim), 2. `TransposeOpLowering` - the original pattern without the part extracted into `Transpose2DWithUnitDimToShapeCast`. The rationale behind the split: * the output generated by `Transpose2DWithUnitDimToShapeCast` doesn't really match the intended output from `TransposeOpLowering` as documented in the source file - it doesn't make much sense to keep it embedded inside `TransposeOpLowering`, * `Transpose2DWithUnitDimToShapeCast` _does_ work for scalable vectors, `TransposeOpLowering` _does_ not.
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Andrzej Warzyński (banach-space) ChangesSplits
The rationale behind the split:
Full diff: https://github.com/llvm/llvm-project/pull/91935.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 7011c478fefba..0706f22cb8b12 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -326,6 +326,10 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
VectorType inputType = op.getSourceVectorType();
VectorType resType = op.getResultVectorType();
+ if (inputType.isScalable())
+ return rewriter.notifyMatchFailure(
+ op, "This lowering does not support scalable vectors");
+
// Set up convenience transposition table.
ArrayRef<int64_t> transp = op.getPermutation();
@@ -334,28 +338,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
return rewriter.notifyMatchFailure(
op, "Options specifies lowering to shuffle");
- // Replace:
- // vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
- // vector<1xnxelty>
- // with:
- // vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
- //
- // Source with leading unit dim (inverse) is also replaced. Unit dim must
- // be fixed. Non-unit can be scalable.
- if (resType.getRank() == 2 &&
- ((resType.getShape().front() == 1 &&
- !resType.getScalableDims().front()) ||
- (resType.getShape().back() == 1 &&
- !resType.getScalableDims().back())) &&
- transp == ArrayRef<int64_t>({1, 0})) {
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
- return success();
- }
-
- // TODO: Add support for scalable vectors
- if (inputType.isScalable())
- return failure();
-
// Handle a true 2-D matrix transpose differently when requested.
if (vectorTransformOptions.vectorTransposeLowering ==
vector::VectorTransposeLowering::Flat &&
@@ -411,6 +393,48 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
vector::VectorTransformsOptions vectorTransformOptions;
};
+/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
+/// to 2D vectors with at least one unit dim. For example:
+///
+/// Replace:
+/// vector.transpose %0, [1, 0] : vector<4x1xi32>> to
+/// vector<1x4xi32>
+/// with:
+/// vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
+///
+/// Source with leading unit dim (inverse) is also replaced. Unit dim must
+/// be fixed. Non-unit can be scalable.
+class Transpose2DWithUnitDimToShapeCast
+ : public OpRewritePattern<vector::TransposeOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
+
+ LogicalResult matchAndRewrite(vector::TransposeOp op,
+ PatternRewriter &rewriter) const override {
+ Value input = op.getVector();
+ VectorType resType = op.getResultVectorType();
+
+ // Set up convenience transposition table.
+ ArrayRef<int64_t> transp = op.getPermutation();
+
+ if (resType.getRank() == 2 &&
+ ((resType.getShape().front() == 1 &&
+ !resType.getScalableDims().front()) ||
+ (resType.getShape().back() == 1 &&
+ !resType.getScalableDims().back())) &&
+ transp == ArrayRef<int64_t>({1, 0})) {
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
+ return success();
+ }
+
+ return failure();
+ }
+};
+
/// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
/// If the strategy is Shuffle1D, it will be lowered to:
/// vector.shape_cast 2D -> 1D
@@ -483,6 +507,8 @@ class TransposeOp2DToShuffleLowering
void mlir::vector::populateVectorTransposeLoweringPatterns(
RewritePatternSet &patterns, VectorTransformsOptions options,
PatternBenefit benefit) {
+ patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
+ benefit);
patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
options, patterns.getContext(), benefit);
}
|
Note that it only works if (and only if) the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approving this PR, but I do wonder (with the benefit of hindsight), if the Transpose2DWithUnitDimToShapeCast
lowering would be better as a (more specific) canonicalization of a shape_cast + transpose
. Since what we really want here is for the illegal vector types (e.g. vector<[4]x1xf32>
) to be eliminated, which won't always happen if the transpose is replaced with a shape cast.
E.g. the rewrite could be:
%0 = some_op
%1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
%2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
->
%0 = some_op
%1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think such a change causes failure on SPIR-V backend. I think we've had this discussion previously.
Waiting for more people to chime in. cc @kuhar and @antiagainst
It shouldn't, this is merely moving code around. |
IIRC this is now going to be generating a vector.shape_cast that isnt handled on SPIR-V backends, and this pattern is being added to a "generic vector lowering" pattern set. I'd suggest adding a "populateVectorTransposeLoweringForLLVMPatterns" and adding this to that path. I know the SPIR-V path isnt tested as well in MLIR, but the shape_cast isnt supported on SPIR-V path. |
IIRC this lowering has not been a problem on SPIR-V for a while now (since |
No, llvm-project/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir Lines 801 to 858 in 1fd196c
In our previous discussion about these patterns, @antiagainst confirmed that those are not a problem for SPIR-V:
In any case, let me re-iterate, this patch does not add any new patterns, it is merely moving the existing patterns.
No, this pattern:
This PR merely extracts that logic into a separate class, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, I understand now. Thanks for the explanation and sorry for the noise!
That's a good point, but I'd like to get input from other reviewers first (specifically, folks that are using these patterns for something other than scalable vectors). Let me add a TODO with your suggestion and we can re-visit at some later time. Hopefully somebody will notice and bring it up.
No worries, it's good to have somebody with a SPIR-V hat on scanning these PRs, thanks for taking a look! |
Splits
TransposeOpLowering
into two patterns:Transpose2DWithUnitDimToShapeCast
- rewrites 2Dvector.transpose
as
vector.shape_cast
(there has to be at least one unit dim),TransposeOpLowering
- the original pattern without the partextracted into
Transpose2DWithUnitDimToShapeCast
.The rationale behind the split:
Transpose2DWithUnitDimToShapeCast
doesn'treally match the intended output from
TransposeOpLowering
asdocumented in the source file - it doesn't make much sense to keep
it embedded inside
TransposeOpLowering
,Transpose2DWithUnitDimToShapeCast
does work for scalable vectors,TransposeOpLowering
does not.