-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Vector] Fix vector.extract lowering to llvm for 0-d vectors #117731
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Kunwar Grover (Groverkss) ChangesThe current implementation of lowering to llvm for vector.extract incorrectly assumes that if the number of indices is zero, the operation can be folded away. This PR removes this condition and relies on the folder to do it instead. This PR also unifies the logic for scalar extracts and slice extracts, which as a side effect also enables vector.extract lowering for n-d vector.extract with dynamic inner most dimension. (This was only prevented by a conservative check in the old implementation) Full diff: https://github.com/llvm/llvm-project/pull/117731.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 58ca84c8d7bca6..3f47b20cdb577b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1096,43 +1096,50 @@ class VectorExtractOpConversion
SmallVector<OpFoldResult> positionVec = getMixedValues(
adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
- // Extract entire vector. Should be handled by folder, but just to be safe.
- ArrayRef<OpFoldResult> position(positionVec);
- if (position.empty()) {
- rewriter.replaceOp(extractOp, adaptor.getVector());
- return success();
- }
-
- // One-shot extraction of vector from array (only requires extractvalue).
- // Except for extracting 1-element vectors.
- if (isa<VectorType>(resultType) &&
- position.size() !=
- static_cast<size_t>(extractOp.getSourceVectorType().getRank())) {
- if (extractOp.hasDynamicPosition())
- return failure();
-
- Value extracted = rewriter.create<LLVM::ExtractValueOp>(
- loc, adaptor.getVector(), getAsIntegers(position));
- rewriter.replaceOp(extractOp, extracted);
- return success();
- }
+ // Determine if we need to extract a scalar as the result. We extract
+ // a scalar if the extract is full rank i.e. the number of indices is equal
+ // to source vector rank.
+ bool isScalarExtract =
+ positionVec.size() == extractOp.getSourceVectorType().getRank();
+ // Determine if we need to extract a slice out of the original vector. We
+ // always need to extract a slice if the input rank >= 2.
+ bool isSlicingExtract = extractOp.getSourceVectorType().getRank() >= 2;
- // Potential extraction of 1-D vector from array.
Value extracted = adaptor.getVector();
- if (position.size() > 1) {
- if (extractOp.hasDynamicPosition())
+ if (isSlicingExtract) {
+ ArrayRef<OpFoldResult> position(positionVec);
+ if (isScalarExtract) {
+ // If we are extracting a scalar from the returned slice, we need to
+ // extract a N-1 D slice.
+ position = position.drop_back();
+ }
+ // llvm.extractvalue does not support dynamic dimensions.
+ if (!llvm::all_of(position,
+ [](OpFoldResult x) { return isa<Attribute>(x); })) {
return failure();
+ }
+ extracted = rewriter.create<LLVM::ExtractValueOp>(
+ loc, extracted, getAsIntegers(position));
+ }
- SmallVector<int64_t> nMinusOnePosition =
- getAsIntegers(position.drop_back());
- extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
- nMinusOnePosition);
+ if (isScalarExtract) {
+ Value position;
+ if (positionVec.empty()) {
+ // A scalar extract with no position is a 0-D vector extract. The LLVM
+ // type converter converts 0-D vectors to 1-D vectors, so we need to add
+ // a constant position.
+ auto idxType = rewriter.getIndexType();
+ position = rewriter.create<LLVM::ConstantOp>(
+ loc, typeConverter->convertType(idxType),
+ rewriter.getIntegerAttr(idxType, 0));
+ } else {
+ position = getAsLLVMValue(rewriter, loc, positionVec.back());
+ }
+ extracted =
+ rewriter.create<LLVM::ExtractElementOp>(loc, extracted, position);
}
- Value lastPosition = getAsLLVMValue(rewriter, loc, position.back());
- // Remaining extraction of element from 1-D LLVM vector.
- rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(extractOp, extracted,
- lastPosition);
+ rewriter.replaceOp(extractOp, extracted);
return success();
}
};
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index da0222bc942376..cd687becb82b82 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1258,26 +1258,65 @@ func.func @extract_scalar_from_vec_1d_f32_dynamic_idx_scalable(%arg0: vector<[16
// -----
-func.func @extract_scalar_from_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {
+func.func @extract_scalar_from_vec_2d_f32_inner_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {
%0 = vector.extract %arg0[0, %arg1]: f32 from vector<1x16xf32>
return %0 : f32
}
-// Multi-dim vectors are not supported but this test shouldn't crash.
+// Multi-dim vectors are supported if the inner most dimension is dynamic.
-// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idx(
-// CHECK: vector.extract
+// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_inner_dynamic_idx(
+// CHECK: llvm.extractvalue
+// CHECK: llvm.extractelement
-func.func @extract_scalar_from_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 {
+func.func @extract_scalar_from_vec_2d_f32_inner_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 {
%0 = vector.extract %arg0[0, %arg1]: f32 from vector<1x[16]xf32>
return %0 : f32
}
-// Multi-dim vectors are not supported but this test shouldn't crash.
+// Multi-dim vectors are supported if the inner most dimension is dynamic.
+
+// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_inner_dynamic_idx_scalable(
+// CHECK: llvm.extractvalue
+// CHECK: llvm.extractelement
+
+// -----
-// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idx_scalable(
+func.func @extract_scalar_from_vec_2d_f32_outer_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {
+ %0 = vector.extract %arg0[%arg1, 0]: f32 from vector<1x16xf32>
+ return %0 : f32
+}
+
+// Multi-dim vectors are supported if the inner most dimension is dynamic.
+
+// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_outer_dynamic_idx(
// CHECK: vector.extract
+func.func @extract_scalar_from_vec_2d_f32_outer_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 {
+ %0 = vector.extract %arg0[%arg1, 0]: f32 from vector<1x[16]xf32>
+ return %0 : f32
+}
+
+// Multi-dim vectors with outer dimension as dynamic are not supported, but it
+// shouldn't crash.
+
+// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_outer_dynamic_idx_scalable(
+// CHECK: vector.extract
+
+// -----
+
+func.func @extract_scalar_from_vec_0d_index(%arg0: vector<index>) -> index {
+ %0 = vector.extract %arg0[]: index from vector<index>
+ return %0 : index
+}
+// CHECK-LABEL: @extract_scalar_from_vec_0d_index(
+// CHECK-SAME: %[[A:.*]]: vector<index>)
+// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<index> to vector<1xi64>
+// CHECK: %[[T1:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[T2:.*]] = llvm.extractelement %[[T0]][%[[T1]] : i64] : vector<1xi64>
+// CHECK: %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : i64 to index
+// CHECK: return %[[T3]] : index
+
// -----
func.func @insertelement_into_vec_0d_f32(%arg0: f32, %arg1: vector<f32>) -> vector<f32> {
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a couple of drive-by comments
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM % nits. Please wait for an approval from someone who has stake in llvm lowering before merging this.
c513cb1
to
9ba9d44
Compare
@banach-space ping :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the updates and sorry for missing this! LGTM, really appreciate the detailed documentation 🙏🏻
I've left a few nits. Feel free to ignore.
…ine with `vector.extract` (#128915) This is doing the same as #117731 did for `vector.extract`, but for `vector.insert`. It is a bit more complicated as the insertion destination may itself need to be extracted. As the test shows, this fixes two previously unsupported cases: - Dynamic indices - 0-D vectors. --------- Signed-off-by: Benoit Jacob <[email protected]>
…o LLVM in line with `vector.extract` (#128915) This is doing the same as llvm/llvm-project#117731 did for `vector.extract`, but for `vector.insert`. It is a bit more complicated as the insertion destination may itself need to be extracted. As the test shows, this fixes two previously unsupported cases: - Dynamic indices - 0-D vectors. --------- Signed-off-by: Benoit Jacob <[email protected]>
The current implementation of lowering to llvm for vector.extract incorrectly assumes that if the number of indices is zero, the operation can be folded away. This PR removes this condition and relies on the folder to do it instead.
This PR also unifies the logic for scalar extracts and slice extracts, which as a side effect also enables vector.extract lowering for n-d vector.extract with dynamic inner most dimension. (This was only prevented by a conservative check in the old implementation)