Skip to content

[mlir][vector] Add pattern to break down reductions into arith ops #75727

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,25 @@ void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);

/// Patterns to break down vector reductions into a series of arith reductions
/// over vector elements. This is intended to simplify code with reductions
/// over small vector types and avoid more specialized reduction lowering when
/// possible.
///
/// Reductions over more than `maxNumElementsToExtract` elements are left
/// untouched, as are reductions over scalable vectors and masked reductions.
/// `benefit` is the pattern benefit assigned to the added pattern.
///
/// Example:
/// ```
/// %a = vector.reduction <add>, %x : vector<2xf32> into f32
/// ```
/// is transformed into:
/// ```
/// %y = vector.extract %x[0] : f32 from vector<2xf32>
/// %z = vector.extract %x[1] : f32 from vector<2xf32>
/// %a = arith.addf %y, %z : f32
/// ```
void populateBreakDownVectorReductionPatterns(
RewritePatternSet &patterns, unsigned maxNumElementsToExtract = 2,
PatternBenefit benefit = 1);

/// Populate `patterns` with the following patterns.
///
/// [DecomposeDifferentRankInsertStridedSlice]
Expand Down
63 changes: 63 additions & 0 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"

#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>
Expand Down Expand Up @@ -44,6 +45,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-to-vector"
Expand Down Expand Up @@ -1578,6 +1580,60 @@ struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
}
};

/// Rewrites a small fixed-size `vector.reduction` as one `vector.extract` per
/// element followed by a chain of arith reduction ops. Reductions over scalable
/// vectors, masked reductions, and reductions with more than
/// `maxNumElementsToExtract` elements are not matched. When the reduction
/// carries an accumulator operand, it is combined last.
///
/// Example:
/// ```
/// %a = vector.reduction <add>, %x : vector<2xf32> into f32
/// ```
/// is transformed into:
/// ```
/// %y = vector.extract %x[0] : f32 from vector<2xf32>
/// %z = vector.extract %x[1] : f32 from vector<2xf32>
/// %a = arith.addf %y, %z : f32
/// ```
struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
  BreakDownVectorReduction(MLIRContext *context,
                           unsigned maxNumElementsToExtract,
                           PatternBenefit benefit)
      : OpRewritePattern(context, benefit),
        maxNumElementsToExtract(maxNumElementsToExtract) {}

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    VectorType srcType = op.getSourceVectorType();
    // Neither scalable vectors nor masked reductions are handled here.
    if (srcType.isScalable())
      return failure();
    if (op.isMasked())
      return failure();
    assert(srcType.getRank() == 1 && "Expected a 1-d vector");

    // Bail out (with a diagnostic) when the vector exceeds the configured cap.
    const int64_t elemCount = srcType.getNumElements();
    if (elemCount > maxNumElementsToExtract) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("has too many vector elements ({0}) to break down "
                            "(max allowed: {1})",
                            elemCount, maxNumElementsToExtract));
    }

    Location loc = op.getLoc();

    // Create every extract up front, in element order, before any arith op.
    SmallVector<Value> elems;
    elems.reserve(elemCount);
    for (int64_t pos = 0; pos < elemCount; ++pos) {
      Value elem =
          rewriter.create<vector::ExtractOp>(loc, op.getVector(), pos);
      elems.push_back(elem);
    }

    // Fold the extracted scalars left to right, starting from element 0.
    Value result = elems.front();
    for (int64_t pos = 1; pos < elemCount; ++pos) {
      result = vector::makeArithReduction(rewriter, loc, op.getKind(), result,
                                          elems[pos], op.getFastmathAttr());
    }

    // The optional accumulator is combined after all vector elements.
    if (Value acc = op.getAcc()) {
      result = vector::makeArithReduction(rewriter, loc, op.getKind(), result,
                                          acc, op.getFastmathAttr());
    }

    rewriter.replaceOp(op, result);
    return success();
  }

private:
  /// Largest element count for which a reduction is still broken down.
  unsigned maxNumElementsToExtract = 0;
};

} // namespace

void mlir::vector::populateFoldArithExtensionPatterns(
Expand Down Expand Up @@ -1656,6 +1712,13 @@ void mlir::vector::populateChainedVectorReductionFoldingPatterns(
PatternBenefit(benefit.getBenefit() + 1));
}

/// Adds the `BreakDownVectorReduction` pattern to `patterns`, forwarding the
/// element-count cap and the pattern benefit to its constructor.
void mlir::vector::populateBreakDownVectorReductionPatterns(
    RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
    PatternBenefit benefit) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<BreakDownVectorReduction>(ctx, maxNumElementsToExtract,
                                         benefit);
}

//===----------------------------------------------------------------------===//
// TableGen'd enum attribute definitions
//===----------------------------------------------------------------------===//
Expand Down
126 changes: 126 additions & 0 deletions mlir/test/Dialect/Vector/break-down-vector-reduction.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// RUN: mlir-opt %s --test-vector-break-down-reduction-patterns --cse | FileCheck %s

// NOTE: This test pass is set to break down vector reductions of size 2 or fewer.

// Every floating-point reduction kind over a 2-element vector becomes two
// extracts plus a single arith op. The run line applies --cse, so all six
// reductions are expected to share the same pair of extracts.
// CHECK-LABEL: func.func @reduce_2x_f32(
// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
// CHECK-DAG: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<2xf32>
// CHECK-DAG: %[[E1:.+]] = vector.extract %[[ARG0]][1] : f32 from vector<2xf32>
// CHECK-DAG: %[[R0:.+]] = arith.addf %[[E0]], %[[E1]] : f32
// CHECK-DAG: %[[R1:.+]] = arith.mulf %[[E0]], %[[E1]] : f32
// CHECK-DAG: %[[R2:.+]] = arith.minnumf %[[E0]], %[[E1]] : f32
// CHECK-DAG: %[[R3:.+]] = arith.maxnumf %[[E0]], %[[E1]] : f32
// CHECK-DAG: %[[R4:.+]] = arith.minimumf %[[E0]], %[[E1]] : f32
// CHECK-DAG: %[[R5:.+]] = arith.maximumf %[[E0]], %[[E1]] : f32
// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]]
func.func @reduce_2x_f32(%arg0: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
%0 = vector.reduction <add>, %arg0 : vector<2xf32> into f32
%1 = vector.reduction <mul>, %arg0 : vector<2xf32> into f32
%2 = vector.reduction <minf>, %arg0 : vector<2xf32> into f32
%3 = vector.reduction <maxf>, %arg0 : vector<2xf32> into f32
%4 = vector.reduction <minimumf>, %arg0 : vector<2xf32> into f32
%5 = vector.reduction <maximumf>, %arg0 : vector<2xf32> into f32
return %0, %1, %2, %3, %4, %5 : f32, f32, f32, f32, f32, f32
}

// Integer counterpart of the 2-element case: each integer reduction kind
// (arithmetic, signed/unsigned min/max, bitwise) becomes two extracts plus a
// single arith op, with the extracts shared after --cse.
// CHECK-LABEL: func.func @reduce_2x_i32(
// CHECK-SAME: %[[ARG0:.+]]: vector<2xi32>) -> (i32, i32, i32, i32, i32, i32, i32, i32, i32) {
// CHECK-DAG: %[[E0:.+]] = vector.extract %[[ARG0]][0] : i32 from vector<2xi32>
// CHECK-DAG: %[[E1:.+]] = vector.extract %[[ARG0]][1] : i32 from vector<2xi32>
// CHECK-DAG: %[[R0:.+]] = arith.addi %[[E0]], %[[E1]] : i32
// CHECK-DAG: %[[R1:.+]] = arith.muli %[[E0]], %[[E1]] : i32
// CHECK-DAG: %[[R2:.+]] = arith.minsi %[[E0]], %[[E1]] : i32
// CHECK-DAG: %[[R3:.+]] = arith.maxsi %[[E0]], %[[E1]] : i32
// CHECK-DAG: %[[R4:.+]] = arith.minui %[[E0]], %[[E1]] : i32
// CHECK-DAG: %[[R5:.+]] = arith.maxui %[[E0]], %[[E1]] : i32
// CHECK-DAG: %[[R6:.+]] = arith.andi %[[E0]], %[[E1]] : i32
// CHECK-DAG: %[[R7:.+]] = arith.ori %[[E0]], %[[E1]] : i32
// CHECK-DAG: %[[R8:.+]] = arith.xori %[[E0]], %[[E1]] : i32
// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]], %[[R6]], %[[R7]], %[[R8]]
func.func @reduce_2x_i32(%arg0: vector<2xi32>) -> (i32, i32, i32, i32, i32, i32, i32, i32, i32) {
%0 = vector.reduction <add>, %arg0 : vector<2xi32> into i32
%1 = vector.reduction <mul>, %arg0 : vector<2xi32> into i32
%2 = vector.reduction <minsi>, %arg0 : vector<2xi32> into i32
%3 = vector.reduction <maxsi>, %arg0 : vector<2xi32> into i32
%4 = vector.reduction <minui>, %arg0 : vector<2xi32> into i32
%5 = vector.reduction <maxui>, %arg0 : vector<2xi32> into i32
%6 = vector.reduction <and>, %arg0 : vector<2xi32> into i32
%7 = vector.reduction <or>, %arg0 : vector<2xi32> into i32
%8 = vector.reduction <xor>, %arg0 : vector<2xi32> into i32
return %0, %1, %2, %3, %4, %5, %6, %7, %8 : i32, i32, i32, i32, i32, i32, i32, i32, i32
}

// A single-element reduction without an accumulator needs no arith op at all:
// the result is just the extracted element.
// CHECK-LABEL: func.func @reduce_1x_f32(
// CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>) -> f32 {
// CHECK-NEXT: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
// CHECK-NEXT: return %[[E0]] : f32
func.func @reduce_1x_f32(%arg0: vector<1xf32>) -> f32 {
%0 = vector.reduction <add>, %arg0 : vector<1xf32> into f32
return %0 : f32
}

// With an accumulator operand, the single extracted element is combined with
// the accumulator using one floating-point arith op.
// CHECK-LABEL: func.func @reduce_1x_acc_f32(
// CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: f32) -> f32 {
// CHECK-NEXT: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
// CHECK-NEXT: %[[R0:.+]] = arith.addf %[[E0]], %[[ARG1]] : f32
// CHECK-NEXT: return %[[R0]] : f32
func.func @reduce_1x_acc_f32(%arg0: vector<1xf32>, %arg1: f32) -> f32 {
%0 = vector.reduction <add>, %arg0, %arg1 : vector<1xf32> into f32
return %0 : f32
}

// Integer variant of the single-element-plus-accumulator case: the extracted
// element is combined with the accumulator using one integer arith op.
// CHECK-LABEL: func.func @reduce_1x_acc_i32(
// CHECK-SAME: %[[ARG0:.+]]: vector<1xi32>, %[[ARG1:.+]]: i32) -> i32 {
// CHECK-NEXT: %[[E0:.+]] = vector.extract %[[ARG0]][0] : i32 from vector<1xi32>
// CHECK-NEXT: %[[R0:.+]] = arith.addi %[[E0]], %[[ARG1]] : i32
// CHECK-NEXT: return %[[R0]] : i32
func.func @reduce_1x_acc_i32(%arg0: vector<1xi32>, %arg1: i32) -> i32 {
%0 = vector.reduction <add>, %arg0, %arg1 : vector<1xi32> into i32
return %0 : i32
}

// Two elements plus an accumulator: the vector elements are combined first and
// the accumulator is applied last. The fastmath flags of the reduction are
// expected on every arith op generated for it.
// CHECK-LABEL: func.func @reduce_2x_acc_f32(
// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: f32) -> (f32, f32) {
// CHECK-DAG: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<2xf32>
// CHECK-DAG: %[[E1:.+]] = vector.extract %[[ARG0]][1] : f32 from vector<2xf32>
// CHECK: %[[A0:.+]] = arith.addf %[[E0]], %[[E1]] : f32
// CHECK: %[[R0:.+]] = arith.addf %[[A0]], %[[ARG1]] : f32
// CHECK: %[[M0:.+]] = arith.mulf %[[E0]], %[[E1]] fastmath<nnan> : f32
// CHECK: %[[R1:.+]] = arith.mulf %[[M0]], %[[ARG1]] fastmath<nnan> : f32
// CHECK-NEXT: return %[[R0]], %[[R1]] : f32, f32
func.func @reduce_2x_acc_f32(%arg0: vector<2xf32>, %arg1: f32) -> (f32, f32) {
%0 = vector.reduction <add>, %arg0, %arg1 : vector<2xf32> into f32
%1 = vector.reduction <mul>, %arg0, %arg1 fastmath<nnan> : vector<2xf32> into f32
return %0, %1 : f32, f32
}

// Negative test: three elements exceed the 2-element limit configured for this
// test pass, so the reduction must be left unchanged.
// CHECK-LABEL: func.func @reduce_3x_f32(
// CHECK-SAME: %[[ARG0:.+]]: vector<3xf32>) -> f32 {
// CHECK-NEXT: %[[R0:.+]] = vector.reduction <add>, %[[ARG0]] : vector<3xf32> into f32
// CHECK-NEXT: return %[[R0]] : f32
func.func @reduce_3x_f32(%arg0: vector<3xf32>) -> f32 {
%0 = vector.reduction <add>, %arg0 : vector<3xf32> into f32
return %0 : f32
}

// Masking is not handled yet: a reduction nested in `vector.mask` must be left
// unchanged.
// NOTE(review): the 3-element vector is also above the 2-element limit of this
// test pass, so this case would presumably stay unchanged even without the
// masking check; a 2-element vector would isolate the masking behavior.
// CHECK-LABEL: func.func @reduce_mask_3x_f32
// CHECK-NEXT: %[[M:.+]] = vector.create_mask
// CHECK-NEXT: %[[R:.+]] = vector.mask %[[M]]
// CHECK-SAME: vector.reduction <add>
// CHECK-NEXT: return %[[R]] : f32
func.func @reduce_mask_3x_f32(%arg0: vector<3xf32>, %arg1: index) -> f32 {
%mask = vector.create_mask %arg1 : vector<3xi1>
%0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<3xf32> into f32 } : vector<3xi1> -> f32
return %0 : f32
}

// Scalable vectors are not supported: a reduction over a scalable vector type
// must be left unchanged, even though its static size factor is within the
// limit.
// CHECK-LABEL: func.func @reduce_scalable_f32(
// CHECK-SAME: %[[ARG0:.+]]: vector<[1]xf32>) -> f32 {
// CHECK-NEXT: %[[R0:.+]] = vector.reduction <add>, %[[ARG0]] : vector<[1]xf32> into f32
// CHECK-NEXT: return %[[R0]] : f32
func.func @reduce_scalable_f32(%arg0: vector<[1]xf32>) -> f32 {
%0 = vector.reduction <add>, %arg0 : vector<[1]xf32> into f32
return %0 : f32
}
23 changes: 23 additions & 0 deletions mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,27 @@ struct TestVectorChainedReductionFoldingPatterns
}
};

/// Test-only function pass that greedily applies the break-down-reduction
/// patterns, capped at two elements per reduction.
struct TestVectorBreakDownReductionPatterns
    : public PassWrapper<TestVectorBreakDownReductionPatterns,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestVectorBreakDownReductionPatterns)

  /// Command-line flag that selects this pass.
  StringRef getArgument() const final {
    return "test-vector-break-down-reduction-patterns";
  }

  /// One-line summary of the pass.
  StringRef getDescription() const final {
    return "Test patterns to break down vector reductions into arith "
           "reductions";
  }

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet reductionPatterns(ctx);
    populateBreakDownVectorReductionPatterns(reductionPatterns,
                                             /*maxNumElementsToExtract=*/2);
    // The driver's result is deliberately discarded.
    (void)applyPatternsAndFoldGreedily(getOperation(),
                                       std::move(reductionPatterns));
  }
};

struct TestFlattenVectorTransferPatterns
: public PassWrapper<TestFlattenVectorTransferPatterns,
OperationPass<func::FuncOp>> {
Expand Down Expand Up @@ -827,6 +848,8 @@ void registerTestVectorLowerings() {

PassRegistration<TestVectorChainedReductionFoldingPatterns>();

PassRegistration<TestVectorBreakDownReductionPatterns>();

PassRegistration<TestFlattenVectorTransferPatterns>();

PassRegistration<TestVectorScanLowering>();
Expand Down