@@ -190,34 +190,65 @@ void quantized_conv_out(
190
190
// per-channel
191
191
bool per_tensor_quantized = bias_scale.numel () == 1 ;
192
192
193
- conv2d_nchw_core_generic<uint8_t , uint8_t , int32_t , uint8_t , true >(
194
- input.const_data_ptr <uint8_t >(),
195
- weight.const_data_ptr <uint8_t >(),
196
- bias.const_data_ptr <int32_t >(),
197
- out.mutable_data_ptr <uint8_t >(),
198
- n,
199
- c,
200
- h,
201
- w,
202
- oc,
203
- wc,
204
- wh,
205
- ww,
206
- oh,
207
- ow,
208
- stride[0 ],
209
- stride[1 ],
210
- padding[0 ],
211
- padding[1 ],
212
- dilation[0 ],
213
- dilation[1 ],
214
- groups,
215
- in_zero_point,
216
- weight_zero_point.const_data_ptr <int32_t >(),
217
- bias_scale.const_data_ptr <float >(),
218
- output_scale,
219
- (uint8_t )output_zero_point,
220
- per_tensor_quantized);
193
+ if (out.scalar_type () == exec_aten::ScalarType::Byte) {
194
+ conv2d_nchw_core_generic<uint8_t , uint8_t , int32_t , uint8_t , true >(
195
+ input.const_data_ptr <uint8_t >(),
196
+ weight.const_data_ptr <uint8_t >(),
197
+ bias.const_data_ptr <int32_t >(),
198
+ out.mutable_data_ptr <uint8_t >(),
199
+ n,
200
+ c,
201
+ h,
202
+ w,
203
+ oc,
204
+ wc,
205
+ wh,
206
+ ww,
207
+ oh,
208
+ ow,
209
+ stride[0 ],
210
+ stride[1 ],
211
+ padding[0 ],
212
+ padding[1 ],
213
+ dilation[0 ],
214
+ dilation[1 ],
215
+ groups,
216
+ in_zero_point,
217
+ weight_zero_point.const_data_ptr <int32_t >(),
218
+ bias_scale.const_data_ptr <float >(),
219
+ output_scale,
220
+ (uint8_t )output_zero_point,
221
+ per_tensor_quantized);
222
+ } else if (out.scalar_type () == exec_aten::ScalarType::Char) {
223
+ conv2d_nchw_core_generic<int8_t , int8_t , int32_t , int8_t , true >(
224
+ input.const_data_ptr <int8_t >(),
225
+ weight.const_data_ptr <int8_t >(),
226
+ bias.const_data_ptr <int32_t >(),
227
+ out.mutable_data_ptr <int8_t >(),
228
+ n,
229
+ c,
230
+ h,
231
+ w,
232
+ oc,
233
+ wc,
234
+ wh,
235
+ ww,
236
+ oh,
237
+ ow,
238
+ stride[0 ],
239
+ stride[1 ],
240
+ padding[0 ],
241
+ padding[1 ],
242
+ dilation[0 ],
243
+ dilation[1 ],
244
+ groups,
245
+ in_zero_point,
246
+ weight_zero_point.const_data_ptr <int32_t >(),
247
+ bias_scale.const_data_ptr <float >(),
248
+ output_scale,
249
+ (int8_t )output_zero_point,
250
+ per_tensor_quantized);
251
+ }
221
252
}
222
253
223
254
}; // namespace native
0 commit comments