Skip to content

Commit e0d037f

Browse files
committed
[mlir][tosa] Enhance error_if and verify checks for RESCALE Op
* add verifier for rank-0 input with per-channel * add checkErrorIfRescale to tosa validation pass that align with TOSAv1.0 * add LIT tests Change-Id: Ia07e8c2ee66d8ee4113bea5ad9fa859b5986b009 Signed-off-by: Peng Sun <[email protected]>
1 parent 92bba68 commit e0d037f

File tree

5 files changed

+220
-1
lines changed

5 files changed

+220
-1
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3186,6 +3186,12 @@ LogicalResult RescaleOp::verify() {
31863186
// otherwise numChannel is dimension in input shape's last axis
31873187
int64_t numChannels = 1;
31883188
if (getPerChannel()) {
3189+
if (inputType.getRank() < 1) {
3190+
emitOpError("requires input to be at least rank 1 when per_channel is "
3191+
"true, but got rank ")
3192+
<< inputType.getRank();
3193+
return failure();
3194+
}
31893195
numChannels = inputType.getDimSize(inputType.getRank() - 1);
31903196
}
31913197

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1033,8 +1033,88 @@ bool checkErrorIfTable(Operation *op) {
10331033
return true;
10341034
}
10351035

1036+
bool checkErrorIfRescale(Operation *op) {
1037+
auto rescale = dyn_cast<tosa::RescaleOp>(op);
1038+
if (!rescale)
1039+
return true;
1040+
1041+
auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
1042+
auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
1043+
if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
1044+
!outputType.getElementType().isInteger())
1045+
return true;
1046+
1047+
auto inElemType = inputType.getElementType();
1048+
auto outElemType = outputType.getElementType();
1049+
auto inWidth = inElemType.getIntOrFloatBitWidth();
1050+
auto outWidth = outElemType.getIntOrFloatBitWidth();
1051+
1052+
bool inputUnsigned = rescale.getInputUnsigned();
1053+
bool outputUnsigned = rescale.getOutputUnsigned();
1054+
1055+
bool scale32 = rescale.getScale32();
1056+
auto roundingMode = rescale.getRoundingMode();
1057+
1058+
// ERROR_IF(scale32 && is_same<in_t,i48_t>())
1059+
if (scale32 && inWidth == 48) {
1060+
op->emitOpError() << "scale32 is not allowed with 48-bit input.";
1061+
return false;
1062+
}
1063+
1064+
// ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND))
1065+
if (!scale32 && roundingMode == "DOUBLE_ROUND") {
1066+
op->emitOpError() << "DOUBLE_ROUND is only allowed with scale32=true.";
1067+
return false;
1068+
}
1069+
1070+
// ERROR_IF(input_unsigned && output_unsigned)
1071+
if (inputUnsigned && outputUnsigned) {
1072+
op->emitOpError() << "input and output cannot be both unsigned.";
1073+
return false;
1074+
}
1075+
1076+
// ERROR_IF(is_same<out_t,i32_t>() && input_unsigned)
1077+
if (outWidth == 32 && inputUnsigned) {
1078+
op->emitOpError() << "i32 output type is not allowed with unsigned input.";
1079+
return false;
1080+
}
1081+
1082+
// ERROR_IF(is_same<in_t,i32_t>() && output_unsigned)
1083+
if (inWidth == 32 && outputUnsigned) {
1084+
op->emitOpError() << "i32 input type is not allowed with unsigned output.";
1085+
return false;
1086+
}
1087+
1088+
// ERROR_IF(is_same<in_t,i48_t>() && output_unsigned)
1089+
if (inWidth == 48 && outputUnsigned) {
1090+
op->emitOpError() << "i48 input type is not allowed with unsigned output.";
1091+
return false;
1092+
}
1093+
1094+
// ERROR_IF(is_same<in_t, i48_t> && input_unsigned)
1095+
if (inWidth == 48 && inputUnsigned) {
1096+
op->emitOpError() << "i48 input type cannot be unsigned.";
1097+
return false;
1098+
}
1099+
1100+
// ERROR_IF(is_same<in_t, i32_t> && input_unsigned)
1101+
if (inWidth == 32 && inputUnsigned) {
1102+
op->emitOpError() << "i32 input type cannot be unsigned.";
1103+
return false;
1104+
}
1105+
1106+
// ERROR_IF(is_same<out_t, i32_t> && output_unsigned)
1107+
if (outWidth == 32 && outputUnsigned) {
1108+
op->emitOpError() << "i32 output type cannot be unsigned.";
1109+
return false;
1110+
}
1111+
1112+
return true;
1113+
}
1114+
10361115
LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
1037-
if (!checkErrorIfResize(op) || !checkErrorIfMul(op) || !checkErrorIfTable(op))
1116+
if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
1117+
!checkErrorIfTable(op) || !checkErrorIfRescale(op))
10381118
return failure();
10391119
return success();
10401120
}

mlir/test/Dialect/Tosa/error_if_check.mlir

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,111 @@ func.func @test_i8_table_size(%arg0: tensor<2x64xi8>, %arg1: tensor<513xi8>) ->
129129
%0 = tosa.table %arg0, %arg1 : (tensor<2x64xi8>, tensor<513xi8>) -> tensor<2x64xi8>
130130
return %0 : tensor<2x64xi8>
131131
}
132+
133+
// -----
134+
// CHECK-LABEL: test_error_input_zp_not_allowed
135+
func.func @test_error_input_zp_not_allowed(%arg0: tensor<1xi48>) -> tensor<1xi8> {
136+
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
137+
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
138+
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
139+
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
140+
// expected-error@+1 {{'tosa.rescale' op scale32 is not allowed with 48-bit input}}
141+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1xi48>, tensor<1xi32>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
142+
return %0 : tensor<1xi8>
143+
}
144+
145+
// -----
146+
// CHECK-LABEL: test_error_scale32_with_i48
147+
func.func @test_error_scale32_with_i48(%arg0: tensor<1xi48>) -> tensor<1xi8> {
148+
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
149+
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
150+
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
151+
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
152+
// expected-error@+1 {{'tosa.rescale' op scale32 is not allowed with 48-bit input}}
153+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1xi48>, tensor<1xi32>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
154+
return %0 : tensor<1xi8>
155+
}
156+
157+
// -----
158+
// CHECK-LABEL: test_error_input_output_unsigned
159+
func.func @test_error_input_output_unsigned(%arg0: tensor<1xi8>) -> tensor<1xi16> {
160+
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
161+
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
162+
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
163+
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
164+
// expected-error@+1 {{'tosa.rescale' op input and output cannot be both unsigned}}
165+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = true} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<1xi16>
166+
return %0 : tensor<1xi16>
167+
}
168+
169+
// -----
170+
// CHECK-LABEL: test_error_i32_output_unsigned_input
171+
func.func @test_error_i32_output_unsigned_input(%arg0: tensor<1xi8>) -> tensor<1xi32> {
172+
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
173+
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
174+
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
175+
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
176+
// expected-error@+1 {{'tosa.rescale' op i32 output type is not allowed with unsigned input}}
177+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1xi32>
178+
return %0 : tensor<1xi32>
179+
}
180+
181+
// -----
182+
// CHECK-LABEL: test_error_i32_input_unsigned_output
183+
func.func @test_error_i32_input_unsigned_output(%arg0: tensor<1xi32>) -> tensor<1xi8> {
184+
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
185+
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
186+
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
187+
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
188+
// expected-error@+1 {{'tosa.rescale' op i32 input type is not allowed with unsigned output}}
189+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi32>, tensor<1xi16>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1xi8>
190+
return %0 : tensor<1xi8>
191+
}
192+
193+
// -----
194+
// CHECK-LABEL: test_error_i48_input_unsigned_output
195+
func.func @test_error_i48_input_unsigned_output(%arg0: tensor<1xi48>) -> tensor<1xi8> {
196+
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
197+
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
198+
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
199+
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
200+
// expected-error@+1 {{'tosa.rescale' op i48 input type is not allowed with unsigned output}}
201+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
202+
return %0 : tensor<1xi8>
203+
}
204+
205+
// -----
206+
// CHECK-LABEL: test_error_i48_unsigned_input
207+
func.func @test_error_i48_input_unsigned_output(%arg0: tensor<1xi48>) -> tensor<1xi8> {
208+
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
209+
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
210+
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
211+
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
212+
// expected-error@+1 {{'tosa.rescale' op i48 input type cannot be unsigned}}
213+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
214+
return %0 : tensor<1xi8>
215+
}
216+
217+
// -----
218+
// CHECK-LABEL: test_error_i32_unsigned_input
219+
func.func @test_error_i32_input_unsigned_output(%arg0: tensor<1xi32>) -> tensor<1xi8> {
220+
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
221+
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
222+
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
223+
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
224+
// expected-error@+1 {{'tosa.rescale' op i32 input type cannot be unsigned}}
225+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi32>, tensor<1xi16>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1xi8>
226+
return %0 : tensor<1xi8>
227+
}
228+
229+
// -----
230+
// CHECK-LABEL: test_error_i32_unsigned_output
231+
func.func @test_error_i32_unsigned_output(%arg0: tensor<1xi8>) -> tensor<1xi32> {
232+
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
233+
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
234+
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
235+
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
236+
// expected-error@+1 {{'tosa.rescale' op i32 output type cannot be unsigned}}
237+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1xi32>
238+
return %0 : tensor<1xi32>
239+
}

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1669,6 +1669,18 @@ func.func @test_rescale_invalid_non_perchannel_shift_shape(%arg0: tensor<13x21x3
16691669
return %0 : tensor<13x21x3xi16>
16701670
}
16711671

1672+
// -----
1673+
// CHECK-LABEL: test_error_double_round_without_scale32
1674+
func.func @test_error_double_round_without_scale32(%arg0: tensor<1xi8>) -> tensor<1xi16> {
1675+
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
1676+
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
1677+
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
1678+
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
1679+
// expected-error@+1 {{'tosa.rescale' op DOUBLE_ROUND is only allowed with scale32=true}}
1680+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "DOUBLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<1xi16>
1681+
return %0 : tensor<1xi16>
1682+
}
1683+
16721684
// -----
16731685
// CHECK-LABEL: test_matmul_a_zp_same_element_type
16741686
func.func @test_matmul_a_zp_same_element_type(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {

mlir/test/Dialect/Tosa/verifier.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,3 +319,16 @@ func.func @test_conv3d_wholly_divisible_output_width(%arg0: tensor<1x4x8x21x19xf
319319
: (tensor<1x4x8x21x19xf32>, tensor<34x1x1x1x17xf32>, tensor<21xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x8x19x34xf32>
320320
return %0 : tensor<1x4x8x19x34xf32>
321321
}
322+
323+
// -----
324+
325+
// CHECK-LABEL: test_error_scalar_input_with_per_channel
326+
func.func @test_error_scalar_input_with_per_channel(%arg0: tensor<i8>) -> tensor<i16> {
327+
%multiplier = "tosa.const"() {values = dense<4> : tensor<1xi32> } : () -> tensor<1xi32>
328+
%shift = "tosa.const"() {values = dense<2> : tensor<1xi8> } : () -> tensor<1xi8>
329+
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
330+
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
331+
// expected-error@+1 {{'tosa.rescale' op requires input to be at least rank 1 when per_channel is true, but got rank 0}}
332+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<i8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<i16>
333+
return %0 : tensor<i16>
334+
}

0 commit comments

Comments
 (0)