Skip to content

[mlir][tosa] Improve tosa-infer-shapes for ops consumed by non-TOSA operators #72715

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 1, 2023

Conversation

sabauma
Copy link
Contributor

@sabauma sabauma commented Nov 17, 2023

TOSA operators consumed by non-TOSA ops generally do not have their types inferred, as that would alter the types expected by their consumers. This prevents type refinement on many TOSA operators when the IR contains a mix of dialects.

This change modifies tosa-infer-shapes to update the types of all TOSA operators during inference. When a consumer of that TOSA op is not safe to update, a tensor.cast is inserted back to the original type. This behavior is similar to how TOSA ops consumed by func.return are handled.

This allows for more type refinement of TOSA ops, and the additional tensor.cast operators may be removed by later canonicalizations.

@llvmbot
Copy link
Member

llvmbot commented Nov 17, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Spenser Bauman (sabauma)

Changes

TOSA operators consumed by non-TOSA ops generally do not have their types inferred, as that would alter the types expected by their consumers. This prevents type refinement on many TOSA operators when the IR contains a mix of dialects.

This change modifies tosa-infer-shapes to update the types of all TOSA operators during inference. When a consumer of that TOSA op is not safe to update, a tensor.cast is inserted back to the original type. This behavior is similar to how TOSA ops consumed by func.return are handled.

This allows for more type refinement of TOSA ops, and the additional tensor.cast operators may be removed by later canonicalizations.


Full diff: https://github.com/llvm/llvm-project/pull/72715.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp (+32-43)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+12)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index 3cc16a91edce747..94066d044ddd990 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -183,17 +183,27 @@ void propagateShapesToTosaWhile(Operation &op) {
   }
 }
 
+// Track the old type for each operand whose type was updated
+// during inference. This information is used to introduce casts
+// back to the type expected by the operand after inference.
+struct TypeRewriteInfo {
+  OpOperand* operand;
+  Type oldType;
+};
+
 void propagateShapesInRegion(Region &region) {
   // Check whether this use case is replaceable. We define an op as
-  // being replaceable if it is used by a ReturnOp, a TosaOp, or an op with a
+  // being replaceable if it is used by a TosaOp, or an op with a
   // type-inference related interface.
+  // When a non-replaceable use is encountered, the value is wrapped in a
+  // cast back to the original type after inference.
   auto isReplaceableUser = [](Operation *user) -> bool {
-    return isa<func::ReturnOp>(user) ||
-           user->getDialect()->getNamespace() ==
+    return user->getDialect()->getNamespace() ==
                TosaDialect::getDialectNamespace() ||
            isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
   };
 
+  llvm::SmallVector<TypeRewriteInfo> requiresUpdate;
   for (auto &block : region) {
     for (Operation &op : block) {
       if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
@@ -219,9 +229,6 @@ void propagateShapesInRegion(Region &region) {
           Value result = std::get<0>(it);
           ShapedTypeComponents predictedShape = std::get<1>(it);
 
-          if (!llvm::all_of(result.getUsers(), isReplaceableUser))
-            continue;
-
           // Determine the knowledge based on the output type.
           // TODO: should also query WIP type probably
           Type resultTy = result.getType();
@@ -246,10 +253,29 @@ void propagateShapesInRegion(Region &region) {
 
           // Set new type
           result.setType(newKnowledge.getType());
+
+          // Collect all uses of the operation which require update.
+          for (auto& user : result.getUses()) {
+            if (!isReplaceableUser(user.getOwner()))
+              requiresUpdate.push_back({&user, resultTy});
+          }
         }
       }
     }
   }
+
+  // For each use whose type changed, cast the value with the new type back to
+  // the old type.
+  IRRewriter rewriter(region.getContext());
+  for (auto [operand, oldType] : requiresUpdate) {
+    rewriter.setInsertionPoint(operand->getOwner());
+
+    auto oldValue = operand->get();
+
+    auto loc = oldValue.getLoc();
+    auto castOp = rewriter.create<tensor::CastOp>(loc, oldType, oldValue);
+    operand->set(castOp);
+  }
 }
 
 /// Pass that performs shape propagation across TOSA operations. This includes
@@ -259,44 +285,7 @@ struct TosaInferShapes
 public:
   void runOnOperation() override {
     func::FuncOp func = getOperation();
-
-    IRRewriter rewriter(func.getContext());
-
     propagateShapesInRegion(func.getBody());
-
-    // Insert UnrealizedConversionCasts to guarantee ReturnOp agress with
-    // the FuncOp type.
-    func.walk([&](func::ReturnOp op) {
-      func::FuncOp parent = dyn_cast<func::FuncOp>(op->getParentOp());
-      if (!parent)
-        return;
-
-      rewriter.setInsertionPoint(op);
-      FunctionType funcTy = func.getFunctionType();
-      auto resultTys = funcTy.getResults();
-
-      bool castAdded = false;
-      SmallVector<Value> castedValues;
-      for (auto it : llvm::zip(op->getOperands(), resultTys)) {
-        auto operand = std::get<0>(it);
-        auto currentTy = operand.getType();
-        auto castTy = std::get<1>(it);
-        if (currentTy == castTy) {
-          castedValues.push_back(operand);
-          continue;
-        }
-
-        castedValues.push_back(
-            rewriter.create<tensor::CastOp>(op.getLoc(), castTy, operand)
-                .getResult());
-
-        castAdded = true;
-      }
-
-      if (castAdded) {
-        rewriter.replaceOpWithNewOp<func::ReturnOp>(op, castedValues);
-      }
-    });
   }
 };
 } // namespace
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 7af66ae1dbc90f0..f057431a841b591 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1262,6 +1262,17 @@ func.func @test_non_tosa_consumer_extract(%arg0: tensor<4x4xf32>, %arg1: index)
 
 // -----
 
+// CHECK-LABEL: test_non_tosa_consumer_still_propagates
+func.func @test_non_tosa_consumer_still_propagates(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>) -> tensor<?x?xf32> {
+  // CHECK: tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<1x1x1xf32>
+  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<?x1x1xf32>
+  %1 = arith.constant dense<[1, 1]> : tensor<2xindex>
+  %2 = tensor.reshape %0(%1) : (tensor<?x1x1xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+  return %2 : tensor<?x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: test_tosa_use_def_chain
 func.func @test_tosa_use_def_chain(%arg0: tensor<1x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<?x16x16x16xf32> {
   // CHECK: [[CONV:%.+]] = tosa.conv2d %arg0, %arg1, %arg2
@@ -1298,3 +1309,4 @@ func.func @test_large_constant_permutation() {
   %72 = tosa.transpose %14, %cst_26 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
   return
 }
+

@sabauma
Copy link
Contributor Author

sabauma commented Nov 17, 2023

Copy link

github-actions bot commented Nov 17, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

…perators

TOSA operators consumed by non-TOSA ops generally do not have their
types inferred, as that would alter the types expected by their
consumers. This prevents type refinement on many TOSA operators when the
IR contains a mix of dialects.

This change modifies tosa-infer-shapes to update the types of all TOSA
operators during inference. When a consumer of that TOSA op is not safe
to update, a tensor.cast is inserted back to the original type. This
behavior is similar to how TOSA ops consumed by func.return are handled.

This allows for more type refinement of TOSA ops, and the additional
tensor.cast operators may be removed by later canonicalizations.
@sabauma sabauma force-pushed the infer-shapes-improvements branch from e3b95cb to 1f5e063 Compare November 17, 2023 23:40
@rsuderman rsuderman self-requested a review November 20, 2023 18:04
@sabauma
Copy link
Contributor Author

sabauma commented Nov 20, 2023

@rsuderman, if the change looks good to you, would you mind merging. I don't have write permissions.

@sabauma sabauma merged commit 852f6be into llvm:main Dec 1, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants