Skip to content

Commit 9436558

Browse files
committed
Improve getPreservedProducerResults estimation in ElementwiseOpFusion
This commit changes the getPreservedProducerResults function so that it takes the consumer into account along with the producer, in order to predict which of the producer’s outputs can be dropped during the fusion process. It provides a more accurate prediction, considering that the fusion process also depends on the consumer.
1 parent 3d1e1d9 commit 9436558

File tree

2 files changed

+42
-7
lines changed

2 files changed

+42
-7
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,8 @@ struct ElementwiseOpFusionResult {
498498
Operation *fusedOp;
499499
llvm::DenseMap<Value, Value> replacements;
500500
static llvm::SmallDenseSet<int>
501-
getPreservedProducerResults(GenericOp producer, GenericOp consumer);
501+
getPreservedProducerResults(GenericOp producer, GenericOp consumer,
502+
OpOperand *fusedOperand);
502503
};
503504
FailureOr<ElementwiseOpFusionResult>
504505
fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,20 +70,54 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
7070
return t1.compose(fusedConsumerArgIndexMap);
7171
}
7272

73+
// Checks if the given operand can be dropped, and the remaining operands
74+
// of the fused producer & consumer after the fusion can still compute the
75+
// bounds of the op.
76+
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
77+
GenericOp producer, GenericOp consumer,
78+
ArrayRef<OpOperand *> opOperandsToIgnore) {
79+
SmallVector<AffineMap> indexingMaps;
80+
81+
SmallVector<GenericOp> ops = {producer, consumer};
82+
for (auto &op : ops) {
83+
for (auto &opOperand : op->getOpOperands()) {
84+
if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
85+
continue;
86+
}
87+
indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
88+
}
89+
}
90+
91+
// The concatanation of the remained indexing maps must be invertible, so
92+
// the bounds of the op can be still computed after dropping the selected
93+
// operand. inversePermutation returns an empty AffineMap in case the
94+
// concatanated indexing maps are not invertible.
95+
return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
96+
}
97+
7398
/// Returns a set of indices of the producer's results which would
7499
/// be preserved after the fusion.
75-
llvm::SmallDenseSet<int>
76-
ElementwiseOpFusionResult::getPreservedProducerResults(GenericOp producer,
77-
GenericOp consumer) {
100+
llvm::SmallDenseSet<int> ElementwiseOpFusionResult::getPreservedProducerResults(
101+
GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
78102
llvm::SmallDenseSet<int> preservedProducerResults;
103+
llvm::SmallVector<OpOperand *> opOperandsToIgnore;
104+
105+
// The fusedOperand will be removed during the fusion
106+
opOperandsToIgnore.emplace_back(fusedOperand);
107+
79108
for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
80109
auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
110+
opOperandsToIgnore.emplace_back(outputOperand);
81111
if (producer.payloadUsesValueFromOperand(outputOperand) ||
82-
!producer.canOpOperandsBeDropped(outputOperand) ||
112+
!isOpOperandCanBeDroppedAfterFusedLinalgs(producer, consumer,
113+
opOperandsToIgnore) ||
83114
llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
84115
return user != consumer.getOperation();
85116
})) {
86117
preservedProducerResults.insert(producerResult.index());
118+
119+
// In case the operand can't be dropped
120+
opOperandsToIgnore.pop_back_val();
87121
}
88122
}
89123
return preservedProducerResults;
@@ -302,8 +336,8 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
302336
"expected producer of input operand");
303337
/// Find the results of the producer that have uses outside of the consumer.
304338
llvm::SmallDenseSet<int> preservedProducerResults =
305-
ElementwiseOpFusionResult::getPreservedProducerResults(producer,
306-
consumer);
339+
ElementwiseOpFusionResult::getPreservedProducerResults(producer, consumer,
340+
fusedOperand);
307341

308342
// Compute the fused operands list and indexing maps.
309343
SmallVector<Value> fusedInputOperands, fusedOutputOperands;

0 commit comments

Comments
 (0)