Skip to content

Commit c6d419c

Browse files
authored
[TOSA] Allow all integer types in most ops (#86509)
As discussed in one of the previous TOSA community meetings, we would like to allow for more integer types in the TOSA dialect to enable more use cases. For strict standards conformance, the TosaValidation pass can be used. Follow up PRs will extend conversions from TOSA where needed.
1 parent a22bd00 commit c6d419c

File tree

4 files changed

+71
-25
lines changed

4 files changed

+71
-25
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1942,7 +1942,7 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
19421942
);
19431943

19441944
let results = (outs
1945-
TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Plus_F64, Tosa_Int4]>]>:$output
1945+
TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Plus_F64]>]>:$output
19461946
);
19471947

19481948
let hasFolder = 1;

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -38,29 +38,17 @@ class Tosa_QuantizedType<string n, list<int> params, bit signed>
3838
// Used to express accumulator results or compare results.
3939
//===----------------------------------------------------------------------===//
4040

41-
def Tosa_UInt8 : UI<8>;
42-
def Tosa_UInt16 : UI<16>;
43-
4441
def Tosa_Int4 : I<4>;
4542
def Tosa_Int8 : I<8>;
46-
def Tosa_Int16 : I<16>;
4743
def Tosa_Int32 : I<32>;
48-
def Tosa_Int48 : I<48>;
4944
def Tosa_Int64 : I<64>;
5045

51-
def Tosa_SignedInt : AnyTypeOf<[Tosa_Int8,
52-
Tosa_Int16,
53-
Tosa_Int32,
54-
Tosa_Int48,
55-
Tosa_Int64]>;
56-
57-
def Tosa_Bool : I<1>;
58-
59-
// No unsigned unquantized int types.
60-
def Tosa_Int : AnyTypeOf<[Tosa_Bool,
61-
Tosa_UInt8,
62-
Tosa_UInt16,
63-
Tosa_SignedInt]>;
46+
// The TOSA dialect allows more types than the TOSA standard to allow for
47+
// experimentation. For historical reasons, signless is used in the place of
48+
// signed.
49+
// The TosaValidation pass can be used to check for standard conformance.
50+
def Tosa_Int : AnyTypeOf<[AnyUnsignedInteger,
51+
AnySignlessInteger]>;
6452

6553
def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
6654
Tosa_Int64]>;
@@ -172,9 +160,6 @@ class Tosa_TypeLike<list<Type> types, string description = ""> : TypeConstraint<
172160

173161
def Tosa_IntLike : Tosa_TypeLike<[Tosa_Int], "signless-integer-like">;
174162
def Tosa_Int8Like : Tosa_TypeLike<[Tosa_Int8], "signless-integer-8-bit-like">;
175-
def Tosa_Int16Like : Tosa_TypeLike<[Tosa_Int16], "signless-integer-16-bit-like">;
176-
def Tosa_Int32Like : Tosa_TypeLike<[Tosa_Int32], "signless-integer-32-bit-like">;
177-
def Tosa_Int64Like : Tosa_TypeLike<[Tosa_Int64], "signless-integer-64-bit-like">;
178163

179164
//===----------------------------------------------------------------------===//
180165
// Attribute predicates and classes.

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
410410
bool CheckVariable(Operation *op);
411411
bool CheckVariableReadOrWrite(Operation *op);
412412

413+
bool isValidElementType(Type type);
414+
413415
SmallVector<std::function<LogicalResult(Operation *)>> constCheckers;
414416
TosaLevel tosaLevel;
415417
DenseMap<StringAttr, mlir::Type> variablesMap;
@@ -503,15 +505,58 @@ LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
503505
return success();
504506
}
505507

508+
bool TosaValidation::isValidElementType(Type type) {
509+
if ((profile == TosaProfileEnum::BaseInference) && isa<FloatType>(type)) {
510+
return false;
511+
}
512+
if (type.isF64()) {
513+
return false;
514+
}
515+
if (auto intTy = dyn_cast<IntegerType>(type)) {
516+
if (intTy.isUnsigned()) {
517+
switch (intTy.getWidth()) {
518+
case 8:
519+
case 16:
520+
return true;
521+
default:
522+
return false;
523+
}
524+
} else {
525+
// Signless - treated as signed.
526+
switch (intTy.getWidth()) {
527+
case 1:
528+
case 4:
529+
case 8:
530+
case 16:
531+
case 32:
532+
case 48:
533+
case 64:
534+
return true;
535+
default:
536+
return false;
537+
}
538+
}
539+
return false;
540+
}
541+
return true;
542+
}
543+
506544
void TosaValidation::runOnOperation() {
507545
configLevelAndProfile();
508546
getOperation().walk([&](Operation *op) {
509547
for (Value operand : op->getOperands()) {
510-
if ((profile == TosaProfileEnum::BaseInference) &&
511-
isa<FloatType>(getElementTypeOrSelf(operand))) {
548+
auto elementTy = getElementTypeOrSelf(operand);
549+
if (!isValidElementType(elementTy)) {
550+
op->emitOpError() << "is not profile-aligned: element type "
551+
<< elementTy << " is not legal";
512552
return signalPassFailure();
513553
}
514-
if (getElementTypeOrSelf(operand).isF64()) {
554+
}
555+
for (Type resultTy : op->getResultTypes()) {
556+
auto elementTy = getElementTypeOrSelf(resultTy);
557+
if (!isValidElementType(elementTy)) {
558+
op->emitOpError() << "is not profile-aligned: element type "
559+
<< elementTy << " is not legal";
515560
return signalPassFailure();
516561
}
517562
}

mlir/test/Dialect/Tosa/level_check.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,22 @@ func.func @test_const(%arg0 : tensor<1x1xi32>) -> tensor<1x1x1x1x1x1x1xi32> {
115115

116116
// -----
117117

118+
func.func @test_const_i2(%arg0 : tensor<1xi2>) {
119+
// expected-error@+1 {{'tosa.const' op is not profile-aligned: element type 'i2' is not legal}}
120+
%0 = "tosa.const"() {value = dense<0> : tensor<1xi2>} : () -> tensor<1xi2>
121+
return
122+
}
123+
124+
// -----
125+
126+
func.func @test_const_ui32(%arg0 : tensor<1xui32>) {
127+
// expected-error@+1 {{'tosa.const' op is not profile-aligned: element type 'ui32' is not legal}}
128+
%0 = "tosa.const"() {value = dense<0> : tensor<1xui32>} : () -> tensor<1xui32>
129+
return
130+
}
131+
132+
// -----
133+
118134
func.func @test_avgpool2d_kernel_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
119135
// expected-error@+1 {{'tosa.avg_pool2d' op failed level check: kernel <= MAX_KERNEL}}
120136
%0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 8193, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :

0 commit comments

Comments
 (0)