Skip to content

Commit 4a4b233

Browse files
authored
[mlir][linalg] Improve getPreservedProducerResults estimation in ElementwiseOpFusion (#104409)
This commit changes the getPreservedProducerResults function so that it takes the consumer into account along with the producer, in order to predict which of the producer’s outputs can be dropped during the fusion process. It provides a more accurate prediction, considering that the fusion process also depends on the consumer.
1 parent f01f80c commit 4a4b233

File tree

2 files changed

+54
-9
lines changed

2 files changed

+54
-9
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -497,12 +497,19 @@ LogicalResult dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
497497
struct ElementwiseOpFusionResult {
498498
Operation *fusedOp;
499499
llvm::DenseMap<Value, Value> replacements;
500-
static llvm::SmallDenseSet<int>
501-
getPreservedProducerResults(GenericOp producer, GenericOp consumer);
502500
};
503501
FailureOr<ElementwiseOpFusionResult>
504502
fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
505503

504+
/// Returns a set of indices of the producer's results which would
505+
/// be preserved after the fusion.
506+
/// * There is a chance that the implementation of the transformation does not
507+
/// agree with the result of this method. This function gives a prediction based
508+
/// on an optimized fusion.
509+
llvm::SmallDenseSet<int> getPreservedProducerResults(GenericOp producer,
510+
GenericOp consumer,
511+
OpOperand *fusedOperand);
512+
506513
/// Try to peel and canonicalize loop `op` and return the new result.
507514
/// Also applies affine_min/max bounds simplification on the fly where relevant.
508515
// TODO: Add support for scf.parallel and affine.for loops.

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,20 +71,57 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
7171
return t1.compose(fusedConsumerArgIndexMap);
7272
}
7373

74+
// Checks whether all operands listed in `opOperandsToIgnore` can be dropped
75+
// while the remaining operands of the producer & consumer can still be used,
76+
// after the fusion, to compute the bounds of the op. Returns true if so.
77+
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
78+
GenericOp producer, GenericOp consumer,
79+
ArrayRef<OpOperand *> opOperandsToIgnore) {
80+
SmallVector<AffineMap> indexingMaps;
81+
82+
SmallVector<GenericOp> ops = {producer, consumer};
83+
for (auto &op : ops) {
84+
for (auto &opOperand : op->getOpOperands()) {
85+
if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
86+
continue;
87+
}
88+
indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
89+
}
90+
}
91+
92+
// The concatenation of the remaining indexing maps must be invertible, so
93+
// the bounds of the op can still be computed after dropping the ignored
94+
// operands. inversePermutation returns an empty AffineMap in case the
95+
// concatenated indexing maps are not invertible.
96+
return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
97+
}
98+
7499
/// Returns a set of indices of the producer's results which would
75100
/// be preserved after the fusion.
76-
llvm::SmallDenseSet<int>
77-
ElementwiseOpFusionResult::getPreservedProducerResults(GenericOp producer,
78-
GenericOp consumer) {
101+
/// * There is a chance that the implementation of the transformation does not
102+
/// agree with the result of this method. This function gives a prediction based
103+
/// on an optimized fusion.
104+
llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
105+
GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
79106
llvm::SmallDenseSet<int> preservedProducerResults;
107+
llvm::SmallVector<OpOperand *> opOperandsToIgnore;
108+
109+
// The fusedOperand will be removed during the fusion
110+
opOperandsToIgnore.emplace_back(fusedOperand);
111+
80112
for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
81113
auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
114+
opOperandsToIgnore.emplace_back(outputOperand);
82115
if (producer.payloadUsesValueFromOperand(outputOperand) ||
83-
!producer.canOpOperandsBeDropped(outputOperand) ||
116+
!isOpOperandCanBeDroppedAfterFusedLinalgs(producer, consumer,
117+
opOperandsToIgnore) ||
84118
llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
85119
return user != consumer.getOperation();
86120
})) {
87121
preservedProducerResults.insert(producerResult.index());
122+
123+
// The operand can't be dropped, so remove it from the ignore list again.
124+
opOperandsToIgnore.pop_back_val();
88125
}
89126
}
90127
return preservedProducerResults;
@@ -301,10 +338,11 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
301338
// TODO: allow fusing the producer of an output operand.
302339
assert(consumer.isDpsInput(fusedOperand) &&
303340
"expected producer of input operand");
304-
/// Find the results of the producer that have uses outside of the consumer.
341+
/// Find the results of the producer that have uses outside of the consumer,
342+
/// after the fusion.
305343
llvm::SmallDenseSet<int> preservedProducerResults =
306-
ElementwiseOpFusionResult::getPreservedProducerResults(producer,
307-
consumer);
344+
mlir::linalg::getPreservedProducerResults(producer, consumer,
345+
fusedOperand);
308346

309347
// Compute the fused operands list and indexing maps.
310348
SmallVector<Value> fusedInputOperands, fusedOutputOperands;

0 commit comments

Comments
 (0)