Skip to content

Commit cf2d625

Browse files
authored
[mlir][linalg] Expose getPreservedProducerResults method from ElementwiseOpFusion file (#73850)
Declare `getPreservedProducerResults` function which helps to get the preserved results of the producer linalg generic operation as a result of elementwise fusion.
1 parent 52296e2 commit cf2d625

File tree

2 files changed

+24
-11
lines changed

2 files changed

+24
-11
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,8 @@ LogicalResult dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
493493
struct ElementwiseOpFusionResult {
494494
Operation *fusedOp;
495495
llvm::DenseMap<Value, Value> replacements;
496+
static llvm::SmallDenseSet<int>
497+
getPreservedProducerResults(GenericOp producer, GenericOp consumer);
496498
};
497499
FailureOr<ElementwiseOpFusionResult>
498500
fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,25 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
7171
return t1.compose(fusedConsumerArgIndexMap);
7272
}
7373

74+
/// Returns a set of indices of the producer's results which would
75+
/// be preserved after the fusion.
76+
llvm::SmallDenseSet<int>
77+
ElementwiseOpFusionResult::getPreservedProducerResults(GenericOp producer,
78+
GenericOp consumer) {
79+
llvm::SmallDenseSet<int> preservedProducerResults;
80+
for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
81+
auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
82+
if (producer.payloadUsesValueFromOperand(outputOperand) ||
83+
!producer.canOpOperandsBeDropped(outputOperand) ||
84+
llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
85+
return user != consumer.getOperation();
86+
})) {
87+
preservedProducerResults.insert(producerResult.index());
88+
}
89+
}
90+
return preservedProducerResults;
91+
}
92+
7493
/// Conditions for elementwise fusion of generic operations.
7594
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
7695
if (!fusedOperand)
@@ -285,17 +304,9 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
285304
assert(consumer.isDpsInput(fusedOperand) &&
286305
"expected producer of input operand");
287306
/// Find the results of the producer that have uses outside of the consumer.
288-
llvm::SmallDenseSet<int> preservedProducerResults;
289-
for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
290-
auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
291-
if (producer.payloadUsesValueFromOperand(outputOperand) ||
292-
!producer.canOpOperandsBeDropped(outputOperand) ||
293-
llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
294-
return user != consumer.getOperation();
295-
})) {
296-
preservedProducerResults.insert(producerResult.index());
297-
}
298-
}
307+
llvm::SmallDenseSet<int> preservedProducerResults =
308+
ElementwiseOpFusionResult::getPreservedProducerResults(producer,
309+
consumer);
299310

300311
// Compute the fused operands list and indexing maps.
301312
SmallVector<Value> fusedInputOperands, fusedOutputOperands;

0 commit comments

Comments
 (0)