-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[mlir][tosa] Allow zero-points to be unranked #143770
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tosa Author: Luke Hutton (lhutton1) ChangesThis commit allows zero-points used by a number of tosa operations to be unranked. This allows the shape inference pass to propagate shape information. Full diff: https://github.com/llvm/llvm-project/pull/143770.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 536551c8f8437..1cfe6eee576b3 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -152,7 +152,7 @@ def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
def Tosa_ScalarTensor : TosaScalarTensorOf<[Tosa_AnyNumber], [1]>;
def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
-def Tosa_ScalarIntOrFloatTensor : TosaScalarTensorOf<[Tosa_Int, AnyFloat], [1]>;
+def Tosa_ScalarIntOrFloatTensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int, AnyFloat]>, TosaScalarTensorOf<[Tosa_Int, AnyFloat], [1]>]>;
// We include unranked tensors as a supported type for all possible tosa
// Tensors as unranked does not guarantee invalid. If unranked tensors exist
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index a4617fc6fba8b..9982e1a7fe197 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1007,7 +1007,7 @@ func.func @test_pad_rank0_pad_const(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<1
func.func @test_conv2d_rank0_zp(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi32> {
%input_zp = "tosa.const"() <{values = dense<0> : tensor<i8>}> : () -> tensor<i8>
%weight_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
- // expected-error@+1 {{'tosa.conv2d' op operand #3 must be tosa-conformant scalar tensor of unsigned integer or signless integer or floating-point values, but got 'tensor<i8>'}}
+ // expected-error@+1 {{'tosa.conv2d' op operand #3 must be tosa-conformant unranked tensor of unsigned integer or signless integer or floating-point values or tosa-conformant scalar tensor of unsigned integer or signless integer or floating-point values, but got 'tensor<i8>'}}
%0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<i8>, tensor<1xi8>) -> tensor<1x27x27x16xi32>
return %0 : tensor<1x27x27x16xi32>
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 591a3f0acf65d..f5ab4a2241358 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -333,6 +333,19 @@ func.func @test_dynamic_mixed_matmul(%arg0 : tensor<?x3x?xi32>, %arg1 : tensor<?
// -----
+// CHECK-LABEL: @test_unranked_zero_points_matmul
+func.func @test_unranked_zero_points_matmul(%arg0: tensor<1x2x3xf32>, %arg1: tensor<1x3x4xf32>) -> tensor<1x2x4xf32> {
+ %a_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %b_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %a_zp_cast = "tosa.cast"(%a_zp) : (tensor<1xi8>) -> tensor<*xf32>
+ %b_zp_cast = "tosa.cast"(%b_zp) : (tensor<1xi8>) -> tensor<*xf32>
+ // CHECK: tosa.matmul %arg0, %arg1, %2, %3 : (tensor<1x2x3xf32>, tensor<1x3x4xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x2x4xf32>
+ %0 = tosa.matmul %arg0, %arg1, %a_zp_cast, %b_zp_cast : (tensor<1x2x3xf32>, tensor<1x3x4xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<1x2x4xf32>
+ return %0 : tensor<1x2x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: @test_table_static
func.func @test_table_static(%arg0 : tensor<4x5xi16>, %arg1 : tensor<513xi16>) -> () {
// CHECK:tosa.table %arg0, %arg1 : (tensor<4x5xi16>, tensor<513xi16>) -> tensor<4x5xi16>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks! Just one minor nit.
This commit allows zero-points used by a number of tosa operations to be unranked. This allows the shape inference pass to propagate shape information. Change-Id: I20c1a04eb2d306f181ffd5c1574f69fd9e410102
015ab39
to
5f9f404
Compare
This commit allows zero-points used by a number of tosa operations to be unranked. This allows the shape inference pass to propagate shape information.
This commit allows zero-points used by a number of tosa operations to be unranked. This allows the shape inference pass to propagate shape information.
This commit allows zero-points used by a number of tosa operations to be unranked. This allows the shape inference pass to propagate shape information.