[mlir][tosa] Fix tosa-infer-shapes crash #87234

Merged: 2 commits into llvm:main, Apr 2, 2024

Conversation

sabauma (Contributor) commented Apr 1, 2024

The tosa-infer-shapes pass inserts tensor.cast operations to reconcile refined result types with consumers whose types cannot be refined. This process interferes with how types are refined in tosa.while_loop body regions, where types are propagated speculatively (to determine the types of the tosa.yield terminator) and then reverted.

The newly inserted tensor.cast operations cause a crash because they have no original types recorded for the reversion process.

This change modifies the shape propagation behavior so that the introduction of tensor.cast operations interacts correctly with this type reversion process. The new behavior is to introduce tensor.cast operations only once we wish to commit the newly computed types to the IR.
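
At a high level, the fix threads a small record/commit/revert state object through the propagation functions. The following condenses the protocol from the diff further down (all names appear in the patch; this is a sketch, not the literal code):

// Top of the pass: record speculative type updates and commit them at the
// end; tensor.cast operations for non-inferrable users are created only
// at commit time.
TypeModificationState state;
propagateShapesInRegion(func.getBody(), state); // records prior types via state.setType
state.commit();

// Inside propagateShapesToTosaWhile, each fixed-point iteration uses a
// local state that is rolled back instead of committed:
TypeModificationState localState;
propagateShapesInRegion(bodyRegion, localState);
localState.revert(); // restores every recorded original type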

This is an example causing the crash:

func.func @while_dont_crash(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
  %0 = tosa.add %arg0, %arg0 : (tensor<i32>, tensor<i32>) -> tensor<*xi32>

  %1 = tosa.while_loop (%arg1 = %0) : (tensor<*xi32>) -> tensor<*xi32> {
    %2 = "tosa.const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32>
    %3 = tosa.greater_equal %2, %arg1 : (tensor<i32>, tensor<*xi32>) -> tensor<*xi1>
    tosa.yield %3 : tensor<*xi1>
  } do {
  ^bb0(%arg1: tensor<*xi32>):
    // Inferrable operation whose type will refine to tensor<i32>
    %3 = tosa.add %arg1, %arg1 : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>

    // Non-inferrable use site, will require the cast:
    //     tensor.cast %3 : tensor<i32> to tensor<*xi32>
    // 
    // The new cast operation will result in accessing undefined memory through
    // originalTypeMap in the C++ code.
    "use"(%3) : (tensor<*xi32>) -> ()
    tosa.yield %3 : tensor<*xi32>
  }

  return %1 : tensor<*xi32>
}
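
Assuming the IR above is saved to a standalone file (crash.mlir is just a placeholder name here), the crash reproduces with the same invocation the regression test below uses:

mlir-opt --tosa-infer-shapes --allow-unregistered-dialect crash.mlir

The --allow-unregistered-dialect flag is needed because of the unregistered "use" operation.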

The tensor.cast operation inserted in the loop body causes a failure in the code that resets the types after propagation through the loop body:

// The types inferred in the block assume the operand types specified for
// this iteration. We need to restore the original types to ensure that
// future iterations only use the already specified types, not possible
// types from previous iterations.
for (auto &block : bodyRegion) {
  for (auto arg : block.getArguments())
    arg.setType(originalTypeMap[arg]);
  for (auto &op : block)
    for (auto result : op.getResults())
      result.setType(originalTypeMap[result]);  // problematic access
}
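
The access is undefined because originalTypeMap is populated before propagation, so it only contains values that existed at that point. For the result of a tensor.cast created during propagation, llvm::DenseMap::operator[] inserts and returns a default-constructed (null) Type, and setting a null type crashes. A minimal illustration of the failure mode (a hypothetical helper, not code from the patch):

// Resets each result to its recorded type, mirroring the loop above.
void resetToOriginalTypes(Block &block,
                          llvm::DenseMap<Value, Type> &originalTypeMap) {
  for (Operation &op : block) {
    for (Value result : op.getResults()) {
      // If `op` is a tensor.cast inserted after the map was built, this
      // lookup creates and returns a null Type, and setType then crashes.
      result.setType(originalTypeMap[result]);
    }
  }
}

A defensive variant could use originalTypeMap.lookup(result) and skip null results, but the patch instead removes the map entirely in favor of TypeModificationState.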

sabauma requested review from rsuderman and eric-k256 on April 1, 2024 at 12:26
sabauma (Contributor, Author) commented Apr 1, 2024

@rafaelubalmw

llvmbot (Member) commented Apr 1, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Spenser Bauman (sabauma)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/87234.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp (+103-95)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+115-1)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index ad28c564f7dbdd..8614559e2a6f13 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -18,14 +18,9 @@
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/IRMapping.h"
-#include "mlir/IR/Matchers.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/Support/FormatVariadic.h"
 
 namespace mlir {
 namespace tosa {
@@ -39,9 +34,87 @@ using namespace mlir::tosa;
 
 namespace {
 
-void propagateShapesInRegion(Region &region);
+// Check whether this use case is replaceable. We define an op as
+// being replaceable if it is used by a TosaOp, or an op with a
+// type-inference related interface.
+// When a non-replaceable use is encountered, the value is wrapped in a
+// cast back to the original type after inference.
+bool isReplaceableUser(Operation *user) {
+  // Handle unregistered dialects.
+  if (!user->getDialect())
+    return false;
+
+  return user->getDialect()->getNamespace() ==
+             TosaDialect::getDialectNamespace() ||
+         isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
+}
+
+// During type propagation, the types of values in the operator graph are
+// updated. For the tosa.while_loop operation, types are speculatively updated
+// within the body region to determine the output type of the while_loop. This
+// process is performed until a fixed point is reached, then the types are
+// reverted.
+//
+// This class encapsulates the state information needed to perform the reversion
+// process or to commit to the final changes.
+class TypeModificationState {
+public:
+  TypeModificationState() = default;
+
+  ~TypeModificationState() {
+    // Ensure the recorded modifications are either committed or reverted.
+    assert(oldTypes.empty() && "unhandled type modifications");
+  }
+
+  // Update the state of the value and record the old type.
+  void setType(Value value, Type type) {
+    if (value.getType() != type) {
+      oldTypes.emplace_back(value, value.getType());
+      value.setType(type);
+    }
+  }
 
-void propagateShapesToTosaIf(Operation &op) {
+  // Revert changes made to the types in the IR by setting all the affected
+  // values to their old types.
+  void revert() {
+    // Otherwise revert the changes.
+    for (auto [value, type] : oldTypes)
+      value.setType(type);
+
+    oldTypes.clear();
+  }
+
+  // Commit the changes to the types in the IR.
+  // This requires inserting tensor.cast operations to mediate the newly
+  // inferred result types with users that do not support type inference.
+  void commit() {
+    // For each use whose type changed, cast the value with the new type back to
+    // the old type.
+    for (auto [value, oldType] : oldTypes) {
+      for (auto &use : value.getUses()) {
+        if (isReplaceableUser(use.getOwner()))
+          continue;
+
+        OpBuilder builder(value.getContext());
+        builder.setInsertionPoint(use.getOwner());
+
+        Location loc = value.getLoc();
+        use.set(builder.create<tensor::CastOp>(loc, oldType, value));
+      }
+    }
+
+    oldTypes.clear();
+  }
+
+private:
+  // A record of each value whose type was updated along with that value's
+  // previous type.
+  llvm::SmallVector<std::pair<Value, Type>> oldTypes;
+};
+
+void propagateShapesInRegion(Region &region, TypeModificationState &state);
+
+void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
   IfOp ifOp = dyn_cast<IfOp>(op);
   if (!ifOp)
     return;
@@ -58,7 +131,7 @@ void propagateShapesToTosaIf(Operation &op) {
 
       if (inferredTy.hasRank()) {
         Type newType = oldType.clone(inferredTy.getShape());
-        blockArg.setType(newType);
+        state.setType(blockArg, newType);
       }
     }
 
@@ -71,14 +144,14 @@ void propagateShapesToTosaIf(Operation &op) {
           ValueKnowledge::join(operandKnowledge, blockKnowledge);
       if (!joinedKnowledge)
         continue;
-      frontBlock.getArgument(i).setType(joinedKnowledge.getType());
+      state.setType(frontBlock.getArgument(i), joinedKnowledge.getType());
     }
 
-    propagateShapesInRegion(region);
+    propagateShapesInRegion(region, state);
   }
 }
 
-void propagateShapesToTosaWhile(Operation &op) {
+void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
   WhileOp whileOp = dyn_cast<WhileOp>(op);
   if (!whileOp)
     return;
@@ -86,49 +159,29 @@ void propagateShapesToTosaWhile(Operation &op) {
   // Determine what the expected argument types are to the cond/body blocks.
   // The expected arguments should be compatible with ever iteration of the
   // loop body / condition for tosa.while.
-  llvm::SmallVector<Type> argTypes;
-  for (auto operand : op.getOperands()) {
-    auto operandTy = cast<ShapedType>(operand.getType());
-    if (operandTy.hasRank()) {
-      auto newTy = operandTy.clone(operandTy.getShape());
-      argTypes.push_back(newTy);
-    } else {
-      argTypes.push_back(operand.getType());
-    }
-  }
-
-  // Save out the type information so we can restore at the end.
-  llvm::DenseMap<Value, Type> originalTypeMap;
-  for (auto &block : op.getRegion(1)) {
-    for (auto arg : block.getArguments())
-      originalTypeMap[arg] = arg.getType();
-    for (auto &op : block)
-      for (auto result : op.getResults())
-        originalTypeMap[result] = result.getType();
-  }
+  SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());
 
   bool hasNewTypes = true;
   while (hasNewTypes) {
+    TypeModificationState localState;
 
     // Set types on the block args.
     Region &bodyRegion = op.getRegion(1);
     Block &block = bodyRegion.front();
     for (int i = 0, s = argTypes.size(); i < s; i++) {
-      block.getArgument(i).setType(argTypes[i]);
+      localState.setType(block.getArgument(i), argTypes[i]);
     }
 
     // Propagate to the end.
-    propagateShapesInRegion(bodyRegion);
+    propagateShapesInRegion(bodyRegion, localState);
 
-    // Find all the tosa yield types and verify there is atleast one.
+    // Find all the tosa yield types and verify there is a single one.
     llvm::SmallVector<YieldOp> yieldOps;
     for (auto &block : bodyRegion)
       if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
         yieldOps.push_back(yieldOp);
 
-    if (yieldOps.empty())
-      return;
-
+    assert(yieldOps.size() == 1 && "missing or non-unique yield op");
     // Using the new tosa.yield operand types, infer the new subtypes.
     llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
     for (auto ty : argTypes) {
@@ -158,17 +211,8 @@ void propagateShapesToTosaWhile(Operation &op) {
       argTypes[i] = newType;
     }
 
-    // The types inferred in the block assume the operand types specified for
-    // this iteration. We need to restore the original types to ensure that
-    // future iterations only use the already specified types, not possible
-    // types from previous iterations.
-    for (auto &block : bodyRegion) {
-      for (auto arg : block.getArguments())
-        arg.setType(originalTypeMap[arg]);
-      for (auto &op : block)
-        for (auto result : op.getResults())
-          result.setType(originalTypeMap[result]);
-    }
+    // Revert all changes made during the speculative part of the algorithm.
+    localState.revert();
   }
 
   // We now set the block arguments according to the most recent shape
@@ -176,41 +220,22 @@ void propagateShapesToTosaWhile(Operation &op) {
   // iteration.
   for (auto &region : op.getRegions()) {
     for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
-      region.front().getArgument(i).setType(argTypes[i]);
+      state.setType(region.front().getArgument(i), argTypes[i]);
     }
 
-    propagateShapesInRegion(region);
+    propagateShapesInRegion(region, state);
   }
 }
 
-// Track the old type for each operand whose type was updated
-// during inference. This information is used to introduce casts
-// back to the type expected by the operand after inference.
-struct TypeRewriteInfo {
-  OpOperand *operand;
-  Type oldType;
-};
-
-void propagateShapesInRegion(Region &region) {
-  // Check whether this use case is replaceable. We define an op as
-  // being replaceable if it is used by a TosaOp, or an op with a
-  // type-inference related interface.
-  // When a non-replaceable use is encountered, the value is wrapped in a
-  // cast back to the original type after inference.
-  auto isReplaceableUser = [](Operation *user) -> bool {
-    return user->getDialect()->getNamespace() ==
-               TosaDialect::getDialectNamespace() ||
-           isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
-  };
-
-  llvm::SmallVector<TypeRewriteInfo> requiresUpdate;
+void propagateShapesInRegion(Region &region, TypeModificationState &state) {
   for (auto &block : region) {
     for (Operation &op : block) {
-      if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
+      if (!op.getDialect() ||
+          op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
         continue;
 
-      propagateShapesToTosaIf(op);
-      propagateShapesToTosaWhile(op);
+      propagateShapesToTosaIf(op, state);
+      propagateShapesToTosaWhile(op, state);
 
       InferShapedTypeOpInterface shapeInterface =
           dyn_cast<InferShapedTypeOpInterface>(op);
@@ -252,30 +277,11 @@ void propagateShapesInRegion(Region &region) {
             continue;
 
           // Set new type
-          result.setType(newKnowledge.getType());
-
-          // Collect all uses of the operation which require update.
-          for (auto &user : result.getUses()) {
-            if (!isReplaceableUser(user.getOwner()))
-              requiresUpdate.push_back({&user, resultTy});
-          }
+          state.setType(result, newKnowledge.getType());
         }
       }
     }
   }
-
-  // For each use whose type changed, cast the value with the new type back to
-  // the old type.
-  IRRewriter rewriter(region.getContext());
-  for (auto [operand, oldType] : requiresUpdate) {
-    rewriter.setInsertionPoint(operand->getOwner());
-
-    auto oldValue = operand->get();
-
-    auto loc = oldValue.getLoc();
-    auto castOp = rewriter.create<tensor::CastOp>(loc, oldType, oldValue);
-    operand->set(castOp);
-  }
 }
 
 /// Pass that performs shape propagation across TOSA operations. This includes
@@ -285,7 +291,9 @@ struct TosaInferShapes
 public:
   void runOnOperation() override {
     func::FuncOp func = getOperation();
-    propagateShapesInRegion(func.getBody());
+    TypeModificationState state;
+    propagateShapesInRegion(func.getBody(), state);
+    state.commit();
   }
 };
 } // namespace
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 1f0cfaf92c5c74..781e2ddf76fbd3 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --split-input-file --tosa-infer-shapes %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --tosa-infer-shapes --allow-unregistered-dialect %s | FileCheck %s
 
 // CHECK-LABEL: @test_return
 func.func @test_return(%arg0 : tensor<4xf32>) -> tensor<*xf32> {
@@ -1177,6 +1177,120 @@ func.func @while_test(%arg0 : tensor<i32>, %arg1 : tensor<1xi32>) -> () {
 
 // -----
 
+// This test locks down a fix for a crash in the type inference process.
+// The relevant pattern is a while loop whose body contains a TOSA operation which is
+// consumed by a non-inferrable user in the same body.
+// Previously, this would trigger a crash due to how types are cached and then
+// reapplied to the operations in the loops body.
+
+// CHECK-LABEL: @while_dont_crash
+func.func @while_dont_crash(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
+  %0 = tosa.add %arg0, %arg0 : (tensor<i32>, tensor<i32>) -> tensor<*xi32>
+
+  // CHECK:      tosa.while_loop
+  // CHECK-SAME: (tensor<i32>) -> tensor<i32>
+  %1 = tosa.while_loop (%arg1 = %0) : (tensor<*xi32>) -> tensor<*xi32> {
+    %2 = "tosa.const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32>
+
+    // CHECK:       tosa.greater_equal
+    // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = tosa.greater_equal %2, %arg1 : (tensor<i32>, tensor<*xi32>) -> tensor<*xi1>
+
+    tosa.yield %3 : tensor<*xi1>
+
+  } do {
+
+  // CHECK:      ^bb0
+  // CHECK-SAME: tensor<i32>
+  ^bb0(%arg1: tensor<*xi32>):
+
+    // CHECK:     tosa.add
+    // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i32>
+    %3 = tosa.add %arg1, %arg1 : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
+
+    // CHECK: %[[CAST:.+]] = tensor.cast %{{.*}} : tensor<i32> to tensor<*xi32>
+    // CHECK: "use"(%[[CAST]]) : (tensor<*xi32>) -> ()
+    "use"(%3) : (tensor<*xi32>) -> ()
+
+    tosa.yield %3 : tensor<*xi32>
+  }
+
+  // CHECK: tensor.cast
+  return %1 : tensor<*xi32>
+}
+
+// -----
+
+// This test locks down a fix for a crash in the type inference process.
+// The relevant pattern is a while loop whose body contains a TOSA operation which is
+// consumed by a non-inferrable user in the same body.
+
+// CHECK-LABEL: @while_dont_crash_nested
+func.func @while_dont_crash_nested(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
+  %0 = tosa.add %arg0, %arg0 : (tensor<i32>, tensor<i32>) -> tensor<*xi32>
+
+  // CHECK:      tosa.while_loop
+  // CHECK-SAME: (tensor<i32>) -> tensor<i32>
+  %1 = tosa.while_loop (%arg1 = %0) : (tensor<*xi32>) -> tensor<*xi32> {
+    %2 = "tosa.const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32>
+
+    // CHECK:       tosa.greater_equal
+    // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = tosa.greater_equal %2, %arg1 : (tensor<i32>, tensor<*xi32>) -> tensor<*xi1>
+
+    // CHECK:      tosa.yield
+    // CHECK-SAME: tensor<i1>
+    tosa.yield %3 : tensor<*xi1>
+
+  } do {
+
+  // CHECK:      ^bb0
+  // CHECK-SAME: tensor<i32>
+  ^bb0(%arg1: tensor<*xi32>):
+
+    // CHECK:      tosa.while_loop
+    // CHECK-SAME: (tensor<i32>) -> tensor<i32>
+    %1 = tosa.while_loop (%arg2 = %arg1) : (tensor<*xi32>) -> tensor<*xi32> {
+      %2 = "tosa.const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32>
+
+      // CHECK:       tosa.greater_equal
+      // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i1>
+      %4 = tosa.greater_equal %2, %arg2 : (tensor<i32>, tensor<*xi32>) -> tensor<*xi1>
+
+      // CHECK:      tosa.yield
+      // CHECK-SAME: tensor<i1>
+      tosa.yield %4 : tensor<*xi1>
+
+    } do {
+
+    // CHECK:      ^bb0
+    // CHECK-SAME: tensor<i32>
+    ^bb0(%arg2: tensor<*xi32>):
+
+      // CHECK:     tosa.add
+      // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i32>
+      %4 = tosa.add %arg2, %arg2 : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
+
+      // CHECK: %[[CAST:.+]] = tensor.cast %{{.*}} : tensor<i32> to tensor<*xi32>
+      // CHECK: "use"(%[[CAST]]) : (tensor<*xi32>) -> ()
+      "use"(%4) : (tensor<*xi32>) -> ()
+
+      // CHECK:      tosa.yield
+      // CHECK-SAME: tensor<i32>
+      tosa.yield %4 : tensor<*xi32>
+    }
+
+    // CHECK:      tosa.yield
+    // CHECK-SAME: tensor<i32>
+    tosa.yield %1 : tensor<*xi32>
+  }
+
+  // CHECK: tensor.cast
+  return %1 : tensor<*xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_static_rfft2d
 func.func @test_static_rfft2d(%arg0: tensor<5x2x8xf32>) -> () {
   // CHECK: -> (tensor<5x2x5xf32>, tensor<5x2x5xf32>)

rsuderman (Contributor) left a comment:

The implementation appears to be fine; it just could use some condensing of the tests/checks.

// CHECK-LABEL: @while_dont_crash
func.func @while_dont_crash(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
  %0 = tosa.add %arg0, %arg0 : (tensor<i32>, tensor<i32>) -> tensor<*xi32>

Contributor:
Could we try to condense this a little bit? There are a ton of empty newlines everywhere, and it would be nice to get it a little more compact.

Contributor Author:
Trimmed out all the whitespace. I was just following the structure of prior tests. Arguably, the FileCheck directives are not that important, since we're just ensuring the pass does not crash.

I would also be happy moving this to a different file, rather than having it inline with all the correctness tests.

sabauma merged commit 0a94d35 into llvm:main on Apr 2, 2024
if (!user->getDialect())
  return false;

return user->getDialect()->getNamespace() ==
Member:
Note: this is a string comparison for dialects; comparing the IDs would be cheaper.

Contributor Author:
This code already existed in the pass. I just moved it into a separate function so it could be used elsewhere.
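
For reference, a TypeID-based comparison might look like this (a sketch with a hypothetical helper name, not the merged code; the follow-up PR referenced at the end of this thread removed the string comparisons):

// Compare the dialect's TypeID instead of its namespace string.
bool isTosaDialectOp(Operation *user) {
  Dialect *dialect = user->getDialect(); // null for unregistered dialects
  return dialect && dialect->getTypeID() == TypeID::get<tosa::TosaDialect>();
}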

@@ -39,9 +34,87 @@ using namespace mlir::tosa;

namespace {

void propagateShapesInRegion(Region &region);
// Check whether this use case is replaceable. We define an op as
// being replaceable if it is used by a TosaOp, or an op with a
Member:
OOC what TOSA ops do not have the type inference interface?

Contributor Author:
I just tried removing that condition. Based on how the tests failed, at least tosa.yield. Everything in mlir/Dialect/Tosa/IR/TosaUtilOps.td lacks an inference interface as well.

// process is performed until a fixed point is reached, then the types are
// reverted.
//
// This class encapsulates the state information needed to perform the reversion
Member:
s/reversion/roll back/ ? (just given that's the term used in dialect conversion)

Contributor Author:
I can update that. Given that I've already merged the change, I'll create a follow-up PR to address your comments.

if (isReplaceableUser(use.getOwner()))
  continue;

OpBuilder builder(value.getContext());
Member:
I think you can combine these three statements by using

ImplicitLocOpBuilder builder(value.getLoc(), use.getOwner());
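
Spelled out, the suggestion collapses the builder setup and cast creation into two statements (a sketch of the reviewer's proposal, not the merged code):

// ImplicitLocOpBuilder carries value's location and sets the insertion
// point before use.getOwner(), so no separate setInsertionPoint is needed.
ImplicitLocOpBuilder builder(value.getLoc(), use.getOwner());
use.set(builder.create<tensor::CastOp>(oldType, value));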

// This requires inserting tensor.cast operations to mediate the newly
// inferred result types with users that do not support type inference.
void commit() {
// For each use whose type changed, cast the value with the new type back to
Member:
The cast is only where type information can't be propagated/refined (e.g., isReplaceableUser == canBeRefined).

Contributor Author:
Yes. canBeRefined would be a better name.

// For each use whose type changed, cast the value with the new type back to
// the old type.
for (auto [value, oldType] : oldTypes) {
  for (auto &use : value.getUses()) {
Member:
What happens if you have a user that uses the result of the op multiple times?

Contributor Author:
This logic will generate one tensor.cast operation per use. That was true of the prior logic, though.
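
For illustration, the duplication could be avoided by creating the cast lazily once per value and reusing it across uses. The follow-up PR referenced at the end of this thread does eliminate the redundant casts; the sketch below is one possible shape for that, not the merged code:

for (auto [value, oldType] : oldTypes) {
  Value cast; // created lazily, shared by every non-replaceable use
  for (OpOperand &use : llvm::make_early_inc_range(value.getUses())) {
    Operation *owner = use.getOwner();
    if (isReplaceableUser(owner))
      continue;
    if (!cast) {
      OpBuilder builder(value.getContext());
      builder.setInsertionPointAfterValue(value);
      cast = builder.create<tensor::CastOp>(value.getLoc(), oldType, value);
    }
    if (owner != cast.getDefiningOp()) // don't rewrite the cast's own operand
      use.set(cast);
  }
}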

for (auto result : op.getResults())
result.setType(originalTypeMap[result]);
}
// Revert all changes made during the speculative part of the algorithm.
Member:
Why is this speculative? Is it too aggressive in assuming equality but lacking a way to relax it?

Contributor Author:
Maybe 'speculative' is the wrong word. Happy to change if you have a better term.

The inference step for while loops refines the block argument types and propagates those types through the body block multiple times until a fixed point is reached. This is to infer consistent types for the tosa.while_loop's block arguments, yield arguments, and operation results.

After each propagation, the types of each operation in the body are reset to their previous values to avoid leaking possibly overly refined results between propagation steps.

I think that, ideally, the TosaInferShapes pass would avoid modifying the IR until it has fully inferred the desired types, and only then update the IR. That would avoid these statefulness issues. I initially attempted to do that, but the InferShapedTypeOpInterface does not allow enough information to be passed in to the inference method (e.g. refinements made to body region operations).

for (auto &block : region) {
for (Operation &op : block) {
if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
if (!op.getDialect() ||
op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
Member:
This is string compare too.

sabauma pushed a commit to sabauma/llvm-project that referenced this pull request Apr 3, 2024
This change addresses some of the additional review feedback on
llvm#87234.

A summary of the changes:

1. Cleaned up the language to use 'roll back' rather than revert to
   reduce the chance of confusion. Improved some function names as well.
2. Eliminated string comparisons on dialect names.
3. Prevented the introduction of redundant tensor.cast operations for the
   same value.
sabauma added a commit that referenced this pull request May 10, 2024
…apes (#87660)

This change addresses some of the additional review feedback on
#87234.

A summary of the changes:

1. Cleaned up the language to use 'roll back' rather than revert to
reduce the chance of confusion. Improved some function names as well.
2. Eliminated string comparisons on dialect names.
3. Prevented the introduction of redundant tensor.cast operations for
the same value.

---------

Co-authored-by: Spenser Bauman <sabauma@fastmail>