@@ -48,6 +48,10 @@ static bool areElementwiseOpsFusable(LinalgOp producer, LinalgOp consumer,
48
48
if (consumerIndexMap.getNumResults () != producer.getNumLoops ())
49
49
return false ;
50
50
51
+ // Currently support only operations with single result.
52
+ if (producer.getNumOutputs () != 1 )
53
+ return false ;
54
+
51
55
// Finally the index_map for the result must be invertible. For now just
52
56
// verify it is a permutation.
53
57
AffineMap producerResultIndexMap = producer.getOutputIndexingMap (0 );
@@ -209,10 +213,12 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
209
213
210
214
static Optional<SmallVector<Value, 1 >>
211
215
fuseElementwiseOpsImpl (LinalgOp producer, OpOperand &consumerOpOperand,
216
+ const ControlElementwiseOpsFusionFn &controlFn,
212
217
PatternRewriter &rewriter) {
213
218
LinalgOp consumer = cast<LinalgOp>(consumerOpOperand.getOwner ());
214
219
unsigned consumerIdx = consumerOpOperand.getOperandNumber ();
215
- if (!areElementwiseOpsFusable (producer, consumer, consumerIdx))
220
+ if (!areElementwiseOpsFusable (producer, consumer, consumerIdx) ||
221
+ !controlFn (producer->getResult (0 ), consumerOpOperand))
216
222
return llvm::None;
217
223
218
224
unsigned numFusedOperands =
@@ -1041,18 +1047,22 @@ struct FoldReshapeWithGenericOpByExpansion
1041
1047
1042
1048
// / Pattern to fold a GenericOp/IndexedGenericOp with a splat constant.
1043
1049
template <typename LinalgOpTy>
1044
- struct FoldSplatConstants : public OpRewritePattern <LinalgOpTy> {
1045
- using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
1050
+ class FoldSplatConstants : public OpRewritePattern <LinalgOpTy> {
1051
+ public:
1052
+ FoldSplatConstants (MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
1053
+ PatternBenefit benefit = 1 )
1054
+ : OpRewritePattern<LinalgOpTy>(context, benefit), controlFn(fun) {}
1046
1055
1047
1056
LogicalResult matchAndRewrite (LinalgOpTy op,
1048
1057
PatternRewriter &rewriter) const override {
1049
1058
if (!op.hasTensorSemantics ())
1050
1059
return failure ();
1051
1060
LinalgOp linalgOp = cast<LinalgOp>(op.getOperation ());
1052
- for (auto operand : llvm::enumerate (linalgOp.getInputs ())) {
1053
- ConstantOp constantOp = operand.value ().getDefiningOp <ConstantOp>();
1061
+ for (auto operand : llvm::enumerate (linalgOp.getInputOpOperands ())) {
1062
+ ConstantOp constantOp = operand.value ().get (). getDefiningOp <ConstantOp>();
1054
1063
if (!constantOp ||
1055
- !constantOp.value ().cast <DenseElementsAttr>().isSplat ())
1064
+ !constantOp.value ().cast <DenseElementsAttr>().isSplat () ||
1065
+ !controlFn (constantOp->getResult (0 ), operand.value ()))
1056
1066
continue ;
1057
1067
1058
1068
// The indexing_maps for the operands of the fused operation are same as
@@ -1099,11 +1109,15 @@ struct FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
1099
1109
}
1100
1110
return failure ();
1101
1111
}
1112
+
1113
+ private:
1114
+ ControlElementwiseOpsFusionFn controlFn;
1102
1115
};
1103
1116
} // namespace
1104
1117
1105
1118
static Optional<SmallVector<Value, 1 >>
1106
- fuseElementwiseOps (PatternRewriter &rewriter, OpOperand &consumerOpOperand) {
1119
+ fuseElementwiseOps (PatternRewriter &rewriter, OpOperand &consumerOpOperand,
1120
+ const ControlElementwiseOpsFusionFn &controlFn) {
1107
1121
Operation *producer = consumerOpOperand.get ().getDefiningOp ();
1108
1122
if (!producer || producer->getNumResults () != 1 )
1109
1123
return llvm::None;
@@ -1114,14 +1128,17 @@ fuseElementwiseOps(PatternRewriter &rewriter, OpOperand &consumerOpOperand) {
1114
1128
return llvm::None;
1115
1129
1116
1130
return fuseElementwiseOpsImpl (cast<LinalgOp>(producer), consumerOpOperand,
1117
- rewriter);
1131
+ controlFn, rewriter);
1118
1132
}
1119
1133
1120
1134
namespace {
1121
1135
// / Patterns to fuse a generic op, with the producer of its operands.
1122
1136
template <typename LinalgOpTy>
1123
- struct FuseElementwiseOps : public OpRewritePattern <LinalgOpTy> {
1124
- using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
1137
+ class FuseElementwiseOps : public OpRewritePattern <LinalgOpTy> {
1138
+ public:
1139
+ FuseElementwiseOps (MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
1140
+ PatternBenefit benefit = 1 )
1141
+ : OpRewritePattern<LinalgOpTy>(context, benefit), controlFn(fun) {}
1125
1142
1126
1143
LogicalResult matchAndRewrite (LinalgOpTy op,
1127
1144
PatternRewriter &rewriter) const override {
@@ -1132,14 +1149,17 @@ struct FuseElementwiseOps : public OpRewritePattern<LinalgOpTy> {
1132
1149
if (!producerOp || !producerOp.hasTensorSemantics ())
1133
1150
continue ;
1134
1151
Optional<SmallVector<Value, 1 >> fusedOpResults =
1135
- fuseElementwiseOps (rewriter, opOperand);
1152
+ fuseElementwiseOps (rewriter, opOperand, controlFn );
1136
1153
if (fusedOpResults) {
1137
1154
rewriter.replaceOp (op, *fusedOpResults);
1138
1155
return success ();
1139
1156
}
1140
1157
}
1141
1158
return failure ();
1142
1159
}
1160
+
1161
+ private:
1162
+ ControlElementwiseOpsFusionFn controlFn;
1143
1163
};
1144
1164
1145
1165
// / Pass that fuses generic ops on tensors. Used only for testing.
@@ -1148,7 +1168,10 @@ struct FusionOfTensorOpsPass
1148
1168
void runOnOperation () override {
1149
1169
Operation *op = getOperation ();
1150
1170
RewritePatternSet patterns (op->getContext ());
1151
- populateElementwiseOpsFusionPatterns (patterns, allowFoldingUnitDimReshapes);
1171
+ populateElementwiseOpsFusionPatterns (
1172
+ patterns,
1173
+ LinalgElementwiseFusionOptions ().setAllowFoldingUnitDimReshapes (
1174
+ allowFoldingUnitDimReshapes));
1152
1175
(void )applyPatternsAndFoldGreedily (op->getRegions (), std::move (patterns));
1153
1176
}
1154
1177
};
@@ -1193,14 +1216,14 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
1193
1216
}
1194
1217
1195
1218
void mlir::linalg::populateElementwiseOpsFusionPatterns (
1196
- RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes ) {
1219
+ RewritePatternSet &patterns, LinalgElementwiseFusionOptions options ) {
1197
1220
auto *context = patterns.getContext ();
1198
1221
patterns
1199
1222
.add <FuseElementwiseOps<GenericOp>, FuseElementwiseOps<IndexedGenericOp>,
1200
1223
FoldSplatConstants<GenericOp>, FoldSplatConstants<IndexedGenericOp>>(
1201
- context);
1202
- populateFoldReshapeOpsByExpansionPatterns (patterns,
1203
- allowFoldingUnitDimReshapes);
1224
+ context, options. controlElementwiseOpsFusionFn );
1225
+ populateFoldReshapeOpsByExpansionPatterns (
1226
+ patterns, options. allowFoldingUnitDimReshapes );
1204
1227
GenericOp::getCanonicalizationPatterns (patterns, context);
1205
1228
IndexedGenericOp::getCanonicalizationPatterns (patterns, context);
1206
1229
TensorReshapeOp::getCanonicalizationPatterns (patterns, context);
0 commit comments