Skip to content

[mlir][vector] Add lower-vector-multi-reduction pass #87333

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
#define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
Expand All @@ -22,6 +23,11 @@ std::unique_ptr<Pass> createVectorBufferizePass();
/// Creates an instance of the `vector.mask` lowering pass.
std::unique_ptr<Pass> createLowerVectorMaskPass();

/// Creates an instance of the `vector.multi_reduction` lowering pass.
/// `option` selects the lowering strategy the pass is created with; when the
/// argument is omitted it defaults to
/// `VectorMultiReductionLowering::InnerParallel`.
std::unique_ptr<Pass> createLowerVectorMultiReductionPass(
VectorMultiReductionLowering option =
VectorMultiReductionLowering::InnerParallel);

//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
Expand Down
18 changes: 18 additions & 0 deletions mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,22 @@ def LowerVectorMaskPass : Pass<"lower-vector-mask", "func::FuncOp"> {
let constructor = "mlir::vector::createLowerVectorMaskPass()";
}

// Declarative spec of the `lower-vector-multi-reduction` pass, anchored on
// `func::FuncOp`. Instances are built through
// `mlir::vector::createLowerVectorMultiReductionPass()`.
def LowerVectorMultiReduction : Pass<"lower-vector-multi-reduction", "func::FuncOp"> {
let summary = "Lower 'vector.multi_reduction' operations";
let constructor = "mlir::vector::createLowerVectorMultiReductionPass()";
// `lowering-strategy` accepts `inner-parallel` (the default) or
// `inner-reduction`; the value is stored in the `loweringStrategy` member.
let options = [
Option<"loweringStrategy", "lowering-strategy", "mlir::vector::VectorMultiReductionLowering",
/*default=*/"mlir::vector::VectorMultiReductionLowering::InnerParallel",
"Select the strategy to control how multi_reduction is lowered.",
[{::llvm::cl::values(
clEnumValN(mlir::vector::VectorMultiReductionLowering::InnerParallel,
"inner-parallel",
"Lower multi_reduction into outer-reduction and inner-parallel ops."),
clEnumValN(mlir::vector::VectorMultiReductionLowering::InnerReduction,
"inner-reduction",
"Lower multi_reduction into outer-parallel and inner-reduction ops.")
)}]>
];
}

#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,19 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace vector {
#define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION
#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
} // namespace vector
} // namespace mlir

#define DEBUG_TYPE "vector-multi-reduction"

Expand Down Expand Up @@ -461,6 +471,31 @@ struct OneDimMultiReductionToTwoDim
return success();
}
};

/// Pass that lowers `vector.multi_reduction` operations by greedily applying
/// the patterns registered by `populateVectorMultiReductionLoweringPatterns`.
/// The flavour of the lowering is driven by the `loweringStrategy` pass
/// option inherited from the generated base class.
struct LowerVectorMultiReductionPass
    : public vector::impl::LowerVectorMultiReductionBase<
          LowerVectorMultiReductionPass> {
  /// Builds the pass with `option` as the value of the `loweringStrategy`
  /// option. Marked `explicit` so a bare lowering-strategy enumerator is never
  /// implicitly converted into a pass object.
  explicit LowerVectorMultiReductionPass(
      vector::VectorMultiReductionLowering option) {
    this->loweringStrategy = option;
  }

  /// Collects the multi-reduction lowering patterns for the selected strategy
  /// and runs the greedy rewrite driver over the operation the pass is
  /// scheduled on. The pass is marked as failed if the driver reports failure.
  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();

    RewritePatternSet loweringPatterns(context);
    populateVectorMultiReductionLoweringPatterns(loweringPatterns,
                                                 this->loweringStrategy);

    if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
      signalPassFailure();
  }

  /// Declares the vector dialect as a dependency of this pass.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<vector::VectorDialect>();
  }
};

} // namespace

void mlir::vector::populateVectorMultiReductionLoweringPatterns(
Expand All @@ -476,3 +511,8 @@ void mlir::vector::populateVectorMultiReductionLoweringPatterns(
patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
benefit);
}

/// Factory for the `vector.multi_reduction` lowering pass; `option` is
/// forwarded to the pass constructor as the lowering strategy.
std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
    vector::VectorMultiReductionLowering option) {
  std::unique_ptr<Pass> loweringPass =
      std::make_unique<LowerVectorMultiReductionPass>(option);
  return loweringPass;
}
45 changes: 45 additions & 0 deletions mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// RUN: mlir-opt -lower-vector-multi-reduction="lowering-strategy=inner-reduction" -split-input-file %s | FileCheck %s --check-prefixes=ALL,INNER-REDUCTION
// RUN: mlir-opt -lower-vector-multi-reduction="lowering-strategy=inner-parallel" -split-input-file %s | FileCheck %s --check-prefixes=ALL,INNER-PARALLEL
// RUN: mlir-opt -lower-vector-multi-reduction -split-input-file %s | FileCheck %s --check-prefixes=ALL,INNER-PARALLEL

// Reduces the innermost dimension of a 2-D vector with <mul>. The checks below
// pin the full expansion for both lowering strategies: per-row
// vector.reduction ops for inner-reduction, and a transpose followed by
// element-wise arith.mulf ops for inner-parallel.
func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
  %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
  return %0 : vector<2xf32>
}
// ALL-LABEL: func @vector_multi_reduction
// ALL-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.+]]: vector<2xf32>)
// INNER-REDUCTION-DAG: %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<2xf32>
// INNER-REDUCTION-DAG: %[[C0:.+]] = arith.constant 0 : index
// INNER-REDUCTION-DAG: %[[C1:.+]] = arith.constant 1 : index
// INNER-REDUCTION: %[[V0:.+]] = vector.extract %[[INPUT]][0]
// INNER-REDUCTION: %[[ACC0:.+]] = vector.extract %[[ACC]][0]
// INNER-REDUCTION: %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32
// INNER-REDUCTION: %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0]], %[[RESULT_VEC_0]][%[[C0]] : index] : vector<2xf32>
// INNER-REDUCTION: %[[V1:.+]] = vector.extract %[[INPUT]][1]
// INNER-REDUCTION: %[[ACC1:.+]] = vector.extract %[[ACC]][1]
// INNER-REDUCTION: %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<4xf32> into f32
// INNER-REDUCTION: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<2xf32>
// INNER-REDUCTION: return %[[RESULT_VEC]]

// INNER-PARALLEL: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
// INNER-PARALLEL: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
// INNER-PARALLEL: %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
// INNER-PARALLEL: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
// INNER-PARALLEL: %[[RV01:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
// INNER-PARALLEL: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
// INNER-PARALLEL: %[[RV012:.+]] = arith.mulf %[[V2]], %[[RV01]] : vector<2xf32>
// INNER-PARALLEL: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
// INNER-PARALLEL: %[[RESULT_VEC:.+]] = arith.mulf %[[V3]], %[[RV012]] : vector<2xf32>
// INNER-PARALLEL: return %[[RESULT_VEC]] : vector<2xf32>

// -----

// Reduces dimensions 0 and 2 of a 3x4x5 vector, leaving the middle dimension
// (size 4) as the only parallel one. Only the transpose emitted by each
// strategy is checked: permutation [1, 0, 2] for inner-reduction and
// [0, 2, 1] for inner-parallel.
func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
%0 = vector.multi_reduction <add>, %arg0, %acc [0, 2] : vector<3x4x5xf32> to vector<4xf32>
return %0 : vector<4xf32>
}

// ALL-LABEL: func @vector_multi_reduction_parallel_middle
// ALL-SAME: %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
// INNER-REDUCTION: vector.transpose %[[INPUT]], [1, 0, 2] : vector<3x4x5xf32> to vector<4x3x5xf32>
// INNER-PARALLEL: vector.transpose %[[INPUT]], [0, 2, 1] : vector<3x4x5xf32> to vector<3x5x4xf32>