Skip to content

[MLIR] VectorToLLVM: Fix vector.insert conversion for 0-D vectors, and add a test #128810

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from

Conversation

bjacob
Copy link
Contributor

@bjacob bjacob commented Feb 26, 2025

The handling of vector.insert in VectorToLLVM was incorrectly handling the case of a 0-D destination vector, as in:

%0 = vector.insert %src, %dst[] : f32 into vector<f32>

Since the type conversion to LLVM convertes vector<f32> to vector<1xf32>, it was required to rewrite the op into a llvm.insertelement into such a vector<1xf32>. Instead, the existing code simply returned the source value, as if the converted type was the scalar type.

Tests added. There were no tests convering the vector.insert conversions.

Signed-off-by: Benoit Jacob <[email protected]>
@llvmbot
Copy link
Member

llvmbot commented Feb 26, 2025

@llvm/pr-subscribers-mlir

Author: Benoit Jacob (bjacob)

Changes

The handling of vector.insert in VectorToLLVM was incorrectly handling the case of a 0-D destination vector, as in:

%0 = vector.insert %src, %dst[] : f32 into vector&lt;f32&gt;

Since the type conversion to LLVM convertes vector&lt;f32&gt; to vector&lt;1xf32&gt;, it was required to rewrite the op into a llvm.insertelement into such a vector&lt;1xf32&gt;. Instead, the existing code simply returned the source value, as if the converted type was the scalar type.

Tests added. There were no tests convering the vector.insert conversions.


Full diff: https://github.com/llvm/llvm-project/pull/128810.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+7-3)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+29)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index c9d637ce81f93..a5e9e9bf6498b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1233,11 +1233,15 @@ class VectorInsertOpConversion
     SmallVector<OpFoldResult> positionVec = getMixedValues(
         adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
 
-    // Overwrite entire vector with value. Should be handled by folder, but
-    // just to be safe.
     ArrayRef<OpFoldResult> position(positionVec);
+    // Case of empty position, used with 0-D destination vector. In that case,
+    // the converted destination type is a LLVM vector of size 1, and we need
+    // a 0 as the position.
     if (position.empty()) {
-      rewriter.replaceOp(insertOp, adaptor.getSource());
+      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
+          insertOp, llvmResultType, adaptor.getDest(), adaptor.getSource(),
+          rewriter.create<LLVM::ConstantOp>(loc,
+                                            rewriter.getI64IntegerAttr(0)));
       return success();
     }
 
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 36b37a137ac1e..72ca06ba7d9a4 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1787,3 +1787,32 @@ func.func @step() -> vector<4xindex> {
   %0 = vector.step : vector<4xindex>
   return %0 : vector<4xindex>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.insert
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @insert_0d
+// CHECK: llvm.insertelement {{.*}} : vector<1xf32>
+func.func @insert_0d(%src: f32, %dst: vector<f32>) -> vector<f32> {
+  %0 = vector.insert %src, %dst[] : f32 into vector<f32>
+  return %0 : vector<f32>
+}
+
+// CHECK-LABEL: @insert_1d
+// CHECK: llvm.insertelement {{.*}} : vector<2xf32>
+func.func @insert_1d(%src: f32, %dst: vector<2xf32>) -> vector<2xf32> {
+  %0 = vector.insert %src, %dst[1] : f32 into vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: @insert_2d
+// CHECK: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf32>>
+// CHECK: llvm.insertelement {{.*}} : vector<2xf32>
+// CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf32>>
+func.func @insert_2d(%src: f32, %dst: vector<2x2xf32>) -> vector<2x2xf32> {
+  %0 = vector.insert %src, %dst[1, 0] : f32 into vector<2x2xf32>
+  return %0 : vector<2x2xf32>
+}

@bjacob
Copy link
Contributor Author

bjacob commented Feb 26, 2025

@MaheshRavishankar , @kuhar , @jsjodin , @krzysz00 , this seems to actually fix the issue that I was trying earlier to fix in #128203.

What I am trying to do is enable usage of ROCm device lib functions for all math functions by disabling polynomial approximations and instead preserving math ops as-is until MathToROCDL. What I was missing earlier is that some math functions are purposely not handled by MathToROCDL because they don't have a target ROCm device library function, because the intended lowering is to a LLVM intrinsics, in MathToLLVM. That is the case of math.log and math.exp for element type f32. Earlier, I was misinterpreting the failure to match in MathToROCDL as the problem. Instead, it was working-as-intended that it was failing, and it always was MathToLLVM that was handling it, correctly rewriting as an intrinsic.

The conversion that was leading to the unrealized_conversion_cast that I was seeing was not a conversion of a math op, but of the next op in the source IR which was a vector.insert.

@kuhar kuhar requested a review from Groverkss February 26, 2025 05:07
@Groverkss
Copy link
Member


// -----

//===----------------------------------------------------------------------===//
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Groverkss
Copy link
Member

Groverkss commented Feb 26, 2025

I did a similar bugfix for vector.extract recently: #117731 . One thing I did in that PR, which would be nice to do for vector.insert also is revisiting the logic for the lowering again. It has gone through three iterations: static indices only, dynamic indices and now this PR and has become special casing for each case.

You don't have to do it if you just want to land this fix, but if you are interested in improving the logic, it would be good to do it in this PR.

@bjacob
Copy link
Contributor Author

bjacob commented Feb 26, 2025

Thanks @Groverkss for the context! I will read your PR and try to mirror it for insert.

@bjacob bjacob closed this Feb 26, 2025
@bjacob
Copy link
Contributor Author

bjacob commented Feb 26, 2025

Sent #128915 for review. Thanks again @Groverkss .

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants