Skip to content

Commit 1b9d475

Browse files
authored
[mlir][tosa] Align validation profiles and extensions to TOSA v1.0 spec (#132768)
- Add missing int16 extension for concat operator - Remove int16 extension for cast operator - Add pro_int and pro_fp profiles for const_shape operator Signed-off-by: Jerry Ge <[email protected]>
1 parent 0fb4ef4 commit 1b9d475

File tree

3 files changed

+9
-4
lines changed

3 files changed

+9
-4
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1882,7 +1882,7 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat"> {
18821882

18831883
list<Availability> availability = [
18841884
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
1885-
Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
1885+
Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16, Tosa_EXT_INT16]>,
18861886
];
18871887

18881888
let hasCanonicalizer = 1;
@@ -2318,7 +2318,7 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
23182318

23192319
list<Availability> availability = [
23202320
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
2321-
Extension<[Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
2321+
Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
23222322
];
23232323

23242324
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";

mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,11 @@ def Tosa_ConstShapeOp : Tosa_ShapeOp<"const_shape", [ConstantLike, Pure]> {
7575

7676
let results = (outs Tosa_Shape : $output);
7777

78+
list<Availability> availability = [
79+
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
80+
Extension<[]>,
81+
];
82+
7883
let hasVerifier = 1;
7984
}
8085

mlir/test/Dialect/Tosa/availability.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ func.func @test_reduce_sum(%arg0: tensor<13x21x3xf32>) -> tensor<1x21x3xf32> {
507507
// CHECK-LABEL: concat
508508
func.func @test_concat(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<26x21x3xf32> {
509509
// CHECK: profiles: [ [pro_int, pro_fp] ]
510-
// CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
510+
// CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16, int16] ]
511511
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<26x21x3xf32>
512512
return %0 : tensor<26x21x3xf32>
513513
}
@@ -606,7 +606,7 @@ func.func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> {
606606
// CHECK-LABEL: cast
607607
func.func @test_cast1(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
608608
// CHECK: profiles: [ [pro_int, pro_fp] ]
609-
// CHECK: extensions: [ [int16, fp8e4m3, fp8e5m2, bf16] ]
609+
// CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
610610
%0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xf32>
611611
return %0 : tensor<13x21x3xf32>
612612
}

0 commit comments

Comments
 (0)