Skip to content

[mlir][Vector] Fix vector.extract lowering to llvm for 0-d vectors #117731

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 4, 2024

Conversation

Groverkss
Copy link
Member

The current implementation of lowering to llvm for vector.extract incorrectly assumes that if the number of indices is zero, the operation can be folded away. This PR removes this condition and relies on the folder to do it instead.

This PR also unifies the logic for scalar extracts and slice extracts, which as a side effect also enables vector.extract lowering for n-d vector.extract with dynamic inner most dimension. (This was only prevented by a conservative check in the old implementation)

@llvmbot
Copy link
Member

llvmbot commented Nov 26, 2024

@llvm/pr-subscribers-mlir

Author: Kunwar Grover (Groverkss)

Changes

The current implementation of lowering to llvm for vector.extract incorrectly assumes that if the number of indices is zero, the operation can be folded away. This PR removes this condition and relies on the folder to do it instead.

This PR also unifies the logic for scalar extracts and slice extracts, which as a side effect also enables vector.extract lowering for n-d vector.extract with dynamic inner most dimension. (This was only prevented by a conservative check in the old implementation)


Full diff: https://github.com/llvm/llvm-project/pull/117731.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+38-31)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+46-7)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 58ca84c8d7bca6..3f47b20cdb577b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1096,43 +1096,50 @@ class VectorExtractOpConversion
     SmallVector<OpFoldResult> positionVec = getMixedValues(
         adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
 
-    // Extract entire vector. Should be handled by folder, but just to be safe.
-    ArrayRef<OpFoldResult> position(positionVec);
-    if (position.empty()) {
-      rewriter.replaceOp(extractOp, adaptor.getVector());
-      return success();
-    }
-
-    // One-shot extraction of vector from array (only requires extractvalue).
-    // Except for extracting 1-element vectors.
-    if (isa<VectorType>(resultType) &&
-        position.size() !=
-            static_cast<size_t>(extractOp.getSourceVectorType().getRank())) {
-      if (extractOp.hasDynamicPosition())
-        return failure();
-
-      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
-          loc, adaptor.getVector(), getAsIntegers(position));
-      rewriter.replaceOp(extractOp, extracted);
-      return success();
-    }
+    // Determine if we need to extract a scalar as the result. We extract
+    // a scalar if the extract is full rank i.e. the number of indices is equal
+    // to source vector rank.
+    bool isScalarExtract =
+        positionVec.size() == extractOp.getSourceVectorType().getRank();
+    // Determine if we need to extract a slice out of the original vector. We
+    // always need to extract a slice if the input rank >= 2.
+    bool isSlicingExtract = extractOp.getSourceVectorType().getRank() >= 2;
 
-    // Potential extraction of 1-D vector from array.
     Value extracted = adaptor.getVector();
-    if (position.size() > 1) {
-      if (extractOp.hasDynamicPosition())
+    if (isSlicingExtract) {
+      ArrayRef<OpFoldResult> position(positionVec);
+      if (isScalarExtract) {
+        // If we are extracting a scalar from the returned slice, we need to
+        // extract a N-1 D slice.
+        position = position.drop_back();
+      }
+      // llvm.extractvalue does not support dynamic dimensions.
+      if (!llvm::all_of(position,
+                        [](OpFoldResult x) { return isa<Attribute>(x); })) {
         return failure();
+      }
+      extracted = rewriter.create<LLVM::ExtractValueOp>(
+          loc, extracted, getAsIntegers(position));
+    }
 
-      SmallVector<int64_t> nMinusOnePosition =
-          getAsIntegers(position.drop_back());
-      extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
-                                                        nMinusOnePosition);
+    if (isScalarExtract) {
+      Value position;
+      if (positionVec.empty()) {
+        // A scalar extract with no position is a 0-D vector extract. The LLVM
+        // type converter converts 0-D vectors to 1-D vectors, so we need to add
+        // a constant position.
+        auto idxType = rewriter.getIndexType();
+        position = rewriter.create<LLVM::ConstantOp>(
+            loc, typeConverter->convertType(idxType),
+            rewriter.getIntegerAttr(idxType, 0));
+      } else {
+        position = getAsLLVMValue(rewriter, loc, positionVec.back());
+      }
+      extracted =
+          rewriter.create<LLVM::ExtractElementOp>(loc, extracted, position);
     }
 
-    Value lastPosition = getAsLLVMValue(rewriter, loc, position.back());
-    // Remaining extraction of element from 1-D LLVM vector.
-    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(extractOp, extracted,
-                                                        lastPosition);
+    rewriter.replaceOp(extractOp, extracted);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index da0222bc942376..cd687becb82b82 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1258,26 +1258,65 @@ func.func @extract_scalar_from_vec_1d_f32_dynamic_idx_scalable(%arg0: vector<[16
 
 // -----
 
-func.func @extract_scalar_from_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {
+func.func @extract_scalar_from_vec_2d_f32_inner_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {
   %0 = vector.extract %arg0[0, %arg1]: f32 from vector<1x16xf32>
   return %0 : f32
 }
 
-// Multi-dim vectors are not supported but this test shouldn't crash.
+// Multi-dim vectors are supported if the inner most dimension is dynamic.
 
-// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idx(
-//       CHECK:   vector.extract
+// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_inner_dynamic_idx(
+//       CHECK:   llvm.extractvalue
+//       CHECK:   llvm.extractelement
 
-func.func @extract_scalar_from_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 {
+func.func @extract_scalar_from_vec_2d_f32_inner_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 {
   %0 = vector.extract %arg0[0, %arg1]: f32 from vector<1x[16]xf32>
   return %0 : f32
 }
 
-// Multi-dim vectors are not supported but this test shouldn't crash.
+// Multi-dim vectors are supported if the inner most dimension is dynamic.
+
+// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_inner_dynamic_idx_scalable(
+//       CHECK:   llvm.extractvalue
+//       CHECK:   llvm.extractelement
+
+// -----
 
-// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idx_scalable(
+func.func @extract_scalar_from_vec_2d_f32_outer_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {
+  %0 = vector.extract %arg0[%arg1, 0]: f32 from vector<1x16xf32>
+  return %0 : f32
+}
+
+// Multi-dim vectors are supported if the inner most dimension is dynamic.
+
+// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_outer_dynamic_idx(
 //       CHECK:   vector.extract
 
+func.func @extract_scalar_from_vec_2d_f32_outer_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 {
+  %0 = vector.extract %arg0[%arg1, 0]: f32 from vector<1x[16]xf32>
+  return %0 : f32
+}
+
+// Multi-dim vectors with outer dimension as dynamic are not supported, but it
+// shouldn't crash.
+
+// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_outer_dynamic_idx_scalable(
+//       CHECK:   vector.extract
+
+// -----
+
+func.func @extract_scalar_from_vec_0d_index(%arg0: vector<index>) -> index {
+  %0 = vector.extract %arg0[]: index from vector<index>
+  return %0 : index
+}
+// CHECK-LABEL: @extract_scalar_from_vec_0d_index(
+//  CHECK-SAME:   %[[A:.*]]: vector<index>)
+//       CHECK:   %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<index> to vector<1xi64>
+//       CHECK:   %[[T1:.*]] = llvm.mlir.constant(0 : index) : i64
+//       CHECK:   %[[T2:.*]] = llvm.extractelement %[[T0]][%[[T1]] : i64] : vector<1xi64>
+//       CHECK:   %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : i64 to index
+//       CHECK:   return %[[T3]] : index
+
 // -----
 
 func.func @insertelement_into_vec_0d_f32(%arg0: f32, %arg1: vector<f32>) -> vector<f32> {

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a couple of drive-by comments

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM % nits. Please wait for an approval from someone who has stake in llvm lowering before merging this.

@Groverkss
Copy link
Member Author

@banach-space ping :)

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates and sorry for missing this! LGTM, really appreciate the detailed documentation 🙏🏻

I've left a few nits. Feel free to ignore.

@Groverkss Groverkss merged commit a8f9271 into llvm:main Dec 4, 2024
6 of 7 checks passed
bjacob added a commit that referenced this pull request Feb 26, 2025
…ine with `vector.extract` (#128915)

This is doing the same as
#117731 did for
`vector.extract`, but for `vector.insert`.

It is a bit more complicated as the insertion destination may itself
need to be extracted.

As the test shows, this fixes two previously unsupported cases:
- Dynamic indices
- 0-D vectors.

---------

Signed-off-by: Benoit Jacob <[email protected]>
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Feb 26, 2025
…o LLVM in line with `vector.extract` (#128915)

This is doing the same as
llvm/llvm-project#117731 did for
`vector.extract`, but for `vector.insert`.

It is a bit more complicated as the insertion destination may itself
need to be extracted.

As the test shows, this fixes two previously unsupported cases:
- Dynamic indices
- 0-D vectors.

---------

Signed-off-by: Benoit Jacob <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants