Skip to content

[mlir][tosa] Enhance error_if and verify checks for RESCALE Op #137021

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3206,6 +3206,12 @@ LogicalResult RescaleOp::verify() {
// otherwise numChannel is dimension in input shape's last axis
int64_t numChannels = 1;
if (getPerChannel()) {
if (inputType.getRank() < 1) {
emitOpError("requires input to be at least rank 1 when per_channel is "
"true, but got rank ")
<< inputType.getRank();
return failure();
}
numChannels = inputType.getDimSize(inputType.getRank() - 1);
}

Expand Down
82 changes: 81 additions & 1 deletion mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1033,8 +1033,88 @@ bool checkErrorIfTable(Operation *op) {
return true;
}

bool checkErrorIfRescale(Operation *op) {
auto rescale = dyn_cast<tosa::RescaleOp>(op);
if (!rescale)
return true;

auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
!outputType.getElementType().isInteger())
return true;

auto inElemType = inputType.getElementType();
auto outElemType = outputType.getElementType();
auto inWidth = inElemType.getIntOrFloatBitWidth();
auto outWidth = outElemType.getIntOrFloatBitWidth();

bool inputUnsigned = rescale.getInputUnsigned();
bool outputUnsigned = rescale.getOutputUnsigned();

bool scale32 = rescale.getScale32();
auto roundingMode = rescale.getRoundingMode();

// ERROR_IF(scale32 && is_same<in_t,i48_t>())
if (scale32 && inWidth == 48) {
op->emitOpError() << "scale32 is not allowed with 48-bit input.";
return false;
}

// ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND))
if (!scale32 && roundingMode == "DOUBLE_ROUND") {
op->emitOpError() << "DOUBLE_ROUND is only allowed with scale32=true.";
return false;
}

// ERROR_IF(input_unsigned && output_unsigned)
if (inputUnsigned && outputUnsigned) {
op->emitOpError() << "input and output cannot be both unsigned.";
return false;
}

// ERROR_IF(is_same<out_t,i32_t>() && input_unsigned)
if (outWidth == 32 && inputUnsigned) {
op->emitOpError() << "i32 output type is not allowed with unsigned input.";
return false;
}

// ERROR_IF(is_same<in_t,i32_t>() && output_unsigned)
if (inWidth == 32 && outputUnsigned) {
op->emitOpError() << "i32 input type is not allowed with unsigned output.";
return false;
}

// ERROR_IF(is_same<in_t,i48_t>() && output_unsigned)
if (inWidth == 48 && outputUnsigned) {
op->emitOpError() << "i48 input type is not allowed with unsigned output.";
return false;
}

// ERROR_IF(is_same<in_t, i48_t> && input_unsigned)
if (inWidth == 48 && inputUnsigned) {
op->emitOpError() << "i48 input type cannot be unsigned.";
return false;
}

// ERROR_IF(is_same<in_t, i32_t> && input_unsigned)
if (inWidth == 32 && inputUnsigned) {
op->emitOpError() << "i32 input type cannot be unsigned.";
return false;
}

// ERROR_IF(is_same<out_t, i32_t> && output_unsigned)
if (outWidth == 32 && outputUnsigned) {
op->emitOpError() << "i32 output type cannot be unsigned.";
return false;
}

return true;
}

LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
if (!checkErrorIfResize(op) || !checkErrorIfMul(op) || !checkErrorIfTable(op))
if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
!checkErrorIfTable(op) || !checkErrorIfRescale(op))
return failure();
return success();
}
Expand Down
96 changes: 96 additions & 0 deletions mlir/test/Dialect/Tosa/error_if_check.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,99 @@ func.func @test_i8_table_size(%arg0: tensor<2x64xi8>, %arg1: tensor<513xi8>) ->
%0 = tosa.table %arg0, %arg1 : (tensor<2x64xi8>, tensor<513xi8>) -> tensor<2x64xi8>
return %0 : tensor<2x64xi8>
}

// -----
// CHECK-LABEL: test_error_scale32_with_i48
func.func @test_error_scale32_with_i48(%arg0: tensor<1xi48>) -> tensor<1xi8> {
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op scale32 is not allowed with 48-bit input}}
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1xi48>, tensor<1xi32>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
return %0 : tensor<1xi8>
}

// -----
// CHECK-LABEL: test_error_input_output_unsigned
func.func @test_error_input_output_unsigned(%arg0: tensor<1xi8>) -> tensor<1xi16> {
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
// expected-error@+1 {{'tosa.rescale' op input and output cannot be both unsigned}}
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = true} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<1xi16>
return %0 : tensor<1xi16>
}

// -----
// CHECK-LABEL: test_error_i32_output_unsigned_input
func.func @test_error_i32_output_unsigned_input(%arg0: tensor<1xi8>) -> tensor<1xi32> {
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
// expected-error@+1 {{'tosa.rescale' op i32 output type is not allowed with unsigned input}}
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1xi32>
return %0 : tensor<1xi32>
}

// -----
// CHECK-LABEL: test_error_i32_input_unsigned_output
func.func @test_error_i32_input_unsigned_output(%arg0: tensor<1xi32>) -> tensor<1xi8> {
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op i32 input type is not allowed with unsigned output}}
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi32>, tensor<1xi16>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1xi8>
return %0 : tensor<1xi8>
}

// -----
// CHECK-LABEL: test_error_i48_input_unsigned_output
func.func @test_error_i48_input_unsigned_output(%arg0: tensor<1xi48>) -> tensor<1xi8> {
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op i48 input type is not allowed with unsigned output}}
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
return %0 : tensor<1xi8>
}

// -----
// CHECK-LABEL: test_error_i48_unsigned_input
func.func @test_error_i48_input_unsigned_output(%arg0: tensor<1xi48>) -> tensor<1xi8> {
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op i48 input type cannot be unsigned}}
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
return %0 : tensor<1xi8>
}

// -----
// CHECK-LABEL: test_error_i32_unsigned_input
func.func @test_error_i32_input_unsigned_output(%arg0: tensor<1xi32>) -> tensor<1xi8> {
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op i32 input type cannot be unsigned}}
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi32>, tensor<1xi16>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1xi8>
return %0 : tensor<1xi8>
}

// -----
// CHECK-LABEL: test_error_i32_unsigned_output
func.func @test_error_i32_unsigned_output(%arg0: tensor<1xi8>) -> tensor<1xi32> {
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
// expected-error@+1 {{'tosa.rescale' op i32 output type cannot be unsigned}}
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1xi32>
return %0 : tensor<1xi32>
}
12 changes: 12 additions & 0 deletions mlir/test/Dialect/Tosa/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1638,6 +1638,18 @@ func.func @test_rescale_invalid_non_perchannel_shift_shape(%arg0: tensor<13x21x3
return %0 : tensor<13x21x3xi16>
}

// -----
// CHECK-LABEL: test_error_double_round_without_scale32
func.func @test_error_double_round_without_scale32(%arg0: tensor<1xi8>) -> tensor<1xi16> {
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
// expected-error@+1 {{'tosa.rescale' op DOUBLE_ROUND is only allowed with scale32=true}}
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "DOUBLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<1xi16>
return %0 : tensor<1xi16>
}

// -----
// CHECK-LABEL: test_matmul_a_zp_same_element_type
func.func @test_matmul_a_zp_same_element_type(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
Expand Down
12 changes: 12 additions & 0 deletions mlir/test/Dialect/Tosa/verifier.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -358,3 +358,15 @@ func.func @test_concat_axis_sum_error(%arg0: tensor<1x2xf32>, %arg1: tensor<2x?x
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
return %0 : tensor<2x?xf32>
}

// -----

func.func @test_error_scalar_input_with_per_channel(%arg0: tensor<i8>) -> tensor<i16> {
%multiplier = "tosa.const"() {values = dense<4> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<2> : tensor<1xi8> } : () -> tensor<1xi8>
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
// expected-error@+1 {{'tosa.rescale' op requires input to be at least rank 1 when per_channel is true, but got rank 0}}
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<i8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<i16>
return %0 : tensor<i16>
}
Loading