Skip to content

[mlir][linalg] Improve getPreservedProducerResults estimation in ElementwiseOpFusion #104409

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -497,12 +497,19 @@ LogicalResult dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
/// Result type carried by the `FailureOr` that `fuseElementwiseOps` returns.
struct ElementwiseOpFusionResult {
// NOTE(review): presumably the operation created by the fusion - confirm
// against the implementation of `fuseElementwiseOps`.
Operation *fusedOp;
// NOTE(review): presumably maps values produced by the original ops to the
// values that should replace them after the fusion - confirm against callers.
llvm::DenseMap<Value, Value> replacements;
/// Returns a set of indices of the producer's results which would be
/// preserved after the fusion of `producer` into `consumer`.
static llvm::SmallDenseSet<int>
getPreservedProducerResults(GenericOp producer, GenericOp consumer);
};
FailureOr<ElementwiseOpFusionResult>
fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);

/// Returns the set of indices of the producer's results which would
/// be preserved after the fusion.
/// Note: this function gives a prediction that assumes an optimized fusion;
/// the actual implementation of the transformation may not agree with the
/// result of this function.
llvm::SmallDenseSet<int> getPreservedProducerResults(GenericOp producer,
GenericOp consumer,
OpOperand *fusedOperand);

/// Try to peel and canonicalize loop `op` and return the new result.
/// Also applies affine_min/max bounds simplification on the fly where relevant.
// TODO: Add support for scf.parallel and affine.for loops.
Expand Down
52 changes: 45 additions & 7 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,20 +70,57 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
return t1.compose(fusedConsumerArgIndexMap);
}

// Checks whether the operands listed in `opOperandsToIgnore` can be dropped
// while the remaining operands of the fused producer & consumer can still
// compute the bounds of the op.
//
// `producer` / `consumer` are the two ops considered for fusion.
// `opOperandsToIgnore` holds the operands (of either op) that are assumed to
// be removed; their indexing maps are excluded from the check.
//
// Returns true if the indexing maps of the remaining operands are enough to
// compute the bounds, false otherwise.
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
    GenericOp producer, GenericOp consumer,
    ArrayRef<OpOperand *> opOperandsToIgnore) {
  SmallVector<AffineMap> indexingMaps;

  // Collect the indexing maps of every operand that is kept.
  SmallVector<GenericOp> ops = {producer, consumer};
  for (auto &op : ops) {
    for (auto &opOperand : op->getOpOperands()) {
      if (llvm::is_contained(opOperandsToIgnore, &opOperand))
        continue;
      indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
    }
  }

  // If no indexing map is left there is nothing to concatenate, so do not
  // hand an empty list to concatAffineMaps. Without any remaining operand the
  // bounds can only be "computed" when there are no loops at all.
  if (indexingMaps.empty())
    return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;

  // The concatenation of the remaining indexing maps must be invertible, so
  // the bounds of the op can still be computed after dropping the selected
  // operands. inversePermutation returns an empty AffineMap in case the
  // concatenated indexing maps are not invertible.
  return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
}

/// Returns a set of indices of the producer's results which would
/// be preserved after the fusion.
llvm::SmallDenseSet<int>
ElementwiseOpFusionResult::getPreservedProducerResults(GenericOp producer,
GenericOp consumer) {
/// Note: this function gives a prediction that assumes an optimized fusion;
/// the actual implementation of the transformation may not agree with the
/// result of this function.
llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
llvm::SmallDenseSet<int> preservedProducerResults;
llvm::SmallVector<OpOperand *> opOperandsToIgnore;

// The fusedOperand will be removed during the fusion
opOperandsToIgnore.emplace_back(fusedOperand);

for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
opOperandsToIgnore.emplace_back(outputOperand);
if (producer.payloadUsesValueFromOperand(outputOperand) ||
!producer.canOpOperandsBeDropped(outputOperand) ||
!isOpOperandCanBeDroppedAfterFusedLinalgs(producer, consumer,
opOperandsToIgnore) ||
llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
return user != consumer.getOperation();
})) {
preservedProducerResults.insert(producerResult.index());

// In case the operand can't be dropped
opOperandsToIgnore.pop_back_val();
}
}
return preservedProducerResults;
Expand Down Expand Up @@ -300,10 +337,11 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
// TODO: allow fusing the producer of an output operand.
assert(consumer.isDpsInput(fusedOperand) &&
"expected producer of input operand");
/// Find the results of the producer that have uses outside of the consumer.
/// Find the results of the producer that have uses outside of the consumer,
/// after the fusion.
llvm::SmallDenseSet<int> preservedProducerResults =
ElementwiseOpFusionResult::getPreservedProducerResults(producer,
consumer);
mlir::linalg::getPreservedProducerResults(producer, consumer,
fusedOperand);

// Compute the fused operands list and indexing maps.
SmallVector<Value> fusedInputOperands, fusedOutputOperands;
Expand Down
Loading