Skip to content

Commit 353fbec

Browse files
committed
Improve getPreservedProducerResults estimation in ElementwiseOpFusion
This commit changes the getPreservedProducerResults function so that it takes the consumer into account along with the producer, in order to predict which of the producer’s outputs can be dropped during the fusion process. It provides a more accurate prediction, considering that the fusion process also depends on the consumer.
1 parent 3d1e1d9 commit 353fbec

File tree

2 files changed

+41
-4
lines changed

2 files changed

+41
-4
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ struct ElementwiseOpFusionResult {
498498
Operation *fusedOp;
499499
llvm::DenseMap<Value, Value> replacements;
500500
static llvm::SmallDenseSet<int>
501-
getPreservedProducerResults(GenericOp producer, GenericOp consumer);
501+
getPreservedProducerResults(GenericOp producer, GenericOp consumer, OpOperand *fusedOperand);
502502
};
503503
FailureOr<ElementwiseOpFusionResult>
504504
fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,20 +70,56 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
7070
return t1.compose(fusedConsumerArgIndexMap);
7171
}
7272

73+
// Returns true if all operands listed in `opOperandsToIgnore` can be dropped,
74+
// i.e. the indexing maps of the remaining operands of `producer` and `consumer`
75+
// are still sufficient to compute the bounds of the op after the fusion.
76+
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(GenericOp producer,
77+
GenericOp consumer,
78+
ArrayRef<OpOperand *> opOperandsToIgnore) {
79+
SmallVector<AffineMap> indexingMaps;
80+
81+
SmallVector<GenericOp> ops = {producer, consumer};
// Collect the indexing map of every operand of both ops that is not ignored.
82+
for (auto &op : ops) {
83+
for (auto &opOperand : op->getOpOperands()) {
84+
if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
85+
continue;
86+
}
87+
indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
88+
}
89+
}
90+
91+
// The concatenation of the remaining indexing maps must be invertible, so that
92+
// the bounds of the op can still be computed after dropping the ignored operands.
93+
// inversePermutation returns an empty AffineMap when the concatenated indexing
94+
// maps are not invertible.
95+
return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
96+
}
97+
7398
/// Returns a set of indices of the producer's results which would
7499
/// be preserved after the fusion.
75100
llvm::SmallDenseSet<int>
76101
ElementwiseOpFusionResult::getPreservedProducerResults(GenericOp producer,
77-
GenericOp consumer) {
102+
GenericOp consumer,
103+
OpOperand *fusedOperand) {
78104
llvm::SmallDenseSet<int> preservedProducerResults;
105+
llvm::SmallVector<OpOperand*> opOperandsToIgnore;
106+
107+
// The fusedOperand will be removed during the fusion
108+
opOperandsToIgnore.emplace_back(fusedOperand);
109+
79110
for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
80111
auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
112+
opOperandsToIgnore.emplace_back(outputOperand);
81113
if (producer.payloadUsesValueFromOperand(outputOperand) ||
82-
!producer.canOpOperandsBeDropped(outputOperand) ||
114+
!isOpOperandCanBeDroppedAfterFusedLinalgs(producer, consumer,
115+
opOperandsToIgnore) ||
83116
llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
84117
return user != consumer.getOperation();
85118
})) {
86119
preservedProducerResults.insert(producerResult.index());
120+
121+
// The result is preserved, so its init operand stays and must not be ignored.
122+
opOperandsToIgnore.pop_back_val();
87123
}
88124
}
89125
return preservedProducerResults;
@@ -303,7 +339,8 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
303339
/// Find the results of the producer that have uses outside of the consumer.
304340
llvm::SmallDenseSet<int> preservedProducerResults =
305341
ElementwiseOpFusionResult::getPreservedProducerResults(producer,
306-
consumer);
342+
consumer,
343+
fusedOperand);
307344

308345
// Compute the fused operands list and indexing maps.
309346
SmallVector<Value> fusedInputOperands, fusedOutputOperands;

0 commit comments

Comments
 (0)