
Commit 4a4a90f

mcremon-meta authored and facebook-github-bot committed
Allow int8 type in quantized_conv and im2row (#6049)
Summary:
Pull Request resolved: #6049

As titled.

Reviewed By: zonglinpeng

Differential Revision: D63842548

fbshipit-source-id: 5d535fb75f9ca3374b49126e6653082fa60b2ac1
1 parent ba8dc28 commit 4a4a90f

File tree

1 file changed (+59, -28 lines)


backends/cadence/reference/operators/quantized_conv_out.cpp

Lines changed: 59 additions & 28 deletions
@@ -190,34 +190,65 @@ void quantized_conv_out(
   // per-channel
   bool per_tensor_quantized = bias_scale.numel() == 1;
 
-  conv2d_nchw_core_generic<uint8_t, uint8_t, int32_t, uint8_t, true>(
-      input.const_data_ptr<uint8_t>(),
-      weight.const_data_ptr<uint8_t>(),
-      bias.const_data_ptr<int32_t>(),
-      out.mutable_data_ptr<uint8_t>(),
-      n,
-      c,
-      h,
-      w,
-      oc,
-      wc,
-      wh,
-      ww,
-      oh,
-      ow,
-      stride[0],
-      stride[1],
-      padding[0],
-      padding[1],
-      dilation[0],
-      dilation[1],
-      groups,
-      in_zero_point,
-      weight_zero_point.const_data_ptr<int32_t>(),
-      bias_scale.const_data_ptr<float>(),
-      output_scale,
-      (uint8_t)output_zero_point,
-      per_tensor_quantized);
+  if (out.scalar_type() == exec_aten::ScalarType::Byte) {
+    conv2d_nchw_core_generic<uint8_t, uint8_t, int32_t, uint8_t, true>(
+        input.const_data_ptr<uint8_t>(),
+        weight.const_data_ptr<uint8_t>(),
+        bias.const_data_ptr<int32_t>(),
+        out.mutable_data_ptr<uint8_t>(),
+        n,
+        c,
+        h,
+        w,
+        oc,
+        wc,
+        wh,
+        ww,
+        oh,
+        ow,
+        stride[0],
+        stride[1],
+        padding[0],
+        padding[1],
+        dilation[0],
+        dilation[1],
+        groups,
+        in_zero_point,
+        weight_zero_point.const_data_ptr<int32_t>(),
+        bias_scale.const_data_ptr<float>(),
+        output_scale,
+        (uint8_t)output_zero_point,
+        per_tensor_quantized);
+  } else if (out.scalar_type() == exec_aten::ScalarType::Char) {
+    conv2d_nchw_core_generic<int8_t, int8_t, int32_t, int8_t, true>(
+        input.const_data_ptr<int8_t>(),
+        weight.const_data_ptr<int8_t>(),
+        bias.const_data_ptr<int32_t>(),
+        out.mutable_data_ptr<int8_t>(),
+        n,
+        c,
+        h,
+        w,
+        oc,
+        wc,
+        wh,
+        ww,
+        oh,
+        ow,
+        stride[0],
+        stride[1],
+        padding[0],
+        padding[1],
+        dilation[0],
+        dilation[1],
+        groups,
+        in_zero_point,
+        weight_zero_point.const_data_ptr<int32_t>(),
+        bias_scale.const_data_ptr<float>(),
+        output_scale,
+        (int8_t)output_zero_point,
+        per_tensor_quantized);
+  }
 }
 
 }; // namespace native
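The change keeps the existing uint8 (Byte) path of quantized_conv_out intact and adds an int8 (Char) branch that instantiates the same conv2d_nchw_core_generic template with int8_t input, weight, and output types, selected at runtime from the output tensor's dtype. Below is a minimal, self-contained sketch of that dtype-dispatch pattern; Tensor, ScalarType, kernel_impl, and run_kernel are simplified stand-ins invented for illustration, not the actual ExecuTorch types or APIs.

// Sketch only: dispatch a templated kernel on the output tensor's dtype,
// mirroring the Byte/Char branch added in this diff. All names below are
// hypothetical stand-ins, not ExecuTorch APIs.
#include <stdint.h>

#include <stdexcept>

enum class ScalarType { Byte, Char }; // stand-in for exec_aten::ScalarType

struct Tensor {
  ScalarType dtype;
  void* data;
  ScalarType scalar_type() const { return dtype; }
  template <typename T>
  T* mutable_data_ptr() const { return static_cast<T*>(data); }
};

// Hypothetical templated kernel, analogous to conv2d_nchw_core_generic:
// here it just fills the output buffer with the (casted) zero-point value.
template <typename OutT>
void kernel_impl(OutT* out, int numel, OutT out_zero_point) {
  for (int i = 0; i < numel; ++i) {
    out[i] = out_zero_point;
  }
}

// Runtime dispatch: pick the template instantiation from the output dtype,
// the same pattern the patch applies to quantized_conv_out.
void run_kernel(const Tensor& out, int numel, int out_zero_point) {
  if (out.scalar_type() == ScalarType::Byte) {
    kernel_impl<uint8_t>(
        out.mutable_data_ptr<uint8_t>(), numel, (uint8_t)out_zero_point);
  } else if (out.scalar_type() == ScalarType::Char) {
    kernel_impl<int8_t>(
        out.mutable_data_ptr<int8_t>(), numel, (int8_t)out_zero_point);
  } else {
    throw std::runtime_error("unsupported output dtype");
  }
}

int main() {
  int8_t buf[4] = {0};
  Tensor out{ScalarType::Char, buf};
  run_kernel(out, 4, -3); // int8 path: buf becomes {-3, -3, -3, -3}
  return 0;
}

As in the diff, the dispatch key is out.scalar_type(); unsupported dtypes throw in this sketch, whereas the reference operator simply has no else branch.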
