@@ -70,20 +70,54 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
70
70
return t1.compose (fusedConsumerArgIndexMap);
71
71
}
72
72
73
+ // Checks if the given operand can be dropped, and the remaining operands
74
+ // of the fused producer & consumer after the fusion can still compute the
75
+ // bounds of the op.
76
+ static bool isOpOperandCanBeDroppedAfterFusedLinalgs (
77
+ GenericOp producer, GenericOp consumer,
78
+ ArrayRef<OpOperand *> opOperandsToIgnore) {
79
+ SmallVector<AffineMap> indexingMaps;
80
+
81
+ SmallVector<GenericOp> ops = {producer, consumer};
82
+ for (auto &op : ops) {
83
+ for (auto &opOperand : op->getOpOperands ()) {
84
+ if (llvm::is_contained (opOperandsToIgnore, &opOperand)) {
85
+ continue ;
86
+ }
87
+ indexingMaps.push_back (op.getMatchingIndexingMap (&opOperand));
88
+ }
89
+ }
90
+
91
+ // The concatanation of the remained indexing maps must be invertible, so
92
+ // the bounds of the op can be still computed after dropping the selected
93
+ // operand. inversePermutation returns an empty AffineMap in case the
94
+ // concatanated indexing maps are not invertible.
95
+ return inversePermutation (concatAffineMaps (indexingMaps)) != AffineMap ();
96
+ }
97
+
73
98
// / Returns a set of indices of the producer's results which would
74
99
// / be preserved after the fusion.
75
- llvm::SmallDenseSet<int >
76
- ElementwiseOpFusionResult::getPreservedProducerResults (GenericOp producer,
77
- GenericOp consumer) {
100
+ llvm::SmallDenseSet<int > ElementwiseOpFusionResult::getPreservedProducerResults (
101
+ GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
78
102
llvm::SmallDenseSet<int > preservedProducerResults;
103
+ llvm::SmallVector<OpOperand *> opOperandsToIgnore;
104
+
105
+ // The fusedOperand will be removed during the fusion
106
+ opOperandsToIgnore.emplace_back (fusedOperand);
107
+
79
108
for (const auto &producerResult : llvm::enumerate (producer->getResults ())) {
80
109
auto *outputOperand = producer.getDpsInitOperand (producerResult.index ());
110
+ opOperandsToIgnore.emplace_back (outputOperand);
81
111
if (producer.payloadUsesValueFromOperand (outputOperand) ||
82
- !producer.canOpOperandsBeDropped (outputOperand) ||
112
+ !isOpOperandCanBeDroppedAfterFusedLinalgs (producer, consumer,
113
+ opOperandsToIgnore) ||
83
114
llvm::any_of (producerResult.value ().getUsers (), [&](Operation *user) {
84
115
return user != consumer.getOperation ();
85
116
})) {
86
117
preservedProducerResults.insert (producerResult.index ());
118
+
119
+ // This result is preserved, so its init operand is not dropped: remove it
+ // from the ignore list again.
120
+ opOperandsToIgnore.pop_back_val ();
87
121
}
88
122
}
89
123
return preservedProducerResults;
@@ -302,8 +336,8 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
302
336
" expected producer of input operand" );
303
337
// / Find the results of the producer that have uses outside of the consumer.
304
338
llvm::SmallDenseSet<int > preservedProducerResults =
305
- ElementwiseOpFusionResult::getPreservedProducerResults (producer,
306
- consumer );
339
+ ElementwiseOpFusionResult::getPreservedProducerResults (producer, consumer,
340
+ fusedOperand );
307
341
308
342
// Compute the fused operands list and indexing maps.
309
343
SmallVector<Value> fusedInputOperands, fusedOutputOperands;
0 commit comments