Skip to content

Commit 0877926

Browse files
Add UInt16 support to Q and DQ
Differential Revision: D65899080 Pull Request resolved: #6891
1 parent 92ee522 commit 0877926

File tree

5 files changed

+15
-1
lines changed

5 files changed

+15
-1
lines changed

kernels/quantized/cpu/op_dequantize.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ void check_dequantize_per_tensor_args(
3939
input.scalar_type() == ScalarType::Byte ||
4040
input.scalar_type() == ScalarType::Char ||
4141
input.scalar_type() == ScalarType::Bits16 ||
42+
input.scalar_type() == ScalarType::UInt16 ||
4243
input.scalar_type() == ScalarType::Short ||
4344
input.scalar_type() == ScalarType::Int,
4445
"input.scalar_type() %" PRId8 " is not supported:",
@@ -120,6 +121,7 @@ Tensor& dequantize_per_tensor_out(
120121
switch (input.scalar_type()) {
121122
ET_FORALL_INT_TYPES(CALCULATE_INT_TYPE);
122123
CALCULATE_INT_TYPE(uint16_t, Bits16);
124+
CALCULATE_INT_TYPE(uint16_t, UInt16);
123125
default:
124126
ET_CHECK_MSG(
125127
false,
@@ -315,6 +317,7 @@ Tensor& dequantize_per_channel_out(
315317
switch (input.scalar_type()) {
316318
ET_FORALL_INT_TYPES(CALCULATE_FLOAT_TYPE);
317319
CALCULATE_INT_TYPE(uint16_t, Bits16);
320+
CALCULATE_INT_TYPE(uint16_t, UInt16);
318321
default:
319322
ET_CHECK_MSG(
320323
false,

kernels/quantized/cpu/op_quantize.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ void check_quantize_per_tensor_args(
5757
static_cast<int32_t>(std::numeric_limits<int8_t>::min());
5858
quant_max_upper_bound =
5959
static_cast<int32_t>(std::numeric_limits<int8_t>::max());
60-
} else if (dtype == ScalarType::Bits16) {
60+
} else if (dtype == ScalarType::Bits16 || dtype == ScalarType::UInt16) {
6161
quant_min_lower_bound = std::numeric_limits<uint16_t>::min();
6262
quant_max_upper_bound = std::numeric_limits<uint16_t>::max();
6363
} else if (dtype == ScalarType::Short) {
@@ -139,6 +139,7 @@ Tensor& quantize_per_tensor_out(
139139
switch (out.scalar_type()) { \
140140
ET_FORALL_INT_TYPES_WITH(IN_CTYPE, QUANTIZE_IMPL); \
141141
QUANTIZE_IMPL(IN_CTYPE, uint16_t, Bits16) \
142+
QUANTIZE_IMPL(IN_CTYPE, uint16_t, UInt16) \
142143
default: \
143144
ET_CHECK_MSG( \
144145
false, \
@@ -334,6 +335,7 @@ Tensor& quantize_per_channel_out(
334335
switch (out.scalar_type()) { \
335336
ET_FORALL_INT_TYPES_WITH(CTYPE_IN, QUANTIZE_IMPL); \
336337
QUANTIZE_IMPL(CTYPE_IN, uint16_t, Bits16) \
338+
QUANTIZE_IMPL(CTYPE_IN, uint16_t, UInt16) \
337339
default: \
338340
ET_CHECK_MSG( \
339341
false, \

kernels/quantized/test/op_dequantize_test.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ TEST(OpDequantizeOutTest, AllDtypesSupported) {
6363
test_dtype<ScalarType::Char>();
6464
test_dtype<ScalarType::Short>();
6565
test_dtype<ScalarType::Bits16>();
66+
test_dtype<ScalarType::UInt16>();
6667
test_dtype<ScalarType::Int>();
6768
}
6869

kernels/quantized/test/op_quantize_test.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ TEST(OpQuantizeOutTest, AllDtypesSupported) {
5454
test_dtype<ScalarType::Char>();
5555
test_dtype<ScalarType::Short>();
5656
test_dtype<ScalarType::Bits16>();
57+
test_dtype<ScalarType::UInt16>();
5758
test_dtype<ScalarType::Int>();
5859
}
5960

runtime/core/exec_aten/testing_util/tensor_factory.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,13 @@ struct ScalarTypeToCppTypeWrapper<torch::executor::ScalarType::Bits16> {
650650
using ctype = uint16_t;
651651
};
652652

653+
// Use a C type of `uint16_t` instead of `UInt16` to simplify code reuse when
654+
// testing multiple integer types.
655+
template <>
656+
struct ScalarTypeToCppTypeWrapper<torch::executor::ScalarType::UInt16> {
657+
using ctype = uint16_t;
658+
};
659+
653660
// To allow implicit conversion between simple types to `ctype`
654661
#define SPECIALIZE_ScalarTypeToCppTypeWrapper(CTYPE, DTYPE) \
655662
template <> \

0 commit comments

Comments
 (0)