-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][vector]add extractInsertFoldConstantOp fold function and apply it to extractOp and insertOp. #124399
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector]add extractInsertFoldConstantOp fold function and apply it to extractOp and insertOp. #124399
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: lonely eagle (linuxlonelyeagle) Changessee #121631 Full diff: https://github.com/llvm/llvm-project/pull/124399.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3fbfcb4979b495..5021b097fc5ef6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -1977,6 +1978,46 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
return fromElementsOp.getElements()[flatIndex];
}
+// If the dynamic operands of `extractOp` or `insertOp` is result of
+// `constantOp`, then fold it.
+template <typename T>
+static void foldConstantOp(T op, SmallVectorImpl<Value> &operands) {
+ auto staticPosition = op.getStaticPosition().vec();
+ OperandRange dynamicPosition = op.getDynamicPosition();
+
+ // If the dynamic operands is empty, it is returned directly.
+ if (!dynamicPosition.size())
+ return;
+ unsigned index = 0;
+
+ // `opChange` is a flog. If it is true, it means to update `op` in place.
+ bool opChange = false;
+ for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
+ if (!ShapedType::isDynamic(staticPosition[i]))
+ continue;
+ Value position = dynamicPosition[index++];
+
+ // If it is a block parameter, proceed to the next iteration.
+ if (!position.getDefiningOp()) {
+ operands.push_back(position);
+ continue;
+ }
+
+ if (auto constantOp =
+ mlir::dyn_cast<arith::ConstantIndexOp>(position.getDefiningOp())) {
+ opChange = true;
+ staticPosition[i] = constantOp.value();
+ continue;
+ }
+ operands.push_back(position);
+ }
+
+ if (opChange) {
+ op.setStaticPosition(staticPosition);
+ op.getOperation()->setOperands(operands);
+ }
+}
+
OpFoldResult ExtractOp::fold(FoldAdaptor) {
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -1999,6 +2040,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor) {
return val;
if (auto val = foldScalarExtractFromFromElements(*this))
return val;
+ SmallVector<Value> operands = {getVector()};
+ foldConstantOp(*this, operands);
return OpFoldResult();
}
@@ -3028,6 +3071,8 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
// (type mismatch).
if (getNumIndices() == 0 && getSourceType() == getType())
return getSource();
+ SmallVector<Value> operands = {getSource(), getDest()};
+ foldConstantOp(*this, operands);
return {};
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 29bed9aae56827..f8f5f9039bb146 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -4115,3 +4115,39 @@ func.func @step_scalable() -> vector<[4]xindex> {
%0 = vector.step : vector<[4]xindex>
return %0 : vector<[4]xindex>
}
+
+// -----
+
+// CHECK-LABEL: @extract_arith_constnt
+func.func @extract_arith_constnt() -> i32 {
+ %v = arith.constant dense<0> : vector<32x1xi32>
+ %c_0 = arith.constant 0 : index
+ %elem = vector.extract %v[%c_0, %c_0] : i32 from vector<32x1xi32>
+ return %elem : i32
+}
+
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %{{.*}} = llvm.extractelement %[[VAL_2]]{{\[}}%[[VAL_3]] : i64] : vector<1xi32>
+
+// -----
+
+// CHECK-LABEL: @insert_arith_constnt()
+
+func.func @insert_arith_constnt() -> vector<32x1xi32> {
+ %v = arith.constant dense<0> : vector<32x1xi32>
+ %c_0 = arith.constant 0 : index
+ %c_1 = arith.constant 1 : i32
+ %v_1 = vector.insert %c_1, %v[%c_0, %c_0] : i32 into vector<32x1xi32>
+ return %v_1 : vector<32x1xi32>
+}
+
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32
+// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
+// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VAL_5:.*]] = llvm.insertelement %[[VAL_2]], %[[VAL_3]]{{\[}}%[[VAL_4]] : i64] : vector<1xi32>
+// CHECK: %{{.*}} = llvm.insertvalue %[[VAL_5]], %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index dbe0b39422369c..38771f25934495 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -778,12 +778,11 @@ func.func @warp_constant(%laneid: index) -> (vector<1xf32>) {
// CHECK-PROP-LABEL: func.func @vector_extract_1d(
// CHECK-PROP-DAG: %[[C5_I32:.*]] = arith.constant 5 : i32
-// CHECK-PROP-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-PROP: %[[R:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2xf32>) {
// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<64xf32>
// CHECK-PROP: gpu.yield %[[V]] : vector<64xf32>
// CHECK-PROP: }
-// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][%[[C1]]] : f32 from vector<2xf32>
+// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][1] : f32 from vector<2xf32>
// CHECK-PROP: %[[SHUFFLED:.*]], %{{.*}} = gpu.shuffle idx %[[E]], %[[C5_I32]]
// CHECK-PROP: return %[[SHUFFLED]] : f32
func.func @vector_extract_1d(%laneid: index) -> (f32) {
|
3ff7e59
to
b6b4362
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you! This is good, just some test change comments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you! The overall approach looks reasonable, but the implementation can be simplified.
I seem to be having some problems, I don't seem to be able to turn on integration testing.
|
You are missing a dash before DMLIR_INCLUDE_INTEGRATION_TESTS, should be |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG. Would you mind adding support for poison indices? They are about to be introduced by #123488 and they would also fall into the "constant index" category.
mlir/test/Integration/Dialect/Vector/CPU/extract-insert-fold-constant.mlir
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the updates so far 🙏🏻 The folder itself looks good, but I am leaving some suggestions re testing.
Ping: @banach-space Maybe you forgot about this PR, can you speed up this PR a bit? thank you. |
Sorry about the delay, I have been travelling. Posting new round of reviews shortly. |
47ff7c6
to
a2a5c93
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This makes sense to me, thanks so much for seeing this through!
Please wait for @ftynse to unblock this and for @dcaballe and @Groverkss to take another look. Thank you 🙏🏻
Ping @ftynse @Groverkss @dcaballe This PR still needs a final review from you, I re-eliminated the conflict.Thank you. |
… it to extractOp and insertOp. (llvm#124399) add extractInsertFoldConstantOp fold function and apply it to extractOp and insertOp.
… it to extractOp and insertOp. (llvm#124399) add extractInsertFoldConstantOp fold function and apply it to extractOp and insertOp.
see #121631
It implements the constant folding functionality of
insertOp
andextractOp
.Why implement this feature?
I am implementing efficient Gemm on GPU (currently using cuda core), and when using registers, I should use the vector abstraction level.
Consider the following example.It cannot be lowered to llvm ir. The fundamental reason is that
insertOp
andextractOp
have the abovevector.extract %v[%0,%0] : i32 from vector<32x1xi32>
situation.The fundamental purpose of this PR is to solve this problem.mlir-opt vector-extract-insert.mlir -affine-loop-unroll=“unroll-full=true unroll-num-reps=10” -convert -arith-to-llvm -convert-vector-to-llvm