Skip to content

Commit 8d6469b

Browse files
[mlir][vector] Add lower-vector-multi-reduction pass (#87333)
This MR adds the `lower-vector-multi-reduction` pass to lower the `vector.multi_reduction` operation. While the Transform Dialect includes an operation, `transform.apply_patterns.vector.lower_multi_reduction`, intended for a similar purpose, its utility is limited to projects that have adopted the Transform Dialect. Recognizing that not all projects are equipped to integrate this dialect, the proposed pass serves as a vital standalone alternative. It ensures that projects solely dependent on the traditional pass infrastructure can also benefit from the optimized lowering of the `multi_reduction` operation. --------- Co-authored-by: Xiaolei Shi <[email protected]>
1 parent 4956118 commit 8d6469b

File tree

4 files changed

+109
-0
lines changed

4 files changed

+109
-0
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/Passes.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
1010
#define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
1111

12+
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
1213
#include "mlir/Pass/Pass.h"
1314

1415
namespace mlir {
@@ -22,6 +23,11 @@ std::unique_ptr<Pass> createVectorBufferizePass();
2223
/// Creates an instance of the `vector.mask` lowering pass.
2324
std::unique_ptr<Pass> createLowerVectorMaskPass();
2425

26+
/// Creates an instance of the `vector.multi_reduction` lowering pass.
27+
std::unique_ptr<Pass> createLowerVectorMultiReductionPass(
28+
VectorMultiReductionLowering option =
29+
VectorMultiReductionLowering::InnerParallel);
30+
2531
//===----------------------------------------------------------------------===//
2632
// Registration
2733
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Vector/Transforms/Passes.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,22 @@ def LowerVectorMaskPass : Pass<"lower-vector-mask", "func::FuncOp"> {
2121
let constructor = "mlir::vector::createLowerVectorMaskPass()";
2222
}
2323

24+
// Declares the `lower-vector-multi-reduction` pass, anchored on
// `func::FuncOp`. Its single option, `lowering-strategy`, accepts
// `inner-parallel` (the declared default) or `inner-reduction` and is stored
// in the `loweringStrategy` member.
def LowerVectorMultiReduction : Pass<"lower-vector-multi-reduction", "func::FuncOp"> {
25+
let summary = "Lower 'vector.multi_reduction' operations";
26+
let constructor = "mlir::vector::createLowerVectorMultiReductionPass()";
27+
let options = [
28+
Option<"loweringStrategy", "lowering-strategy", "mlir::vector::VectorMultiReductionLowering",
29+
/*default=*/"mlir::vector::VectorMultiReductionLowering::InnerParallel",
30+
"Select the strategy to control how multi_reduction is lowered.",
31+
[{::llvm::cl::values(
32+
clEnumValN(mlir::vector::VectorMultiReductionLowering::InnerParallel,
33+
"inner-parallel",
34+
"Lower multi_reduction into outer-reduction and inner-parallel ops."),
35+
clEnumValN(mlir::vector::VectorMultiReductionLowering::InnerReduction,
36+
"inner-reduction",
37+
"Lower multi_reduction into outer-parallel and inner-reduction ops.")
38+
)}]>
39+
];
40+
}
41+
2442
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES

mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,19 @@
1212
//===----------------------------------------------------------------------===//
1313

1414
#include "mlir/Dialect/Arith/IR/Arith.h"
15+
#include "mlir/Dialect/Func/IR/FuncOps.h"
1516
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
17+
#include "mlir/Dialect/Vector/Transforms/Passes.h"
1618
#include "mlir/IR/Builders.h"
1719
#include "mlir/IR/TypeUtilities.h"
20+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21+
22+
namespace mlir {
23+
namespace vector {
24+
#define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION
25+
#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
26+
} // namespace vector
27+
} // namespace mlir
1828

1929
#define DEBUG_TYPE "vector-multi-reduction"
2030

@@ -461,6 +471,31 @@ struct OneDimMultiReductionToTwoDim
461471
return success();
462472
}
463473
};
474+
475+
/// Pass that applies the `vector.multi_reduction` lowering patterns to the
/// operation it is run on, using the strategy held in `loweringStrategy`.
struct LowerVectorMultiReductionPass
476+
: public vector::impl::LowerVectorMultiReductionBase<
477+
LowerVectorMultiReductionPass> {
478+
/// Stores `option` in the pass's `loweringStrategy` option.
LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) {
479+
this->loweringStrategy = option;
480+
}
481+
482+
/// Populates the multi-reduction lowering patterns for the configured
/// strategy and runs the greedy pattern driver over the current operation.
void runOnOperation() override {
483+
Operation *op = getOperation();
484+
MLIRContext *context = op->getContext();
485+
486+
RewritePatternSet loweringPatterns(context);
487+
populateVectorMultiReductionLoweringPatterns(loweringPatterns,
488+
this->loweringStrategy);
489+
490+
// A failure reported by the greedy driver is surfaced as a pass failure.
if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
491+
signalPassFailure();
492+
}
493+
494+
/// Registers the vector dialect as a dependency of this pass.
void getDependentDialects(DialectRegistry &registry) const override {
495+
registry.insert<vector::VectorDialect>();
496+
}
497+
};
498+
464499
} // namespace
465500

466501
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
@@ -476,3 +511,8 @@ void mlir::vector::populateVectorMultiReductionLoweringPatterns(
476511
patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
477512
benefit);
478513
}
514+
515+
/// Factory for the `vector.multi_reduction` lowering pass: returns a new
/// `LowerVectorMultiReductionPass` constructed with the given `option`.
std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
516+
vector::VectorMultiReductionLowering option) {
517+
return std::make_unique<LowerVectorMultiReductionPass>(option);
518+
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// RUN: mlir-opt -lower-vector-multi-reduction="lowering-strategy=inner-reduction" -split-input-file %s | FileCheck %s --check-prefixes=ALL,INNER-REDUCTION
2+
// RUN: mlir-opt -lower-vector-multi-reduction="lowering-strategy=inner-parallel" -split-input-file %s | FileCheck %s --check-prefixes=ALL,INNER-PARALLEL
3+
// RUN: mlir-opt -lower-vector-multi-reduction -split-input-file %s | FileCheck %s --check-prefixes=ALL,INNER-PARALLEL
4+
5+
// Test input: <mul> reduction over dimension 1 of a 2x4 vector, with %acc as
// the accumulator operand, producing a 2-element vector.
func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
6+
%0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
7+
return %0 : vector<2xf32>
8+
}
9+
// ALL-LABEL: func @vector_multi_reduction
10+
// ALL-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>)
11+
// INNER-REDUCTION-DAG: %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<2xf32>
12+
// INNER-REDUCTION-DAG: %[[C0:.+]] = arith.constant 0 : index
13+
// INNER-REDUCTION-DAG: %[[C1:.+]] = arith.constant 1 : index
14+
// INNER-REDUCTION: %[[V0:.+]] = vector.extract %[[INPUT]][0]
15+
// INNER-REDUCTION: %[[ACC0:.+]] = vector.extract %[[ACC]][0]
16+
// INNER-REDUCTION: %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32
17+
// INNER-REDUCTION: %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0]], %[[RESULT_VEC_0]][%[[C0]] : index] : vector<2xf32>
18+
// INNER-REDUCTION: %[[V1:.+]] = vector.extract %[[INPUT]][1]
19+
// INNER-REDUCTION: %[[ACC1:.+]] = vector.extract %[[ACC]][1]
20+
// INNER-REDUCTION: %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<4xf32> into f32
21+
// INNER-REDUCTION: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<2xf32>
22+
// INNER-REDUCTION: return %[[RESULT_VEC]]
23+
24+
// INNER-PARALLEL: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
25+
// INNER-PARALLEL: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
26+
// INNER-PARALLEL: %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
27+
// INNER-PARALLEL: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
28+
// INNER-PARALLEL: %[[RV01:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
29+
// INNER-PARALLEL: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
30+
// INNER-PARALLEL: %[[RV012:.+]] = arith.mulf %[[V2]], %[[RV01]] : vector<2xf32>
31+
// INNER-PARALLEL: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
32+
// INNER-PARALLEL: %[[RESULT_VEC:.+]] = arith.mulf %[[V3]], %[[RV012]] : vector<2xf32>
33+
// INNER-PARALLEL: return %[[RESULT_VEC]] : vector<2xf32>
34+
35+
// -----
36+
37+
// Test input: <add> reduction over dimensions 0 and 2 of a 3x4x5 vector, so
// only the middle dimension (size 4) remains in the result.
func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
38+
%0 = vector.multi_reduction <add>, %arg0, %acc [0, 2] : vector<3x4x5xf32> to vector<4xf32>
39+
return %0 : vector<4xf32>
40+
}
41+
42+
// ALL-LABEL: func @vector_multi_reduction_parallel_middle
43+
// ALL-SAME: %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
44+
// INNER-REDUCTION: vector.transpose %[[INPUT]], [1, 0, 2] : vector<3x4x5xf32> to vector<4x3x5xf32>
45+
// INNER-PARALLEL: vector.transpose %[[INPUT]], [0, 2, 1] : vector<3x4x5xf32> to vector<3x5x4xf32>

0 commit comments

Comments
 (0)