@@ -1132,11 +1132,21 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
1132
1132
// CHECK-DAG: linalg.yield [[TRUNC]]
1133
1133
%0 = tosa.rescale %arg0 {input_zp = 17 : i32 , output_zp = 22 : i32 , multiplier = array<i32 : 19689 >, shift = array<i8 : 15 >, scale32 = false , double_round = false , per_channel = false } : (tensor <2 xi8 >) -> tensor <2 xi8 >
1134
1134
1135
+ // CHECK: return
1136
+ return
1137
+ }
1138
+
1139
+ // -----
1140
+ // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
1141
+
1142
+ // CHECK-LABEL: @rescale_i8_unsigned_output
1143
+ // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
1144
+ func.func @rescale_i8_unsigned_output (%arg0 : tensor <2 xi8 >) -> () {
1135
1145
// CHECK: [[C0:%.+]] = arith.constant 19689
1136
1146
// CHECK: [[C1:%.+]] = arith.constant 15
1137
1147
// CHECK: [[INIT:%.+]] = tensor.empty()
1138
- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xui8 >)
1139
- // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: ui8 ):
1148
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8 >)
1149
+ // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8 ):
1140
1150
// CHECK: [[C17:%.+]] = arith.constant 17
1141
1151
// CHECK: [[C22:%.+]] = arith.constant 22
1142
1152
// CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
@@ -1148,9 +1158,8 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
1148
1158
// CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
1149
1159
// CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
1150
1160
// CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
1151
- // CHECK-DAG: [[CAST:%.+]] = builtin.unrealized_conversion_cast [[TRUNC]] : i8 to ui8
1152
- // CHECK: linalg.yield [[CAST]]
1153
- %1 = tosa.rescale %arg0 {input_zp = 17 : i32 , output_zp = 22 : i32 , multiplier = array<i32 : 19689 >, shift = array<i8 : 15 >, scale32 = false , double_round = false , per_channel = false } : (tensor <2 xi8 >) -> tensor <2 xui8 >
1161
+ // CHECK: linalg.yield [[TRUNC]]
1162
+ %1 = tosa.rescale %arg0 {input_zp = 17 : i32 , output_zp = 22 : i32 , multiplier = array<i32 : 19689 >, shift = array<i8 : 15 >, scale32 = false , double_round = false , per_channel = false , output_unsigned = true } : (tensor <2 xi8 >) -> tensor <2 xi8 >
1154
1163
1155
1164
// CHECK: return
1156
1165
return
@@ -1171,9 +1180,9 @@ func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
1171
1180
1172
1181
// CHECK: %[[C0:.+]] = arith.constant 0
1173
1182
// CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
1174
- // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xui8 >
1175
- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xui8 >)
1176
- %1 = tosa.rescale %arg0 {input_zp = 17 : i32 , output_zp = 22 : i32 , multiplier = array<i32 : 19689 >, shift = array<i8 : 15 >, scale32 = false , double_round = false , per_channel = false } : (tensor <?x2 xi8 >) -> tensor <?x 2 x ui8 >
1183
+ // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8 >
1184
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8 >)
1185
+ %1 = tosa.rescale %arg0 {input_zp = 17 : i32 , output_zp = 22 : i32 , multiplier = array<i32 : 19689 >, shift = array<i8 : 15 >, scale32 = false , double_round = false , per_channel = false , output_unsigned = true } : (tensor <?x2 xi8 >) -> tensor <?x 2 x i8 >
1177
1186
1178
1187
return
1179
1188
}
@@ -1199,18 +1208,17 @@ func.func @rescale_dyn(%arg0 : tensor<1x?x?x32xi32>) -> () {
1199
1208
1200
1209
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
1201
1210
1202
- // CHECK-LABEL: @rescale_ui8
1211
+ // CHECK-LABEL: @rescale_i8_unsigned_input
1203
1212
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
1204
- func.func @rescale_ui8 (%arg0 : tensor <2 x ui8 >) -> () {
1213
+ func.func @rescale_i8_unsigned_input (%arg0 : tensor <2 x i8 >) -> () {
1205
1214
// CHECK: [[C0:%.+]] = arith.constant 19689
1206
1215
// CHECK: [[C1:%.+]] = arith.constant 15
1207
1216
// CHECK: [[INIT:%.+]] = tensor.empty()
1208
- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xui8 >) outs([[INIT]] : tensor<2xi8>)
1209
- // CHECK: ^bb0([[IN:%.+]]: ui8 , [[UNUSED:%.+]]: i8):
1217
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8 >) outs([[INIT]] : tensor<2xi8>)
1218
+ // CHECK: ^bb0([[IN:%.+]]: i8 , [[UNUSED:%.+]]: i8):
1210
1219
// CHECK: [[C17:%.+]] = arith.constant 17
1211
1220
// CHECK: [[C22:%.+]] = arith.constant 22
1212
- // CHECK-DAG: [[CAST:%.+]] = builtin.unrealized_conversion_cast [[IN]] : ui8 to i8
1213
- // CHECK-DAG: [[IN32:%.+]] = arith.extui [[CAST]]
1221
+ // CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN]]
1214
1222
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
1215
1223
// CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {double_round = false}
1216
1224
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
@@ -1220,7 +1228,7 @@ func.func @rescale_ui8(%arg0 : tensor<2xui8>) -> () {
1220
1228
// CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
1221
1229
// CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
1222
1230
// CHECK: linalg.yield [[TRUNC]]
1223
- %0 = tosa.rescale %arg0 {input_zp = 17 : i32 , output_zp = 22 : i32 , multiplier = array<i32 : 19689 >, shift = array<i8 : 15 >, scale32 = false , double_round = false , per_channel = false } : (tensor <2 x ui8 >) -> tensor <2 xi8 >
1231
+ %0 = tosa.rescale %arg0 {input_zp = 17 : i32 , output_zp = 22 : i32 , multiplier = array<i32 : 19689 >, shift = array<i8 : 15 >, scale32 = false , double_round = false , per_channel = false , input_unsigned = true } : (tensor <2 x i8 >) -> tensor <2 xi8 >
1224
1232
1225
1233
return
1226
1234
}
0 commit comments