Skip to content

Commit 6cb1fac

Browse files
author
morelos
committed
[ET] enabling half dtype output for dequantization and making logic consistent
Pull Request resolved: #11552 # Context Currently the CPU implementation for the dequantization operator (which includes `dequantize_per_token`, `dequantize_per_tensor`, and `dequantize_per_channel`), does not inherently support half (fp16) input scalar types. In order to align with the PyTorch implementation that accepts fp16 and bf16 inputs, this diff aims to enable half input dtype support for the dequantization operators. We will be comparing this implementation against the Vulkan operators. Furthermore, there is a casting bug when applying the zero_point, as only in the `dequantize_per_tensor` implementation does it cast the zero_point to int32, while comparatively for `dequantize_per_channel` and `dequantize_per_token` they do not cast the zero_point. In an environment that only supports 32-bit integers, understandably there will be some inconsistencies in dequantization logic, as per_tensor will contain different overflow logic compared to its respective per_token and per_channel partners, since the latter eliminate the overflow by utilizing a 64-bit value. # Changes As defined in ExecuTorch [scalar_type_util.h](https://github.com/pytorch/executorch/blob/053686242c1687f0d51b3bb8befd14b047d7b025/runtime/core/exec_aten/util/scalar_type_util.h), the changes in this diff include adding a new macro `ET_FORALL_FLOATH_TYPES_WITH` to `util/scalar_type_util.h`, updating the `CALCULATE_INT_TYPE` macro to handle the new dtype. This enables support for Half (fp16), Float (fp32), and Double (fp64). I have also included more comprehensive testing against the input dtypes, including adding double testing since it didn't already exist before. Instead of just confirming that all the output dtypes are supported, we also have a check that all input dtypes are supported now as well. In order to align both dequantization implementations, we cast the zero_point to 32-bit for both to maintain the overflow logic carried over from `dequantize_per_tensor`. 
ghstack-source-id: 290376483 @exported-using-ghexport Differential Revision: [D76289181](https://our.internmc.facebook.com/intern/diff/D76289181/)
1 parent 92d9b14 commit 6cb1fac

File tree

3 files changed

+119
-22
lines changed

3 files changed

+119
-22
lines changed

kernels/quantized/cpu/op_dequantize.cpp

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -288,16 +288,16 @@ Tensor& dequantize_per_tensor_out(
288288
static_cast<float>(scale)); \
289289
} \
290290
} break;
291-
#define CALCULATE_INT_TYPE(IN_CTYPE, in_dtype) \
292-
case ScalarType::in_dtype: \
293-
switch (out.scalar_type()) { \
294-
ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, DEQUANTIZE_IMPL); \
295-
default: \
296-
ET_CHECK_MSG( \
297-
false, \
298-
"Unhandled output dtype %" PRId8, \
299-
static_cast<int8_t>(out.scalar_type())); \
300-
} \
291+
#define CALCULATE_INT_TYPE(IN_CTYPE, in_dtype) \
292+
case ScalarType::in_dtype: \
293+
switch (out.scalar_type()) { \
294+
ET_FORALL_FLOATH_TYPES_WITH(IN_CTYPE, DEQUANTIZE_IMPL); \
295+
default: \
296+
ET_CHECK_MSG( \
297+
false, \
298+
"Unhandled output dtype %" PRId8, \
299+
static_cast<int8_t>(out.scalar_type())); \
300+
} \
301301
break;
302302

303303
switch (input.scalar_type()) {
@@ -459,7 +459,8 @@ Tensor& dequantize_per_channel_out(
459459
} \
460460
out_data_ptr[current_ix] = \
461461
static_cast<CTYPE_OUT>( \
462-
input_data_ptr[current_ix] - zero_point) * \
462+
input_data_ptr[current_ix] - \
463+
static_cast<int32_t>(zero_point)) * \
463464
_scale; \
464465
} \
465466
}, \
@@ -478,23 +479,24 @@ Tensor& dequantize_per_channel_out(
478479
apply_over_dim_list( \
479480
[input_data_ptr, out_data_ptr, _scale, _zero_point](size_t in_ix) { \
480481
out_data_ptr[in_ix] = static_cast<CTYPE_OUT>( \
481-
(input_data_ptr[in_ix] - _zero_point) * _scale); \
482+
(input_data_ptr[in_ix] - static_cast<int32_t>(_zero_point)) * \
483+
_scale); \
482484
}, \
483485
input, \
484486
optional_dim_list, \
485487
channel_ix); \
486488
} \
487489
break;
488-
#define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \
489-
case ScalarType::in_dtype: \
490-
switch (out.scalar_type()) { \
491-
ET_FORALL_FLOAT_TYPES_WITH(CTYPE_IN, DEQUANTIZE_IMPL); \
492-
default: \
493-
ET_CHECK_MSG( \
494-
false, \
495-
"Unhandled output dtype %" PRId8, \
496-
static_cast<int8_t>(out.scalar_type())); \
497-
} \
490+
#define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \
491+
case ScalarType::in_dtype: \
492+
switch (out.scalar_type()) { \
493+
ET_FORALL_FLOATH_TYPES_WITH(CTYPE_IN, DEQUANTIZE_IMPL); \
494+
default: \
495+
ET_CHECK_MSG( \
496+
false, \
497+
"Unhandled output dtype %" PRId8, \
498+
static_cast<int8_t>(out.scalar_type())); \
499+
} \
498500
break;
499501

500502
switch (input.scalar_type()) {

kernels/quantized/test/op_dequantize_test.cpp

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,96 @@ TEST(OpDequantizeOutTest, AllDtypesSupported) {
6767
test_dtype<ScalarType::Int>();
6868
}
6969

70+
/// Test all supported output dtypes for dequantization
71+
template <ScalarType OUT_DTYPE>
72+
void test_output_dtype() {
73+
TensorFactory<ScalarType::Byte> tf;
74+
75+
Tensor input = tf.full({3, 5}, 100);
76+
double scale = 0.5;
77+
int64_t zero_point = 30;
78+
int64_t quant_min = 0;
79+
int64_t quant_max = 255;
80+
81+
TensorFactory<OUT_DTYPE> tfo;
82+
Tensor out = tfo.zeros({3, 5});
83+
// (100 - 30) * 0.5 = 35
84+
Tensor expected = tfo.full({3, 5}, 35);
85+
dequantize_per_tensor_out(
86+
input,
87+
scale,
88+
zero_point,
89+
quant_min,
90+
quant_max,
91+
ScalarType::Byte,
92+
optional<ScalarType>(OUT_DTYPE),
93+
out);
94+
95+
EXPECT_TENSOR_EQ(out, expected);
96+
}
97+
98+
TEST(OpDequantizeOutTest, AllOutputDtypesSupported) {
99+
et_pal_init();
100+
test_output_dtype<ScalarType::Float>();
101+
test_output_dtype<ScalarType::Double>();
102+
test_output_dtype<ScalarType::Half>();
103+
}
104+
105+
TEST(OpDequantizeOutTest, HalfOutput) {
106+
et_pal_init();
107+
TensorFactory<ScalarType::Byte> tf;
108+
109+
Tensor input = tf.full({3, 5}, 10);
110+
double scale = 0.5;
111+
int64_t zero_point = 100000;
112+
int64_t quant_min = 0;
113+
int64_t quant_max = 255;
114+
115+
TensorFactory<ScalarType::Half> tfo;
116+
Tensor out = tfo.zeros({3, 5});
117+
// (10 - 100000) * 0.5 = -49995
118+
dequantize_per_tensor_out(
119+
input,
120+
scale,
121+
zero_point,
122+
quant_min,
123+
quant_max,
124+
ScalarType::Byte,
125+
optional<ScalarType>(ScalarType::Half),
126+
out);
127+
128+
// The expected result should be (10 - 100000) * 0.5 = -49995
129+
Tensor expected = tfo.full({3, 5}, -49995);
130+
EXPECT_TENSOR_EQ(out, expected);
131+
}
132+
133+
TEST(OpDequantizeOutTest, DoubleOutput) {
134+
et_pal_init();
135+
TensorFactory<ScalarType::Byte> tf;
136+
137+
Tensor input = tf.full({3, 5}, 10);
138+
double scale = 0.5;
139+
int64_t zero_point = 100000;
140+
int64_t quant_min = 0;
141+
int64_t quant_max = 255;
142+
143+
TensorFactory<ScalarType::Double> tfo;
144+
Tensor out = tfo.zeros({3, 5});
145+
dequantize_per_tensor_out(
146+
input,
147+
scale,
148+
zero_point,
149+
quant_min,
150+
quant_max,
151+
ScalarType::Byte,
152+
optional<ScalarType>(ScalarType::Double),
153+
out);
154+
155+
// The expected result should be (10 - 100000) * 0.5 = -49995
156+
Tensor expected = tfo.full({3, 5}, -49995);
157+
EXPECT_TENSOR_EQ(out, expected);
158+
}
159+
70160
TEST(OpDequantizeOutTest, NonWholeNumbers) {
71161
et_pal_init();
72162
TensorFactory<ScalarType::Byte> tf;

runtime/core/exec_aten/util/scalar_type_util.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,11 @@ ET_FORALL_SCALAR_TYPES(SPECIALIZE_CppTypeToScalarType)
199199
_(ANOTHER_INPUT, float, Float) \
200200
_(ANOTHER_INPUT, double, Double)
201201

202+
#define ET_FORALL_FLOATH_TYPES_WITH(ANOTHER_INPUT, _) \
203+
_(ANOTHER_INPUT, float, Float) \
204+
_(ANOTHER_INPUT, double, Double) \
205+
_(ANOTHER_INPUT, ::executorch::aten::Half, Half)
206+
202207
#define ET_FORALL_FLOAT_TYPES_WITH2(ANOTHER_INPUT1, ANOTHER_INPUT2, _) \
203208
_(ANOTHER_INPUT1, ANOTHER_INPUT2, float, Float) \
204209
_(ANOTHER_INPUT1, ANOTHER_INPUT2, double, Double)

0 commit comments

Comments
 (0)