Skip to content

Commit d1d2e7a

Browse files
swolchok authored and facebook-github-bot committed
Manual LICM for mutable/const_data_ptr calls in quantize ops (#3784)
Summary: Pull Request resolved: #3784 Profiling showed that these calls were not getting inlined in ATen mode. Since function calls can have side effects, lack of inlining prevented the compiler from doing this transform itself. ghstack-source-id: 228354914 exported-using-ghexport Reviewed By: larryliu0820 Differential Revision: D57987182 fbshipit-source-id: e915fe73c7f8adbb0397729d17d9b368c996d1ca
1 parent ccce5fa commit d1d2e7a

File tree

2 files changed

+43
-31
lines changed

2 files changed

+43
-31
lines changed

kernels/quantized/cpu/op_dequantize.cpp

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -91,15 +91,18 @@ Tensor& dequantize_per_tensor_out(
9191

9292
// calculate the dequantized output, cast scale to float to match fbgemm
9393
// behavior
94-
#define DEQUANTIZE_IMPL(IN_CTYPE, OUT_CTYPE, out_dtype) \
95-
case ScalarType::out_dtype: \
96-
for (size_t i = 0; i < input.numel(); i++) { \
97-
out.mutable_data_ptr<OUT_CTYPE>()[i] = static_cast<OUT_CTYPE>( \
98-
(input.const_data_ptr<IN_CTYPE>()[i] - \
99-
static_cast<int32_t>(zero_point)) * \
100-
static_cast<float>(scale)); \
101-
} \
102-
break;
94+
#define DEQUANTIZE_IMPL(IN_CTYPE, OUT_CTYPE, out_dtype) \
95+
case ScalarType::out_dtype: { \
96+
/* Hoist these function calls out of our inner loop because they might not \
97+
* get inlined without LTO, particularly in ATen mode. */ \
98+
auto* out_data_ptr = out.mutable_data_ptr<OUT_CTYPE>(); \
99+
const auto* input_data_ptr = input.const_data_ptr<IN_CTYPE>(); \
100+
for (size_t i = 0; i < input.numel(); i++) { \
101+
out_data_ptr[i] = static_cast<OUT_CTYPE>( \
102+
(input_data_ptr[i] - static_cast<int32_t>(zero_point)) * \
103+
static_cast<float>(scale)); \
104+
} \
105+
} break;
103106
#define CALCULATE_INT_TYPE(IN_CTYPE, in_dtype) \
104107
case ScalarType::in_dtype: \
105108
switch (out.scalar_type()) { \
@@ -255,11 +258,12 @@ Tensor& dequantize_per_channel_out(
255258
if (zero_point_data != nullptr) { \
256259
_zero_point = zero_point_data[channel_ix]; \
257260
} \
261+
auto* out_data_ptr = out.mutable_data_ptr<CTYPE_OUT>(); \
262+
const auto* input_data_ptr = input.const_data_ptr<CTYPE_IN>(); \
258263
apply_over_dim_list( \
259-
[input, out, _scale, _zero_point](size_t in_ix) { \
260-
out.mutable_data_ptr<CTYPE_OUT>()[in_ix] = static_cast<CTYPE_OUT>( \
261-
(input.const_data_ptr<CTYPE_IN>()[in_ix] - _zero_point) * \
262-
_scale); \
264+
[input_data_ptr, out_data_ptr, _scale, _zero_point](size_t in_ix) { \
265+
out_data_ptr[in_ix] = static_cast<CTYPE_OUT>( \
266+
(input_data_ptr[in_ix] - _zero_point) * _scale); \
263267
}, \
264268
input, \
265269
optional_dim_list, \

kernels/quantized/cpu/op_quantize.cpp

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -118,15 +118,18 @@ Tensor& quantize_per_tensor_out(
118118
check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out);
119119

120120
// calculate the quantized input
121-
#define QUANTIZE_IMPL(IN_CTYPE, OUT_CTYPE, out_dtype) \
122-
case ScalarType::out_dtype: \
123-
for (size_t i = 0; i < input.numel(); i++) { \
124-
IN_CTYPE value = input.const_data_ptr<IN_CTYPE>()[i]; \
125-
out.mutable_data_ptr<OUT_CTYPE>()[i] = \
126-
quantize_val<OUT_CTYPE, IN_CTYPE>( \
127-
scale, zero_point, value, quant_min, quant_max); \
128-
} \
129-
break;
121+
#define QUANTIZE_IMPL(IN_CTYPE, OUT_CTYPE, out_dtype) \
122+
case ScalarType::out_dtype: { \
123+
/* Hoist these function calls out of our inner loop because they might not \
124+
* get inlined without LTO, particularly in ATen mode. */ \
125+
auto* out_data_ptr = out.mutable_data_ptr<OUT_CTYPE>(); \
126+
const auto* input_data_ptr = input.const_data_ptr<IN_CTYPE>(); \
127+
for (size_t i = 0; i < input.numel(); i++) { \
128+
IN_CTYPE value = input_data_ptr[i]; \
129+
out_data_ptr[i] = quantize_val<OUT_CTYPE, IN_CTYPE>( \
130+
scale, zero_point, value, quant_min, quant_max); \
131+
} \
132+
} break;
130133
#define CALCULATE_FLOAT_TYPE(IN_CTYPE, in_dtype) \
131134
case ScalarType::in_dtype: \
132135
switch (out.scalar_type()) { \
@@ -306,16 +309,21 @@ Tensor& quantize_per_channel_out(
306309
for (size_t channel_ix = 0; channel_ix < input.size(axis); ++channel_ix) { \
307310
double _scale = scale_data[channel_ix]; \
308311
int64_t _zero_point = zero_point_data[channel_ix]; \
312+
auto* out_data_ptr = out.mutable_data_ptr<CTYPE_OUT>(); \
313+
const auto* input_data_ptr = input.const_data_ptr<CTYPE_IN>(); \
309314
apply_over_dim_list( \
310-
[input, out, _scale, _zero_point, quant_min, quant_max]( \
311-
size_t in_ix) { \
312-
out.mutable_data_ptr<CTYPE_OUT>()[in_ix] = \
313-
quantize_val<CTYPE_OUT, CTYPE_IN>( \
314-
_scale, \
315-
_zero_point, \
316-
input.const_data_ptr<CTYPE_IN>()[in_ix], \
317-
quant_min, \
318-
quant_max); \
315+
[input_data_ptr, \
316+
out_data_ptr, \
317+
_scale, \
318+
_zero_point, \
319+
quant_min, \
320+
quant_max](size_t in_ix) { \
321+
out_data_ptr[in_ix] = quantize_val<CTYPE_OUT, CTYPE_IN>( \
322+
_scale, \
323+
_zero_point, \
324+
input_data_ptr[in_ix], \
325+
quant_min, \
326+
quant_max); \
319327
}, \
320328
input, \
321329
optional_dim_list, \

0 commit comments

Comments (0)