[mlir][vector] Advance support for extract/insert under the dynamic case #121631
Conversation
@llvm/pr-subscribers-mlir
Author: lonely eagle (linuxlonelyeagle)
Changes: Advance support for vector.extract and vector.insert under the dynamic case.
Full diff: https://github.com/llvm/llvm-project/pull/121631.diff
2 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 9657f583c375bb..4af03126fa1edd 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1096,6 +1096,26 @@ class VectorExtractOpConversion
SmallVector<OpFoldResult> positionVec = getMixedValues(
adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
+ for (unsigned idx = 0; idx < positionVec.size(); ++idx) {
+ if (auto position = llvm::dyn_cast<Value>(positionVec[idx])) {
+ auto defOp = position.getDefiningOp();
+ while (defOp) {
+ if (llvm::isa<arith::ConstantOp, LLVM::ConstantOp>(defOp)) {
+ Attribute value =
+ defOp->getAttr(arith::ConstantOp::getAttributeNames()[0]);
+ positionVec[idx] = OpFoldResult{
+ rewriter.getI64IntegerAttr(cast<IntegerAttr>(value).getInt())};
+ break;
+ } else if (auto unrealizedCastOp =
+ llvm::dyn_cast<UnrealizedConversionCastOp>(defOp)) {
+ defOp = unrealizedCastOp.getOperand(0).getDefiningOp();
+ } else {
+ break;
+ }
+ }
+ }
+ }
+
// The Vector -> LLVM lowering models N-D vectors as nested aggregates of
// 1-d vectors. This nesting is modeled using arrays. We do this conversion
// from a N-d vector extract to a nested aggregate vector extract in two
@@ -1231,6 +1251,25 @@ class VectorInsertOpConversion
SmallVector<OpFoldResult> positionVec = getMixedValues(
adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
+ for (unsigned idx = 0; idx < positionVec.size(); ++idx) {
+ if (auto position = llvm::dyn_cast<Value>(positionVec[idx])) {
+ auto defOp = position.getDefiningOp();
+ while (defOp) {
+ if (llvm::isa<arith::ConstantOp, LLVM::ConstantOp>(defOp)) {
+ Attribute value =
+ defOp->getAttr(arith::ConstantOp::getAttributeNames()[0]);
+ positionVec[idx] = OpFoldResult{
+ rewriter.getI64IntegerAttr(cast<IntegerAttr>(value).getInt())};
+ break;
+ } else if (auto unrealizedCastOp =
+ llvm::dyn_cast<UnrealizedConversionCastOp>(defOp)) {
+ defOp = unrealizedCastOp.getOperand(0).getDefiningOp();
+ } else {
+ break;
+ }
+ }
+ }
+ }
// Overwrite entire vector with value. Should be handled by folder, but
// just to be safe.
@@ -1242,8 +1281,9 @@ class VectorInsertOpConversion
// One-shot insertion of a vector into an array (only requires insertvalue).
if (isa<VectorType>(sourceType)) {
- if (insertOp.hasDynamicPosition())
+ if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
return failure();
+ }
Value inserted = rewriter.create<LLVM::InsertValueOp>(
loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
@@ -1255,8 +1295,9 @@ class VectorInsertOpConversion
Value extracted = adaptor.getDest();
auto oneDVectorType = destVectorType;
if (position.size() > 1) {
- if (insertOp.hasDynamicPosition())
+ if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
return failure();
+ }
oneDVectorType = reducedVectorTypeBack(destVectorType);
extracted = rewriter.create<LLVM::ExtractValueOp>(
@@ -1270,8 +1311,9 @@ class VectorInsertOpConversion
// Potential insertion of resulting 1-D vector into array.
if (position.size() > 1) {
- if (insertOp.hasDynamicPosition())
+ if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
return failure();
+ }
inserted = rewriter.create<LLVM::InsertValueOp>(
loc, adaptor.getDest(), inserted,
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index f95e943250bd44..d16d78556da106 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -4094,3 +4094,93 @@ func.func @step_scalable() -> vector<[4]xindex> {
%0 = vector.step : vector<[4]xindex>
return %0 : vector<[4]xindex>
}
+
+// -----
+
+// CHECK-LABEL: @extract_arith_constnt
+func.func @extract_arith_constnt() -> i32 {
+ %v = arith.constant dense<0> : vector<32x1xi32>
+ %c_0 = arith.constant 0 : index
+ %elem = vector.extract %v[%c_0, %c_0] : i32 from vector<32x1xi32>
+ return %elem : i32
+}
+
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VAL_5:.*]] = llvm.extractelement %[[VAL_3]]{{\[}}%[[VAL_4]] : i64] : vector<1xi32>
+// CHECK: return %[[VAL_5]] : i32
+
+// -----
+
+// CHECK-LABEL: @extract_llvm_constnt()
+
+module {
+ func.func @extract_llvm_constnt() -> i32 {
+ %0 = llvm.mlir.constant(dense<0> : vector<32x1xi32>) : !llvm.array<32 x vector<1xi32>>
+ %1 = builtin.unrealized_conversion_cast %0 : !llvm.array<32 x vector<1xi32>> to vector<32x1xi32>
+ %2 = llvm.mlir.constant(0 : index) : i64
+ %3 = builtin.unrealized_conversion_cast %2 : i64 to index
+ %4 = vector.extract %1[%3, %3] : i32 from vector<32x1xi32>
+ return %4 : i32
+ }
+}
+
+// CHECK: %[[VAL_0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(dense<0> : vector<32x1xi32>) : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VAL_4:.*]] = llvm.extractelement %[[VAL_2]]{{\[}}%[[VAL_3]] : i64] : vector<1xi32>
+// CHECK: return %[[VAL_4]] : i32
+
+// -----
+
+// CHECK-LABEL: @insert_arith_constnt()
+
+func.func @insert_arith_constnt() -> vector<32x1xi32> {
+ %v = arith.constant dense<0> : vector<32x1xi32>
+ %c_0 = arith.constant 0 : index
+ %c_1 = arith.constant 1 : i32
+ %v_1 = vector.insert %c_1, %v[%c_0, %c_0] : i32 into vector<32x1xi32>
+ return %v_1 : vector<32x1xi32>
+}
+
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32
+// CHECK: %[[VAL_4:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_5:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VAL_6:.*]] = llvm.insertelement %[[VAL_3]], %[[VAL_4]]{{\[}}%[[VAL_5]] : i64] : vector<1xi32>
+// CHECK: %[[VAL_7:.*]] = llvm.insertvalue %[[VAL_6]], %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_7]] : !llvm.array<32 x vector<1xi32>> to vector<32x1xi32>
+// CHECK: return %[[VAL_8]] : vector<32x1xi32>
+
+// -----
+
+// CHECK-LABEL: @insert_llvm_constnt()
+
+module {
+ func.func @insert_llvm_constnt() -> vector<32x1xi32> {
+ %0 = llvm.mlir.constant(dense<0> : vector<32x1xi32>) : !llvm.array<32 x vector<1xi32>>
+ %1 = builtin.unrealized_conversion_cast %0 : !llvm.array<32 x vector<1xi32>> to vector<32x1xi32>
+ %2 = llvm.mlir.constant(0 : index) : i64
+ %3 = builtin.unrealized_conversion_cast %2 : i64 to index
+ %4 = llvm.mlir.constant(1 : i32) : i32
+ %5 = vector.insert %4, %1 [%3, %3] : i32 into vector<32x1xi32>
+ return %5 : vector<32x1xi32>
+ }
+}
+
+// CHECK: %[[VAL_0:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(dense<0> : vector<32x1xi32>) : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_2]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VAL_5:.*]] = llvm.insertelement %[[VAL_0]], %[[VAL_3]]{{\[}}%[[VAL_4]] : i64] : vector<1xi32>
+// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[VAL_5]], %[[VAL_2]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_7:.*]] = builtin.unrealized_conversion_cast %[[VAL_6]] : !llvm.array<32 x vector<1xi32>> to vector<32x1xi32>
+// CHECK: return %[[VAL_7]] : vector<32x1xi32>
+// CHECK: }
Do we really need this? It's reasonable for the lowering to expect a canonical form of the operation. This seems to just add support for a non-canonical form where the constant positions aren't folded in.
At least for me, the patch is useful because it fits MLIR's concept of multi-layered intermediate representations.
With this PR you can use the pipeline above to lower the code below. (Honestly, the code below could in theory be codegen'd by hand straight into LLVM dialect IR, since the loop count is known. However, you need to take into account the level at which the current IR sits, so there are times when it is better to use a loop instead of generating the LLVM dialect directly.)
This is why this PR came up.
Have you tried inserting -canonicalize before -convert-arith-to-llvm?
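For illustration only (the author's actual pipeline is not shown in this thread), this is roughly what scheduling the canonicalizer before the conversions could look like when building the pipeline programmatically. The function name is made up, and the pass factory names are the ones I'd expect from the MLIR headers, so treat them as assumptions:

```cpp
// Sketch only: run canonicalization ahead of the arith-to-LLVM conversion so
// that constant indices get folded into vector.extract / vector.insert before
// the vector-to-LLVM patterns run. Pass factory names are assumptions.
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

void addSuggestedPasses(mlir::PassManager &pm) {
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addPass(mlir::createArithToLLVMConversionPass());
  // ...followed by the vector-to-LLVM conversion (-convert-vector-to-llvm).
}
```

On the command line this corresponds to placing -canonicalize ahead of -convert-arith-to-llvm and -convert-vector-to-llvm in the mlir-opt invocation.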
There was an error in the program above; I have corrected it. The following program can be lowered with the pipeline.
> It's reasonable for the lowering to expect a canonical form of the operation.

Not necessarily. Canonicalization is expensive and it may not be desirable to run it after every single lowering pass. Plus, canonicalization patterns should not be mixed with conversion patterns. It's a trade-off, and I'd rather have some simple, cheap additional logic than require running the canonicalizer in the middle of the lowerings.

> This seems to just add support for a non-canonical form where the constant positions aren't folded in.

This looks through unrealized_conversion_cast, which goes beyond canonicalization. IMO, canonicalization should not be able to see through unrealized casts precisely because it doesn't know how they are going to be realized.

That being said, just running the canonicalizer after affine-unroll here doesn't fold the constant indices into the extract. I'd start by implementing a folder for vector.extract that can fold constants in as static positions, and see if that is enough.
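For illustration, a minimal sketch of the kind of folder being suggested, assuming the tablegen-generated accessors for vector.extract's mixed static/dynamic positions (getStaticPosition, getDynamicPosition, setStaticPosition, getDynamicPositionMutable). This is not the upstream implementation, just one way such a fold could work:

```cpp
// Sketch only: turn constant dynamic position operands of vector.extract into
// static positions. Assumes the usual mixed representation, where
// ShapedType::kDynamic in static_position marks an entry supplied by
// dynamic_position.
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;

// Returns true if at least one constant dynamic position was folded.
static bool foldConstantDynamicPositions(vector::ExtractOp extractOp) {
  SmallVector<int64_t> newStatic;
  SmallVector<Value> newDynamic;
  bool changed = false;
  unsigned dynIdx = 0;
  for (int64_t pos : extractOp.getStaticPosition()) {
    if (!ShapedType::isDynamic(pos)) {
      newStatic.push_back(pos);
      continue;
    }
    Value dynPos = extractOp.getDynamicPosition()[dynIdx++];
    IntegerAttr cst;
    if (matchPattern(dynPos, m_Constant(&cst))) {
      // The operand is defined by a ConstantLike op: fold it in.
      newStatic.push_back(cst.getInt());
      changed = true;
      continue;
    }
    newStatic.push_back(ShapedType::kDynamic);
    newDynamic.push_back(dynPos);
  }
  if (!changed)
    return false;
  extractOp.setStaticPosition(newStatic);
  extractOp.getDynamicPositionMutable().assign(newDynamic);
  return true;
}
```

An ExtractOp::fold hook could call this helper and return getResult() when it reports a change, signaling an in-place fold; the same shape would apply to vector.insert.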
I left some comments, but these are mostly for general education. I don't expect this approach to land.
I don't actually know of a better way to handle this unrealized_conversion_cast.
Thank you.
Advance support for vector.extract and vector.insert ops with dynamic positions. See the tests for the specific changes. The duplicated code should be factored into a function, but I don't know where a good place to put it would be. Feel free to give me suggestions, thank you.
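Regarding the duplicated loop: one option, sketched under the assumption that it stays local to ConvertVectorToLLVM.cpp (the helper name foldConstantPositions is only a suggestion, and the existing includes of that file are assumed), would be a static function that both conversion patterns call on their positionVec:

```cpp
// Sketch only: the loop duplicated in VectorExtractOpConversion and
// VectorInsertOpConversion, factored into a file-local helper. Behavior
// mirrors the code in this PR: walk through unrealized_conversion_cast
// chains and, when an arith/LLVM constant is found, replace the Value
// position with an i64 integer attribute.
static void foldConstantPositions(SmallVectorImpl<OpFoldResult> &positionVec,
                                  Builder &builder) {
  for (OpFoldResult &ofr : positionVec) {
    auto position = llvm::dyn_cast<Value>(ofr);
    if (!position)
      continue;
    Operation *defOp = position.getDefiningOp();
    while (defOp) {
      if (llvm::isa<arith::ConstantOp, LLVM::ConstantOp>(defOp)) {
        Attribute value =
            defOp->getAttr(arith::ConstantOp::getAttributeNames()[0]);
        ofr = OpFoldResult(
            builder.getI64IntegerAttr(cast<IntegerAttr>(value).getInt()));
        break;
      }
      if (auto castOp = llvm::dyn_cast<UnrealizedConversionCastOp>(defOp)) {
        // Look through unrealized casts to find the original defining op.
        defOp = castOp.getOperand(0).getDefiningOp();
        continue;
      }
      break;
    }
  }
}
```

Each pattern would then invoke it right after its getMixedValues call, replacing the two copies of the loop.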