Skip to content

[mlir][tosa] Cleanups for post-merge review comments in tosa-infer-shapes #87660

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 10, 2024

Conversation

sabauma
Copy link
Contributor

@sabauma sabauma commented Apr 4, 2024

This change addresses some of the additional review feedback on #87234.

A summary of the changes:

  1. Cleaned up the language to use 'roll back' rather than revert to reduce the chance of confusion. Improved some function names as well.
  2. Eliminated string comparisons on dialect names.
  3. Prevented the introduction of redundant tensor.cast operations for the same value.

This change addresses some of the additional review feedback on
llvm#87234.

A summary of the changes:

1. Cleaned up the language to use 'roll back' rather than revert to
   reduce the chance of confusion. Improved some function names as well.
2. Eliminated string comparisons on dialect names.
3. Prevented the introduction of redundant tensor.cast operations for the
   same value.
@sabauma sabauma requested a review from jpienaar April 4, 2024 16:55
@sabauma sabauma self-assigned this Apr 4, 2024
@llvmbot
Copy link
Member

llvmbot commented Apr 4, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Spenser Bauman (sabauma)

Changes

This change addresses some of the additional review feedback on #87234.

A summary of the changes:

  1. Cleaned up the language to use 'roll back' rather than revert to reduce the chance of confusion. Improved some function names as well.
  2. Eliminated string comparisons on dialect names.
  3. Prevented the introduction of redundant tensor.cast operations for the same value.

Full diff: https://github.com/llvm/llvm-project/pull/87660.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp (+22-22)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index 8614559e2a6f13..d01891a04d2aac 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -39,13 +40,9 @@ namespace {
 // type-inference related interface.
 // When a non-replaceable use is encountered, the value is wrapped in a
 // cast back to the original type after inference.
-bool isReplaceableUser(Operation *user) {
-  // Handle unregistered dialects.
-  if (!user->getDialect())
-    return false;
-
-  return user->getDialect()->getNamespace() ==
-             TosaDialect::getDialectNamespace() ||
+bool canBeRefined(Operation *user) {
+  Dialect *tosaDialect = user->getContext()->getLoadedDialect<TosaDialect>();
+  return user->getDialect() == tosaDialect ||
          isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
 }
 
@@ -53,16 +50,16 @@ bool isReplaceableUser(Operation *user) {
 // updated. For the tosa.while_loop operation, types are speculatively updated
 // within the body region to determine the output type of the while_loop. This
 // process is performed until a fixed point is reached, then the types are
-// reverted.
+// rolled back.
 //
-// This class encapsulates the state information needed to perform the reversion
+// This class encapsulates the state information needed to perform the roll back
 // process or to commit to the final changes.
 class TypeModificationState {
 public:
   TypeModificationState() = default;
 
   ~TypeModificationState() {
-    // Ensure the recorded modifications are either committed or reverted.
+    // Ensure the recorded modifications are either committed or rolled back.
     assert(oldTypes.empty() && "unhandled type modifications");
   }
 
@@ -74,10 +71,9 @@ class TypeModificationState {
     }
   }
 
-  // Revert changes made to the types in the IR by setting all the affected
+  // Roll back changes made to the types in the IR by setting all the affected
   // values to their old types.
-  void revert() {
-    // Otherwise revert the changes.
+  void rollBack() {
     for (auto [value, type] : oldTypes)
       value.setType(type);
 
@@ -91,15 +87,18 @@ class TypeModificationState {
     // For each use whose type changed, cast the value with the new type back to
     // the old type.
     for (auto [value, oldType] : oldTypes) {
+      tensor::CastOp castedValue;
       for (auto &use : value.getUses()) {
-        if (isReplaceableUser(use.getOwner()))
+        if (canBeRefined(use.getOwner()))
           continue;
 
-        OpBuilder builder(value.getContext());
-        builder.setInsertionPoint(use.getOwner());
+        // Cache the cast to avoid generating duplicates
+        if (!castedValue) {
+          ImplicitLocOpBuilder builder{value.getLoc(), use.getOwner()};
+          castedValue = builder.create<tensor::CastOp>(oldType, value);
+        }
 
-        Location loc = value.getLoc();
-        use.set(builder.create<tensor::CastOp>(loc, oldType, value));
+        use.set(castedValue);
       }
     }
 
@@ -211,8 +210,8 @@ void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
       argTypes[i] = newType;
     }
 
-    // Revert all changes made during the speculative part of the algorithm.
-    localState.revert();
+    // Roll back all changes made during the speculative part of the algorithm.
+    localState.rollBack();
   }
 
   // We now set the block arguments according to the most recent shape
@@ -228,10 +227,11 @@ void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
 }
 
 void propagateShapesInRegion(Region &region, TypeModificationState &state) {
+  Dialect *tosaDialect = region.getContext()->getLoadedDialect<TosaDialect>();
+
   for (auto &block : region) {
     for (Operation &op : block) {
-      if (!op.getDialect() ||
-          op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
+      if (op.getDialect() != tosaDialect)
         continue;
 
       propagateShapesToTosaIf(op, state);

@joker-eph joker-eph changed the title [mlir][tosa] Address review feedback from jpiennar [mlir][tosa] Cleanups for post-merge review comments in tosa-infer-shapes Apr 9, 2024
@sabauma sabauma force-pushed the infer-shapes-followup branch from 4319af9 to 8ae0ffb Compare April 9, 2024 15:36
@sabauma sabauma merged commit e513f2c into llvm:main May 10, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants