Skip to content

[mlir][linalg] Expose getPreservedProducerResults method from ElementwiseOpFusion file #73850

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

amirBish
Copy link
Contributor

Declare getPreservedProducerResults function which helps to get the preserved results of the producer linalg generic operation as a result of elementwise fusion.

@llvmbot
Copy link
Member

llvmbot commented Nov 29, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Amir Bishara (amirBish)

Changes

Declare getPreservedProducerResults function which helps to get the preserved results of the producer linalg generic operation as a result of elementwise fusion.


Full diff: https://github.com/llvm/llvm-project/pull/73850.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+2)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+22-11)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 6c4e16bd94f47d4..29787354a0f6cc1 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -493,6 +493,8 @@ LogicalResult dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
 struct ElementwiseOpFusionResult {
   Operation *fusedOp;
   llvm::DenseMap<Value, Value> replacements;
+  static llvm::SmallDenseSet<int>
+  getPreservedProducerResults(GenericOp producer, GenericOp consumer);
 };
 FailureOr<ElementwiseOpFusionResult>
 fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index f0393e44fc00c27..325da9881b9391e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -71,6 +71,25 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
   return t1.compose(fusedConsumerArgIndexMap);
 }
 
+/// Returns a set of indices of the producer's results which would
+/// be preserved after the fusion.
+llvm::SmallDenseSet<int>
+ElementwiseOpFusionResult::getPreservedProducerResults(GenericOp producer,
+                                                       GenericOp consumer) {
+  llvm::SmallDenseSet<int> preservedProducerResults;
+  for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
+    auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
+    if (producer.payloadUsesValueFromOperand(outputOperand) ||
+        !producer.canOpOperandsBeDropped(outputOperand) ||
+        llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
+          return user != consumer.getOperation();
+        })) {
+      preservedProducerResults.insert(producerResult.index());
+    }
+  }
+  return preservedProducerResults;
+}
+
 /// Conditions for elementwise fusion of generic operations.
 bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
   if (!fusedOperand)
@@ -285,17 +304,9 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
   assert(consumer.isDpsInput(fusedOperand) &&
          "expected producer of input operand");
   /// Find the results of the producer that have uses outside of the consumer.
-  llvm::SmallDenseSet<int> preservedProducerResults;
-  for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
-    auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
-    if (producer.payloadUsesValueFromOperand(outputOperand) ||
-        !producer.canOpOperandsBeDropped(outputOperand) ||
-        llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
-          return user != consumer.getOperation();
-        })) {
-      preservedProducerResults.insert(producerResult.index());
-    }
-  }
+  llvm::SmallDenseSet<int> preservedProducerResults =
+      ElementwiseOpFusionResult::getPreservedProducerResults(producer,
+                                                             consumer);
 
   // Compute the fused operands list and indexing maps.
   SmallVector<Value> fusedInputOperands, fusedOutputOperands;

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure what this is being done for. Could you provide some more context. In general having some callback like that in the result object is not recommended.

@amirBish amirBish force-pushed the amirBish/linalg/expose-getPreservedProducerResults branch from a78e102 to 4ac4db9 Compare December 7, 2023 15:49
Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This works too... But might be better to leave it as a static method in the result object (as you had it before). It won't matter in the end, all it's doing is providing scoping for the method.

@amirBish amirBish force-pushed the amirBish/linalg/expose-getPreservedProducerResults branch from 4ac4db9 to e21b3fb Compare December 7, 2023 16:22
Copy link

github-actions bot commented Dec 7, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@amirBish amirBish force-pushed the amirBish/linalg/expose-getPreservedProducerResults branch 2 times, most recently from 1146126 to 31e462d Compare December 7, 2023 22:22
…wiseOpFusion file

Declare `getPreservedProducerResults` function which helps
to get the preserved results of the producer linalg generic
operation as a result of elementwise fusion.
@amirBish amirBish force-pushed the amirBish/linalg/expose-getPreservedProducerResults branch from 31e462d to 2a3996d Compare December 8, 2023 08:01
@amirBish amirBish merged commit cf2d625 into llvm:main Dec 8, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants