Skip to content

Commit 91134d6

Browse files
manuelcandales authored and facebook-github-bot committed
optional zero points on dequantize per channel (#2364)
Summary: X-link: pytorch/pytorch#121724 Pull Request resolved: #2364 bypass-github-export-checks Reviewed By: mikekgfb Differential Revision: D54709217 fbshipit-source-id: 2d1efe19a65bba8d014fa35bbb06a1b3a2af97d8
1 parent 969638c commit 91134d6

File tree

2 files changed

+26
-15
lines changed

2 files changed

+26
-15
lines changed

kernels/quantized/cpu/op_dequantize.cpp

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ Tensor& dequantize_per_tensor_tensor_args_out(
166166
Tensor& dequantize_per_channel_out(
167167
const Tensor& input,
168168
const Tensor& scale,
169-
const Tensor& zero_point,
169+
const optional<Tensor>& opt_zero_points,
170170
int64_t axis,
171171
int64_t quant_min,
172172
int64_t quant_max,
@@ -201,16 +201,19 @@ Tensor& dequantize_per_channel_out(
201201
ssize_t(scale.numel()),
202202
ssize_t(input.size(axis)));
203203

204-
ET_CHECK_MSG(
205-
zero_point.scalar_type() == ScalarType::Long,
206-
"zero_point.scalar_type() %" PRId8 " is not integer type",
207-
static_cast<int8_t>(zero_point.scalar_type()));
204+
if (opt_zero_points.has_value()) {
205+
auto zero_point = opt_zero_points.value();
206+
ET_CHECK_MSG(
207+
zero_point.scalar_type() == ScalarType::Long,
208+
"zero_point.scalar_type() %" PRId8 " is not integer type",
209+
static_cast<int8_t>(zero_point.scalar_type()));
208210

209-
ET_CHECK_MSG(
210-
zero_point.numel() == input.size(axis),
211-
"zero_point.numel() %zd != input.size(axis) %zd",
212-
ssize_t(zero_point.numel()),
213-
ssize_t(input.size(axis)));
211+
ET_CHECK_MSG(
212+
zero_point.numel() == input.size(axis),
213+
"zero_point.numel() %zd != input.size(axis) %zd",
214+
ssize_t(zero_point.numel()),
215+
ssize_t(input.size(axis)));
216+
}
214217

215218
check_dequantize_per_tensor_args(
216219
input, quant_min, quant_max, dtype, out_dtype, out);
@@ -225,7 +228,12 @@ Tensor& dequantize_per_channel_out(
225228
}
226229
}
227230
const double* scale_data = scale.const_data_ptr<double>();
228-
const int64_t* zero_point_data = zero_point.const_data_ptr<int64_t>();
231+
const int64_t* zero_point_data;
232+
if (opt_zero_points.has_value()) {
233+
zero_point_data = opt_zero_points.value().const_data_ptr<int64_t>();
234+
} else {
235+
zero_point_data = nullptr;
236+
}
229237

230238
exec_aten::optional<exec_aten::ArrayRef<int64_t>> optional_dim_list{
231239
exec_aten::ArrayRef<int64_t>{dims, size_t(input.dim() - 1)}};
@@ -242,7 +250,10 @@ Tensor& dequantize_per_channel_out(
242250
case ScalarType::out_dtype: \
243251
for (size_t channel_ix = 0; channel_ix < input.size(axis); ++channel_ix) { \
244252
double _scale = scale_data[channel_ix]; \
245-
int64_t _zero_point = zero_point_data[channel_ix]; \
253+
int64_t _zero_point = 0; \
254+
if (zero_point_data != nullptr) { \
255+
_zero_point = zero_point_data[channel_ix]; \
256+
} \
246257
apply_over_dim_list( \
247258
[input, out, _scale, _zero_point](size_t in_ix) { \
248259
out.mutable_data_ptr<CTYPE_OUT>()[in_ix] = static_cast<CTYPE_OUT>( \
@@ -284,7 +295,7 @@ Tensor& dequantize_per_channel_out(
284295
RuntimeContext& context,
285296
const Tensor& input,
286297
const Tensor& scale,
287-
const Tensor& zero_point,
298+
const optional<Tensor>& opt_zero_points,
288299
int64_t axis,
289300
int64_t quant_min,
290301
int64_t quant_max,
@@ -295,7 +306,7 @@ Tensor& dequantize_per_channel_out(
295306
return dequantize_per_channel_out(
296307
input,
297308
scale,
298-
zero_point,
309+
opt_zero_points,
299310
axis,
300311
quant_min,
301312
quant_max,

kernels/quantized/quantized.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
- arg_meta: null
2929
kernel_name: torch::executor::quantize_per_channel_out
3030

31-
- func: quantized_decomposed::dequantize_per_channel.out(Tensor input, Tensor scales, Tensor zero_points, int axis, int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None, Tensor(a!) out) -> Tensor(a!)
31+
- func: quantized_decomposed::dequantize_per_channel.out(Tensor input, Tensor scales, Tensor? zero_points, int axis, int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None, Tensor(a!) out) -> Tensor(a!)
3232
variants: function
3333
kernels:
3434
- arg_meta: null

0 commit comments

Comments (0)