Skip to content

Commit 356bd2c

Browse files
authored
[mlir][tosa] Allow unsigned types for rescale ops during validation (#138253)
This commit allows unsigned types (ui8/ui16/ui32) when checking for valid element types, only for rescale operators. Signed-off-by: Luke Hutton <[email protected]>
1 parent b3ef15a commit 356bd2c

File tree

2 files changed

+41
-6
lines changed

2 files changed

+41
-6
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
562562

563563
bool CheckVariable(Operation *op);
564564
bool CheckVariableReadOrWrite(Operation *op);
565-
bool isValidElementType(Type type);
565+
bool isValidElementType(Type type, const bool allowUnsigned = false);
566566

567567
SmallVector<
568568
std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
@@ -1176,7 +1176,7 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
11761176
return success();
11771177
}
11781178

1179-
bool TosaValidation::isValidElementType(Type type) {
1179+
bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
11801180
if (isa<FloatType>(type)) {
11811181
return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
11821182
Float8E5M2Type>(type);
@@ -1191,6 +1191,13 @@ bool TosaValidation::isValidElementType(Type type) {
11911191
case 48:
11921192
return true;
11931193
}
1194+
} else if (allowUnsigned && intTy.isUnsigned()) {
1195+
switch (intTy.getWidth()) {
1196+
case 8:
1197+
case 16:
1198+
case 32:
1199+
return true;
1200+
}
11941201
}
11951202
} else if (mlir::isa<tosa::shapeType>(type)) {
11961203
return true;
@@ -1209,19 +1216,23 @@ void TosaValidation::runOnOperation() {
12091216
if (op->getDialect() != tosaDialect)
12101217
return;
12111218

1212-
// perform valid element type check at the beginning to
1213-
// protect rest of code against quantized element types
1219+
// validate operator element types:
1220+
// - rescale operator is allowed to have ui8/ui16/ui32
1221+
// operands/results
1222+
// - perform valid element type check at the beginning to
1223+
// protect rest of code against quantized element types
1224+
const bool opIsRescale = isa<tosa::RescaleOp>(op);
12141225
for (Value operand : op->getOperands()) {
12151226
auto elementTy = getElementTypeOrSelf(operand);
1216-
if (!isValidElementType(elementTy)) {
1227+
if (!isValidElementType(elementTy, opIsRescale)) {
12171228
op->emitOpError() << "is not profile-aligned: element type "
12181229
<< elementTy << " is not legal";
12191230
return signalPassFailure();
12201231
}
12211232
}
12221233
for (Type resultTy : op->getResultTypes()) {
12231234
auto elementTy = getElementTypeOrSelf(resultTy);
1224-
if (!isValidElementType(elementTy)) {
1235+
if (!isValidElementType(elementTy, opIsRescale)) {
12251236
op->emitOpError() << "is not profile-aligned: element type "
12261237
<< elementTy << " is not legal";
12271238
return signalPassFailure();

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1937,3 +1937,27 @@ func.func @test_clamp_min_larger_than_max_fp32(%arg0: tensor<13x21x3xf32>) -> te
19371937
%0 = tosa.clamp %arg0 {min_val = 2.0 : f32, max_val = -1.1: f32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
19381938
return %0 : tensor<13x21x3xf32>
19391939
}
1940+
1941+
// -----
1942+
1943+
// CHECK-LABEL: test_rescale_input_unsigned
1944+
func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui8>) -> (tensor<1x1xi8>) {
1945+
%0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
1946+
%1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
1947+
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
1948+
%3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
1949+
%r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xui8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8>
1950+
return %r : tensor<1x1xi8>
1951+
}
1952+
1953+
// -----
1954+
1955+
// CHECK-LABEL: test_rescale_output_unsigned
1956+
func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui8>) {
1957+
%0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
1958+
%1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
1959+
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
1960+
%3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
1961+
%r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
1962+
return %r : tensor<1x1xui8>
1963+
}

0 commit comments

Comments
 (0)