Skip to content

Commit b0f9a61

Browse files
authored
Add support for uint16 in quant and dequant kernels
Differential Revision: D65370235 Pull Request resolved: #6724
1 parent 485a5df commit b0f9a61

File tree

6 files changed

+31
-0
lines changed

6 files changed

+31
-0
lines changed

backends/cadence/hifi/kernels/kernels.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ void requantize(
165165
typed_quantize_val(int8_t);
166166
typed_quantize_val(uint8_t);
167167
typed_quantize_val(int16_t);
168+
typed_quantize_val(uint16_t);
168169
#undef typed_quantize_val
169170

170171
#define typed_quantize_vec(dtype) \
@@ -177,6 +178,7 @@ typed_quantize_val(int16_t);
177178
typed_quantize_vec(int8_t);
178179
typed_quantize_vec(uint8_t);
179180
typed_quantize_vec(int16_t);
181+
typed_quantize_vec(uint16_t);
180182
typed_quantize_vec(int32_t);
181183
#undef typed_quantize_vec
182184

@@ -186,6 +188,7 @@ typed_quantize_vec(int32_t);
186188
typed_dequantize_val(int8_t);
187189
typed_dequantize_val(uint8_t);
188190
typed_dequantize_val(int16_t);
191+
typed_dequantize_val(uint16_t);
189192
#undef typed_dequantize_val
190193

191194
#define typed_dequantize_vec(dtype) \
@@ -198,6 +201,7 @@ typed_dequantize_val(int16_t);
198201
typed_dequantize_vec(int8_t);
199202
typed_dequantize_vec(uint8_t);
200203
typed_dequantize_vec(int16_t);
204+
typed_dequantize_vec(uint16_t);
201205
typed_dequantize_vec(int32_t);
202206
#undef typed_dequantize_vec
203207

backends/cadence/hifi/operators/dequantize_per_tensor.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ void dequantize_per_tensor_out(
4141
} else if (input.scalar_type() == ScalarType::Short) {
4242
const int16_t* input_data = input.const_data_ptr<int16_t>();
4343
dequantize<int16_t>(out_data, input_data, scale, zero_point, numel);
44+
} else if (input.scalar_type() == ScalarType::Bits16) {
45+
const uint16_t* input_data = input.const_data_ptr<uint16_t>();
46+
dequantize<uint16_t>(out_data, input_data, scale, zero_point, numel);
4447
} else if (input.scalar_type() == ScalarType::Int) {
4548
const int32_t* input_data = input.const_data_ptr<int32_t>();
4649
dequantize<int32_t>(out_data, input_data, scale, zero_point, numel);

backends/cadence/hifi/operators/quantize_per_tensor.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ void quantize_per_tensor_out(
4444
int16_t* out_data = out.mutable_data_ptr<int16_t>();
4545
cadence::impl::HiFi::kernels::quantize<int16_t>(
4646
out_data, input_data, 1. / scale, zero_point, numel);
47+
} else if (out.scalar_type() == ScalarType::Bits16) {
48+
uint16_t* out_data = out.mutable_data_ptr<uint16_t>();
49+
cadence::impl::HiFi::kernels::quantize<uint16_t>(
50+
out_data, input_data, 1. / scale, zero_point, numel);
4751
} else if (out.scalar_type() == ScalarType::Int) {
4852
int32_t* out_data = out.mutable_data_ptr<int32_t>();
4953
cadence::impl::HiFi::kernels::quantize<int32_t>(

backends/cadence/reference/kernels/kernels.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ void dequantize(
6565
typed_quantize_val(int8_t);
6666
typed_quantize_val(uint8_t);
6767
typed_quantize_val(int16_t);
68+
typed_quantize_val(uint16_t);
6869
typed_quantize_val(int32_t);
6970
#undef typed_quantize_val
7071

@@ -78,6 +79,7 @@ typed_quantize_val(int32_t);
7879
typed_quantize_vec(int8_t);
7980
typed_quantize_vec(uint8_t);
8081
typed_quantize_vec(int16_t);
82+
typed_quantize_vec(uint16_t);
8183
typed_quantize_vec(int32_t);
8284
#undef typed_quantize_vec
8385

@@ -86,6 +88,7 @@ typed_quantize_vec(int32_t);
8688
typed_dequantize_val(int8_t);
8789
typed_dequantize_val(uint8_t);
8890
typed_dequantize_val(int16_t);
91+
typed_dequantize_val(uint16_t);
8992
typed_dequantize_val(int32_t);
9093
#undef typed_dequantize_val
9194

@@ -99,6 +102,7 @@ typed_dequantize_val(int32_t);
99102
typed_dequantize_vec(int8_t);
100103
typed_dequantize_vec(uint8_t);
101104
typed_dequantize_vec(int16_t);
105+
typed_dequantize_vec(uint16_t);
102106
typed_dequantize_vec(int32_t);
103107
#undef typed_dequantize_vec
104108

backends/cadence/reference/operators/dequantize_per_tensor.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,14 @@ void dequantize_per_tensor_out(
3737
const int8_t* input_data = input.const_data_ptr<int8_t>();
3838
impl::reference::kernels::dequantize<int8_t>(
3939
out_data, input_data, scale, zero_point, numel);
40+
} else if (input.scalar_type() == ScalarType::Bits16) {
41+
const uint16_t* input_data = input.const_data_ptr<uint16_t>();
42+
impl::reference::kernels::dequantize<uint16_t>(
43+
out_data, input_data, scale, zero_point, numel);
44+
} else if (input.scalar_type() == ScalarType::Short) {
45+
const int16_t* input_data = input.const_data_ptr<int16_t>();
46+
impl::reference::kernels::dequantize<int16_t>(
47+
out_data, input_data, scale, zero_point, numel);
4048
} else if (input.scalar_type() == ScalarType::Int) {
4149
const int32_t* input_data = input.const_data_ptr<int32_t>();
4250
impl::reference::kernels::dequantize<int32_t>(

backends/cadence/reference/operators/quantize_per_tensor.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,14 @@ void quantize_per_tensor_out(
3939
int8_t* out_data = out.mutable_data_ptr<int8_t>();
4040
impl::reference::kernels::quantize<int8_t>(
4141
out_data, input_data, 1. / scale, zero_point, numel);
42+
} else if (out.scalar_type() == ScalarType::Bits16) {
43+
uint16_t* out_data = out.mutable_data_ptr<uint16_t>();
44+
impl::reference::kernels::quantize<uint16_t>(
45+
out_data, input_data, 1. / scale, zero_point, numel);
46+
} else if (out.scalar_type() == ScalarType::Short) {
47+
int16_t* out_data = out.mutable_data_ptr<int16_t>();
48+
impl::reference::kernels::quantize<int16_t>(
49+
out_data, input_data, 1. / scale, zero_point, numel);
4250
} else if (out.scalar_type() == ScalarType::Int) {
4351
int32_t* out_data = out.mutable_data_ptr<int32_t>();
4452
impl::reference::kernels::quantize<int32_t>(

0 commit comments

Comments
 (0)