Skip to content

Commit ea23897

Browse files
authored
[mlir] [TOSA] Allow any floating point type (#91745)
After #86509 allowed all integer types in TOSA ops, this PR allows TOSA ops on all floating point types. This helps to experiment with `f64` and 8-bit float types when spec conformance is not required.
1 parent 922fafa commit ea23897

File tree

5 files changed

+20
-26
lines changed

5 files changed

+20
-26
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1857,11 +1857,11 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
18571857
}];
18581858

18591859
let arguments = (ins
1860-
Tosa_Tensor_Plus_F64:$input
1860+
Tosa_Tensor:$input
18611861
);
18621862

18631863
let results = (outs
1864-
Tosa_Tensor_Plus_F64:$output
1864+
Tosa_Tensor:$output
18651865
);
18661866

18671867
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
@@ -1944,7 +1944,7 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
19441944
);
19451945

19461946
let results = (outs
1947-
TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Plus_F64]>]>:$output
1947+
TensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
19481948
);
19491949

19501950
let hasFolder = 1;

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -71,28 +71,16 @@ def Tosa_QuantizedInt : AnyTypeOf<[ Tosa_QuantizedType<"uint8", [8], 0>,
7171
Tosa_QuantizedType<"int16", [16, 0], 1>,
7272
Tosa_QuantizedType<"int32", [32, 0], 1>]>;
7373

74-
//===----------------------------------------------------------------------===//
75-
// Floating-point types.
76-
//===----------------------------------------------------------------------===//
77-
def Tosa_Float : AnyTypeOf<[
78-
F32,
79-
F16,
80-
BF16]>;
81-
8274
//===----------------------------------------------------------------------===//
8375
// Multi-category types.
8476
//===----------------------------------------------------------------------===//
85-
def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float],
77+
def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
8678
"number">;
8779

88-
// Add F64 type support just for tosa::CastOp and tosa::ConstOp
89-
def Tosa_AnyNumber_Plus_F64 : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float, F64],
90-
"number_plus_f64">;
91-
9280
// For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
9381
// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp, tosa::FullyConnectedOp
9482
def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
95-
Tosa_QuantizedInt, Tosa_Float]>;
83+
Tosa_QuantizedInt, AnyFloat]>;
9684

9785
//===----------------------------------------------------------------------===//
9886
// Tensor types
@@ -101,18 +89,17 @@ def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
10189
def Tosa_Int32Tensor : TensorOf<[Tosa_Int32]>;
10290
def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>;
10391

104-
def Tosa_FloatTensor : TensorOf<[Tosa_Float]>;
92+
def Tosa_FloatTensor : TensorOf<[AnyFloat]>;
10593

10694
// Either ranked or unranked tensor of TOSA supported element types.
10795
def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>;
108-
def Tosa_Tensor_Plus_F64 : TensorOf<[Tosa_AnyNumber_Plus_F64]>;
10996

11097
// Must be ranked but no further constraints
11198
def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>;
11299

113100
// Any tensor element type allowed in Tosa ops.
114101
def Tosa_ElementType : Type<Or<[Tosa_Int.predicate, Tosa_QuantizedInt.predicate,
115-
Tosa_Float.predicate]>, "tosa.dtype">;
102+
AnyFloat.predicate]>, "tosa.dtype">;
116103

117104
class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
118105
AnyTypeOf<[TensorOf<allowedTypes>, NoneType], description>;

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -506,11 +506,10 @@ LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
506506
}
507507

508508
bool TosaValidation::isValidElementType(Type type) {
509-
if ((profile == TosaProfileEnum::BaseInference) && isa<FloatType>(type)) {
510-
return false;
511-
}
512-
if (type.isF64()) {
513-
return false;
509+
if (isa<FloatType>(type)) {
510+
if (profile == TosaProfileEnum::BaseInference)
511+
return false;
512+
return type.isF32() || type.isF16() || type.isBF16();
514513
}
515514
if (auto intTy = dyn_cast<IntegerType>(type)) {
516515
if (intTy.isUnsigned()) {

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: t
2020
// -----
2121

2222
func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
23-
// expected-error@+1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or 32-bit float or 16-bit float or bfloat16 type values, but got 'tensor<*xi8>'}}
23+
// expected-error@+1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or floating-point values, but got 'tensor<*xi8>'}}
2424
%0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
2525
: (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8>
2626
return %0 : tensor<1x27x27x16xi8>

mlir/test/Dialect/Tosa/level_check.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,14 @@ func.func @test_const_ui32(%arg0 : tensor<1xui32>) {
131131

132132
// -----
133133

134+
func.func @test_const_f64(%arg0 : tensor<1xf64>) {
135+
// expected-error@+1 {{'tosa.const' op is not profile-aligned: element type 'f64' is not legal}}
136+
%0 = "tosa.const"() {value = dense<0.0> : tensor<1xf64>} : () -> tensor<1xf64>
137+
return
138+
}
139+
140+
// -----
141+
134142
func.func @test_avgpool2d_kernel_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
135143
// expected-error@+1 {{'tosa.avg_pool2d' op failed level check: kernel <= MAX_KERNEL}}
136144
%0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 8193, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :

0 commit comments

Comments
 (0)