Skip to content

[mlir][Vector] Fix scalable InsertSlice/ExtractSlice lowering #124861

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 31, 2025

Conversation

dcaballe
Copy link
Contributor

It looks like scalable vector.insert/extractslice ops made their way through lowering patterns that generate vector.shuffle ops. I'm not sure why this wasn't caught by the verifier, probably because the shuffle op was folded into something else as part of the same rewrite and the IR wasn't verified.

This PR fixes the issue by preventing scalable vector.insert/extractslice ops to be lowered to vector shuffles. Instead, they are now lowered to a sequence of insert/extractelement ops using an existing patter.

@llvmbot
Copy link
Member

llvmbot commented Jan 29, 2025

@llvm/pr-subscribers-mlir-vector

Author: Diego Caballero (dcaballe)

Changes

It looks like scalable vector.insert/extractslice ops made their way through lowering patterns that generate vector.shuffle ops. I'm not sure why this wasn't caught by the verifier, probably because the shuffle op was folded into something else as part of the same rewrite and the IR wasn't verified.

This PR fixes the issue by preventing scalable vector.insert/extractslice ops to be lowered to vector shuffles. Instead, they are now lowered to a sequence of insert/extractelement ops using an existing patter.


Full diff: https://github.com/llvm/llvm-project/pull/124861.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp (+20-1)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+8-7)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 72405fcfef00f1..2c32634544b90b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -96,11 +96,15 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
                                 PatternRewriter &rewriter) const override {
     auto srcType = op.getSourceVectorType();
     auto dstType = op.getDestVectorType();
+    int64_t srcRank = srcType.getRank();
+
+    // Scalable vectors are not supported by vector shuffle.
+    if ((srcType.isScalable() || dstType.isScalable()) && srcRank == 1)
+      return failure();
 
     if (op.getOffsets().getValue().empty())
       return failure();
 
-    int64_t srcRank = srcType.getRank();
     int64_t dstRank = dstType.getRank();
     assert(dstRank >= srcRank);
     if (dstRank != srcRank)
@@ -184,6 +188,11 @@ class Convert1DExtractStridedSliceIntoShuffle
   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                 PatternRewriter &rewriter) const override {
     auto dstType = op.getType();
+    auto srcType = op.getSourceVectorType();
+
+    // Scalable vectors are not supported by vector shuffle.
+    if (dstType.isScalable() || srcType.isScalable())
+      return failure();
 
     assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
 
@@ -331,4 +340,14 @@ void vector::populateVectorInsertExtractStridedSliceTransforms(
   patterns.add<ConvertSameRankInsertStridedSliceIntoShuffle,
                Convert1DExtractStridedSliceIntoShuffle>(patterns.getContext(),
                                                         benefit);
+  // Generate chains of extract/insert ops for scalable vectors only as they
+  // can't be lowered to vector shuffles.
+  populateVectorExtractStridedSliceToExtractInsertChainPatterns(
+      patterns,
+      /*controlFn=*/
+      [](ExtractStridedSliceOp op) {
+        return op.getType().isScalable() ||
+               op.getSourceVectorType().isScalable();
+      },
+      benefit);
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 62649b83d887d1..7df6defc0f202f 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2026,13 +2026,14 @@ func.func @extract_strided_slice_f32_1d_from_2d_scalable(%arg0: vector<4x[8]xf32
 // CHECK-LABEL:   func.func @extract_strided_slice_f32_1d_from_2d_scalable(
 //  CHECK-SAME:    %[[ARG:.*]]: vector<4x[8]xf32>)
 //       CHECK:    %[[A:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x[8]xf32> to !llvm.array<4 x vector<[8]xf32>>
-//       CHECK:    %[[T0:.*]] = llvm.mlir.undef : !llvm.array<2 x vector<[8]xf32>>
-//       CHECK:    %[[T1:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vector<[8]xf32>>
-//       CHECK:    %[[T2:.*]] = llvm.insertvalue %[[T1]], %[[T0]][0] : !llvm.array<2 x vector<[8]xf32>>
-//       CHECK:    %[[T3:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vector<[8]xf32>>
-//       CHECK:    %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[T2]][1] : !llvm.array<2 x vector<[8]xf32>>
-//       CHECK:    %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T4]] : !llvm.array<2 x vector<[8]xf32>> to vector<2x[8]xf32>
-//       CHECK:    return %[[T5]]
+//       CHECK:    %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x[8]xf32>
+//       CHECK:    %[[DST:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<2x[8]xf32> to !llvm.array<2 x vector<[8]xf32>>
+//       CHECK:    %[[E0:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vector<[8]xf32>>
+//       CHECK:    %[[E1:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vector<[8]xf32>>
+//       CHECK:    %[[I0:.*]] = llvm.insertvalue %[[E0]], %[[DST]][0] : !llvm.array<2 x vector<[8]xf32>>
+//       CHECK:    %[[I1:.*]] = llvm.insertvalue %[[E1]], %[[I0]][1] : !llvm.array<2 x vector<[8]xf32>>
+//       CHECK:    %[[RES:.*]] = builtin.unrealized_conversion_cast %[[I1]] : !llvm.array<2 x vector<[8]xf32>> to vector<2x[8]xf32>
+//       CHECK:    return %[[RES]]
 
 // -----
 

@llvmbot
Copy link
Member

llvmbot commented Jan 29, 2025

@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

Changes

It looks like scalable vector.insert/extractslice ops made their way through lowering patterns that generate vector.shuffle ops. I'm not sure why this wasn't caught by the verifier, probably because the shuffle op was folded into something else as part of the same rewrite and the IR wasn't verified.

This PR fixes the issue by preventing scalable vector.insert/extractslice ops to be lowered to vector shuffles. Instead, they are now lowered to a sequence of insert/extractelement ops using an existing patter.


Full diff: https://github.com/llvm/llvm-project/pull/124861.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp (+20-1)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+8-7)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 72405fcfef00f13..2c32634544b90b0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -96,11 +96,15 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
                                 PatternRewriter &rewriter) const override {
     auto srcType = op.getSourceVectorType();
     auto dstType = op.getDestVectorType();
+    int64_t srcRank = srcType.getRank();
+
+    // Scalable vectors are not supported by vector shuffle.
+    if ((srcType.isScalable() || dstType.isScalable()) && srcRank == 1)
+      return failure();
 
     if (op.getOffsets().getValue().empty())
       return failure();
 
-    int64_t srcRank = srcType.getRank();
     int64_t dstRank = dstType.getRank();
     assert(dstRank >= srcRank);
     if (dstRank != srcRank)
@@ -184,6 +188,11 @@ class Convert1DExtractStridedSliceIntoShuffle
   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                 PatternRewriter &rewriter) const override {
     auto dstType = op.getType();
+    auto srcType = op.getSourceVectorType();
+
+    // Scalable vectors are not supported by vector shuffle.
+    if (dstType.isScalable() || srcType.isScalable())
+      return failure();
 
     assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
 
@@ -331,4 +340,14 @@ void vector::populateVectorInsertExtractStridedSliceTransforms(
   patterns.add<ConvertSameRankInsertStridedSliceIntoShuffle,
                Convert1DExtractStridedSliceIntoShuffle>(patterns.getContext(),
                                                         benefit);
+  // Generate chains of extract/insert ops for scalable vectors only as they
+  // can't be lowered to vector shuffles.
+  populateVectorExtractStridedSliceToExtractInsertChainPatterns(
+      patterns,
+      /*controlFn=*/
+      [](ExtractStridedSliceOp op) {
+        return op.getType().isScalable() ||
+               op.getSourceVectorType().isScalable();
+      },
+      benefit);
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 62649b83d887d1a..7df6defc0f202f1 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2026,13 +2026,14 @@ func.func @extract_strided_slice_f32_1d_from_2d_scalable(%arg0: vector<4x[8]xf32
 // CHECK-LABEL:   func.func @extract_strided_slice_f32_1d_from_2d_scalable(
 //  CHECK-SAME:    %[[ARG:.*]]: vector<4x[8]xf32>)
 //       CHECK:    %[[A:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x[8]xf32> to !llvm.array<4 x vector<[8]xf32>>
-//       CHECK:    %[[T0:.*]] = llvm.mlir.undef : !llvm.array<2 x vector<[8]xf32>>
-//       CHECK:    %[[T1:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vector<[8]xf32>>
-//       CHECK:    %[[T2:.*]] = llvm.insertvalue %[[T1]], %[[T0]][0] : !llvm.array<2 x vector<[8]xf32>>
-//       CHECK:    %[[T3:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vector<[8]xf32>>
-//       CHECK:    %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[T2]][1] : !llvm.array<2 x vector<[8]xf32>>
-//       CHECK:    %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T4]] : !llvm.array<2 x vector<[8]xf32>> to vector<2x[8]xf32>
-//       CHECK:    return %[[T5]]
+//       CHECK:    %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x[8]xf32>
+//       CHECK:    %[[DST:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<2x[8]xf32> to !llvm.array<2 x vector<[8]xf32>>
+//       CHECK:    %[[E0:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vector<[8]xf32>>
+//       CHECK:    %[[E1:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vector<[8]xf32>>
+//       CHECK:    %[[I0:.*]] = llvm.insertvalue %[[E0]], %[[DST]][0] : !llvm.array<2 x vector<[8]xf32>>
+//       CHECK:    %[[I1:.*]] = llvm.insertvalue %[[E1]], %[[I0]][1] : !llvm.array<2 x vector<[8]xf32>>
+//       CHECK:    %[[RES:.*]] = builtin.unrealized_conversion_cast %[[I1]] : !llvm.array<2 x vector<[8]xf32>> to vector<2x[8]xf32>
+//       CHECK:    return %[[RES]]
 
 // -----
 

@dcaballe dcaballe requested review from kuhar and Groverkss January 29, 2025 23:57
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix!

I see 3 changes (i.e. places where scalable vectors are disabled). However, there are no new tests. Shouldn't there be 3 new negative tests that show that the updated patterns fail?

int64_t srcRank = srcType.getRank();

// Scalable vectors are not supported by vector shuffle.
if ((srcType.isScalable() || dstType.isScalable()) && srcRank == 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we also checking for rank-1? Seems unrelated to this PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

srcRank == 1 you mean? That check is needed because that's the rank that leads to generating a vector shuffle. Otherwise the op is decomposed into lower order ops.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Sorry, I replied from my phone. See line 127 for the vector shuffle path).

@dcaballe
Copy link
Contributor Author

The bug is triggered by existing tests + one of the upcoming PRs (which is not adding any new scalable tests). The output IR is the same except for one test where the op order changes but it's effectively the same code. As mentioned, we generated valid IR but going through an intermediate invalid vector shuffle. The output IR should be the same before and after this change. Happy to add any extra test you think makes sense.

@banach-space
Copy link
Contributor

Sorry for being a bit slow, reading this on a tablet and the GitHub app is far from ideal.

I'm not sure why this wasn't caught by the verifier

Have you checked the debug dump? Sometimes, an intermediate IR might be invalid while the overall result remains valid. In such cases, mlir-opt would return success. This could help in designing a dedicated test for the issue.

The bug is triggered by existing tests

If the bug is triggered by existing tests, shouldn't they be updated as part of this PR? I suspect this isn’t feasible because the affected patterns aren’t tested in isolation—highlighting a gap in Vector dialect testing. The fact that only mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir is updated supports this point (as I also alluded to in #124308).

As mentioned, we generated valid IR but going through an intermediate invalid vector shuffle.

By "IR," do you mean LLVM IR or the intermediate representation being checked?

Happy to add any extra test you think makes sense

Yes, that would be great. The vector-to-llvm pipeline feels like a jackhammer when used to test such a nuance in a very specific pattern.

Sadly, the only relevant-looking test file that I could find is this:

But I'm on a tablet using GitHub app, so I'm pretty limited with my search.


Tl;Dr The fix is correct and much appreciated—I’ll approve this to unblock you. Let’s add a TODO for a dedicated test for this pattern. I can take care of it when adding tests for scalable vectors.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

Please, could you include a TODO to add tests for the affected patterns (which seem to be missing)?

@dcaballe
Copy link
Contributor Author

Have you checked the debug dump?

I checked with GDB

If the bug is triggered by existing tests, shouldn't they be updated as part of this PR?

Just to make it clearer. With this PR we go from:

vector.insert_slice/extract_slice op -> vector.shuffle op -> chain of llvm extract/insert ops

to

vector.insert_slice/extract_slice op -> chain of llvm extract/insert ops

so the output IR is "the same". The modified CHECK rules are needed because with the PR it looks like sometimes the llvm extract/insert ops are generated in a different order. That is, we go from:

llvm.extractvalue
llvm.insertvalue
llvm.extractvalue
llvm.insertvalue

to

llvm.extractvalue
llvm.extractvalue
llvm.insertvalue
llvm.insertvalue

There are no more CHECK rules that need to be modified as part of this fix.

We can of course test this populate... also in isolation by adding a new test pass. We can do that for this populate... and others but I wouldn't do that as part of this PR. I'll add a TODO, as suggested.

It looks like scalable `vector.insert/extractslice` ops made their way
through lowering patterns that generate `vector.shuffle`` ops. I'm not
sure why this wasn't caught by the verifier, probably because the shuffle
op was folded into something else as part of the same rewrite and the IR
wasn't verified.

This PR fixes the issue by preventing scalable vector.insert/extractslice
ops to be lowered to vector shuffles. Instead, they are now lowered to a
sequence of insert/extractelement ops using an existing patter.
@dcaballe dcaballe force-pushed the scalable-insert-extract-shuffle branch from a25b5fe to 2dfe220 Compare January 31, 2025 22:13
@dcaballe dcaballe merged commit 2b04291 into llvm:main Jan 31, 2025
5 of 7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants