Skip to content

[mlir] [TOSA] Allow any floating point type #91745

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1857,11 +1857,11 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
}];

let arguments = (ins
Tosa_Tensor_Plus_F64:$input
Tosa_Tensor:$input
);

let results = (outs
Tosa_Tensor_Plus_F64:$output
Tosa_Tensor:$output
);

let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
Expand Down Expand Up @@ -1944,7 +1944,7 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
);

let results = (outs
TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Plus_F64]>]>:$output
TensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
);

let hasFolder = 1;
Expand Down
21 changes: 4 additions & 17 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -71,28 +71,16 @@ def Tosa_QuantizedInt : AnyTypeOf<[ Tosa_QuantizedType<"uint8", [8], 0>,
Tosa_QuantizedType<"int16", [16, 0], 1>,
Tosa_QuantizedType<"int32", [32, 0], 1>]>;

//===----------------------------------------------------------------------===//
// Floating-point types.
//===----------------------------------------------------------------------===//
def Tosa_Float : AnyTypeOf<[
F32,
F16,
BF16]>;

//===----------------------------------------------------------------------===//
// Multi-category types.
//===----------------------------------------------------------------------===//
def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float],
def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
"number">;

// Add F64 type support just for tosa::CastOp and tosa::ConstOp
def Tosa_AnyNumber_Plus_F64 : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float, F64],
"number_plus_f64">;

// For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp, tosa::FullyConnectedOp
def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
Tosa_QuantizedInt, Tosa_Float]>;
Tosa_QuantizedInt, AnyFloat]>;

//===----------------------------------------------------------------------===//
// Tensor types
Expand All @@ -101,18 +89,17 @@ def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
def Tosa_Int32Tensor : TensorOf<[Tosa_Int32]>;
def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>;

def Tosa_FloatTensor : TensorOf<[Tosa_Float]>;
def Tosa_FloatTensor : TensorOf<[AnyFloat]>;

// Either ranked or unranked tensor of TOSA supported element types.
def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>;
def Tosa_Tensor_Plus_F64 : TensorOf<[Tosa_AnyNumber_Plus_F64]>;

// Must be ranked but no further constraints
def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>;

// Any tensor element type allowed in Tosa ops.
def Tosa_ElementType : Type<Or<[Tosa_Int.predicate, Tosa_QuantizedInt.predicate,
Tosa_Float.predicate]>, "tosa.dtype">;
AnyFloat.predicate]>, "tosa.dtype">;

class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
AnyTypeOf<[TensorOf<allowedTypes>, NoneType], description>;
Expand Down
9 changes: 4 additions & 5 deletions mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -506,11 +506,10 @@ LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
}

bool TosaValidation::isValidElementType(Type type) {
if ((profile == TosaProfileEnum::BaseInference) && isa<FloatType>(type)) {
return false;
}
if (type.isF64()) {
return false;
if (isa<FloatType>(type)) {
if (profile == TosaProfileEnum::BaseInference)
return false;
return type.isF32() || type.isF16() || type.isBF16();
}
if (auto intTy = dyn_cast<IntegerType>(type)) {
if (intTy.isUnsigned()) {
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Tosa/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: t
// -----

func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
// expected-error@+1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or 32-bit float or 16-bit float or bfloat16 type values, but got 'tensor<*xi8>'}}
// expected-error@+1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or floating-point values, but got 'tensor<*xi8>'}}
%0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8>
return %0 : tensor<1x27x27x16xi8>
Expand Down
8 changes: 8 additions & 0 deletions mlir/test/Dialect/Tosa/level_check.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,14 @@ func.func @test_const_ui32(%arg0 : tensor<1xui32>) {

// -----

func.func @test_const_f64(%arg0 : tensor<1xf64>) {
// expected-error@+1 {{'tosa.const' op is not profile-aligned: element type 'f64' is not legal}}
%0 = "tosa.const"() {value = dense<0.0> : tensor<1xf64>} : () -> tensor<1xf64>
return
}

// -----

func.func @test_avgpool2d_kernel_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
// expected-error@+1 {{'tosa.avg_pool2d' op failed level check: kernel <= MAX_KERNEL}}
%0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 8193, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :
Expand Down