Skip to content

Commit 944a2fe

Browse files
author
MaheshRavishankar
committed
[mlir][Linalg] Add callbacks to fusion of elementwise operations to control fusion.
Currently, elementwise operation fusion in Linalg fuses everything it can. Without additional checks, this can exceed the resource limits of the target hardware. This patch adds a callback that clients can use to implement a cost function: once two elementwise operations are deemed structurally fusable, the callback decides whether the fusion is actually applied. Differential Revision: https://reviews.llvm.org/D99820
1 parent 8c7bf2f commit 944a2fe

File tree

6 files changed

+216
-17
lines changed

6 files changed

+216
-17
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ class FrozenRewritePatternSet;
2323

2424
namespace linalg {
2525

26+
struct LinalgElementwiseFusionOptions;
2627
struct LinalgFusionOptions;
2728
struct LinalgTilingOptions;
2829

@@ -69,9 +70,40 @@ void populateLinalgBufferizePatterns(BufferizeTypeConverter &converter,
6970
/// tensors.
7071
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
7172

73+
/// Callback type used to decide whether a producer/consumer pair should be
/// fused. It receives the producer's result and the consumer operand that uses
/// it; returning false makes the fusion code skip that pair.
using ControlElementwiseOpsFusionFn =
74+
std::function<bool(const OpResult &producer, const OpOperand &consumer)>;
75+
76+
/// Options that control fusion of elementwise operations.
77+
struct LinalgElementwiseFusionOptions {
78+
/// Enable fusion of reshapes that are introducing unit-dimensions into the
79+
/// shape with elementwise operations. By default this is disabled.
80+
bool allowFoldingUnitDimReshapes = false;
81+
82+
LinalgElementwiseFusionOptions &setAllowFoldingUnitDimReshapes(bool val) {
83+
allowFoldingUnitDimReshapes = val;
84+
return *this;
85+
}
86+
87+
/// Function that allows the caller to control when to stop fusion. Once a
88+
/// producer is deemed fusable with the consumer (structurally), this callback
89+
/// can be used to abort the fusion based on non-structural constraints. This
90+
/// is the hook for cost models to control the amount of fusion done.
91+
ControlElementwiseOpsFusionFn controlElementwiseOpsFusionFn =
92+
[](const OpResult & /*producer */, const OpOperand & /*consumer */) {
93+
return true;
94+
};
95+
96+
LinalgElementwiseFusionOptions &
97+
setControlElementwiseOpsFusionFn(ControlElementwiseOpsFusionFn fun) {
98+
controlElementwiseOpsFusionFn = std::move(fun);
99+
return *this;
100+
}
101+
};
102+
72103
/// Patterns for fusing linalg operation on tensors.
73104
void populateElementwiseOpsFusionPatterns(
74-
RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
105+
RewritePatternSet &patterns,
106+
LinalgElementwiseFusionOptions options = LinalgElementwiseFusionOptions());
75107

76108
/// Performs standalone tiling of a single LinalgOp by `tileSizes`.
77109
/// and permute the loop nest according to `interchangeVector`

mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ static bool areElementwiseOpsFusable(LinalgOp producer, LinalgOp consumer,
4848
if (consumerIndexMap.getNumResults() != producer.getNumLoops())
4949
return false;
5050

51+
// Currently support only operations with single result.
52+
if (producer.getNumOutputs() != 1)
53+
return false;
54+
5155
// Finally the index_map for the result must be invertible. For now just
5256
// verify it is a permutation.
5357
AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
@@ -209,10 +213,12 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
209213

210214
static Optional<SmallVector<Value, 1>>
211215
fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
216+
const ControlElementwiseOpsFusionFn &controlFn,
212217
PatternRewriter &rewriter) {
213218
LinalgOp consumer = cast<LinalgOp>(consumerOpOperand.getOwner());
214219
unsigned consumerIdx = consumerOpOperand.getOperandNumber();
215-
if (!areElementwiseOpsFusable(producer, consumer, consumerIdx))
220+
if (!areElementwiseOpsFusable(producer, consumer, consumerIdx) ||
221+
!controlFn(producer->getResult(0), consumerOpOperand))
216222
return llvm::None;
217223

218224
unsigned numFusedOperands =
@@ -1041,18 +1047,22 @@ struct FoldReshapeWithGenericOpByExpansion
10411047

10421048
/// Pattern to fold a GenericOp/IndexedGenericOp with a splat constant.
10431049
template <typename LinalgOpTy>
1044-
struct FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
1045-
using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
1050+
class FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
1051+
public:
1052+
/// Builds the pattern and stores a copy of `fun`, the callback consulted
/// before a splat constant is folded into its consumer. The callback is only
/// copied, so it is accepted by const reference; this also allows temporaries
/// and const callbacks, matching how the fusion helpers take it.
FoldSplatConstants(MLIRContext *context,
                   const ControlElementwiseOpsFusionFn &fun,
                   PatternBenefit benefit = 1)
    : OpRewritePattern<LinalgOpTy>(context, benefit), controlFn(fun) {}
10461055

10471056
LogicalResult matchAndRewrite(LinalgOpTy op,
10481057
PatternRewriter &rewriter) const override {
10491058
if (!op.hasTensorSemantics())
10501059
return failure();
10511060
LinalgOp linalgOp = cast<LinalgOp>(op.getOperation());
1052-
for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
1053-
ConstantOp constantOp = operand.value().getDefiningOp<ConstantOp>();
1061+
for (auto operand : llvm::enumerate(linalgOp.getInputOpOperands())) {
1062+
ConstantOp constantOp = operand.value().get().getDefiningOp<ConstantOp>();
10541063
if (!constantOp ||
1055-
!constantOp.value().cast<DenseElementsAttr>().isSplat())
1064+
!constantOp.value().cast<DenseElementsAttr>().isSplat() ||
1065+
!controlFn(constantOp->getResult(0), operand.value()))
10561066
continue;
10571067

10581068
// The indexing_maps for the operands of the fused operation are same as
@@ -1099,11 +1109,15 @@ struct FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
10991109
}
11001110
return failure();
11011111
}
1112+
1113+
private:
1114+
ControlElementwiseOpsFusionFn controlFn;
11021115
};
11031116
} // namespace
11041117

11051118
static Optional<SmallVector<Value, 1>>
1106-
fuseElementwiseOps(PatternRewriter &rewriter, OpOperand &consumerOpOperand) {
1119+
fuseElementwiseOps(PatternRewriter &rewriter, OpOperand &consumerOpOperand,
1120+
const ControlElementwiseOpsFusionFn &controlFn) {
11071121
Operation *producer = consumerOpOperand.get().getDefiningOp();
11081122
if (!producer || producer->getNumResults() != 1)
11091123
return llvm::None;
@@ -1114,14 +1128,17 @@ fuseElementwiseOps(PatternRewriter &rewriter, OpOperand &consumerOpOperand) {
11141128
return llvm::None;
11151129

11161130
return fuseElementwiseOpsImpl(cast<LinalgOp>(producer), consumerOpOperand,
1117-
rewriter);
1131+
controlFn, rewriter);
11181132
}
11191133

11201134
namespace {
11211135
/// Patterns to fuse a generic op, with the producer of its operands.
11221136
template <typename LinalgOpTy>
1123-
struct FuseElementwiseOps : public OpRewritePattern<LinalgOpTy> {
1124-
using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
1137+
class FuseElementwiseOps : public OpRewritePattern<LinalgOpTy> {
1138+
public:
1139+
/// Builds the pattern and stores a copy of `fun`, the callback passed on to
/// `fuseElementwiseOps` for every candidate operand. The callback is only
/// copied, so it is accepted by const reference; this also allows temporaries
/// and const callbacks, matching how the fusion helpers take it.
FuseElementwiseOps(MLIRContext *context,
                   const ControlElementwiseOpsFusionFn &fun,
                   PatternBenefit benefit = 1)
    : OpRewritePattern<LinalgOpTy>(context, benefit), controlFn(fun) {}
11251142

11261143
LogicalResult matchAndRewrite(LinalgOpTy op,
11271144
PatternRewriter &rewriter) const override {
@@ -1132,14 +1149,17 @@ struct FuseElementwiseOps : public OpRewritePattern<LinalgOpTy> {
11321149
if (!producerOp || !producerOp.hasTensorSemantics())
11331150
continue;
11341151
Optional<SmallVector<Value, 1>> fusedOpResults =
1135-
fuseElementwiseOps(rewriter, opOperand);
1152+
fuseElementwiseOps(rewriter, opOperand, controlFn);
11361153
if (fusedOpResults) {
11371154
rewriter.replaceOp(op, *fusedOpResults);
11381155
return success();
11391156
}
11401157
}
11411158
return failure();
11421159
}
1160+
1161+
private:
1162+
ControlElementwiseOpsFusionFn controlFn;
11431163
};
11441164

11451165
/// Pass that fuses generic ops on tensors. Used only for testing.
@@ -1148,7 +1168,10 @@ struct FusionOfTensorOpsPass
11481168
void runOnOperation() override {
11491169
Operation *op = getOperation();
11501170
RewritePatternSet patterns(op->getContext());
1151-
populateElementwiseOpsFusionPatterns(patterns, allowFoldingUnitDimReshapes);
1171+
populateElementwiseOpsFusionPatterns(
1172+
patterns,
1173+
LinalgElementwiseFusionOptions().setAllowFoldingUnitDimReshapes(
1174+
allowFoldingUnitDimReshapes));
11521175
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
11531176
}
11541177
};
@@ -1193,14 +1216,14 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
11931216
}
11941217

11951218
void mlir::linalg::populateElementwiseOpsFusionPatterns(
1196-
RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
1219+
RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
11971220
auto *context = patterns.getContext();
11981221
patterns
11991222
.add<FuseElementwiseOps<GenericOp>, FuseElementwiseOps<IndexedGenericOp>,
12001223
FoldSplatConstants<GenericOp>, FoldSplatConstants<IndexedGenericOp>>(
1201-
context);
1202-
populateFoldReshapeOpsByExpansionPatterns(patterns,
1203-
allowFoldingUnitDimReshapes);
1224+
context, options.controlElementwiseOpsFusionFn);
1225+
populateFoldReshapeOpsByExpansionPatterns(
1226+
patterns, options.allowFoldingUnitDimReshapes);
12041227
GenericOp::getCanonicalizationPatterns(patterns, context);
12051228
IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
12061229
TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns -split-input-file | FileCheck %s

// The test pass installs a fusion control callback that rejects a fusion when
// the consumer's remaining inputs plus the producer's inputs exceed four
// distinct values. Fusing the first multiply into the final op yields exactly
// four inputs and is accepted; fusing either remaining multiply would need
// five, so both stay as separate ops feeding the fused one.

#map0 = affine_map<(d0, d1) -> (d0, d1)>
#binary2Dpointwise = {
  indexing_maps = [#map0, #map0, #map0],
  iterator_types = ["parallel", "parallel"]
}
#ternary2Dpointwise = {
  indexing_maps = [#map0, #map0, #map0, #map0],
  iterator_types = ["parallel", "parallel"]
}
func @test_fusion_limit(
    %arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>,
    %arg3 : tensor<?x?xf32>, %arg4 : tensor<?x?xf32>, %arg5 : tensor<?x?xf32>)
    -> tensor<?x?xf32> {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %d0 = memref.dim %arg0, %c0 : tensor<?x?xf32>
  %d1 = memref.dim %arg0, %c1 : tensor<?x?xf32>
  %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
  %0 = linalg.generic #binary2Dpointwise
      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%init : tensor<?x?xf32>) {
    ^bb0(%arg6 : f32, %arg7 : f32, %arg8 : f32):
      %1 = mulf %arg6, %arg7 : f32
      linalg.yield %1 : f32
    } -> tensor<?x?xf32>
  %2 = linalg.generic #binary2Dpointwise
      ins(%arg2, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%init : tensor<?x?xf32>) {
    ^bb0(%arg6 : f32, %arg7 : f32, %arg8 : f32):
      %3 = mulf %arg6, %arg7 : f32
      linalg.yield %3 : f32
    } -> tensor<?x?xf32>
  %4 = linalg.generic #binary2Dpointwise
      ins(%arg4, %arg5 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%init : tensor<?x?xf32>) {
    ^bb0(%arg6 : f32, %arg7 : f32, %arg8 : f32):
      %5 = mulf %arg6, %arg7 : f32
      linalg.yield %5 : f32
    } -> tensor<?x?xf32>
  %6 = linalg.generic #ternary2Dpointwise
      ins(%0, %2, %4 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%init : tensor<?x?xf32>) {
    ^bb0(%arg6 : f32, %arg7 : f32, %arg8 : f32, %arg9 : f32):
      %7 = addf %arg6, %arg7 : f32
      %8 = addf %7, %arg8 : f32
      linalg.yield %8 : f32
    } -> tensor<?x?xf32>
  return %6 : tensor<?x?xf32>
}
// CHECK-LABEL: func @test_fusion_limit
//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//  CHECK-SAME:   %[[ARG5:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//       CHECK:   %[[OP1:.+]] = linalg.generic {{.+}} ins(%[[ARG2]], %[[ARG3]]
//       CHECK:   %[[OP2:.+]] = linalg.generic {{.+}} ins(%[[ARG4]], %[[ARG5]]
//       CHECK:   %[[OP3:.+]] = linalg.generic {{.+}} ins(%[[ARG0]], %[[ARG1]], %[[OP1]], %[[OP2]]
//       CHECK:   return %[[OP3]]

mlir/test/lib/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ add_mlir_library(MLIRTestTransforms
1919
TestGpuRewrite.cpp
2020
TestInlining.cpp
2121
TestLinalgCodegenStrategy.cpp
22+
TestLinalgElementwiseFusion.cpp
2223
TestLinalgFusionTransforms.cpp
2324
TestLinalgHoisting.cpp
2425
TestLinalgTransforms.cpp
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
//===- TestLinalgElementwiseFusion.cpp - Test Linalg elementwise fusion ---===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass for testing fusion of elementwise operations in
10+
// Linalg, mainly linalg options.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15+
#include "mlir/Pass/Pass.h"
16+
#include "mlir/Pass/PassManager.h"
17+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18+
#include "llvm/ADT/TypeSwitch.h"
19+
20+
namespace mlir {
21+
22+
/// Adds operands of `op` to `operandSet`; a null `op` is ignored. For a Linalg
/// op only its inputs are added, for any other op all of its operands are.
static void addOperands(Operation *op, llvm::SetVector<Value> &operandSet) {
  if (!op)
    return;

  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
    // Bind the range once so both iterators come from the same object.
    auto inputs = linalgOp.getInputs();
    operandSet.insert(inputs.begin(), inputs.end());
    return;
  }

  operandSet.insert(op->operand_begin(), op->operand_end());
}
34+
35+
template <int limit = 3>
36+
static bool setFusedOpOperandLimit(const OpResult &producer,
37+
const OpOperand &consumer) {
38+
llvm::SetVector<Value> fusedOpOperands;
39+
if (producer.getOwner()->getNumResults() != 1)
40+
return false;
41+
addOperands(consumer.getOwner(), fusedOpOperands);
42+
fusedOpOperands.remove(producer);
43+
addOperands(producer.getOwner(), fusedOpOperands);
44+
return fusedOpOperands.size() <= limit;
45+
}
46+
47+
namespace {
48+
struct TestLinalgElementwiseFusion
49+
: public PassWrapper<TestLinalgElementwiseFusion, FunctionPass> {
50+
void getDependentDialects(DialectRegistry &registry) const override {
51+
registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
52+
tensor::TensorDialect>();
53+
}
54+
55+
void runOnFunction() override {
56+
MLIRContext *context = &this->getContext();
57+
FuncOp funcOp = this->getFunction();
58+
RewritePatternSet fusionPatterns(context);
59+
60+
linalg::populateElementwiseOpsFusionPatterns(
61+
fusionPatterns,
62+
linalg::LinalgElementwiseFusionOptions()
63+
.setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>));
64+
65+
(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
66+
std::move(fusionPatterns));
67+
}
68+
};
69+
} // namespace
70+
71+
namespace test {
72+
void registerTestLinalgElementwiseFusion() {
73+
PassRegistration<TestLinalgElementwiseFusion> testElementwiseFusionPass(
74+
"test-linalg-elementwise-fusion-patterns",
75+
"Test Linalg element wise operation fusion patterns");
76+
}
77+
} // namespace test
78+
79+
} // namespace mlir

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ void registerTestGpuParallelLoopMappingPass();
7777
void registerTestIRVisitorsPass();
7878
void registerTestInterfaces();
7979
void registerTestLinalgCodegenStrategy();
80+
void registerTestLinalgElementwiseFusion();
8081
void registerTestLinalgFusionTransforms();
8182
void registerTestLinalgTensorFusionTransforms();
8283
void registerTestLinalgGreedyFusion();
@@ -154,6 +155,7 @@ void registerTestPasses() {
154155
test::registerTestIRVisitorsPass();
155156
test::registerTestInterfaces();
156157
test::registerTestLinalgCodegenStrategy();
158+
test::registerTestLinalgElementwiseFusion();
157159
test::registerTestLinalgFusionTransforms();
158160
test::registerTestLinalgTensorFusionTransforms();
159161
test::registerTestLinalgGreedyFusion();

0 commit comments

Comments
 (0)