Skip to content

Commit 111ead9

Browse files
committed
[mlir][vector] Add pattern to break down small reductions into arith ops
The number of vector elements considered 'small' enough to extract is parameterized. This is to avoid going into specialized reduction lowering when a single/couple of arith ops can do. Targets without dedicated reduction intrinsics can use that as an emulation path too. Depends on #75846. Please enter the commit message for your changes. Lines starting
1 parent a528cee commit 111ead9

File tree

4 files changed

+231
-0
lines changed

4 files changed

+231
-0
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,25 @@ void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
166166
void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &patterns,
167167
PatternBenefit benefit = 1);
168168

169+
/// Patterns to break down vector reductions into a series of arith reductions
170+
/// over vector elements. This is intended to be simplify code with reductions
171+
/// over small vector types and avoid more specialized reduction lowering when
172+
/// possible.
173+
///
174+
/// Example:
175+
/// ```
176+
/// %a = vector.reduction <add> %x : vector<2xf32> into f32
177+
/// ```
178+
/// is transformed into:
179+
/// ```
180+
/// %y = vector.extract %x[0] : f32 from vector<2xf32>
181+
/// %z = vector.extract %x[1] : f32 from vector<2xf32>
182+
/// %a = arith.addf %y, %z : f32
183+
/// ```
184+
void populateBreakDownVectorReductionPatterns(
185+
RewritePatternSet &patterns, unsigned maxNumElementsToExtract = 2,
186+
PatternBenefit benefit = 1);
187+
169188
/// Populate `patterns` with the following patterns.
170189
///
171190
/// [DecomposeDifferentRankInsertStridedSlice]

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
1414

15+
#include <cassert>
1516
#include <cstdint>
1617
#include <functional>
1718
#include <optional>
@@ -44,6 +45,7 @@
4445
#include "llvm/ADT/STLExtras.h"
4546
#include "llvm/Support/CommandLine.h"
4647
#include "llvm/Support/Debug.h"
48+
#include "llvm/Support/FormatVariadic.h"
4749
#include "llvm/Support/raw_ostream.h"
4850

4951
#define DEBUG_TYPE "vector-to-vector"
@@ -1578,6 +1580,60 @@ struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
15781580
}
15791581
};
15801582

1583+
/// Example:
1584+
/// ```
1585+
/// %a = vector.reduction <add> %x : vector<2xf32> into f32
1586+
/// ```
1587+
/// is transformed into:
1588+
/// ```
1589+
/// %y = vector.extract %x[0] : f32 from vector<2xf32>
1590+
/// %z = vector.extract %x[1] : f32 from vector<2xf32>
1591+
/// %a = arith.addf %y, %z : f32
1592+
/// ```
1593+
struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
1594+
BreakDownVectorReduction(MLIRContext *context,
1595+
unsigned maxNumElementsToExtract,
1596+
PatternBenefit benefit)
1597+
: OpRewritePattern(context, benefit),
1598+
maxNumElementsToExtract(maxNumElementsToExtract) {}
1599+
1600+
LogicalResult matchAndRewrite(vector::ReductionOp op,
1601+
PatternRewriter &rewriter) const override {
1602+
VectorType type = op.getSourceVectorType();
1603+
if (type.isScalable() || op.isMasked())
1604+
return failure();
1605+
assert(type.getRank() == 1 && "Expected a 1-d vector");
1606+
1607+
int64_t numElems = type.getNumElements();
1608+
if (numElems > maxNumElementsToExtract) {
1609+
return rewriter.notifyMatchFailure(
1610+
op, llvm::formatv("has too many vector elements ({0}) to break down "
1611+
"(max allowed: {1})",
1612+
numElems, maxNumElementsToExtract));
1613+
}
1614+
1615+
Location loc = op.getLoc();
1616+
SmallVector<Value> extracted(numElems, nullptr);
1617+
for (auto [idx, extractedElem] : llvm::enumerate(extracted))
1618+
extractedElem = rewriter.create<vector::ExtractOp>(
1619+
loc, op.getVector(), static_cast<int64_t>(idx));
1620+
1621+
Value res = extracted.front();
1622+
for (auto extractedElem : llvm::drop_begin(extracted))
1623+
res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
1624+
extractedElem, op.getFastmathAttr());
1625+
if (Value acc = op.getAcc())
1626+
res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
1627+
op.getFastmathAttr());
1628+
1629+
rewriter.replaceOp(op, res);
1630+
return success();
1631+
}
1632+
1633+
private:
1634+
unsigned maxNumElementsToExtract = 0;
1635+
};
1636+
15811637
} // namespace
15821638

15831639
void mlir::vector::populateFoldArithExtensionPatterns(
@@ -1656,6 +1712,13 @@ void mlir::vector::populateChainedVectorReductionFoldingPatterns(
16561712
PatternBenefit(benefit.getBenefit() + 1));
16571713
}
16581714

1715+
void mlir::vector::populateBreakDownVectorReductionPatterns(
1716+
RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
1717+
PatternBenefit benefit) {
1718+
patterns.add<BreakDownVectorReduction>(patterns.getContext(),
1719+
maxNumElementsToExtract, benefit);
1720+
}
1721+
16591722
//===----------------------------------------------------------------------===//
16601723
// TableGen'd enum attribute definitions
16611724
//===----------------------------------------------------------------------===//
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
// RUN: mlir-opt %s --test-vector-break-down-reduction-patterns --cse | FileCheck %s
2+
3+
// NOTE: This test pass is set break down vector reductions of size 2 or fewer.
4+
5+
// CHECK-LABEL: func.func @reduce_2x_f32(
6+
// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
7+
// CHECK-DAG: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<2xf32>
8+
// CHECK-DAG: %[[E1:.+]] = vector.extract %[[ARG0]][1] : f32 from vector<2xf32>
9+
// CHECK-DAG: %[[R0:.+]] = arith.addf %[[E0]], %[[E1]] : f32
10+
// CHECK-DAG: %[[R1:.+]] = arith.mulf %[[E0]], %[[E1]] : f32
11+
// CHECK-DAG: %[[R2:.+]] = arith.minnumf %[[E0]], %[[E1]] : f32
12+
// CHECK-DAG: %[[R3:.+]] = arith.maxnumf %[[E0]], %[[E1]] : f32
13+
// CHECK-DAG: %[[R4:.+]] = arith.minimumf %[[E0]], %[[E1]] : f32
14+
// CHECK-DAG: %[[R5:.+]] = arith.maximumf %[[E0]], %[[E1]] : f32
15+
// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]]
16+
func.func @reduce_2x_f32(%arg0: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
17+
%0 = vector.reduction <add>, %arg0 : vector<2xf32> into f32
18+
%1 = vector.reduction <mul>, %arg0 : vector<2xf32> into f32
19+
%2 = vector.reduction <minf>, %arg0 : vector<2xf32> into f32
20+
%3 = vector.reduction <maxf>, %arg0 : vector<2xf32> into f32
21+
%4 = vector.reduction <minimumf>, %arg0 : vector<2xf32> into f32
22+
%5 = vector.reduction <maximumf>, %arg0 : vector<2xf32> into f32
23+
return %0, %1, %2, %3, %4, %5 : f32, f32, f32, f32, f32, f32
24+
}
25+
26+
// CHECK-LABEL: func.func @reduce_2x_i32(
27+
// CHECK-SAME: %[[ARG0:.+]]: vector<2xi32>) -> (i32, i32, i32, i32, i32, i32, i32, i32, i32) {
28+
// CHECK-DAG: %[[E0:.+]] = vector.extract %[[ARG0]][0] : i32 from vector<2xi32>
29+
// CHECK-DAG: %[[E1:.+]] = vector.extract %[[ARG0]][1] : i32 from vector<2xi32>
30+
// CHECK-DAG: %[[R0:.+]] = arith.addi %[[E0]], %[[E1]] : i32
31+
// CHECK-DAG: %[[R1:.+]] = arith.muli %[[E0]], %[[E1]] : i32
32+
// CHECK-DAG: %[[R2:.+]] = arith.minsi %[[E0]], %[[E1]] : i32
33+
// CHECK-DAG: %[[R3:.+]] = arith.maxsi %[[E0]], %[[E1]] : i32
34+
// CHECK-DAG: %[[R4:.+]] = arith.minui %[[E0]], %[[E1]] : i32
35+
// CHECK-DAG: %[[R5:.+]] = arith.maxui %[[E0]], %[[E1]] : i32
36+
// CHECK-DAG: %[[R6:.+]] = arith.andi %[[E0]], %[[E1]] : i32
37+
// CHECK-DAG: %[[R7:.+]] = arith.ori %[[E0]], %[[E1]] : i32
38+
// CHECK-DAG: %[[R8:.+]] = arith.xori %[[E0]], %[[E1]] : i32
39+
// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]], %[[R6]], %[[R7]], %[[R8]]
40+
func.func @reduce_2x_i32(%arg0: vector<2xi32>) -> (i32, i32, i32, i32, i32, i32, i32, i32, i32) {
41+
%0 = vector.reduction <add>, %arg0 : vector<2xi32> into i32
42+
%1 = vector.reduction <mul>, %arg0 : vector<2xi32> into i32
43+
%2 = vector.reduction <minsi>, %arg0 : vector<2xi32> into i32
44+
%3 = vector.reduction <maxsi>, %arg0 : vector<2xi32> into i32
45+
%4 = vector.reduction <minui>, %arg0 : vector<2xi32> into i32
46+
%5 = vector.reduction <maxui>, %arg0 : vector<2xi32> into i32
47+
%6 = vector.reduction <and>, %arg0 : vector<2xi32> into i32
48+
%7 = vector.reduction <or>, %arg0 : vector<2xi32> into i32
49+
%8 = vector.reduction <xor>, %arg0 : vector<2xi32> into i32
50+
return %0, %1, %2, %3, %4, %5, %6, %7, %8 : i32, i32, i32, i32, i32, i32, i32, i32, i32
51+
}
52+
53+
// CHECK-LABEL: func.func @reduce_1x_f32(
54+
// CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>) -> f32 {
55+
// CHECK-NEXT: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
56+
// CHECK-NEXT: return %[[E0]] : f32
57+
func.func @reduce_1x_f32(%arg0: vector<1xf32>) -> f32 {
58+
%0 = vector.reduction <add>, %arg0 : vector<1xf32> into f32
59+
return %0 : f32
60+
}
61+
62+
// CHECK-LABEL: func.func @reduce_1x_acc_f32(
63+
// CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: f32) -> f32 {
64+
// CHECK-NEXT: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
65+
// CHECK-NEXT: %[[R0:.+]] = arith.addf %[[E0]], %[[ARG1]] : f32
66+
// CHECK-NEXT: return %[[R0]] : f32
67+
func.func @reduce_1x_acc_f32(%arg0: vector<1xf32>, %arg1: f32) -> f32 {
68+
%0 = vector.reduction <add>, %arg0, %arg1 : vector<1xf32> into f32
69+
return %0 : f32
70+
}
71+
72+
// CHECK-LABEL: func.func @reduce_1x_acc_i32(
73+
// CHECK-SAME: %[[ARG0:.+]]: vector<1xi32>, %[[ARG1:.+]]: i32) -> i32 {
74+
// CHECK-NEXT: %[[E0:.+]] = vector.extract %[[ARG0]][0] : i32 from vector<1xi32>
75+
// CHECK-NEXT: %[[R0:.+]] = arith.addi %[[E0]], %[[ARG1]] : i32
76+
// CHECK-NEXT: return %[[R0]] : i32
77+
func.func @reduce_1x_acc_i32(%arg0: vector<1xi32>, %arg1: i32) -> i32 {
78+
%0 = vector.reduction <add>, %arg0, %arg1 : vector<1xi32> into i32
79+
return %0 : i32
80+
}
81+
82+
// CHECK-LABEL: func.func @reduce_2x_acc_f32(
83+
// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: f32) -> (f32, f32) {
84+
// CHECK-DAG: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<2xf32>
85+
// CHECK-DAG: %[[E1:.+]] = vector.extract %[[ARG0]][1] : f32 from vector<2xf32>
86+
// CHECK: %[[A0:.+]] = arith.addf %[[E0]], %[[E1]] : f32
87+
// CHECK: %[[R0:.+]] = arith.addf %[[A0]], %[[ARG1]] : f32
88+
// CHECK: %[[M0:.+]] = arith.mulf %[[E0]], %[[E1]] fastmath<nnan> : f32
89+
// CHECK: %[[R1:.+]] = arith.mulf %[[M0]], %[[ARG1]] fastmath<nnan> : f32
90+
// CHECK-NEXT: return %[[R0]], %[[R1]] : f32, f32
91+
func.func @reduce_2x_acc_f32(%arg0: vector<2xf32>, %arg1: f32) -> (f32, f32) {
92+
%0 = vector.reduction <add>, %arg0, %arg1 : vector<2xf32> into f32
93+
%1 = vector.reduction <mul>, %arg0, %arg1 fastmath<nnan> : vector<2xf32> into f32
94+
return %0, %1 : f32, f32
95+
}
96+
97+
// CHECK-LABEL: func.func @reduce_3x_f32(
98+
// CHECK-SAME: %[[ARG0:.+]]: vector<3xf32>) -> f32 {
99+
// CHECK-NEXT: %[[R0:.+]] = vector.reduction <add>, %[[ARG0]] : vector<3xf32> into f32
100+
// CHECK-NEXT: return %[[R0]] : f32
101+
func.func @reduce_3x_f32(%arg0: vector<3xf32>) -> f32 {
102+
%0 = vector.reduction <add>, %arg0 : vector<3xf32> into f32
103+
return %0 : f32
104+
}
105+
106+
// Masking is not handled yet.
107+
// CHECK-LABEL: func.func @reduce_mask_3x_f32
108+
// CHECK-NEXT: %[[M:.+]] = vector.create_mask
109+
// CHECK-NEXT: %[[R:.+]] = vector.mask %[[M]]
110+
// CHECK-SAME: vector.reduction <add>
111+
// CHECK-NEXT: return %[[R]] : f32
112+
func.func @reduce_mask_3x_f32(%arg0: vector<3xf32>, %arg1: index) -> f32 {
113+
%mask = vector.create_mask %arg1 : vector<3xi1>
114+
%0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<3xf32> into f32 } : vector<3xi1> -> f32
115+
return %0 : f32
116+
}
117+
118+
// Scalable vectors are not supported.
119+
// CHECK-LABEL: func.func @reduce_scalable_f32(
120+
// CHECK-SAME: %[[ARG0:.+]]: vector<[1]xf32>) -> f32 {
121+
// CHECK-NEXT: %[[R0:.+]] = vector.reduction <add>, %[[ARG0]] : vector<[1]xf32> into f32
122+
// CHECK-NEXT: return %[[R0]] : f32
123+
func.func @reduce_scalable_f32(%arg0: vector<[1]xf32>) -> f32 {
124+
%0 = vector.reduction <add>, %arg0 : vector<[1]xf32> into f32
125+
return %0 : f32
126+
}

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,27 @@ struct TestVectorChainedReductionFoldingPatterns
439439
}
440440
};
441441

442+
struct TestVectorBreakDownReductionPatterns
443+
: public PassWrapper<TestVectorBreakDownReductionPatterns,
444+
OperationPass<func::FuncOp>> {
445+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
446+
TestVectorBreakDownReductionPatterns)
447+
448+
StringRef getArgument() const final {
449+
return "test-vector-break-down-reduction-patterns";
450+
}
451+
StringRef getDescription() const final {
452+
return "Test patterns to break down vector reductions into arith "
453+
"reductions";
454+
}
455+
void runOnOperation() override {
456+
RewritePatternSet patterns(&getContext());
457+
populateBreakDownVectorReductionPatterns(patterns,
458+
/*maxNumElementsToExtract=*/2);
459+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
460+
}
461+
};
462+
442463
struct TestFlattenVectorTransferPatterns
443464
: public PassWrapper<TestFlattenVectorTransferPatterns,
444465
OperationPass<func::FuncOp>> {
@@ -827,6 +848,8 @@ void registerTestVectorLowerings() {
827848

828849
PassRegistration<TestVectorChainedReductionFoldingPatterns>();
829850

851+
PassRegistration<TestVectorBreakDownReductionPatterns>();
852+
830853
PassRegistration<TestFlattenVectorTransferPatterns>();
831854

832855
PassRegistration<TestVectorScanLowering>();

0 commit comments

Comments
 (0)