[mlir][tosa] Fix lowering of tosa.matmul with dynamic outputs #72724

Merged 1 commit into llvm:main on Nov 22, 2023

Conversation

sabauma (Contributor) commented Nov 18, 2023

The existing lowering of tosa.matmul will construct illegal tensor.empty operations when the output type is more dynamic than the input types.

%0 = tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<?x1x1xf32>

When constructing the tensor.empty operation, consult the output type rather than the input types to decide whether a dimension is dynamic.
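For context, tensor.empty takes exactly one index operand per dynamic dimension of its result type, so emitting it without a size operand for the dynamic dimension fails verification. A minimal before/after sketch of the emitted IR (value names are illustrative; the fixed form matches the new test case below):

  // Before the fix: dynamic-ness was judged from the static input types,
  // so no size operand was supplied for the dynamic result dimension.
  // This tensor.empty is illegal and fails verification.
  %empty = tensor.empty() : tensor<?x1x1xf32>

  // After the fix: the dynamic dimension is detected on the output type,
  // and its size is materialized from the corresponding input dimension.
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %arg0, %c0 : tensor<1x1x8xf32>
  %empty = tensor.empty(%dim) : tensor<?x1x1xf32>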

llvmbot (Member) commented Nov 18, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Spenser Bauman (sabauma)

Full diff: https://github.com/llvm/llvm-project/pull/72724.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+3-6)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+14)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 99a65f63038a43f..9e374be534985e5 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -540,21 +540,18 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
     auto outputTy = cast<ShapedType>(op.getType());
     auto outputElementTy = outputTy.getElementType();
 
-    auto firstOperandTy = cast<ShapedType>(op->getOperand(0).getType());
-    auto secondOperandTy = cast<ShapedType>(op->getOperand(1).getType());
-
     SmallVector<Value> dynDims;
     dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());
 
-    if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(0)) {
+    if (!outputTy.hasRank() || outputTy.isDynamicDim(0)) {
       dynDims[0] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 0);
     }
 
-    if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(1)) {
+    if (!outputTy.hasRank() || outputTy.isDynamicDim(1)) {
       dynDims[1] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 1);
     }
 
-    if (!secondOperandTy.hasRank() || secondOperandTy.isDynamicDim(2)) {
+    if (!outputTy.hasRank() || outputTy.isDynamicDim(2)) {
       dynDims[2] = rewriter.create<tensor::DimOp>(loc, op->getOperand(1), 2);
     }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 1cf7c8dee606899..4edc75331932803 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -68,6 +68,20 @@ func.func @matmul_dyn_independent_dim(%arg0: tensor<1x5x?xf32>, %arg1: tensor<1x
 
 // -----
 
+// CHECK-LABEL: @matmul_dyn_output
+func.func @matmul_dyn_output(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>) -> tensor<?x1x1xf32> {
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  // CHECK: %[[DIM0:.+]] = tensor.dim %arg0, %[[C0]] : tensor<1x1x8xf32>
+  // CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x1x1xf32>
+  // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<?x1x1xf32>) -> tensor<?x1x1xf32>
+  // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x1x8xf32>, tensor<1x8x1xf32>) outs(%[[FILLED]] : tensor<?x1x1xf32>) -> tensor<?x1x1xf32>
+  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<?x1x1xf32>
+  return %0 : tensor<?x1x1xf32>
+}
+
+// -----
+
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
 // CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 

sabauma (Contributor, Author) commented Nov 20, 2023

Hi @matthias-springer,

If you're comfortable with this change, would you mind merging it? I do not have write access to the repo.

@GeorgeARM GeorgeARM merged commit eed17dc into llvm:main Nov 22, 2023
sabauma added a commit to sabauma/llvm-project that referenced this pull request Dec 1, 2023
The lowering of tosa.conv2d produces an illegal tensor.empty operation
where the number of inputs does not match the number of dynamic
dimensions in the output type.

The fix is to base the generation of tensor.dim operations off the
result type of the conv2d operation, rather than the input type. The
problem and fix are very similar to llvm#72724, but for convolution.
sabauma added a commit that referenced this pull request Dec 1, 2023
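To illustrate the analogous conv2d failure mode, here is a hypothetical example (the shapes, attributes, and value names are invented for illustration and are not taken from the follow-up patch): every operand has a static shape, but the result batch dimension is dynamic, so deriving the dynamic-dimension check from the input types produced a tensor.empty with zero size operands for a result type with one dynamic dimension.

  // Hypothetical: static operands, dynamic output batch dimension. The old
  // lowering emitted tensor.empty() with no size operands for the
  // tensor<?x47x40x28xf32> result, which fails verification.
  %0 = tosa.conv2d %input, %weights, %bias
         {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>,
          stride = array<i64: 1, 1>}
       : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>)
       -> tensor<?x47x40x28xf32>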