@@ -166,7 +166,7 @@ Tensor& dequantize_per_tensor_tensor_args_out(
 Tensor& dequantize_per_channel_out(
     const Tensor& input,
     const Tensor& scale,
-    const Tensor& zero_point,
+    const optional<Tensor>& opt_zero_points,
     int64_t axis,
     int64_t quant_min,
     int64_t quant_max,
@@ -201,16 +201,19 @@ Tensor& dequantize_per_channel_out(
       ssize_t(scale.numel()),
       ssize_t(input.size(axis)));

-  ET_CHECK_MSG(
-      zero_point.scalar_type() == ScalarType::Long,
-      "zero_point.scalar_type() %" PRId8 " is not integer type",
-      static_cast<int8_t>(zero_point.scalar_type()));
+  if (opt_zero_points.has_value()) {
+    auto zero_point = opt_zero_points.value();
+    ET_CHECK_MSG(
+        zero_point.scalar_type() == ScalarType::Long,
+        "zero_point.scalar_type() %" PRId8 " is not integer type",
+        static_cast<int8_t>(zero_point.scalar_type()));

-  ET_CHECK_MSG(
-      zero_point.numel() == input.size(axis),
-      "zero_point.numel() %zd != input.size(axis) %zd",
-      ssize_t(zero_point.numel()),
-      ssize_t(input.size(axis)));
+    ET_CHECK_MSG(
+        zero_point.numel() == input.size(axis),
+        "zero_point.numel() %zd != input.size(axis) %zd",
+        ssize_t(zero_point.numel()),
+        ssize_t(input.size(axis)));
+  }

   check_dequantize_per_tensor_args(
       input, quant_min, quant_max, dtype, out_dtype, out);
@@ -225,7 +228,12 @@ Tensor& dequantize_per_channel_out(
     }
   }
   const double* scale_data = scale.const_data_ptr<double>();
-  const int64_t* zero_point_data = zero_point.const_data_ptr<int64_t>();
+  const int64_t* zero_point_data;
+  if (opt_zero_points.has_value()) {
+    zero_point_data = opt_zero_points.value().const_data_ptr<int64_t>();
+  } else {
+    zero_point_data = nullptr;
+  }

   exec_aten::optional<exec_aten::ArrayRef<int64_t>> optional_dim_list{
       exec_aten::ArrayRef<int64_t>{dims, size_t(input.dim() - 1)}};
@@ -242,7 +250,10 @@ Tensor& dequantize_per_channel_out(
   case ScalarType::out_dtype:                                                 \
     for (size_t channel_ix = 0; channel_ix < input.size(axis); ++channel_ix) { \
       double _scale = scale_data[channel_ix];                                 \
-      int64_t _zero_point = zero_point_data[channel_ix];                      \
+      int64_t _zero_point = 0;                                                \
+      if (zero_point_data != nullptr) {                                       \
+        _zero_point = zero_point_data[channel_ix];                            \
+      }                                                                       \
       apply_over_dim_list(                                                    \
           [input, out, _scale, _zero_point](size_t in_ix) {                   \
             out.mutable_data_ptr<CTYPE_OUT>()[in_ix] = static_cast<CTYPE_OUT>( \
@@ -284,7 +295,7 @@ Tensor& dequantize_per_channel_out(
     RuntimeContext& context,
     const Tensor& input,
     const Tensor& scale,
-    const Tensor& zero_point,
+    const optional<Tensor>& opt_zero_points,
     int64_t axis,
     int64_t quant_min,
     int64_t quant_max,
@@ -295,7 +306,7 @@ Tensor& dequantize_per_channel_out(
   return dequantize_per_channel_out(
       input,
       scale,
-      zero_point,
+      opt_zero_points,
      axis,
      quant_min,
      quant_max,
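For context, here is a minimal standalone sketch of the pattern this diff introduces: the per-channel zero-point becomes optional, and when it is absent the dequantization falls back to a zero-point of 0 for every channel. This is not the ExecuTorch kernel itself; it uses plain std::optional and std::vector in place of exec_aten::Tensor, and the function and variable names are illustrative only.

// dequantize_sketch.cpp -- standalone illustration of an optional zero-point.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Dequantize one quantized value per channel: out = (q - zero_point) * scale.
// If opt_zero_points is empty, every channel uses a zero-point of 0,
// mirroring the nullptr fallback added in the kernel above.
std::vector<double> dequantize_per_channel(
    const std::vector<int8_t>& quantized,
    const std::vector<double>& scales,
    const std::optional<std::vector<int64_t>>& opt_zero_points) {
  const int64_t* zero_point_data =
      opt_zero_points.has_value() ? opt_zero_points->data() : nullptr;

  std::vector<double> out(quantized.size());
  for (std::size_t channel_ix = 0; channel_ix < quantized.size(); ++channel_ix) {
    int64_t zero_point = 0;
    if (zero_point_data != nullptr) {
      zero_point = zero_point_data[channel_ix];
    }
    out[channel_ix] =
        (static_cast<int64_t>(quantized[channel_ix]) - zero_point) *
        scales[channel_ix];
  }
  return out;
}

int main() {
  std::vector<int8_t> q = {10, -20, 30};
  std::vector<double> scales = {0.5, 0.25, 0.1};

  // With explicit per-channel zero-points.
  auto with_zp = dequantize_per_channel(q, scales, std::vector<int64_t>{2, -2, 0});
  // Without zero-points: treated as all zeros.
  auto without_zp = dequantize_per_channel(q, scales, std::nullopt);

  for (std::size_t i = 0; i < q.size(); ++i) {
    std::cout << with_zp[i] << " " << without_zp[i] << "\n";
  }
  return 0;
}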