-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][vector] Add lower-vector-multi-reduction pass #87333
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Add lower-vector-multi-reduction pass #87333
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: None (xiaoleis-nv) Changes: This MR adds the `lower-vector-multi-reduction` pass to lower the `vector.multi_reduction` operation. Full diff: https://github.com/llvm/llvm-project/pull/87333.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
index bf89b01e2b60c5..911402551e14d4 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
#define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
@@ -22,6 +23,11 @@ std::unique_ptr<Pass> createVectorBufferizePass();
/// Creates an instance of the `vector.mask` lowering pass.
std::unique_ptr<Pass> createLowerVectorMaskPass();
+/// Creates an instance of the `vector.multi_reduction` lowering pass; `option` selects the lowering strategy and defaults to `VectorMultiReductionLowering::InnerParallel`.
+std::unique_ptr<Pass> createLowerVectorMultiReductionPass(
+ VectorMultiReductionLowering option =
+ VectorMultiReductionLowering::InnerParallel);
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 4911a61ab3c25d..31a0b3b2f0c53d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -21,4 +21,22 @@ def LowerVectorMaskPass : Pass<"lower-vector-mask", "func::FuncOp"> {
let constructor = "mlir::vector::createLowerVectorMaskPass()";
}
+def LowerVectorMultiReduction : Pass<"lower-vector-multi-reduction", "func::FuncOp"> {
+ let summary = "Lower 'vector.multi_reduction' operations";
+ let constructor = "mlir::vector::createLowerVectorMultiReductionPass()"; // Called with no arguments, so the factory's default argument is used.
+ let options = [ // One option: which of the two multi_reduction lowering strategies the pass applies.
+ Option<"loweringStrategy", "lowering-strategy", "mlir::vector::VectorMultiReductionLowering",
+ /*default=*/"mlir::vector::VectorMultiReductionLowering::InnerParallel",
+ "Select the strategy to control how multi_reduction is lowered.",
+ [{::llvm::cl::values(
+ clEnumValN(mlir::vector::VectorMultiReductionLowering::InnerParallel,
+ "inner-parallel",
+ "Lower multi_reduction into outer-reduction and inner-parallel ops."),
+ clEnumValN(mlir::vector::VectorMultiReductionLowering::InnerReduction,
+ "inner-reduction",
+ "Lower multi_reduction into outer-parallel and inner-reduction ops.")
+ )}]>
+ ];
+}
+
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index bed2c2496719dd..2f21c50c63473b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -12,9 +12,19 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace vector {
+#define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+} // namespace vector
+} // namespace mlir
#define DEBUG_TYPE "vector-multi-reduction"
@@ -461,6 +471,31 @@ struct OneDimMultiReductionToTwoDim
return success();
}
};
+
+/// Pass that lowers `vector.multi_reduction` by greedily applying the patterns
+/// registered by `populateVectorMultiReductionLoweringPatterns`.
+struct LowerVectorMultiReductionPass
+    : public vector::impl::LowerVectorMultiReductionBase<
+          LowerVectorMultiReductionPass> {
+  /// Stores `option` in the `loweringStrategy` pass option, which
+  /// `runOnOperation` later hands to the pattern-population helper.
+  LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) {
+    this->loweringStrategy = option;
+  }
+
+  /// Builds the multi-reduction lowering patterns for the selected strategy
+  /// and applies them to the current operation. Signals pass failure when the
+  /// greedy driver reports failure.
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *context = op->getContext();
+
+    RewritePatternSet loweringPatterns(context);
+    populateVectorMultiReductionLoweringPatterns(loweringPatterns,
+                                                 this->loweringStrategy);
+
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
+      signalPassFailure();
+  }
+
+  /// Registers the vector dialect as a dependency of this pass.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect>();
+  }
+};
+
} // namespace
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
@@ -476,3 +511,8 @@ void mlir::vector::populateVectorMultiReductionLoweringPatterns(
patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
benefit);
}
+
+/// Factory for the `vector.multi_reduction` lowering pass; `option` is
+/// forwarded to the pass constructor as the lowering strategy.
+std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
+    vector::VectorMultiReductionLowering option) {
+  using PassT = LowerVectorMultiReductionPass;
+  return std::make_unique<PassT>(option);
+}
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-inner-parallel-pass-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-inner-parallel-pass-lowering.mlir
new file mode 100644
index 00000000000000..502cd7a1cbbcbc
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-inner-parallel-pass-lowering.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt -lower-vector-multi-reduction="lowering-strategy=inner-parallel" -split-input-file %s | FileCheck %s
+
+// -----
+func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
+ return %0 : vector<2xf32>
+}
+// Expected output: a transpose to vector<4x2xf32>, then four chained arith.mulf ops seeded with the accumulator.
+// CHECK-LABEL: func @vector_multi_reduction
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
+// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
+// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
+// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[RV01:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
+// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[RV012:.+]] = arith.mulf %[[V2]], %[[RV01]] : vector<2xf32>
+// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[RESULT_VEC:.+]] = arith.mulf %[[V3]], %[[RV012]] : vector<2xf32>
+// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
+
+// -----
+func.func @vector_multi_reduction_min(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.multi_reduction <minnumf>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
+ return %0 : vector<2xf32>
+}
+// Expected output: a transpose to vector<4x2xf32>, then four chained arith.minnumf ops seeded with the accumulator.
+// CHECK-LABEL: func @vector_multi_reduction_min
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
+// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
+// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[RV0:.+]] = arith.minnumf %[[V0]], %[[ACC]] : vector<2xf32>
+// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[RV01:.+]] = arith.minnumf %[[V1]], %[[RV0]] : vector<2xf32>
+// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[RV012:.+]] = arith.minnumf %[[V2]], %[[RV01]] : vector<2xf32>
+// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[RESULT_VEC:.+]] = arith.minnumf %[[V3]], %[[RV012]] : vector<2xf32>
+// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
\ No newline at end of file
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-inner-reduction-pass-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-inner-reduction-pass-lowering.mlir
new file mode 100644
index 00000000000000..f051ce73fc49b4
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-inner-reduction-pass-lowering.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt -lower-vector-multi-reduction="lowering-strategy=inner-reduction" -split-input-file %s | FileCheck %s
+
+// -----
+// Inner-reduction strategy: each row of the input is reduced with
+// vector.reduction and the scalar is inserted into the result vector.
+func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
+ return %0 : vector<2xf32>
+}
+// The insertelement operands reference the captured reduction results
+// (RV0 / RV1) rather than re-capturing them, so the data flow is verified.
+// CHECK-LABEL: func @vector_multi_reduction
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>)
+// CHECK-DAG: %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<2xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[V0:.+]] = vector.extract %[[INPUT]][0]
+// CHECK: %[[ACC0:.+]] = vector.extract %[[ACC]][0]
+// CHECK: %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32
+// CHECK: %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0]], %[[RESULT_VEC_0]][%[[C0]] : index] : vector<2xf32>
+// CHECK: %[[V1:.+]] = vector.extract %[[INPUT]][1]
+// CHECK: %[[ACC1:.+]] = vector.extract %[[ACC]][1]
+// CHECK: %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<4xf32> into f32
+// CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<2xf32>
+// CHECK: return %[[RESULT_VEC]]
+
+// -----
+func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
+ return %0 : f32
+} // Both dims are reduced: expect a shape_cast to vector<8xf32>, one vector.reduction, then insert/extract through vector<1xf32>.
+// CHECK-LABEL: func @vector_multi_reduction_to_scalar
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: f32)
+// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
+// CHECK: %[[REDUCED:.*]] = vector.reduction <mul>, %[[CASTED]], %[[ACC]] : vector<8xf32> into f32
+// CHECK: %[[INSERTED:.*]] = vector.insertelement %[[REDUCED]], {{.*}} : vector<1xf32>
+// CHECK: %[[RES:.*]] = vector.extract %[[INSERTED]][0] : f32 from vector<1xf32>
+// CHECK: return %[[RES]]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the purpose of the PR? I think we already have test coverage in these two files:
mlir/test/Dialect/Vector/vector-multi-reduction-inner-parallel-pass-lowering.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Dialect/Vector/vector-multi-reduction-inner-parallel-pass-lowering.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Dialect/Vector/vector-multi-reduction-inner-reduction-pass-lowering.mlir
Outdated
Show resolved
Hide resolved
@xiaoleis-nv - could you elaborate a bit on the rationale for introducing this pass? And add it to the summary? :) |
@hanhanW @banach-space I've updated this PR as per your suggestions, including a revised summary that explains the rationale behind this PR. Could you please take another look and review it? Thank you. |
I don't have the permission to merge this PR. Could someone help me by merging it if there are no further concerns? Many thanks! @joker-eph @banach-space @hanhanW |
mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
Outdated
Show resolved
Hide resolved
I can land it for you tomorrow if there is no new traffic. Thanks for addressing my comments! |
I can help land this if my comment of |
@banach-space @hanhanW, all comments on this PR have been addressed, and it looks good to me for merging if there are no further comments. Thank you for your valuable suggestions! |
I think all the comments are addressed, so I'm going to land the PR. If there are other comments, let's address them in a follow-up. |
This MR adds the `lower-vector-multi-reduction` pass to lower the `vector.multi_reduction` operation. While the Transform Dialect includes an operation, `transform.apply_patterns.vector.lower_multi_reduction`, intended for a similar purpose, its utility is limited to projects that have adopted the Transform Dialect. Recognizing that not all projects are equipped to integrate this dialect, the proposed pass serves as a vital standalone alternative. It ensures that projects solely dependent on the traditional pass infrastructure can also benefit from the optimized lowering of the `multi_reduction` operation.