@@ -70,20 +70,56 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
70
70
return t1.compose (fusedConsumerArgIndexMap);
71
71
}
72
72
73
+ // Checks if the given operand can be dropped, and the remaining operands
74
+ // of the fused producer & consumer after the fusion can still compute the
75
+ // bounds of the op.
76
+ static bool isOpOperandCanBeDroppedAfterFusedLinalgs (GenericOp producer,
77
+ GenericOp consumer,
78
+ ArrayRef<OpOperand *> opOperandsToIgnore) {
79
+ SmallVector<AffineMap> indexingMaps;
80
+
81
+ SmallVector<GenericOp> ops = {producer, consumer};
82
+ for (auto &op : ops) {
83
+ for (auto &opOperand : op->getOpOperands ()) {
84
+ if (llvm::is_contained (opOperandsToIgnore, &opOperand)) {
85
+ continue ;
86
+ }
87
+ indexingMaps.push_back (op.getMatchingIndexingMap (&opOperand));
88
+ }
89
+ }
90
+
91
+ // The concatanation of the remained indexing maps must be invertible, so
92
+ // the bounds of the op can be still computed after dropping the selected operand.
93
+ // inversePermutation returns an empty AffineMap in case the concatanated indexing
94
+ // maps are not invertible.
95
+ return inversePermutation (concatAffineMaps (indexingMaps)) != AffineMap ();
96
+ }
97
+
73
98
/// Returns a set of indices of the producer's results which would
74
99
/// be preserved after the fusion.
75
100
llvm::SmallDenseSet<int >
76
101
ElementwiseOpFusionResult::getPreservedProducerResults (GenericOp producer,
77
- GenericOp consumer) {
102
+ GenericOp consumer,
103
+ OpOperand *fusedOperand) {
78
104
llvm::SmallDenseSet<int > preservedProducerResults;
105
+ llvm::SmallVector<OpOperand*> opOperandsToIgnore;
106
+
107
+ // The fusedOperand will be removed during the fusion
108
+ opOperandsToIgnore.emplace_back (fusedOperand);
109
+
79
110
for (const auto &producerResult : llvm::enumerate (producer->getResults ())) {
80
111
auto *outputOperand = producer.getDpsInitOperand (producerResult.index ());
112
+ opOperandsToIgnore.emplace_back (outputOperand);
81
113
if (producer.payloadUsesValueFromOperand (outputOperand) ||
82
- !producer.canOpOperandsBeDropped (outputOperand) ||
114
+ !isOpOperandCanBeDroppedAfterFusedLinalgs (producer, consumer,
115
+ opOperandsToIgnore) ||
83
116
llvm::any_of (producerResult.value ().getUsers (), [&](Operation *user) {
84
117
return user != consumer.getOperation ();
85
118
})) {
86
119
preservedProducerResults.insert (producerResult.index ());
120
+
121
+ // In case the operand can't be dropped
122
+ opOperandsToIgnore.pop_back_val ();
87
123
}
88
124
}
89
125
return preservedProducerResults;
@@ -303,7 +339,8 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
303
339
/// Find the results of the producer that have uses outside of the consumer.
304
340
llvm::SmallDenseSet<int > preservedProducerResults =
305
341
ElementwiseOpFusionResult::getPreservedProducerResults (producer,
306
- consumer);
342
+ consumer,
343
+ fusedOperand);
307
344
308
345
// Compute the fused operands list and indexing maps.
309
346
SmallVector<Value> fusedInputOperands, fusedOutputOperands;
0 commit comments