@@ -71,20 +71,57 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
71
71
return t1.compose (fusedConsumerArgIndexMap);
72
72
}
73
73
74
+ // Checks if the given operand can be dropped, and the remaining operands
75
+ // of the fused producer & consumer after the fusion can still compute the
76
+ // bounds of the op.
77
+ static bool isOpOperandCanBeDroppedAfterFusedLinalgs (
78
+ GenericOp producer, GenericOp consumer,
79
+ ArrayRef<OpOperand *> opOperandsToIgnore) {
80
+ SmallVector<AffineMap> indexingMaps;
81
+
82
+ SmallVector<GenericOp> ops = {producer, consumer};
83
+ for (auto &op : ops) {
84
+ for (auto &opOperand : op->getOpOperands ()) {
85
+ if (llvm::is_contained (opOperandsToIgnore, &opOperand)) {
86
+ continue ;
87
+ }
88
+ indexingMaps.push_back (op.getMatchingIndexingMap (&opOperand));
89
+ }
90
+ }
91
+
92
+ // The concatanation of the remained indexing maps must be invertible, so
93
+ // the bounds of the op can be still computed after dropping the selected
94
+ // operand. inversePermutation returns an empty AffineMap in case the
95
+ // concatanated indexing maps are not invertible.
96
+ return inversePermutation (concatAffineMaps (indexingMaps)) != AffineMap ();
97
+ }
98
+
74
99
// / Returns a set of indices of the producer's results which would
75
100
// / be preserved after the fusion.
76
- llvm::SmallDenseSet<int >
77
- ElementwiseOpFusionResult::getPreservedProducerResults (GenericOp producer,
78
- GenericOp consumer) {
101
+ // / * There is a chance that the implementation of the transformation does not
102
+ // / agree with the result of this method. This function gives a prediction based
103
+ // / on an optimized fusion.
104
+ llvm::SmallDenseSet<int > mlir::linalg::getPreservedProducerResults (
105
+ GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
79
106
llvm::SmallDenseSet<int > preservedProducerResults;
107
+ llvm::SmallVector<OpOperand *> opOperandsToIgnore;
108
+
109
+ // The fusedOperand will be removed during the fusion
110
+ opOperandsToIgnore.emplace_back (fusedOperand);
111
+
80
112
for (const auto &producerResult : llvm::enumerate (producer->getResults ())) {
81
113
auto *outputOperand = producer.getDpsInitOperand (producerResult.index ());
114
+ opOperandsToIgnore.emplace_back (outputOperand);
82
115
if (producer.payloadUsesValueFromOperand (outputOperand) ||
83
- !producer.canOpOperandsBeDropped (outputOperand) ||
116
+ !isOpOperandCanBeDroppedAfterFusedLinalgs (producer, consumer,
117
+ opOperandsToIgnore) ||
84
118
llvm::any_of (producerResult.value ().getUsers (), [&](Operation *user) {
85
119
return user != consumer.getOperation ();
86
120
})) {
87
121
preservedProducerResults.insert (producerResult.index ());
122
+
123
+ // In case the operand can't be dropped
124
+ opOperandsToIgnore.pop_back_val ();
88
125
}
89
126
}
90
127
return preservedProducerResults;
@@ -301,10 +338,11 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
301
338
// TODO: allow fusing the producer of an output operand.
302
339
assert (consumer.isDpsInput (fusedOperand) &&
303
340
" expected producer of input operand" );
304
- // / Find the results of the producer that have uses outside of the consumer.
341
+ // / Find the results of the producer that have uses outside of the consumer,
342
+ // / after the fusion.
305
343
llvm::SmallDenseSet<int > preservedProducerResults =
306
- ElementwiseOpFusionResult:: getPreservedProducerResults (producer,
307
- consumer );
344
+ mlir::linalg:: getPreservedProducerResults (producer, consumer ,
345
+ fusedOperand );
308
346
309
347
// Compute the fused operands list and indexing maps.
310
348
SmallVector<Value> fusedInputOperands, fusedOutputOperands;
0 commit comments