Skip to content

Commit 71b1bc8

Browse files
committed
[TOSA] Allow all integer types in most ops (llvm#86509)
As discussed in one of the previous TOSA community meetings, we would like to allow for more integer types in the TOSA dialect to enable more use cases. For strict standards conformance, the TosaValidation pass can be used. Follow up PRs will extend conversions from TOSA where needed.
1 parent 55b8b59 commit 71b1bc8

File tree

4 files changed

+71
-26
lines changed

4 files changed

+71
-26
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1960,7 +1960,7 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
19601960
);
19611961

19621962
let results = (outs
1963-
TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Plus_F64, Tosa_Int4]>]>:$output
1963+
TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Plus_F64]>]>:$output
19641964
);
19651965

19661966
let hasFolder = 1;

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -38,30 +38,17 @@ class Tosa_QuantizedType<string n, list<int> params, bit signed>
3838
// Used to express accumulator results or compare results.
3939
//===----------------------------------------------------------------------===//
4040

41-
def Tosa_UInt8 : UI<8>;
42-
def Tosa_UInt16 : UI<16>;
43-
4441
def Tosa_Int4 : I<4>;
4542
def Tosa_Int8 : I<8>;
46-
def Tosa_Int16 : I<16>;
4743
def Tosa_Int32 : I<32>;
48-
def Tosa_Int48 : I<48>;
4944
def Tosa_Int64 : I<64>;
5045

51-
def Tosa_SignedInt : AnyTypeOf<[Tosa_Int8,
52-
Tosa_Int16,
53-
Tosa_Int32,
54-
Tosa_Int48,
55-
Tosa_Int64]>;
56-
57-
def Tosa_Bool : I<1>;
58-
59-
def Tosa_Int : AnyTypeOf<[Tosa_Bool,
60-
AnyUnsignedInteger,
61-
AnySignlessInteger,
62-
// TODO: For backwards compatibility, keep Tosa_SignedInt, which is actually
63-
// a set of signless types.
64-
Tosa_SignedInt]>;
46+
// The TOSA dialect allows more types than the TOSA standard to allow for
47+
// experimentation. For historical reasons, signless is used in the place of
48+
// signed.
49+
// The TosaValidation pass can be used to check for standard conformance.
50+
def Tosa_Int : AnyTypeOf<[AnyUnsignedInteger,
51+
AnySignlessInteger]>;
6552

6653
def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
6754
Tosa_Int64]>;
@@ -173,9 +160,6 @@ class Tosa_TypeLike<list<Type> types, string description = ""> : TypeConstraint<
173160

174161
def Tosa_IntLike : Tosa_TypeLike<[Tosa_Int], "signless-integer-like">;
175162
def Tosa_Int8Like : Tosa_TypeLike<[Tosa_Int8], "signless-integer-8-bit-like">;
176-
def Tosa_Int16Like : Tosa_TypeLike<[Tosa_Int16], "signless-integer-16-bit-like">;
177-
def Tosa_Int32Like : Tosa_TypeLike<[Tosa_Int32], "signless-integer-32-bit-like">;
178-
def Tosa_Int64Like : Tosa_TypeLike<[Tosa_Int64], "signless-integer-64-bit-like">;
179163

180164
//===----------------------------------------------------------------------===//
181165
// Attribute predicates and classes.

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
410410
bool CheckVariable(Operation *op);
411411
bool CheckVariableReadOrWrite(Operation *op);
412412

413+
bool isValidElementType(Type type);
414+
413415
SmallVector<std::function<LogicalResult(Operation *)>> constCheckers;
414416
TosaLevel tosaLevel;
415417
DenseMap<StringAttr, mlir::Type> variablesMap;
@@ -503,15 +505,58 @@ LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
503505
return success();
504506
}
505507

508+
bool TosaValidation::isValidElementType(Type type) {
509+
if ((profile == TosaProfileEnum::BaseInference) && isa<FloatType>(type)) {
510+
return false;
511+
}
512+
if (type.isF64()) {
513+
return false;
514+
}
515+
if (auto intTy = dyn_cast<IntegerType>(type)) {
516+
if (intTy.isUnsigned()) {
517+
switch (intTy.getWidth()) {
518+
case 8:
519+
case 16:
520+
return true;
521+
default:
522+
return false;
523+
}
524+
} else {
525+
// Signless - treated as signed.
526+
switch (intTy.getWidth()) {
527+
case 1:
528+
case 4:
529+
case 8:
530+
case 16:
531+
case 32:
532+
case 48:
533+
case 64:
534+
return true;
535+
default:
536+
return false;
537+
}
538+
}
539+
return false;
540+
}
541+
return true;
542+
}
543+
506544
void TosaValidation::runOnOperation() {
507545
configLevelAndProfile();
508546
getOperation().walk([&](Operation *op) {
509547
for (Value operand : op->getOperands()) {
510-
if ((profile == TosaProfileEnum::BaseInference) &&
511-
isa<FloatType>(getElementTypeOrSelf(operand))) {
548+
auto elementTy = getElementTypeOrSelf(operand);
549+
if (!isValidElementType(elementTy)) {
550+
op->emitOpError() << "is not profile-aligned: element type "
551+
<< elementTy << " is not legal";
512552
return signalPassFailure();
513553
}
514-
if (getElementTypeOrSelf(operand).isF64()) {
554+
}
555+
for (Type resultTy : op->getResultTypes()) {
556+
auto elementTy = getElementTypeOrSelf(resultTy);
557+
if (!isValidElementType(elementTy)) {
558+
op->emitOpError() << "is not profile-aligned: element type "
559+
<< elementTy << " is not legal";
515560
return signalPassFailure();
516561
}
517562
}

mlir/test/Dialect/Tosa/level_check.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,22 @@ func.func @test_const(%arg0 : tensor<1x1xi32>) -> tensor<1x1x1x1x1x1x1xi32> {
115115

116116
// -----
117117

118+
func.func @test_const_i2(%arg0 : tensor<1xi2>) {
119+
// expected-error@+1 {{'tosa.const' op is not profile-aligned: element type 'i2' is not legal}}
120+
%0 = "tosa.const"() {value = dense<0> : tensor<1xi2>} : () -> tensor<1xi2>
121+
return
122+
}
123+
124+
// -----
125+
126+
func.func @test_const_ui32(%arg0 : tensor<1xui32>) {
127+
// expected-error@+1 {{'tosa.const' op is not profile-aligned: element type 'ui32' is not legal}}
128+
%0 = "tosa.const"() {value = dense<0> : tensor<1xui32>} : () -> tensor<1xui32>
129+
return
130+
}
131+
132+
// -----
133+
118134
func.func @test_avgpool2d_kernel_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
119135
// expected-error@+1 {{'tosa.avg_pool2d' op failed level check: kernel <= MAX_KERNEL}}
120136
%0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 8193, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :

0 commit comments

Comments
 (0)