Skip to content

Commit 53f5870

Browse files
sxu authored and facebook-github-bot committed
Support quantizing to and dequantizing from uint16_t (Bits16) (#5839)
Summary: Pull Request resolved: #5839 Reviewed By: kimishpatel Differential Revision: D63730600 fbshipit-source-id: 7d4fb7f73ca81275ac276ec8dc186c8a1685ca75
1 parent 2027a14 commit 53f5870

File tree

5 files changed

+26
-3
lines changed

5 files changed

+26
-3
lines changed

kernels/quantized/cpu/op_dequantize.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ void check_dequantize_per_tensor_args(
3838
ET_CHECK_MSG(
3939
input.scalar_type() == ScalarType::Byte ||
4040
input.scalar_type() == ScalarType::Char ||
41+
input.scalar_type() == ScalarType::Bits16 ||
4142
input.scalar_type() == ScalarType::Short ||
4243
input.scalar_type() == ScalarType::Int,
4344
"input.scalar_type() %" PRId8 " is not supported:",
@@ -118,6 +119,7 @@ Tensor& dequantize_per_tensor_out(
118119

119120
switch (input.scalar_type()) {
120121
ET_FORALL_INT_TYPES(CALCULATE_INT_TYPE);
122+
CALCULATE_INT_TYPE(uint16_t, Bits16);
121123
default:
122124
ET_CHECK_MSG(
123125
false,
@@ -312,6 +314,7 @@ Tensor& dequantize_per_channel_out(
312314

313315
switch (input.scalar_type()) {
314316
ET_FORALL_INT_TYPES(CALCULATE_FLOAT_TYPE);
317+
CALCULATE_INT_TYPE(uint16_t, Bits16);
315318
default:
316319
ET_CHECK_MSG(
317320
false,

kernels/quantized/cpu/op_quantize.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ void check_quantize_per_tensor_args(
5757
static_cast<int32_t>(std::numeric_limits<int8_t>::min());
5858
quant_max_upper_bound =
5959
static_cast<int32_t>(std::numeric_limits<int8_t>::max());
60+
} else if (dtype == ScalarType::Bits16) {
61+
quant_min_lower_bound = std::numeric_limits<uint16_t>::min();
62+
quant_max_upper_bound = std::numeric_limits<uint16_t>::max();
6063
} else if (dtype == ScalarType::Short) {
6164
quant_min_lower_bound = std::numeric_limits<int16_t>::min();
6265
quant_max_upper_bound = std::numeric_limits<int16_t>::max();
@@ -135,6 +138,7 @@ Tensor& quantize_per_tensor_out(
135138
case ScalarType::in_dtype: \
136139
switch (out.scalar_type()) { \
137140
ET_FORALL_INT_TYPES_WITH(IN_CTYPE, QUANTIZE_IMPL); \
141+
QUANTIZE_IMPL(IN_CTYPE, uint16_t, Bits16) \
138142
default: \
139143
ET_CHECK_MSG( \
140144
false, \
@@ -329,6 +333,7 @@ Tensor& quantize_per_channel_out(
329333
case ScalarType::in_dtype: \
330334
switch (out.scalar_type()) { \
331335
ET_FORALL_INT_TYPES_WITH(CTYPE_IN, QUANTIZE_IMPL); \
336+
QUANTIZE_IMPL(CTYPE_IN, uint16_t, Bits16) \
332337
default: \
333338
ET_CHECK_MSG( \
334339
false, \

kernels/quantized/test/op_dequantize_test.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ void test_dtype() {
6060
TEST(OpDequantizeOutTest, AllDtypesSupported) {
6161
et_pal_init();
6262
test_dtype<ScalarType::Byte>();
63+
test_dtype<ScalarType::Char>();
64+
test_dtype<ScalarType::Short>();
65+
test_dtype<ScalarType::Bits16>();
66+
test_dtype<ScalarType::Int>();
6367
}
6468

6569
TEST(OpDequantizeOutTest, NonWholeNumbers) {

kernels/quantized/test/op_quantize_test.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,14 @@ void test_dtype() {
3535
Tensor input = tf.full({3, 5}, 4);
3636
double scale = 0.5;
3737

38-
int64_t zero_point = 127;
38+
int64_t zero_point = 108;
3939
int64_t quant_min = 0;
40-
int64_t quant_max = 255;
40+
int64_t quant_max = 127;
4141

4242
TensorFactory<DTYPE> tfo;
4343
Tensor out = tfo.zeros({3, 5});
4444
// 4 / 0.5 + 108
45-
Tensor expected = tfo.full({3, 5}, 135);
45+
Tensor expected = tfo.full({3, 5}, 116);
4646
quantize_per_tensor_out(
4747
input, scale, zero_point, quant_min, quant_max, DTYPE, out);
4848

@@ -51,6 +51,10 @@ void test_dtype() {
5151

5252
TEST(OpQuantizeOutTest, AllDtypesSupported) {
5353
test_dtype<ScalarType::Byte>();
54+
test_dtype<ScalarType::Char>();
55+
test_dtype<ScalarType::Short>();
56+
test_dtype<ScalarType::Bits16>();
57+
test_dtype<ScalarType::Int>();
5458
}
5559

5660
TEST(OpQuantizeOutTest, TensorArgOverload) {

runtime/core/exec_aten/testing_util/tensor_factory.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,13 @@ struct ScalarTypeToCppTypeWrapper<torch::executor::ScalarType::Bool> {
645645
using ctype = uint8_t;
646646
};
647647

648+
// Use a C type of `uint16_t` instead of `Bits16` to simplify code reuse when
649+
// testing multiple integer types.
650+
template <>
651+
struct ScalarTypeToCppTypeWrapper<torch::executor::ScalarType::Bits16> {
652+
using ctype = uint16_t;
653+
};
654+
648655
// To allow implicit conversion between simple types to `ctype`
649656
#define SPECIALIZE_ScalarTypeToCppTypeWrapper(CTYPE, DTYPE) \
650657
template <> \

0 commit comments

Comments
(0)