[mlir][Vector] Add fold transpose(shape_cast) -> shape_cast #73951
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

Changes

This folds transpose(shape_cast) into a new shape_cast, when the transpose just permutes a unit dim from the result of the shape_cast. Example:
```
%0 = vector.shape_cast %vec : vector<[4]xf32> to vector<[4]x1xf32>
%1 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
```
Folds to:
```
vector.shape_cast %vec : vector<[4]xf32> to vector<1x[4]xf32>
```
This is an (alternate) fix for lowering matmuls to ArmSME.

Full diff: https://github.com/llvm/llvm-project/pull/73951.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c462b23e1133fc9..cf006adaee72a25 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5548,12 +5548,55 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
}
};
+/// Folds transpose(shape_cast) into a new shape_cast, when the transpose just
+/// permutes a unit dim from the result of the shape_cast.
+class FoldTransposeShapeCast : public OpRewritePattern<TransposeOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(TransposeOp transpOp,
+ PatternRewriter &rewriter) const override {
+ Value transposeSrc = transpOp.getVector();
+ auto shapeCastOp = transposeSrc.getDefiningOp<vector::ShapeCastOp>();
+ if (!shapeCastOp)
+ return failure();
+
+ auto sourceType = transpOp.getSourceVectorType();
+ auto resultType = transpOp.getResultVectorType();
+
+ auto filterUnitDims = [](VectorType type) {
+ return llvm::make_filter_range(
+ llvm::zip_equal(type.getShape(), type.getScalableDims()),
+ [&](auto dim) {
+ auto [size, isScalble] = dim;
+ return size != 1 || isScalble;
+ });
+ };
+
+ auto sourceWithoutUnitDims = filterUnitDims(sourceType);
+ auto resultWithoutUnitDims = filterUnitDims(sourceType);
+
+ // If this transpose just permutes a unit dim, then we can fold it into the
+ // shape_cast.
+ for (auto [srcDim, resDim] :
+ llvm::zip_equal(sourceWithoutUnitDims, resultWithoutUnitDims)) {
+ if (srcDim != resDim)
+ return failure();
+ }
+
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transpOp, resultType,
+ shapeCastOp.getSource());
+
+ return success();
+ };
+};
+
} // namespace
void vector::TransposeOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
- TransposeFolder, FoldTransposeSplat>(context);
+ TransposeFolder, FoldTransposeSplat, FoldTransposeShapeCast>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1021c73cc57d341..6bfb477ecf97285 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -67,6 +67,18 @@ func.func @create_mask_transpose_to_transposed_create_mask(
// -----
+// CHECK-LABEL: transposed_unit_dim_shape_cast_to_shape_cast
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
+func.func @transposed_unit_dim_shape_cast_to_shape_cast(%vec: vector<[4]xf32>) -> vector<1x[4]xf32> {
+ // CHECK: vector.shape_cast %[[VEC]] : vector<[4]xf32> to vector<1x[4]xf32>
+ // CHECK-NOT: vector.transpose
+ %0 = vector.shape_cast %vec : vector<[4]xf32> to vector<[4]x1xf32>
+ %1 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+ return %1 : vector<1x[4]xf32>
+}
+
+// -----
+
// CHECK-LABEL: extract_from_create_mask
// CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
func.func @extract_from_create_mask(%dim0: index, %dim1: index) -> vector<[4]x[4]xi1> {
LGTM! I'd give time for others to take a look though given the ongoing discussion.
```
    if (!shapeCastOp)
      return failure();

    auto sourceType = transpOp.getSourceVectorType();
```
If we had
```
%0 = shape_cast ... : vector<4x1x4> to vector<2x2x1x4>
transpose %0 [0, 1, 3, 2] : vector<2x2x1x4> to vector<2x2x4x1>
```
This devolves to the same discussion in the other PR. Since there's already a shape_cast in the source I won't block here, but would it still work to use the source vector type of the shape_cast?
It's still a legal shape_cast and will lower to (pretty much) the same thing. But yeah, the point here is we're not adding a shape_cast where there was not already one before, so this should not cause problems for SPIR-V :)
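For concreteness, a sketch of what the fold would produce for the 4-D example above (the f32 element type and the %src value name are assumed for illustration):
```
// Before: the transpose only moves the unit dim (dim 2) to the end.
%0 = vector.shape_cast %src : vector<4x1x4xf32> to vector<2x2x1x4xf32>
%1 = vector.transpose %0, [0, 1, 3, 2] : vector<2x2x1x4xf32> to vector<2x2x4x1xf32>

// After the fold: a single shape_cast from the original source straight to
// the transposed type.
%1 = vector.shape_cast %src : vector<4x1x4xf32> to vector<2x2x4x1xf32>
```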
The point I wanted to make was that the pattern in the test is definitely a reasonable canonicalization (i.e. the transpose is just shuffling unit dims introduced by the shape_cast), but unit dims present before the shape_cast touch on the same discussion. My main concern with making the earlier pattern a canonicalization wasn't SPIR-V specific, it was more a matter of whether it was a pattern we wanted globally. Semantic quirks of shape_cast aside (being discussed on Discourse), vector.transpose or vector.contract (i.e. higher level vector operations) play nicer with transpose than shape_cast. That's why I see the vector lowering pattern as reasonable; SPIR-V should use the "LLVM" lowering for shape_cast in that case. Making it a canonicalization though means that this needs to be the canonical representation everywhere. Because there's already a shape_cast in the IR here, that's why I'm not blocking, but similarly there's no reason we couldn't have a shape_cast alongside "higher level" vector IR, hence my question.
I feel like we've been blocking your work with how much this conversation got blown up though, and I am sorry about that :(
The point for us is that vector<[4]x1xf32> is an impossible type; there's no legal lowering for it in LLVM (only trailing scalable dimensions are supported). So we need a mechanism (such as this) that allows it to be eliminated.
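To illustrate (the LLVM-dialect types below reflect my understanding of the standard n-D vector conversion and are shown only for comparison):
```
// vector<1x[4]xf32> has a trailing scalable dim, so it can convert to an
// LLVM array of 1-D scalable vectors:
//   vector<1x[4]xf32>  ~>  !llvm.array<1 x vector<[4]xf32>>
// vector<[4]x1xf32> would need an array with a scalable number of elements,
// for which no LLVM type exists, hence "only trailing scalable dims".
```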
Also LGTM, and this seems like a sensible path forward that should unlock us, given it only applies when there's a shape cast to begin with.
This does look good to me. Do you mind if I check if it fixes the issue and get back to you?
Fine with me 👍
LGTM !
LGTM, thanks!
Kicked off a PR on IREE (iree-org/iree#15748) that cherry-picks this change and, if it fixes the original issue, undoes the revert we were carrying locally.
This does not depend on the previous transpose lowering (which will likely still lead to problems for you). This patch alone is all we need for lowering SME matmuls :)
Brain fade on my part.... So really we will also need to revert the original patch upstream.... That's a different discussion. Fixed up the PR for now so that it just tests this patch.
One thing to note is that with the original patch reverted, the transpose lowering generates invalid code for scalable vectors. It must at least return
This folding rule seems wrong to me? It does not even look at the transpose pattern of the non-unit dims. I have code that looks like this:
```
%22 = vector.shape_cast %21 : vector<1x256x256xf32> to vector<256x256xf32>
%23 = vector.transpose %22, [1, 0] : vector<256x256xf32> to vector<256x256xf32>
```
and it gets simplified to
```
%22 = vector.shape_cast %21 : vector<1x256x256xf32> to vector<256x256xf32>
```
which obviously produces different values!
…73951)" (#74579)

This reverts commit f42b761. The fold pattern is incorrect, because it does not even look at the permutation of non-unit dims and is happy to replace a pattern such as
```
%22 = vector.shape_cast %21 : vector<1x256x256xf32> to vector<256x256xf32>
%23 = vector.transpose %22, [1, 0] : vector<256x256xf32> to vector<256x256xf32>
```
with
```
%22 = vector.shape_cast %21 : vector<1x256x256xf32> to vector<256x256xf32>
```
which is obviously incorrect.
This folds transpose(shape_cast) into a new shape_cast, when the transpose just permutes a unit dim from the result of the shape_cast. Example:
```
%0 = vector.shape_cast %vec : vector<[4]xf32> to vector<[4]x1xf32>
%1 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
```
Folds to:
```
%0 = vector.shape_cast %vec : vector<[4]xf32> to vector<1x[4]xf32>
```
This is an (alternate) fix for lowering matmuls to ArmSME.

---

Corrected version of llvm#73951.
I've prepared a patch here: https://github.com/llvm/llvm-project/compare/main...MacDue:llvm-project:transpose_of_shape_cast_v2?expand=1, that I believe fixes the correctness issues. Sorry for the inconvenience! :) I won't create a new PR because we have alternate solutions to what this aimed to solve (and the general n-D case for this fold is a little tricky).
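For reference, a minimal sketch of the extra check a corrected fold needs (this is my own illustration, not the code from the branch linked above; the helper name is hypothetical):
```
// Returns true if `permutation` only moves unit dims of `sourceType`, i.e.
// the non-unit dims keep their relative order. Only then does the transpose
// leave the underlying data untouched, making it equivalent to a shape_cast.
// Scalable dims are never treated as unit dims, even when their size is 1.
static bool permutationOnlyMovesUnitDims(VectorType sourceType,
                                         ArrayRef<int64_t> permutation) {
  ArrayRef<int64_t> shape = sourceType.getShape();
  ArrayRef<bool> scalableDims = sourceType.getScalableDims();
  int64_t lastSeenNonUnitDim = -1;
  for (int64_t sourceDim : permutation) {
    if (shape[sourceDim] == 1 && !scalableDims[sourceDim])
      continue; // Unit dims may move freely without reordering any data.
    if (sourceDim < lastSeenNonUnitDim)
      return false; // A non-unit dim is reordered: this is a real transpose.
    lastSeenNonUnitDim = sourceDim;
  }
  return true;
}
```
For the 256x256 example above this returns false (the two non-unit dims swap), while for the unit-dim cases in the test it returns true.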