@@ -71,6 +71,25 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
71
71
return t1.compose (fusedConsumerArgIndexMap);
72
72
}
73
73
74
+ // / Returns a set of indices of the producer's results which would
75
+ // / be preserved after the fusion.
76
+ llvm::SmallDenseSet<int >
77
+ ElementwiseOpFusionResult::getPreservedProducerResults (GenericOp producer,
78
+ GenericOp consumer) {
79
+ llvm::SmallDenseSet<int > preservedProducerResults;
80
+ for (const auto &producerResult : llvm::enumerate (producer->getResults ())) {
81
+ auto *outputOperand = producer.getDpsInitOperand (producerResult.index ());
82
+ if (producer.payloadUsesValueFromOperand (outputOperand) ||
83
+ !producer.canOpOperandsBeDropped (outputOperand) ||
84
+ llvm::any_of (producerResult.value ().getUsers (), [&](Operation *user) {
85
+ return user != consumer.getOperation ();
86
+ })) {
87
+ preservedProducerResults.insert (producerResult.index ());
88
+ }
89
+ }
90
+ return preservedProducerResults;
91
+ }
92
+
74
93
// / Conditions for elementwise fusion of generic operations.
75
94
bool mlir::linalg::areElementwiseOpsFusable (OpOperand *fusedOperand) {
76
95
if (!fusedOperand)
@@ -285,17 +304,9 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
285
304
assert (consumer.isDpsInput (fusedOperand) &&
286
305
" expected producer of input operand" );
287
306
// / Find the results of the producer that have uses outside of the consumer.
288
- llvm::SmallDenseSet<int > preservedProducerResults;
289
- for (const auto &producerResult : llvm::enumerate (producer->getResults ())) {
290
- auto *outputOperand = producer.getDpsInitOperand (producerResult.index ());
291
- if (producer.payloadUsesValueFromOperand (outputOperand) ||
292
- !producer.canOpOperandsBeDropped (outputOperand) ||
293
- llvm::any_of (producerResult.value ().getUsers (), [&](Operation *user) {
294
- return user != consumer.getOperation ();
295
- })) {
296
- preservedProducerResults.insert (producerResult.index ());
297
- }
298
- }
307
+ llvm::SmallDenseSet<int > preservedProducerResults =
308
+ ElementwiseOpFusionResult::getPreservedProducerResults (producer,
309
+ consumer);
299
310
300
311
// Compute the fused operands list and indexing maps.
301
312
SmallVector<Value> fusedInputOperands, fusedOutputOperands;
0 commit comments