Skip to content

Commit 64d1f4d

Browse files
authored
[ET] enabling half dtype output for dequantization and making logic consistent
Differential Revision: D76289181. Pull Request resolved: #11552.
1 parent 9bb255c commit 64d1f4d

File tree

3 files changed

+119
-22
lines changed

3 files changed

+119
-22
lines changed

kernels/quantized/cpu/op_dequantize.cpp

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -288,16 +288,16 @@ Tensor& dequantize_per_tensor_out(
288288
static_cast<float>(scale)); \
289289
} \
290290
} break;
291-
#define CALCULATE_INT_TYPE(IN_CTYPE, in_dtype) \
292-
case ScalarType::in_dtype: \
293-
switch (out.scalar_type()) { \
294-
ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, DEQUANTIZE_IMPL); \
295-
default: \
296-
ET_CHECK_MSG( \
297-
false, \
298-
"Unhandled output dtype %" PRId8, \
299-
static_cast<int8_t>(out.scalar_type())); \
300-
} \
291+
#define CALCULATE_INT_TYPE(IN_CTYPE, in_dtype) \
292+
case ScalarType::in_dtype: \
293+
switch (out.scalar_type()) { \
294+
ET_FORALL_FLOATH_TYPES_WITH(IN_CTYPE, DEQUANTIZE_IMPL); \
295+
default: \
296+
ET_CHECK_MSG( \
297+
false, \
298+
"Unhandled output dtype %" PRId8, \
299+
static_cast<int8_t>(out.scalar_type())); \
300+
} \
301301
break;
302302

303303
switch (input.scalar_type()) {
@@ -459,7 +459,8 @@ Tensor& dequantize_per_channel_out(
459459
} \
460460
out_data_ptr[current_ix] = \
461461
static_cast<CTYPE_OUT>( \
462-
input_data_ptr[current_ix] - zero_point) * \
462+
input_data_ptr[current_ix] - \
463+
static_cast<int32_t>(zero_point)) * \
463464
_scale; \
464465
} \
465466
}, \
@@ -478,23 +479,24 @@ Tensor& dequantize_per_channel_out(
478479
apply_over_dim_list( \
479480
[input_data_ptr, out_data_ptr, _scale, _zero_point](size_t in_ix) { \
480481
out_data_ptr[in_ix] = static_cast<CTYPE_OUT>( \
481-
(input_data_ptr[in_ix] - _zero_point) * _scale); \
482+
(input_data_ptr[in_ix] - static_cast<int32_t>(_zero_point)) * \
483+
_scale); \
482484
}, \
483485
input, \
484486
optional_dim_list, \
485487
channel_ix); \
486488
} \
487489
break;
488-
#define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \
489-
case ScalarType::in_dtype: \
490-
switch (out.scalar_type()) { \
491-
ET_FORALL_FLOAT_TYPES_WITH(CTYPE_IN, DEQUANTIZE_IMPL); \
492-
default: \
493-
ET_CHECK_MSG( \
494-
false, \
495-
"Unhandled output dtype %" PRId8, \
496-
static_cast<int8_t>(out.scalar_type())); \
497-
} \
490+
#define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \
491+
case ScalarType::in_dtype: \
492+
switch (out.scalar_type()) { \
493+
ET_FORALL_FLOATH_TYPES_WITH(CTYPE_IN, DEQUANTIZE_IMPL); \
494+
default: \
495+
ET_CHECK_MSG( \
496+
false, \
497+
"Unhandled output dtype %" PRId8, \
498+
static_cast<int8_t>(out.scalar_type())); \
499+
} \
498500
break;
499501

500502
switch (input.scalar_type()) {

kernels/quantized/test/op_dequantize_test.cpp

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,96 @@ TEST(OpDequantizeOutTest, AllDtypesSupported) {
6767
test_dtype<ScalarType::Int>();
6868
}
6969

70+
/// Test all supported output dtypes for dequantization
71+
template <ScalarType OUT_DTYPE>
72+
void test_output_dtype() {
73+
TensorFactory<ScalarType::Byte> tf;
74+
75+
Tensor input = tf.full({3, 5}, 100);
76+
double scale = 0.5;
77+
int64_t zero_point = 30;
78+
int64_t quant_min = 0;
79+
int64_t quant_max = 255;
80+
81+
TensorFactory<OUT_DTYPE> tfo;
82+
Tensor out = tfo.zeros({3, 5});
83+
// (100 - 30) * 0.5 = 35
84+
Tensor expected = tfo.full({3, 5}, 35);
85+
dequantize_per_tensor_out(
86+
input,
87+
scale,
88+
zero_point,
89+
quant_min,
90+
quant_max,
91+
ScalarType::Byte,
92+
optional<ScalarType>(OUT_DTYPE),
93+
out);
94+
95+
EXPECT_TENSOR_EQ(out, expected);
96+
}
97+
98+
TEST(OpDequantizeOutTest, AllOutputDtypesSupported) {
99+
et_pal_init();
100+
test_output_dtype<ScalarType::Float>();
101+
test_output_dtype<ScalarType::Double>();
102+
test_output_dtype<ScalarType::Half>();
103+
}
104+
105+
TEST(OpDequantizeOutTest, HalfOutput) {
106+
et_pal_init();
107+
TensorFactory<ScalarType::Byte> tf;
108+
109+
Tensor input = tf.full({3, 5}, 10);
110+
double scale = 0.5;
111+
int64_t zero_point = 100000;
112+
int64_t quant_min = 0;
113+
int64_t quant_max = 255;
114+
115+
TensorFactory<ScalarType::Half> tfo;
116+
Tensor out = tfo.zeros({3, 5});
117+
// (10 - 100000) * 0.5 = -49995
118+
dequantize_per_tensor_out(
119+
input,
120+
scale,
121+
zero_point,
122+
quant_min,
123+
quant_max,
124+
ScalarType::Byte,
125+
optional<ScalarType>(ScalarType::Half),
126+
out);
127+
128+
// The expected result should be (10 - 100000) * 0.5 = -49995
129+
Tensor expected = tfo.full({3, 5}, -49995);
130+
EXPECT_TENSOR_EQ(out, expected);
131+
}
132+
133+
TEST(OpDequantizeOutTest, DoubleOutput) {
134+
et_pal_init();
135+
TensorFactory<ScalarType::Byte> tf;
136+
137+
Tensor input = tf.full({3, 5}, 10);
138+
double scale = 0.5;
139+
int64_t zero_point = 100000;
140+
int64_t quant_min = 0;
141+
int64_t quant_max = 255;
142+
143+
TensorFactory<ScalarType::Double> tfo;
144+
Tensor out = tfo.zeros({3, 5});
145+
dequantize_per_tensor_out(
146+
input,
147+
scale,
148+
zero_point,
149+
quant_min,
150+
quant_max,
151+
ScalarType::Byte,
152+
optional<ScalarType>(ScalarType::Double),
153+
out);
154+
155+
// The expected result should be (10 - 100000) * 0.5 = -49995
156+
Tensor expected = tfo.full({3, 5}, -49995);
157+
EXPECT_TENSOR_EQ(out, expected);
158+
}
159+
70160
TEST(OpDequantizeOutTest, NonWholeNumbers) {
71161
et_pal_init();
72162
TensorFactory<ScalarType::Byte> tf;

runtime/core/exec_aten/util/scalar_type_util.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,11 @@ ET_FORALL_SCALAR_TYPES(SPECIALIZE_CppTypeToScalarType)
199199
_(ANOTHER_INPUT, float, Float) \
200200
_(ANOTHER_INPUT, double, Double)
201201

202+
#define ET_FORALL_FLOATH_TYPES_WITH(ANOTHER_INPUT, _) \
203+
_(ANOTHER_INPUT, float, Float) \
204+
_(ANOTHER_INPUT, double, Double) \
205+
_(ANOTHER_INPUT, ::executorch::aten::Half, Half)
206+
202207
#define ET_FORALL_FLOAT_TYPES_WITH2(ANOTHER_INPUT1, ANOTHER_INPUT2, _) \
203208
_(ANOTHER_INPUT1, ANOTHER_INPUT2, float, Float) \
204209
_(ANOTHER_INPUT1, ANOTHER_INPUT2, double, Double)

0 commit comments

Comments
 (0)