-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Vector] Fix scalable InsertSlice/ExtractSlice lowering #124861
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-vector Author: Diego Caballero (dcaballe) ChangesIt looks like scalable This PR fixes the issue by preventing scalable vector.insert/extractslice ops to be lowered to vector shuffles. Instead, they are now lowered to a sequence of insert/extractelement ops using an existing patter. Full diff: https://github.com/llvm/llvm-project/pull/124861.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 72405fcfef00f1..2c32634544b90b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -96,11 +96,15 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
PatternRewriter &rewriter) const override {
auto srcType = op.getSourceVectorType();
auto dstType = op.getDestVectorType();
+ int64_t srcRank = srcType.getRank();
+
+ // Scalable vectors are not supported by vector shuffle.
+ if ((srcType.isScalable() || dstType.isScalable()) && srcRank == 1)
+ return failure();
if (op.getOffsets().getValue().empty())
return failure();
- int64_t srcRank = srcType.getRank();
int64_t dstRank = dstType.getRank();
assert(dstRank >= srcRank);
if (dstRank != srcRank)
@@ -184,6 +188,11 @@ class Convert1DExtractStridedSliceIntoShuffle
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
auto dstType = op.getType();
+ auto srcType = op.getSourceVectorType();
+
+ // Scalable vectors are not supported by vector shuffle.
+ if (dstType.isScalable() || srcType.isScalable())
+ return failure();
assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
@@ -331,4 +340,14 @@ void vector::populateVectorInsertExtractStridedSliceTransforms(
patterns.add<ConvertSameRankInsertStridedSliceIntoShuffle,
Convert1DExtractStridedSliceIntoShuffle>(patterns.getContext(),
benefit);
+ // Generate chains of extract/insert ops for scalable vectors only as they
+ // can't be lowered to vector shuffles.
+ populateVectorExtractStridedSliceToExtractInsertChainPatterns(
+ patterns,
+ /*controlFn=*/
+ [](ExtractStridedSliceOp op) {
+ return op.getType().isScalable() ||
+ op.getSourceVectorType().isScalable();
+ },
+ benefit);
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 62649b83d887d1..7df6defc0f202f 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2026,13 +2026,14 @@ func.func @extract_strided_slice_f32_1d_from_2d_scalable(%arg0: vector<4x[8]xf32
// CHECK-LABEL: func.func @extract_strided_slice_f32_1d_from_2d_scalable(
// CHECK-SAME: %[[ARG:.*]]: vector<4x[8]xf32>)
// CHECK: %[[A:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x[8]xf32> to !llvm.array<4 x vector<[8]xf32>>
-// CHECK: %[[T0:.*]] = llvm.mlir.undef : !llvm.array<2 x vector<[8]xf32>>
-// CHECK: %[[T1:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vector<[8]xf32>>
-// CHECK: %[[T2:.*]] = llvm.insertvalue %[[T1]], %[[T0]][0] : !llvm.array<2 x vector<[8]xf32>>
-// CHECK: %[[T3:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vector<[8]xf32>>
-// CHECK: %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[T2]][1] : !llvm.array<2 x vector<[8]xf32>>
-// CHECK: %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T4]] : !llvm.array<2 x vector<[8]xf32>> to vector<2x[8]xf32>
-// CHECK: return %[[T5]]
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x[8]xf32>
+// CHECK: %[[DST:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<2x[8]xf32> to !llvm.array<2 x vector<[8]xf32>>
+// CHECK: %[[E0:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vector<[8]xf32>>
+// CHECK: %[[E1:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vector<[8]xf32>>
+// CHECK: %[[I0:.*]] = llvm.insertvalue %[[E0]], %[[DST]][0] : !llvm.array<2 x vector<[8]xf32>>
+// CHECK: %[[I1:.*]] = llvm.insertvalue %[[E1]], %[[I0]][1] : !llvm.array<2 x vector<[8]xf32>>
+// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[I1]] : !llvm.array<2 x vector<[8]xf32>> to vector<2x[8]xf32>
+// CHECK: return %[[RES]]
// -----
|
@llvm/pr-subscribers-mlir Author: Diego Caballero (dcaballe) ChangesIt looks like scalable This PR fixes the issue by preventing scalable vector.insert/extractslice ops to be lowered to vector shuffles. Instead, they are now lowered to a sequence of insert/extractelement ops using an existing patter. Full diff: https://github.com/llvm/llvm-project/pull/124861.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 72405fcfef00f13..2c32634544b90b0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -96,11 +96,15 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
PatternRewriter &rewriter) const override {
auto srcType = op.getSourceVectorType();
auto dstType = op.getDestVectorType();
+ int64_t srcRank = srcType.getRank();
+
+ // Scalable vectors are not supported by vector shuffle.
+ if ((srcType.isScalable() || dstType.isScalable()) && srcRank == 1)
+ return failure();
if (op.getOffsets().getValue().empty())
return failure();
- int64_t srcRank = srcType.getRank();
int64_t dstRank = dstType.getRank();
assert(dstRank >= srcRank);
if (dstRank != srcRank)
@@ -184,6 +188,11 @@ class Convert1DExtractStridedSliceIntoShuffle
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
auto dstType = op.getType();
+ auto srcType = op.getSourceVectorType();
+
+ // Scalable vectors are not supported by vector shuffle.
+ if (dstType.isScalable() || srcType.isScalable())
+ return failure();
assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
@@ -331,4 +340,14 @@ void vector::populateVectorInsertExtractStridedSliceTransforms(
patterns.add<ConvertSameRankInsertStridedSliceIntoShuffle,
Convert1DExtractStridedSliceIntoShuffle>(patterns.getContext(),
benefit);
+ // Generate chains of extract/insert ops for scalable vectors only as they
+ // can't be lowered to vector shuffles.
+ populateVectorExtractStridedSliceToExtractInsertChainPatterns(
+ patterns,
+ /*controlFn=*/
+ [](ExtractStridedSliceOp op) {
+ return op.getType().isScalable() ||
+ op.getSourceVectorType().isScalable();
+ },
+ benefit);
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 62649b83d887d1a..7df6defc0f202f1 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2026,13 +2026,14 @@ func.func @extract_strided_slice_f32_1d_from_2d_scalable(%arg0: vector<4x[8]xf32
// CHECK-LABEL: func.func @extract_strided_slice_f32_1d_from_2d_scalable(
// CHECK-SAME: %[[ARG:.*]]: vector<4x[8]xf32>)
// CHECK: %[[A:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x[8]xf32> to !llvm.array<4 x vector<[8]xf32>>
-// CHECK: %[[T0:.*]] = llvm.mlir.undef : !llvm.array<2 x vector<[8]xf32>>
-// CHECK: %[[T1:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vector<[8]xf32>>
-// CHECK: %[[T2:.*]] = llvm.insertvalue %[[T1]], %[[T0]][0] : !llvm.array<2 x vector<[8]xf32>>
-// CHECK: %[[T3:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vector<[8]xf32>>
-// CHECK: %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[T2]][1] : !llvm.array<2 x vector<[8]xf32>>
-// CHECK: %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T4]] : !llvm.array<2 x vector<[8]xf32>> to vector<2x[8]xf32>
-// CHECK: return %[[T5]]
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x[8]xf32>
+// CHECK: %[[DST:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<2x[8]xf32> to !llvm.array<2 x vector<[8]xf32>>
+// CHECK: %[[E0:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vector<[8]xf32>>
+// CHECK: %[[E1:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vector<[8]xf32>>
+// CHECK: %[[I0:.*]] = llvm.insertvalue %[[E0]], %[[DST]][0] : !llvm.array<2 x vector<[8]xf32>>
+// CHECK: %[[I1:.*]] = llvm.insertvalue %[[E1]], %[[I0]][1] : !llvm.array<2 x vector<[8]xf32>>
+// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[I1]] : !llvm.array<2 x vector<[8]xf32>> to vector<2x[8]xf32>
+// CHECK: return %[[RES]]
// -----
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the fix!
I see 3 changes (i.e. places where scalable vectors are disabled). However, there are no new tests. Shouldn't there be 3 new negative tests that show that the updated patterns fail?
int64_t srcRank = srcType.getRank(); | ||
|
||
// Scalable vectors are not supported by vector shuffle. | ||
if ((srcType.isScalable() || dstType.isScalable()) && srcRank == 1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are we also checking for rank-1? Seems unrelated to this PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
srcRank == 1
you mean? That check is needed because that's the rank that leads to generating a vector shuffle. Otherwise the op is decomposed into lower order ops.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(Sorry, I replied from my phone. See line 127 for the vector shuffle path).
The bug is triggered by existing tests + one of the upcoming PRs (which is not adding any new scalable tests). The output IR is the same except for one test where the op order changes but it's effectively the same code. As mentioned, we generated valid IR but going through an intermediate invalid vector shuffle. The output IR should be the same before and after this change. Happy to add any extra test you think makes sense. |
Sorry for being a bit slow, reading this on a tablet and the GitHub app is far from ideal.
Have you checked the debug dump? Sometimes, an intermediate IR might be invalid while the overall result remains valid. In such cases,
If the bug is triggered by existing tests, shouldn't they be updated as part of this PR? I suspect this isn’t feasible because the affected patterns aren’t tested in isolation—highlighting a gap in Vector dialect testing. The fact that only
By "IR," do you mean LLVM IR or the intermediate representation being checked?
Yes, that would be great. The vector-to-llvm pipeline feels like a jackhammer when used to test such a nuance in a very specific pattern. Sadly, the only relevant-looking test file that I could find is this: But I'm on a tablet using GitHub app, so I'm pretty limited with my search. Tl;Dr The fix is correct and much appreciated—I’ll approve this to unblock you. Let’s add a TODO for a dedicated test for this pattern. I can take care of it when adding tests for scalable vectors. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
Please, could you include a TODO to add tests for the affected patterns (which seem to be missing)?
I checked with GDB
Just to make it clearer. With this PR we go from:
to
so the output IR is "the same". The modified CHECK rules are needed because with the PR it looks like sometimes the llvm extract/insert ops are generated in a different order. That is, we go from:
to
There are no more CHECK rules that need to be modified as part of this fix. We can of course test this |
It looks like scalable `vector.insert/extractslice` ops made their way through lowering patterns that generate `vector.shuffle`` ops. I'm not sure why this wasn't caught by the verifier, probably because the shuffle op was folded into something else as part of the same rewrite and the IR wasn't verified. This PR fixes the issue by preventing scalable vector.insert/extractslice ops to be lowered to vector shuffles. Instead, they are now lowered to a sequence of insert/extractelement ops using an existing patter.
a25b5fe
to
2dfe220
Compare
It looks like scalable
vector.insert/extractslice
ops made their way through lowering patterns that generatevector.shuffle
ops. I'm not sure why this wasn't caught by the verifier, probably because the shuffle op was folded into something else as part of the same rewrite and the IR wasn't verified.This PR fixes the issue by preventing scalable vector.insert/extractslice ops to be lowered to vector shuffles. Instead, they are now lowered to a sequence of insert/extractelement ops using an existing patter.