Skip to content

Commit 0767711

Browse files
authored
[mlir][vector] Add pattern to break down reductions into arith ops (#75727)
The number of vector elements considered 'small' enough to extract is parameterized. This is to avoid going into specialized reduction lowering when a single/couple of arith ops can do. Targets without dedicated reduction intrinsics can use that as an emulation path too. Depends on #75846.
1 parent 5b57da3 commit 0767711

File tree

4 files changed

+231
-0
lines changed

4 files changed

+231
-0
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,25 @@ void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
166166
void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &patterns,
167167
PatternBenefit benefit = 1);
168168

169+
/// Patterns to break down vector reductions into a series of arith reductions
170+
/// over vector elements. This is intended to be simplify code with reductions
171+
/// over small vector types and avoid more specialized reduction lowering when
172+
/// possible.
173+
///
174+
/// Example:
175+
/// ```
176+
/// %a = vector.reduction <add> %x : vector<2xf32> into f32
177+
/// ```
178+
/// is transformed into:
179+
/// ```
180+
/// %y = vector.extract %x[0] : f32 from vector<2xf32>
181+
/// %z = vector.extract %x[1] : f32 from vector<2xf32>
182+
/// %a = arith.addf %y, %z : f32
183+
/// ```
184+
void populateBreakDownVectorReductionPatterns(
185+
RewritePatternSet &patterns, unsigned maxNumElementsToExtract = 2,
186+
PatternBenefit benefit = 1);
187+
169188
/// Populate `patterns` with the following patterns.
170189
///
171190
/// [DecomposeDifferentRankInsertStridedSlice]

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
1414

15+
#include <cassert>
1516
#include <cstdint>
1617
#include <functional>
1718
#include <optional>
@@ -44,6 +45,7 @@
4445
#include "llvm/ADT/STLExtras.h"
4546
#include "llvm/Support/CommandLine.h"
4647
#include "llvm/Support/Debug.h"
48+
#include "llvm/Support/FormatVariadic.h"
4749
#include "llvm/Support/raw_ostream.h"
4850

4951
#define DEBUG_TYPE "vector-to-vector"
@@ -1578,6 +1580,60 @@ struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
15781580
}
15791581
};
15801582

1583+
/// Example:
1584+
/// ```
1585+
/// %a = vector.reduction <add> %x : vector<2xf32> into f32
1586+
/// ```
1587+
/// is transformed into:
1588+
/// ```
1589+
/// %y = vector.extract %x[0] : f32 from vector<2xf32>
1590+
/// %z = vector.extract %x[1] : f32 from vector<2xf32>
1591+
/// %a = arith.addf %y, %z : f32
1592+
/// ```
1593+
struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
1594+
BreakDownVectorReduction(MLIRContext *context,
1595+
unsigned maxNumElementsToExtract,
1596+
PatternBenefit benefit)
1597+
: OpRewritePattern(context, benefit),
1598+
maxNumElementsToExtract(maxNumElementsToExtract) {}
1599+
1600+
LogicalResult matchAndRewrite(vector::ReductionOp op,
1601+
PatternRewriter &rewriter) const override {
1602+
VectorType type = op.getSourceVectorType();
1603+
if (type.isScalable() || op.isMasked())
1604+
return failure();
1605+
assert(type.getRank() == 1 && "Expected a 1-d vector");
1606+
1607+
int64_t numElems = type.getNumElements();
1608+
if (numElems > maxNumElementsToExtract) {
1609+
return rewriter.notifyMatchFailure(
1610+
op, llvm::formatv("has too many vector elements ({0}) to break down "
1611+
"(max allowed: {1})",
1612+
numElems, maxNumElementsToExtract));
1613+
}
1614+
1615+
Location loc = op.getLoc();
1616+
SmallVector<Value> extracted(numElems, nullptr);
1617+
for (auto [idx, extractedElem] : llvm::enumerate(extracted))
1618+
extractedElem = rewriter.create<vector::ExtractOp>(
1619+
loc, op.getVector(), static_cast<int64_t>(idx));
1620+
1621+
Value res = extracted.front();
1622+
for (auto extractedElem : llvm::drop_begin(extracted))
1623+
res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
1624+
extractedElem, op.getFastmathAttr());
1625+
if (Value acc = op.getAcc())
1626+
res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
1627+
op.getFastmathAttr());
1628+
1629+
rewriter.replaceOp(op, res);
1630+
return success();
1631+
}
1632+
1633+
private:
1634+
unsigned maxNumElementsToExtract = 0;
1635+
};
1636+
15811637
} // namespace
15821638

15831639
void mlir::vector::populateFoldArithExtensionPatterns(
@@ -1656,6 +1712,13 @@ void mlir::vector::populateChainedVectorReductionFoldingPatterns(
16561712
PatternBenefit(benefit.getBenefit() + 1));
16571713
}
16581714

1715+
void mlir::vector::populateBreakDownVectorReductionPatterns(
1716+
RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
1717+
PatternBenefit benefit) {
1718+
patterns.add<BreakDownVectorReduction>(patterns.getContext(),
1719+
maxNumElementsToExtract, benefit);
1720+
}
1721+
16591722
//===----------------------------------------------------------------------===//
16601723
// TableGen'd enum attribute definitions
16611724
//===----------------------------------------------------------------------===//
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
// RUN: mlir-opt %s --test-vector-break-down-reduction-patterns --cse | FileCheck %s
2+
3+
// NOTE: This test pass is set break down vector reductions of size 2 or fewer.
4+
5+
// CHECK-LABEL: func.func @reduce_2x_f32(
6+
// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
7+
// CHECK-DAG: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<2xf32>
8+
// CHECK-DAG: %[[E1:.+]] = vector.extract %[[ARG0]][1] : f32 from vector<2xf32>
9+
// CHECK-DAG: %[[R0:.+]] = arith.addf %[[E0]], %[[E1]] : f32
10+
// CHECK-DAG: %[[R1:.+]] = arith.mulf %[[E0]], %[[E1]] : f32
11+
// CHECK-DAG: %[[R2:.+]] = arith.minnumf %[[E0]], %[[E1]] : f32
12+
// CHECK-DAG: %[[R3:.+]] = arith.maxnumf %[[E0]], %[[E1]] : f32
13+
// CHECK-DAG: %[[R4:.+]] = arith.minimumf %[[E0]], %[[E1]] : f32
14+
// CHECK-DAG: %[[R5:.+]] = arith.maximumf %[[E0]], %[[E1]] : f32
15+
// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]]
16+
func.func @reduce_2x_f32(%arg0: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
17+
%0 = vector.reduction <add>, %arg0 : vector<2xf32> into f32
18+
%1 = vector.reduction <mul>, %arg0 : vector<2xf32> into f32
19+
%2 = vector.reduction <minf>, %arg0 : vector<2xf32> into f32
20+
%3 = vector.reduction <maxf>, %arg0 : vector<2xf32> into f32
21+
%4 = vector.reduction <minimumf>, %arg0 : vector<2xf32> into f32
22+
%5 = vector.reduction <maximumf>, %arg0 : vector<2xf32> into f32
23+
return %0, %1, %2, %3, %4, %5 : f32, f32, f32, f32, f32, f32
24+
}
25+
26+
// CHECK-LABEL: func.func @reduce_2x_i32(
27+
// CHECK-SAME: %[[ARG0:.+]]: vector<2xi32>) -> (i32, i32, i32, i32, i32, i32, i32, i32, i32) {
28+
// CHECK-DAG: %[[E0:.+]] = vector.extract %[[ARG0]][0] : i32 from vector<2xi32>
29+
// CHECK-DAG: %[[E1:.+]] = vector.extract %[[ARG0]][1] : i32 from vector<2xi32>
30+
// CHECK-DAG: %[[R0:.+]] = arith.addi %[[E0]], %[[E1]] : i32
31+
// CHECK-DAG: %[[R1:.+]] = arith.muli %[[E0]], %[[E1]] : i32
32+
// CHECK-DAG: %[[R2:.+]] = arith.minsi %[[E0]], %[[E1]] : i32
33+
// CHECK-DAG: %[[R3:.+]] = arith.maxsi %[[E0]], %[[E1]] : i32
34+
// CHECK-DAG: %[[R4:.+]] = arith.minui %[[E0]], %[[E1]] : i32
35+
// CHECK-DAG: %[[R5:.+]] = arith.maxui %[[E0]], %[[E1]] : i32
36+
// CHECK-DAG: %[[R6:.+]] = arith.andi %[[E0]], %[[E1]] : i32
37+
// CHECK-DAG: %[[R7:.+]] = arith.ori %[[E0]], %[[E1]] : i32
38+
// CHECK-DAG: %[[R8:.+]] = arith.xori %[[E0]], %[[E1]] : i32
39+
// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]], %[[R6]], %[[R7]], %[[R8]]
40+
func.func @reduce_2x_i32(%arg0: vector<2xi32>) -> (i32, i32, i32, i32, i32, i32, i32, i32, i32) {
41+
%0 = vector.reduction <add>, %arg0 : vector<2xi32> into i32
42+
%1 = vector.reduction <mul>, %arg0 : vector<2xi32> into i32
43+
%2 = vector.reduction <minsi>, %arg0 : vector<2xi32> into i32
44+
%3 = vector.reduction <maxsi>, %arg0 : vector<2xi32> into i32
45+
%4 = vector.reduction <minui>, %arg0 : vector<2xi32> into i32
46+
%5 = vector.reduction <maxui>, %arg0 : vector<2xi32> into i32
47+
%6 = vector.reduction <and>, %arg0 : vector<2xi32> into i32
48+
%7 = vector.reduction <or>, %arg0 : vector<2xi32> into i32
49+
%8 = vector.reduction <xor>, %arg0 : vector<2xi32> into i32
50+
return %0, %1, %2, %3, %4, %5, %6, %7, %8 : i32, i32, i32, i32, i32, i32, i32, i32, i32
51+
}
52+
53+
// CHECK-LABEL: func.func @reduce_1x_f32(
54+
// CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>) -> f32 {
55+
// CHECK-NEXT: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
56+
// CHECK-NEXT: return %[[E0]] : f32
57+
func.func @reduce_1x_f32(%arg0: vector<1xf32>) -> f32 {
58+
%0 = vector.reduction <add>, %arg0 : vector<1xf32> into f32
59+
return %0 : f32
60+
}
61+
62+
// CHECK-LABEL: func.func @reduce_1x_acc_f32(
63+
// CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: f32) -> f32 {
64+
// CHECK-NEXT: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
65+
// CHECK-NEXT: %[[R0:.+]] = arith.addf %[[E0]], %[[ARG1]] : f32
66+
// CHECK-NEXT: return %[[R0]] : f32
67+
func.func @reduce_1x_acc_f32(%arg0: vector<1xf32>, %arg1: f32) -> f32 {
68+
%0 = vector.reduction <add>, %arg0, %arg1 : vector<1xf32> into f32
69+
return %0 : f32
70+
}
71+
72+
// CHECK-LABEL: func.func @reduce_1x_acc_i32(
73+
// CHECK-SAME: %[[ARG0:.+]]: vector<1xi32>, %[[ARG1:.+]]: i32) -> i32 {
74+
// CHECK-NEXT: %[[E0:.+]] = vector.extract %[[ARG0]][0] : i32 from vector<1xi32>
75+
// CHECK-NEXT: %[[R0:.+]] = arith.addi %[[E0]], %[[ARG1]] : i32
76+
// CHECK-NEXT: return %[[R0]] : i32
77+
func.func @reduce_1x_acc_i32(%arg0: vector<1xi32>, %arg1: i32) -> i32 {
78+
%0 = vector.reduction <add>, %arg0, %arg1 : vector<1xi32> into i32
79+
return %0 : i32
80+
}
81+
82+
// CHECK-LABEL: func.func @reduce_2x_acc_f32(
83+
// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: f32) -> (f32, f32) {
84+
// CHECK-DAG: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<2xf32>
85+
// CHECK-DAG: %[[E1:.+]] = vector.extract %[[ARG0]][1] : f32 from vector<2xf32>
86+
// CHECK: %[[A0:.+]] = arith.addf %[[E0]], %[[E1]] : f32
87+
// CHECK: %[[R0:.+]] = arith.addf %[[A0]], %[[ARG1]] : f32
88+
// CHECK: %[[M0:.+]] = arith.mulf %[[E0]], %[[E1]] fastmath<nnan> : f32
89+
// CHECK: %[[R1:.+]] = arith.mulf %[[M0]], %[[ARG1]] fastmath<nnan> : f32
90+
// CHECK-NEXT: return %[[R0]], %[[R1]] : f32, f32
91+
func.func @reduce_2x_acc_f32(%arg0: vector<2xf32>, %arg1: f32) -> (f32, f32) {
92+
%0 = vector.reduction <add>, %arg0, %arg1 : vector<2xf32> into f32
93+
%1 = vector.reduction <mul>, %arg0, %arg1 fastmath<nnan> : vector<2xf32> into f32
94+
return %0, %1 : f32, f32
95+
}
96+
97+
// CHECK-LABEL: func.func @reduce_3x_f32(
98+
// CHECK-SAME: %[[ARG0:.+]]: vector<3xf32>) -> f32 {
99+
// CHECK-NEXT: %[[R0:.+]] = vector.reduction <add>, %[[ARG0]] : vector<3xf32> into f32
100+
// CHECK-NEXT: return %[[R0]] : f32
101+
func.func @reduce_3x_f32(%arg0: vector<3xf32>) -> f32 {
102+
%0 = vector.reduction <add>, %arg0 : vector<3xf32> into f32
103+
return %0 : f32
104+
}
105+
106+
// Masking is not handled yet.
107+
// CHECK-LABEL: func.func @reduce_mask_3x_f32
108+
// CHECK-NEXT: %[[M:.+]] = vector.create_mask
109+
// CHECK-NEXT: %[[R:.+]] = vector.mask %[[M]]
110+
// CHECK-SAME: vector.reduction <add>
111+
// CHECK-NEXT: return %[[R]] : f32
112+
func.func @reduce_mask_3x_f32(%arg0: vector<3xf32>, %arg1: index) -> f32 {
113+
%mask = vector.create_mask %arg1 : vector<3xi1>
114+
%0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<3xf32> into f32 } : vector<3xi1> -> f32
115+
return %0 : f32
116+
}
117+
118+
// Scalable vectors are not supported.
119+
// CHECK-LABEL: func.func @reduce_scalable_f32(
120+
// CHECK-SAME: %[[ARG0:.+]]: vector<[1]xf32>) -> f32 {
121+
// CHECK-NEXT: %[[R0:.+]] = vector.reduction <add>, %[[ARG0]] : vector<[1]xf32> into f32
122+
// CHECK-NEXT: return %[[R0]] : f32
123+
func.func @reduce_scalable_f32(%arg0: vector<[1]xf32>) -> f32 {
124+
%0 = vector.reduction <add>, %arg0 : vector<[1]xf32> into f32
125+
return %0 : f32
126+
}

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,27 @@ struct TestVectorChainedReductionFoldingPatterns
439439
}
440440
};
441441

442+
struct TestVectorBreakDownReductionPatterns
443+
: public PassWrapper<TestVectorBreakDownReductionPatterns,
444+
OperationPass<func::FuncOp>> {
445+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
446+
TestVectorBreakDownReductionPatterns)
447+
448+
StringRef getArgument() const final {
449+
return "test-vector-break-down-reduction-patterns";
450+
}
451+
StringRef getDescription() const final {
452+
return "Test patterns to break down vector reductions into arith "
453+
"reductions";
454+
}
455+
void runOnOperation() override {
456+
RewritePatternSet patterns(&getContext());
457+
populateBreakDownVectorReductionPatterns(patterns,
458+
/*maxNumElementsToExtract=*/2);
459+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
460+
}
461+
};
462+
442463
struct TestFlattenVectorTransferPatterns
443464
: public PassWrapper<TestFlattenVectorTransferPatterns,
444465
OperationPass<func::FuncOp>> {
@@ -827,6 +848,8 @@ void registerTestVectorLowerings() {
827848

828849
PassRegistration<TestVectorChainedReductionFoldingPatterns>();
829850

851+
PassRegistration<TestVectorBreakDownReductionPatterns>();
852+
830853
PassRegistration<TestFlattenVectorTransferPatterns>();
831854

832855
PassRegistration<TestVectorScanLowering>();

0 commit comments

Comments
 (0)