-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][linalg] Expose getPreservedProducerResults method from ElementwiseOpFusion file #73850
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][linalg] Expose getPreservedProducerResults method from ElementwiseOpFusion file #73850
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: Amir Bishara (amirBish) ChangesDeclare Full diff: https://github.com/llvm/llvm-project/pull/73850.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 6c4e16bd94f47d4..29787354a0f6cc1 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -493,6 +493,8 @@ LogicalResult dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
struct ElementwiseOpFusionResult {
Operation *fusedOp;
llvm::DenseMap<Value, Value> replacements;
+ static llvm::SmallDenseSet<int>
+ getPreservedProducerResults(GenericOp producer, GenericOp consumer);
};
FailureOr<ElementwiseOpFusionResult>
fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index f0393e44fc00c27..325da9881b9391e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -71,6 +71,25 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
return t1.compose(fusedConsumerArgIndexMap);
}
+/// Returns a set of indices of the producer's results which would
+/// be preserved after the fusion.
+llvm::SmallDenseSet<int>
+ElementwiseOpFusionResult::getPreservedProducerResults(GenericOp producer,
+ GenericOp consumer) {
+ llvm::SmallDenseSet<int> preservedProducerResults;
+ for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
+ auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
+ if (producer.payloadUsesValueFromOperand(outputOperand) ||
+ !producer.canOpOperandsBeDropped(outputOperand) ||
+ llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
+ return user != consumer.getOperation();
+ })) {
+ preservedProducerResults.insert(producerResult.index());
+ }
+ }
+ return preservedProducerResults;
+}
+
/// Conditions for elementwise fusion of generic operations.
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
if (!fusedOperand)
@@ -285,17 +304,9 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
assert(consumer.isDpsInput(fusedOperand) &&
"expected producer of input operand");
/// Find the results of the producer that have uses outside of the consumer.
- llvm::SmallDenseSet<int> preservedProducerResults;
- for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
- auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
- if (producer.payloadUsesValueFromOperand(outputOperand) ||
- !producer.canOpOperandsBeDropped(outputOperand) ||
- llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
- return user != consumer.getOperation();
- })) {
- preservedProducerResults.insert(producerResult.index());
- }
- }
+ llvm::SmallDenseSet<int> preservedProducerResults =
+ ElementwiseOpFusionResult::getPreservedProducerResults(producer,
+ consumer);
// Compute the fused operands list and indexing maps.
SmallVector<Value> fusedInputOperands, fusedOutputOperands;
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure what this is being done for. Could you provide some more context. In general having some callback like that in the result object is not recommended.
a78e102
to
4ac4db9
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This works too... But might be better to leave it as a static method in the result object (as you had it before). It won't matter in the end, all it's doing is providing scoping for the method.
4ac4db9
to
e21b3fb
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
1146126
to
31e462d
Compare
…wiseOpFusion file Declare `getPreservedProducerResults` function which helps to get the preserved results of the producer linalg generic operation as a result of elementwise fusion.
31e462d
to
2a3996d
Compare
Declare
getPreservedProducerResults
function which helps to get the preserved results of the producer linalg generic operation as a result of elementwise fusion.