[MLIR] VectorToLLVM: Fix vector.insert conversion for 0-D vectors, and add a test #128810
Signed-off-by: Benoit Jacob <[email protected]>
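@llvm/pr-subscribers-mlir
Author: Benoit Jacob (bjacob)
Changes
The handling of vector.insert in VectorToLLVM was incorrect for a 0-D destination vector, as in:

    %0 = vector.insert %src, %dst[] : f32 into vector<f32>

Since the type conversion to LLVM converts vector<f32> to vector<1xf32>, the op needed to be rewritten into an llvm.insertelement into such a vector<1xf32>. Instead, the existing code simply returned the source value, as if the converted type were the scalar type.
Tests added. There were no tests covering the vector.insert conversions.
Full diff: https://github.com/llvm/llvm-project/pull/128810.diff
2 Files Affected: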
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index c9d637ce81f93..a5e9e9bf6498b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1233,11 +1233,15 @@ class VectorInsertOpConversion
     SmallVector<OpFoldResult> positionVec = getMixedValues(
         adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
-    // Overwrite entire vector with value. Should be handled by folder, but
-    // just to be safe.
     ArrayRef<OpFoldResult> position(positionVec);
+    // Case of empty position, used with 0-D destination vector. In that case,
+    // the converted destination type is a LLVM vector of size 1, and we need
+    // a 0 as the position.
     if (position.empty()) {
-      rewriter.replaceOp(insertOp, adaptor.getSource());
+      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
+          insertOp, llvmResultType, adaptor.getDest(), adaptor.getSource(),
+          rewriter.create<LLVM::ConstantOp>(loc,
+                                            rewriter.getI64IntegerAttr(0)));
       return success();
     }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 36b37a137ac1e..72ca06ba7d9a4 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1787,3 +1787,32 @@ func.func @step() -> vector<4xindex> {
   %0 = vector.step : vector<4xindex>
   return %0 : vector<4xindex>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.insert
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @insert_0d
+// CHECK: llvm.insertelement {{.*}} : vector<1xf32>
+func.func @insert_0d(%src: f32, %dst: vector<f32>) -> vector<f32> {
+  %0 = vector.insert %src, %dst[] : f32 into vector<f32>
+  return %0 : vector<f32>
+}
+
+// CHECK-LABEL: @insert_1d
+// CHECK: llvm.insertelement {{.*}} : vector<2xf32>
+func.func @insert_1d(%src: f32, %dst: vector<2xf32>) -> vector<2xf32> {
+  %0 = vector.insert %src, %dst[1] : f32 into vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: @insert_2d
+// CHECK: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf32>>
+// CHECK: llvm.insertelement {{.*}} : vector<2xf32>
+// CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf32>>
+func.func @insert_2d(%src: f32, %dst: vector<2x2xf32>) -> vector<2x2xf32> {
+  %0 = vector.insert %src, %dst[1, 0] : f32 into vector<2x2xf32>
+  return %0 : vector<2x2xf32>
+}
@MaheshRavishankar, @kuhar, @jsjodin, @krzysz00, this seems to actually fix the issue that I was trying to fix earlier in #128203. What I am trying to do is enable usage of ROCm device lib functions for all math functions by disabling polynomial approximations and instead preserving math ops as-is until MathToROCDL.
What I was missing earlier is that some math functions are purposely not handled by MathToROCDL: they have no target ROCm device library function because the intended lowering is to an LLVM intrinsic, in MathToLLVM. That is the case for math.log and math.exp with element type f32.
Earlier, I was misinterpreting the failure to match in MathToROCDL as the problem. In fact, that failure was working as intended, and it was always MathToLLVM that handled these ops, correctly rewriting them as intrinsics. The conversion that led to the unrealized_conversion_cast I was seeing was not the conversion of a math op, but of the next op in the source IR, which was a vector.insert.
The vector.insert tests are here: https://github.com/llvm/llvm-project/blob/main/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir#L628
There are existing tests in https://github.com/llvm/llvm-project/blob/main/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir#L628. Can we move the newly added test there?
I did a similar bugfix for vector.extract recently: #117731. One thing I did in that PR, which would be nice to do for vector.insert as well, was to revisit the lowering logic as a whole. It has gone through three iterations (static indices only, then dynamic indices, and now this PR) and has turned into a separate special case for each. You don't have to do it if you just want to land this fix, but if you are interested in improving the logic, it would be good to do it in this PR.
Thanks @Groverkss for the context! I will read your PR and try to mirror it for insert.
Sent #128915 for review. Thanks again @Groverkss.
The handling of vector.insert in VectorToLLVM was incorrect for a 0-D destination vector, as in:

    %0 = vector.insert %src, %dst[] : f32 into vector<f32>

Since the type conversion to LLVM converts vector<f32> to vector<1xf32>, the op needed to be rewritten into an llvm.insertelement into such a vector<1xf32>. Instead, the existing code simply returned the source value, as if the converted type were the scalar type.

Tests added. There were no tests covering the vector.insert conversions.
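
To make the type mismatch concrete, here is a small standalone C++ sketch. It is only a toy model, not MLIR code: `Vec1xF32`, `insertElement`, and `lowerInsert0D` are hypothetical names invented for illustration, with `std::array<float, 1>` standing in for the converted type vector<1xf32> and `insertElement` standing in for llvm.insertelement.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for the converted type: after conversion to LLVM,
// a 0-D vector<f32> is represented as vector<1xf32>.
using Vec1xF32 = std::array<float, 1>;

// Toy model of llvm.insertelement: returns a copy of `dest` with `value`
// written at `index`.
template <std::size_t N>
std::array<float, N> insertElement(std::array<float, N> dest, float value,
                                   std::size_t index) {
  assert(index < N && "index out of bounds");
  dest[index] = value;
  return dest;
}

// Toy model of the fixed lowering for an empty position. The previous code
// effectively returned `source` itself, which has the scalar type, not the
// 1-element vector type that the converted result requires. The fix inserts
// the scalar at constant index 0 of the converted destination.
Vec1xF32 lowerInsert0D(float source, Vec1xF32 dest) {
  return insertElement(dest, source, 0);
}
```

With this model, `lowerInsert0D(2.5f, Vec1xF32{1.0f})` yields a one-element array holding 2.5f, so the result keeps the converted vector type rather than collapsing to a scalar.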