Skip to content

Commit 0bb7710

Browse files
lhutton1Jaddyen
authored andcommitted
[mlir][tosa] Allow zero-points to be unranked (llvm#143770)
This commit allows zero-points used by a number of tosa operations to be unranked. This allows the shape inference pass to propagate shape information.
1 parent 3148a18 commit 0bb7710

File tree

3 files changed

+13
-2
lines changed

3 files changed

+13
-2
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
152152

153153
def Tosa_ScalarTensor : TosaScalarTensorOf<[Tosa_AnyNumber], [1]>;
154154
def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
155-
def Tosa_ScalarIntOrFloatTensor : TosaScalarTensorOf<[Tosa_Int, AnyFloat], [1]>;
155+
def Tosa_ScalarIntOrFloatTensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int, AnyFloat]>, TosaScalarTensorOf<[Tosa_Int, AnyFloat], [1]>]>;
156156

157157
// We include unranked tensors as a supported type for all possible tosa
158158
// Tensors as unranked does not guarantee invalid. If unranked tensors exist

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1007,7 +1007,7 @@ func.func @test_pad_rank0_pad_const(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<1
10071007
func.func @test_conv2d_rank0_zp(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi32> {
10081008
%input_zp = "tosa.const"() <{values = dense<0> : tensor<i8>}> : () -> tensor<i8>
10091009
%weight_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
1010-
// expected-error@+1 {{'tosa.conv2d' op operand #3 must be tosa-conformant scalar tensor of unsigned integer or signless integer or floating-point values, but got 'tensor<i8>'}}
1010+
// expected-error@+1 {{'tosa.conv2d' op operand #3 must be tosa-conformant unranked tensor of unsigned integer or signless integer or floating-point values or tosa-conformant scalar tensor of unsigned integer or signless integer or floating-point values, but got 'tensor<i8>'}}
10111011
%0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
10121012
: (tensor<1x29x29x4xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<i8>, tensor<1xi8>) -> tensor<1x27x27x16xi32>
10131013
return %0 : tensor<1x27x27x16xi32>

mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,17 @@ func.func @test_dynamic_mixed_matmul(%arg0 : tensor<?x3x?xi32>, %arg1 : tensor<?
333333

334334
// -----
335335

336+
// CHECK-LABEL: @test_unranked_zero_points_matmul
337+
func.func @test_unranked_zero_points_matmul(%arg0: tensor<1x2x3xf32>, %arg1: tensor<1x3x4xf32>, %zero_point: tensor<1xf32>) -> tensor<1x2x4xf32> {
338+
// CHECK: %[[ZP:.*]] = tosa.cast %arg2 : (tensor<1xf32>) -> tensor<1xf32>
339+
%zero_point_unranked = "tosa.cast"(%zero_point) : (tensor<1xf32>) -> tensor<*xf32>
340+
// CHECK: tosa.matmul %arg0, %arg1, %[[ZP]], %[[ZP]] : (tensor<1x2x3xf32>, tensor<1x3x4xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x2x4xf32>
341+
%0 = tosa.matmul %arg0, %arg1, %zero_point_unranked, %zero_point_unranked : (tensor<1x2x3xf32>, tensor<1x3x4xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<1x2x4xf32>
342+
return %0 : tensor<1x2x4xf32>
343+
}
344+
345+
// -----
346+
336347
// CHECK-LABEL: @test_table_static
337348
func.func @test_table_static(%arg0 : tensor<4x5xi16>, %arg1 : tensor<513xi16>) -> () {
338349
// CHECK:tosa.table %arg0, %arg1 : (tensor<4x5xi16>, tensor<513xi16>) -> tensor<4x5xi16>

0 commit comments

Comments
 (0)