Skip to content

Commit f96be5b

Browse files
authored
Add UInt16 support to Cadence kernels
Differential Revision: D66016288 Pull Request resolved: #6893
1 parent 0877926 commit f96be5b

File tree

4 files changed

+12
-4
lines changed

4 files changed

+12
-4
lines changed

backends/cadence/hifi/operators/dequantize_per_tensor.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ void dequantize_per_tensor_out(
4141
} else if (input.scalar_type() == ScalarType::Short) {
4242
const int16_t* input_data = input.const_data_ptr<int16_t>();
4343
dequantize<int16_t>(out_data, input_data, scale, zero_point, numel);
44-
} else if (input.scalar_type() == ScalarType::Bits16) {
44+
} else if (
45+
input.scalar_type() == ScalarType::Bits16 ||
46+
input.scalar_type() == ScalarType::UInt16) {
4547
const uint16_t* input_data = input.const_data_ptr<uint16_t>();
4648
dequantize<uint16_t>(out_data, input_data, scale, zero_point, numel);
4749
} else if (input.scalar_type() == ScalarType::Int) {

backends/cadence/hifi/operators/quantize_per_tensor.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ void quantize_per_tensor_out(
4444
int16_t* out_data = out.mutable_data_ptr<int16_t>();
4545
cadence::impl::HiFi::kernels::quantize<int16_t>(
4646
out_data, input_data, 1. / scale, zero_point, numel);
47-
} else if (out.scalar_type() == ScalarType::Bits16) {
47+
} else if (
48+
out.scalar_type() == ScalarType::Bits16 ||
49+
out.scalar_type() == ScalarType::UInt16) {
4850
uint16_t* out_data = out.mutable_data_ptr<uint16_t>();
4951
cadence::impl::HiFi::kernels::quantize<uint16_t>(
5052
out_data, input_data, 1. / scale, zero_point, numel);

backends/cadence/reference/operators/dequantize_per_tensor.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ void dequantize_per_tensor_out(
3737
const int8_t* input_data = input.const_data_ptr<int8_t>();
3838
impl::reference::kernels::dequantize<int8_t>(
3939
out_data, input_data, scale, zero_point, numel);
40-
} else if (input.scalar_type() == ScalarType::Bits16) {
40+
} else if (
41+
input.scalar_type() == ScalarType::Bits16 ||
42+
input.scalar_type() == ScalarType::UInt16) {
4143
const uint16_t* input_data = input.const_data_ptr<uint16_t>();
4244
impl::reference::kernels::dequantize<uint16_t>(
4345
out_data, input_data, scale, zero_point, numel);

backends/cadence/reference/operators/quantize_per_tensor.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ void quantize_per_tensor_out(
3939
int8_t* out_data = out.mutable_data_ptr<int8_t>();
4040
impl::reference::kernels::quantize<int8_t>(
4141
out_data, input_data, 1. / scale, zero_point, numel);
42-
} else if (out.scalar_type() == ScalarType::Bits16) {
42+
} else if (
43+
out.scalar_type() == ScalarType::Bits16 ||
44+
out.scalar_type() == ScalarType::UInt16) {
4345
uint16_t* out_data = out.mutable_data_ptr<uint16_t>();
4446
impl::reference::kernels::quantize<uint16_t>(
4547
out_data, input_data, 1. / scale, zero_point, numel);

0 commit comments

Comments
 (0)