@@ -70,20 +70,57 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
70
70
return t1.compose (fusedConsumerArgIndexMap);
71
71
}
72
72
73
+ // Checks if the given operand can be dropped, and the remaining operands
74
+ // of the fused producer & consumer after the fusion can still compute the
75
+ // bounds of the op.
76
+ static bool isOpOperandCanBeDroppedAfterFusedLinalgs (
77
+ GenericOp producer, GenericOp consumer,
78
+ ArrayRef<OpOperand *> opOperandsToIgnore) {
79
+ SmallVector<AffineMap> indexingMaps;
80
+
81
+ SmallVector<GenericOp> ops = {producer, consumer};
82
+ for (auto &op : ops) {
83
+ for (auto &opOperand : op->getOpOperands ()) {
84
+ if (llvm::is_contained (opOperandsToIgnore, &opOperand)) {
85
+ continue ;
86
+ }
87
+ indexingMaps.push_back (op.getMatchingIndexingMap (&opOperand));
88
+ }
89
+ }
90
+
91
+ // The concatanation of the remained indexing maps must be invertible, so
92
+ // the bounds of the op can be still computed after dropping the selected
93
+ // operand. inversePermutation returns an empty AffineMap in case the
94
+ // concatanated indexing maps are not invertible.
95
+ return inversePermutation (concatAffineMaps (indexingMaps)) != AffineMap ();
96
+ }
97
+
73
98
// / Returns a set of indices of the producer's results which would
74
99
// / be preserved after the fusion.
75
- llvm::SmallDenseSet<int >
76
- ElementwiseOpFusionResult::getPreservedProducerResults (GenericOp producer,
77
- GenericOp consumer) {
100
+ // / * There is a chance that the implementation of the transformation does not
101
+ // / agree with the result of this method. This function gives a prediction based
102
+ // / on an optimized fusion.
103
+ llvm::SmallDenseSet<int > mlir::linalg::getPreservedProducerResults (
104
+ GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
78
105
llvm::SmallDenseSet<int > preservedProducerResults;
106
+ llvm::SmallVector<OpOperand *> opOperandsToIgnore;
107
+
108
+ // The fusedOperand will be removed during the fusion
109
+ opOperandsToIgnore.emplace_back (fusedOperand);
110
+
79
111
for (const auto &producerResult : llvm::enumerate (producer->getResults ())) {
80
112
auto *outputOperand = producer.getDpsInitOperand (producerResult.index ());
113
+ opOperandsToIgnore.emplace_back (outputOperand);
81
114
if (producer.payloadUsesValueFromOperand (outputOperand) ||
82
- !producer.canOpOperandsBeDropped (outputOperand) ||
115
+ !isOpOperandCanBeDroppedAfterFusedLinalgs (producer, consumer,
116
+ opOperandsToIgnore) ||
83
117
llvm::any_of (producerResult.value ().getUsers (), [&](Operation *user) {
84
118
return user != consumer.getOperation ();
85
119
})) {
86
120
preservedProducerResults.insert (producerResult.index ());
121
+
122
+ // The result is preserved, so its operand is not dropped: remove it from the ignore list again.
123
+ opOperandsToIgnore.pop_back_val ();
87
124
}
88
125
}
89
126
return preservedProducerResults;
@@ -300,10 +337,11 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
300
337
// TODO: allow fusing the producer of an output operand.
301
338
assert (consumer.isDpsInput (fusedOperand) &&
302
339
" expected producer of input operand" );
303
- // / Find the results of the producer that have uses outside of the consumer.
340
+ // / Find the results of the producer that have uses outside of the consumer,
341
+ // / after the fusion.
304
342
llvm::SmallDenseSet<int > preservedProducerResults =
305
- ElementwiseOpFusionResult:: getPreservedProducerResults (producer,
306
- consumer );
343
+ mlir::linalg:: getPreservedProducerResults (producer, consumer ,
344
+ fusedOperand );
307
345
308
346
// Compute the fused operands list and indexing maps.
309
347
SmallVector<Value> fusedInputOperands, fusedOutputOperands;
0 commit comments