[mlir][tosa] Fix lowering of tosa.matmul with dynamic outputs #72724
The existing lowering of tosa.matmul constructs illegal tensor.empty operations when the output type is more dynamic than the input types:

    %0 = tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<?x1x1xf32>

When constructing the tensor.empty operation, consult the output type rather than the input types to decide whether a dimension is dynamic.
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 99a65f63038a43f..9e374be534985e5 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -540,21 +540,18 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
auto outputTy = cast<ShapedType>(op.getType());
auto outputElementTy = outputTy.getElementType();
- auto firstOperandTy = cast<ShapedType>(op->getOperand(0).getType());
- auto secondOperandTy = cast<ShapedType>(op->getOperand(1).getType());
-
SmallVector<Value> dynDims;
dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());
- if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(0)) {
+ if (!outputTy.hasRank() || outputTy.isDynamicDim(0)) {
dynDims[0] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 0);
}
- if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(1)) {
+ if (!outputTy.hasRank() || outputTy.isDynamicDim(1)) {
dynDims[1] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 1);
}
- if (!secondOperandTy.hasRank() || secondOperandTy.isDynamicDim(2)) {
+ if (!outputTy.hasRank() || outputTy.isDynamicDim(2)) {
dynDims[2] = rewriter.create<tensor::DimOp>(loc, op->getOperand(1), 2);
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 1cf7c8dee606899..4edc75331932803 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -68,6 +68,20 @@ func.func @matmul_dyn_independent_dim(%arg0: tensor<1x5x?xf32>, %arg1: tensor<1x
// -----
+// CHECK-LABEL: @matmul_dyn_output
+func.func @matmul_dyn_output(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>) -> tensor<?x1x1xf32> {
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ // CHECK: %[[DIM0:.+]] = tensor.dim %arg0, %[[C0]] : tensor<1x1x8xf32>
+ // CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+ // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x1x1xf32>
+ // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<?x1x1xf32>) -> tensor<?x1x1xf32>
+ // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x1x8xf32>, tensor<1x8x1xf32>) outs(%[[FILLED]] : tensor<?x1x1xf32>) -> tensor<?x1x1xf32>
+ %0 = tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<?x1x1xf32>
+ return %0 : tensor<?x1x1xf32>
+}
+
+// -----
+
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
If you're comfortable with this change, would you mind merging it? I do not have write access to the repo.
The lowering of tosa.conv2d produces an illegal tensor.empty operation in which the number of dynamic-size operands does not match the number of dynamic dimensions in the output type. The fix is to base the generation of tensor.dim operations on the result type of the conv2d operation, rather than the input type. The problem and fix are very similar to #72724, but for convolution.