Commit bdc1a33

[Executorch][quant] Optimize per channel dequantize
When using a quantized KV cache, the dequantization routine takes a significant amount of time. This diff vectorizes per-channel dequantization for the common case (contiguous int8 input dequantized to float).

Differential Revision: [D63338858](https://our.internmc.facebook.com/intern/diff/D63338858/)

[ghstack-poisoned]
1 parent 5d9d688 commit bdc1a33
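For readers skimming the diff: per-channel dequantization computes `out[i] = (q[i] - zero_point[c]) * scale[c]`, where `c` is the channel index along the quantized axis. Below is a minimal scalar sketch of that computation, assuming a contiguous channel-major [C, N] block; it is an illustration only, not the ExecuTorch kernel added in this commit.

#include <cstddef>
#include <cstdint>

// Reference (scalar) per-channel dequantization over a contiguous,
// channel-major [C, N] block: channel c owns scales[c] / zero_points[c].
// Illustration of the math this commit vectorizes, not the ExecuTorch kernel.
void dequantize_per_channel_reference(
    const int8_t* in,           // C * N quantized values
    const double* scales,       // one scale per channel
    const int64_t* zero_points, // one zero point per channel (may be null)
    float* out,                 // C * N dequantized values
    size_t C,
    size_t N) {
  for (size_t c = 0; c < C; ++c) {
    const double scale = scales[c];
    const int64_t zp = zero_points != nullptr ? zero_points[c] : 0;
    for (size_t i = 0; i < N; ++i) {
      out[c * N + i] = static_cast<float>((in[c * N + i] - zp) * scale);
    }
  }
}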

1 file changed: +176 -0
kernels/quantized/cpu/op_dequantize.cpp

Lines changed: 176 additions & 0 deletions
@@ -11,6 +11,9 @@
 #include <algorithm>
 #include <cinttypes>
 #include <cmath>
+#if defined(__aarch64__) || defined(__ARM_NEON)
+#include <arm_neon.h>
+#endif
 
 /**
  * For an input tensor, use the scale and zero_point arguments to quantize it.
@@ -22,6 +25,8 @@ namespace native {
 using Tensor = exec_aten::Tensor;
 using Scalar = exec_aten::Scalar;
 using ScalarType = exec_aten::ScalarType;
+using StridesType = exec_aten::StridesType;
+using SizesType = exec_aten::SizesType;
 
 namespace {
 
@@ -61,6 +66,163 @@ void check_dequantize_per_tensor_args(
       quant_max);
 }
 
+/**
+ * Applies `fn` once for every (outer index, index along `dim`) pair of the
+ * tensor `in`. `fn` should have the following signature:
+ * void fn(const size_t inner_size, const size_t outer_idx, const size_t dim_idx)
+ * where `inner_size` is the number of contiguous elements in each slice
+ * (the product of the dimensions after `dim`), `outer_idx` indexes the
+ * dimensions before `dim`, and `dim_idx` is the index along `dim` itself.
+ */
+template <typename Fn>
+void apply_over_unpacked_dim(
+    const Fn& fn,
+    const exec_aten::Tensor& in,
+    const int64_t& dim) {
+  if (in.numel() == 0) {
+    return;
+  }
+
+  ET_CHECK_MSG(in.dim() > 0, "Input tensor must have at least one dimension");
+  ET_CHECK_VALID_DIM(dim, in.dim());
+
+  const size_t d = ET_NORMALIZE_IX(dim, in.dim());
+  const size_t dim_size = in.size(d);
+  const size_t outer_size = getLeadingDims(in, d);
+  const size_t inner_size = getTrailingDims(in, d);
+  // Loop through all outer dimensions
+  for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+    // Loop through dim
+    for (size_t unpacked_dim_idx = 0; unpacked_dim_idx < dim_size;
+         ++unpacked_dim_idx) {
+      fn(inner_size, outer_idx, unpacked_dim_idx);
+    }
+  }
+}
+
+void dequantize_optimized(
+    const int8_t* in,
+    const double scale,
+    const int64_t zero_point,
+    float* out,
+    int64_t quant_min,
+    int64_t quant_max,
+    size_t numel) {
+  ET_CHECK_MSG(
+      zero_point >= quant_min,
+      "zero_point %" PRId64 " must be >= quant_min %" PRId64,
+      zero_point,
+      quant_min);
+  ET_CHECK_MSG(
+      zero_point <= quant_max,
+      "zero_point %" PRId64 " must be <= quant_max %" PRId64,
+      zero_point,
+      quant_max);
+  size_t i = 0;
+#if defined(__aarch64__) || defined(__ARM_NEON)
+  int8x8_t zero_point_vec = vdup_n_s8(static_cast<int8_t>(zero_point));
+  float32x4_t scales = vdupq_n_f32(static_cast<float>(scale));
+  constexpr int32_t kVecSize = 16;
+  const size_t num_vecs = numel / kVecSize;
+  // Dequantize 16 int8 values per iteration; the scalar loop below handles
+  // the numel % kVecSize tail.
+  for (size_t vec = 0; vec < num_vecs; ++vec, i += kVecSize) {
+    int8x16_t in_vec = vld1q_s8(in + i);
+    // Widen int8 -> int16 -> int32 while subtracting the zero point, then
+    // convert to float and multiply by the scale.
+    int16x8_t sub_vec_0_7 = vsubl_s8(vget_low_s8(in_vec), zero_point_vec);
+    int32x4_t sub_vec_0_3 = vmovl_s16(vget_low_s16(sub_vec_0_7));
+    int32x4_t sub_vec_4_7 = vmovl_s16(vget_high_s16(sub_vec_0_7));
+    float32x4_t out_vec_0_3 = vmulq_f32(vcvtq_f32_s32(sub_vec_0_3), scales);
+    float32x4_t out_vec_4_7 = vmulq_f32(vcvtq_f32_s32(sub_vec_4_7), scales);
+    vst1q_f32(out + i, out_vec_0_3);
+    vst1q_f32(out + i + 4, out_vec_4_7);
+
+    int16x8_t sub_vec_8_15 = vsubl_s8(vget_high_s8(in_vec), zero_point_vec);
+    int32x4_t sub_vec_8_11 = vmovl_s16(vget_low_s16(sub_vec_8_15));
+    int32x4_t sub_vec_12_15 = vmovl_s16(vget_high_s16(sub_vec_8_15));
+    float32x4_t out_vec_8_11 = vmulq_f32(vcvtq_f32_s32(sub_vec_8_11), scales);
+    float32x4_t out_vec_12_15 = vmulq_f32(vcvtq_f32_s32(sub_vec_12_15), scales);
+    vst1q_f32(out + i + 8, out_vec_8_11);
+    vst1q_f32(out + i + 12, out_vec_12_15);
+  }
+#endif
+  // Scalar tail; also the only path on non-NEON targets.
+  for (; i < numel; i++) {
+    out[i] = (in[i] - zero_point) * scale;
+  }
+}
+
+bool can_use_optimized_dequantize_per_channel(
+    const Tensor& in,
+    const ScalarType in_dtype,
+    exec_aten::optional<ScalarType>& out_dtype) {
+  if (!executorch::runtime::is_contiguous_dim_order(
+          in.dim_order().data(), in.dim()) ||
+      (in_dtype != ScalarType::Char) ||
+      (out_dtype.has_value() && out_dtype.value() != ScalarType::Float)) {
+    return false;
+  }
+  return true;
+}
+
+void dequantize_per_channel_optimized(
+    const Tensor& in,
+    const Tensor& scales,
+    const optional<Tensor>& opt_zero_points,
+    Tensor& out,
+    int64_t axis,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType in_dtype,
+    exec_aten::optional<ScalarType>& out_dtype) {
+  check_dequantize_per_tensor_args(
+      in, quant_min, quant_max, in_dtype, out_dtype, out);
+  ET_CHECK_MSG(
+      executorch::runtime::is_contiguous_dim_order(
+          in.dim_order().data(), in.dim()),
+      "in must be in contiguous dim order");
+  ET_CHECK_MSG(
+      in_dtype == ScalarType::Char,
+      "in.scalar_type() %" PRId8 " is not supported",
+      static_cast<int8_t>(in.scalar_type()));
+  if (out_dtype.has_value()) {
+    ET_CHECK_MSG(
+        out_dtype.value() == ScalarType::Float,
+        "Only float output is supported");
+  }
+  const int8_t* in_data = in.const_data_ptr<int8_t>();
+  float* out_data = out.mutable_data_ptr<float>();
+  const int64_t* zero_points_data = nullptr;
+  if (opt_zero_points.has_value()) {
+    zero_points_data = opt_zero_points.value().const_data_ptr<int64_t>();
+  }
+  const double* scales_data = scales.const_data_ptr<double>();
+  const StridesType axis_stride = in.strides()[axis];
+  const StridesType outer_stride = in.size(axis) * axis_stride;
+  apply_over_unpacked_dim(
+      [in_data,
+       out_data,
+       scales_data,
+       zero_points_data,
+       axis_stride,
+       outer_stride,
+       quant_min,
+       quant_max](
+          SizesType numel, SizesType outer_idx, SizesType unpacked_dim_idx) {
+        // Each (outer_idx, unpacked_dim_idx) pair addresses one contiguous
+        // slice that shares a single scale / zero_point.
+        const int8_t* in_data_local =
+            in_data + outer_idx * outer_stride + unpacked_dim_idx * axis_stride;
+        const double scale = scales_data[unpacked_dim_idx];
+        const int64_t zero_point = zero_points_data != nullptr
+            ? zero_points_data[unpacked_dim_idx]
+            : 0;
+        float* out_data_local = out_data + outer_idx * outer_stride +
+            unpacked_dim_idx * axis_stride;
+        dequantize_optimized(
+            in_data_local,
+            scale,
+            zero_point,
+            out_data_local,
+            quant_min,
+            quant_max,
+            numel);
+      },
+      in,
+      axis);
+}
+
 } // namespace
 
 /**
@@ -225,6 +387,20 @@ Tensor& dequantize_per_channel_out(
   check_dequantize_per_tensor_args(
       input, quant_min, quant_max, dtype, out_dtype, out);
 
+  if (can_use_optimized_dequantize_per_channel(input, dtype, out_dtype)) {
+    dequantize_per_channel_optimized(
+        input,
+        scale,
+        opt_zero_points,
+        out,
+        axis,
+        quant_min,
+        quant_max,
+        dtype,
+        out_dtype);
+    return out;
+  }
+
   // a list contains all dimensions except axis
   int64_t dims[kTensorDimensionLimit];
   for (int64_t i = 0; i < input.dim() - 1; i++) {
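To make the stride arithmetic in `dequantize_per_channel_optimized` concrete, here is a standalone sketch showing how each (outer_idx, channel) pair maps to a contiguous block of `inner_size` elements, which is what lets `dequantize_optimized` run a straight vector loop over each block. The [2, 3, 4] shape and axis = 1 are made-up example values.

#include <cstddef>
#include <cstdio>

int main() {
  // Example: contiguous tensor of shape [2, 3, 4], per-channel axis = 1.
  const size_t sizes[3] = {2, 3, 4};
  const size_t axis = 1;

  size_t outer_size = 1;
  size_t inner_size = 1;
  for (size_t d = 0; d < axis; ++d) outer_size *= sizes[d];      // dims before axis
  for (size_t d = axis + 1; d < 3; ++d) inner_size *= sizes[d];  // dims after axis

  const size_t axis_stride = inner_size;                  // contiguous layout
  const size_t outer_stride = sizes[axis] * axis_stride;  // one full outer block

  for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
    for (size_t ch = 0; ch < sizes[axis]; ++ch) {
      const size_t base = outer_idx * outer_stride + ch * axis_stride;
      std::printf("outer %zu, channel %zu -> elements [%zu, %zu)\n",
                  outer_idx, ch, base, base + inner_size);
    }
  }
  return 0;
}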

0 commit comments
