[mlir][vector] Add lower-vector-multi-reduction pass #87333

Merged

xiaoleis-nv
Contributor

@xiaoleis-nv xiaoleis-nv commented Apr 2, 2024

This PR adds the lower-vector-multi-reduction pass, which lowers vector.multi_reduction operations.

The Transform Dialect already provides an operation for a similar purpose, transform.apply_patterns.vector.lower_multi_reduction, but it is only usable by projects that have adopted the Transform Dialect. Not every project is able to integrate that dialect, so this pass serves as a standalone alternative: projects that depend solely on the traditional pass infrastructure can also benefit from the lowering of vector.multi_reduction operations.

@llvmbot
Member

llvmbot commented Apr 2, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: None (xiaoleis-nv)

Changes

This PR adds the lower-vector-multi-reduction pass, which lowers vector.multi_reduction operations. Two test files have been added to check that both lowering strategies (inner-parallel and inner-reduction) work as expected.


Full diff: https://github.com/llvm/llvm-project/pull/87333.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/Passes.h (+6)
  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/Passes.td (+18)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp (+40)
  • (added) mlir/test/Dialect/Vector/vector-multi-reduction-inner-parallel-pass-lowering.mlir (+39)
  • (added) mlir/test/Dialect/Vector/vector-multi-reduction-inner-reduction-pass-lowering.mlir (+34)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
index bf89b01e2b60c5..911402551e14d4 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
 #define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
 
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
@@ -22,6 +23,11 @@ std::unique_ptr<Pass> createVectorBufferizePass();
 /// Creates an instance of the `vector.mask` lowering pass.
 std::unique_ptr<Pass> createLowerVectorMaskPass();
 
+/// Creates an instance of the `vector.multi_reduction` lowering pass.
+std::unique_ptr<Pass> createLowerVectorMultiReductionPass(
+    VectorMultiReductionLowering option =
+        VectorMultiReductionLowering::InnerParallel);
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 4911a61ab3c25d..31a0b3b2f0c53d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -21,4 +21,22 @@ def LowerVectorMaskPass : Pass<"lower-vector-mask", "func::FuncOp"> {
   let constructor = "mlir::vector::createLowerVectorMaskPass()";
 }
 
+def LowerVectorMultiReduction : Pass<"lower-vector-multi-reduction", "func::FuncOp"> {
+  let summary = "Lower 'vector.multi_reduction' operations";
+  let constructor = "mlir::vector::createLowerVectorMultiReductionPass()";
+  let options = [
+    Option<"loweringStrategy", "lowering-strategy", "mlir::vector::VectorMultiReductionLowering",
+           /*default=*/"mlir::vector::VectorMultiReductionLowering::InnerParallel",
+           "Select the strategy to control how multi_reduction is lowered.",
+           [{::llvm::cl::values(
+            clEnumValN(mlir::vector::VectorMultiReductionLowering::InnerParallel,
+                       "inner-parallel",
+                       "Lower multi_reduction into outer-reduction and inner-parallel ops."),
+            clEnumValN(mlir::vector::VectorMultiReductionLowering::InnerReduction,
+                       "inner-reduction",
+                       "Lower multi_reduction into outer-parallel and inner-reduction ops.")
+        )}]>
+  ];
+}
+
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index bed2c2496719dd..2f21c50c63473b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -12,9 +12,19 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace vector {
+#define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+} // namespace vector
+} // namespace mlir
 
 #define DEBUG_TYPE "vector-multi-reduction"
 
@@ -461,6 +471,31 @@ struct OneDimMultiReductionToTwoDim
     return success();
   }
 };
+
+struct LowerVectorMultiReductionPass
+    : public vector::impl::LowerVectorMultiReductionBase<
+          LowerVectorMultiReductionPass> {
+  LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) {
+    this->loweringStrategy = option;
+  }
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *context = op->getContext();
+
+    RewritePatternSet loweringPatterns(context);
+    populateVectorMultiReductionLoweringPatterns(loweringPatterns,
+                                                 this->loweringStrategy);
+
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
+      signalPassFailure();
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect>();
+  }
+};
+
 } // namespace
 
 void mlir::vector::populateVectorMultiReductionLoweringPatterns(
@@ -476,3 +511,8 @@ void mlir::vector::populateVectorMultiReductionLoweringPatterns(
     patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
                                                     benefit);
 }
+
+std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
+    vector::VectorMultiReductionLowering option) {
+  return std::make_unique<LowerVectorMultiReductionPass>(option);
+}
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-inner-parallel-pass-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-inner-parallel-pass-lowering.mlir
new file mode 100644
index 00000000000000..502cd7a1cbbcbc
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-inner-parallel-pass-lowering.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt -lower-vector-multi-reduction="lowering-strategy=inner-parallel" -split-input-file %s | FileCheck %s
+
+// -----
+func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+    %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
+    return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func @vector_multi_reduction
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
+//       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
+//       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
+//       CHECK:   %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
+//       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
+//       CHECK:   %[[RV01:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
+//       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
+//       CHECK:   %[[RV012:.+]] = arith.mulf %[[V2]], %[[RV01]] : vector<2xf32>
+//       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
+//       CHECK:   %[[RESULT_VEC:.+]] = arith.mulf %[[V3]], %[[RV012]] : vector<2xf32>
+//       CHECK:   return %[[RESULT_VEC]] : vector<2xf32>
+
+// -----
+func.func @vector_multi_reduction_min(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+    %0 = vector.multi_reduction <minnumf>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
+    return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func @vector_multi_reduction_min
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
+//       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
+//       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
+//       CHECK:   %[[RV0:.+]] = arith.minnumf %[[V0]], %[[ACC]] : vector<2xf32>
+//       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
+//       CHECK:   %[[RV01:.+]] = arith.minnumf %[[V1]], %[[RV0]] : vector<2xf32>
+//       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
+//       CHECK:   %[[RV012:.+]] = arith.minnumf %[[V2]], %[[RV01]] : vector<2xf32>
+//       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
+//       CHECK:   %[[RESULT_VEC:.+]] = arith.minnumf %[[V3]], %[[RV012]] : vector<2xf32>
+//       CHECK:   return %[[RESULT_VEC]] : vector<2xf32>
\ No newline at end of file
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-inner-reduction-pass-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-inner-reduction-pass-lowering.mlir
new file mode 100644
index 00000000000000..f051ce73fc49b4
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-inner-reduction-pass-lowering.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt -lower-vector-multi-reduction="lowering-strategy=inner-reduction" -split-input-file %s | FileCheck %s
+
+// -----
+func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+    %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
+    return %0 : vector<2xf32>
+}
+// CHECK-LABEL: func @vector_multi_reduction
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>)
+//   CHECK-DAG:       %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<2xf32>
+//   CHECK-DAG:       %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:       %[[C1:.+]] = arith.constant 1 : index
+//       CHECK:       %[[V0:.+]] = vector.extract %[[INPUT]][0]
+//       CHECK:       %[[ACC0:.+]] = vector.extract %[[ACC]][0]
+//       CHECK:       %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32
+//       CHECK:       %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : index] : vector<2xf32>
+//       CHECK:       %[[V1:.+]] = vector.extract %[[INPUT]][1]
+//       CHECK:       %[[ACC1:.+]] = vector.extract %[[ACC]][1]
+//       CHECK:       %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<4xf32> into f32
+//       CHECK:       %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<2xf32>
+//       CHECK:       return %[[RESULT_VEC]]
+
+// -----
+func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
+    %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
+    return %0 : f32
+}
+// CHECK-LABEL: func @vector_multi_reduction_to_scalar
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: f32)
+//       CHECK:   %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
+//       CHECK:   %[[REDUCED:.*]] = vector.reduction <mul>, %[[CASTED]], %[[ACC]] : vector<8xf32> into f32
+//       CHECK:   %[[INSERTED:.*]] = vector.insertelement %[[REDUCED]], {{.*}} : vector<1xf32>
+//       CHECK:   %[[RES:.*]] = vector.extract %[[INSERTED]][0] : f32 from vector<1xf32>
+//       CHECK:   return %[[RES]]
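
For readers unfamiliar with the two strategies, the CHECK lines in the two test files above can be modelled on plain arrays. This is only a minimal sketch, not the MLIR implementation: the names Input, Acc, reduceInnerParallel and reduceInnerReduction are hypothetical, and std::array stands in for vector<2x4xf32> and vector<2xf32>. It assumes the <mul> reduction over dimension 1 used in the tests.

```cpp
#include <array>
#include <cstddef>

// Hypothetical plain-array stand-ins for vector<2x4xf32> and vector<2xf32>.
using Input = std::array<std::array<float, 4>, 2>;
using Acc = std::array<float, 2>;

// Model of "inner-parallel": the reduction dimension (size 4) is walked on the
// outside, and each step is one element-wise multiply over the two result
// lanes. This mirrors the chain of arith.mulf ops on vector<2xf32> values that
// the first test expects after the vector.transpose.
Acc reduceInnerParallel(const Input &in, const Acc &acc) {
  Acc result = acc;
  for (std::size_t k = 0; k < 4; ++k) {   // row k of the transposed input
    for (std::size_t i = 0; i < 2; ++i) { // element-wise over result lanes
      result[i] = in[i][k] * result[i];
    }
  }
  return result;
}

// Model of "inner-reduction": the parallel dimension (size 2) is walked on the
// outside, and each step reduces one 4-element row to a scalar seeded with the
// matching accumulator lane. This mirrors the vector.reduction <mul> ops
// followed by an insert into the result vector in the second test.
Acc reduceInnerReduction(const Input &in, const Acc &acc) {
  Acc result{};
  for (std::size_t i = 0; i < 2; ++i) {
    float reduced = acc[i];
    for (std::size_t k = 0; k < 4; ++k) {
      reduced *= in[i][k];
    }
    result[i] = reduced;
  }
  return result;
}
```

Both functions compute the same values per lane. The difference is only which loop is outermost, which in the real lowering decides whether the generated ops are element-wise arithmetic on short vectors or horizontal reductions of whole rows.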

Contributor

@hanhanW hanhanW left a comment

@banach-space
Contributor

What is the purpose of the PR?

@xiaoleis-nv - could you elaborate a bit on the rationale for introducing this pass? And add it to the summary? :)

@xiaoleis-nv
Contributor Author

@hanhanW @banach-space I've updated this PR as per your suggestions, including a revised summary that explains the rationale behind this PR. Could you please take another look and review it? Thank you.

@xiaoleis-nv
Contributor Author

xiaoleis-nv commented Apr 8, 2024

I don't have the permission to merge this PR. Could someone help me by merging it if there are no further concerns? Many thanks! @joker-eph @banach-space @hanhanW

@banach-space
Contributor

I don't have the permission to merge this PR. Could someone help me by merging it if there are no further concerns? Many thanks! @joker-eph @banach-space @hanhanW

I can land it for you tomorrow if there is no new traffic. Thanks for addressing my comments!

@hanhanW
Contributor

hanhanW commented Apr 8, 2024

I can help land this if my comment about the // ----- separator is addressed.

@xiaoleis-nv
Contributor Author

@banach-space @hanhanW, all comments on this PR have been addressed, and it looks good to me for merging if there are no further comments. Thank you for your valuable suggestions!

@hanhanW
Contributor

hanhanW commented Apr 9, 2024

I think all the comments are addressed, so I'm going to land the PR. If there are other comments, let's address them in a follow-up.

@hanhanW hanhanW changed the title from "Add lower-vector-multi-reduction pass" to "[mlir][vector] Add lower-vector-multi-reduction pass" Apr 9, 2024
@hanhanW hanhanW merged commit 8d6469b into llvm:main Apr 9, 2024
7 participants