#include <algorithm>
#include <cinttypes>
#include <cmath>
+ #if defined(__aarch64__) || defined(__ARM_NEON)
+ #include <arm_neon.h>
+ #endif

/**
 * For an input tensor, use the scale and zero_point arguments to quantize it.
@@ -22,6 +25,8 @@ namespace native {
using Tensor = exec_aten::Tensor;
using Scalar = exec_aten::Scalar;
using ScalarType = exec_aten::ScalarType;
+ using StridesType = exec_aten::StridesType;
+ using SizesType = exec_aten::SizesType;

namespace {

@@ -61,6 +66,163 @@ void check_dequantize_per_tensor_args(
      quant_max);
}

+ /**
+  * Useful to apply a function `fn` once per slice of the input tensor `in`
+  * along a given dimension `dim`. `fn` should have the following signature:
+  * void fn(const size_t inner_size, const size_t outer_idx, const size_t unpacked_dim_idx)
+  * where `inner_size` is the number of contiguous trailing elements in each
+  * slice, and `outer_idx` / `unpacked_dim_idx` identify the slice being
+  * processed.
+  */
+ template <typename Fn>
+ void apply_over_unpacked_dim(
+     const Fn& fn,
+     const exec_aten::Tensor& in,
+     const int64_t& dim) {
+   if (in.numel() == 0) {
+     return;
+   }
+
+   ET_CHECK_MSG(in.dim() > 0, "Input tensor must have at least one dimension");
+   ET_CHECK_VALID_DIM(dim, in.dim());
+
+   const size_t d = ET_NORMALIZE_IX(dim, in.dim());
+   const size_t dim_size = in.size(d);
+   const size_t outer_size = getLeadingDims(in, d);
+   const size_t inner_size = getTrailingDims(in, d);
+   // Loop through all outer dimensions
+   for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+     // Loop through dim
+     for (size_t unpacked_dim_idx = 0; unpacked_dim_idx < dim_size;
+          ++unpacked_dim_idx) {
+       fn(inner_size, outer_idx, unpacked_dim_idx);
+     }
+   }
+ }
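+ // Example (illustrative, hypothetical shapes): for a contiguous tensor of
+ // shape [2, 3, 4] with dim = 1, outer_size is 2, dim_size is 3, and
+ // inner_size is 4, so `fn` is invoked as fn(4, outer_idx, unpacked_dim_idx)
+ // for outer_idx in [0, 2) and unpacked_dim_idx in [0, 3), once per
+ // contiguous slice of 4 elements.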
+
+ void dequantize_optimized(
+     const int8_t* in,
+     const double scale,
+     const int64_t zero_point,
+     float* out,
+     int64_t quant_min,
+     int64_t quant_max,
+     size_t numel) {
+   ET_CHECK_MSG(
+       zero_point >= quant_min,
+       "zero_point must be %" PRId64 " >= quant_min %" PRId64,
+       zero_point,
+       quant_min);
+   ET_CHECK_MSG(
+       zero_point <= quant_max,
+       "zero_point must be %" PRId64 " <= quant_max %" PRId64,
+       zero_point,
+       quant_max);
+   size_t i = 0;
+ #if defined(__aarch64__) || defined(__ARM_NEON)
+   int8x8_t zero_point_vec = vdup_n_s8(zero_point);
+   float32x4_t scales = vdupq_n_f32(static_cast<float>(scale));
+   constexpr int32_t kVecSize = 16;
+   const size_t num_vecs = numel / kVecSize;
+   // Vectorized path: dequantize 16 int8 values per iteration. Any remaining
+   // tail elements are handled by the scalar loop below.
+   for (size_t vec_idx = 0; vec_idx < num_vecs; ++vec_idx, i += kVecSize) {
+     int8x16_t in_vec = vld1q_s8(in + i);
+     int16x8_t sub_vec_0_7 = vsubl_s8(vget_low_s8(in_vec), zero_point_vec);
+     int32x4_t sub_vec_0_3 = vmovl_s16(vget_low_s16(sub_vec_0_7));
+     int32x4_t sub_vec_4_7 = vmovl_s16(vget_high_s16(sub_vec_0_7));
+     float32x4_t out_vec_0_3 = vmulq_f32(vcvtq_f32_s32(sub_vec_0_3), scales);
+     float32x4_t out_vec_4_7 = vmulq_f32(vcvtq_f32_s32(sub_vec_4_7), scales);
+
+     int16x8_t sub_vec_8_15 = vsubl_s8(vget_high_s8(in_vec), zero_point_vec);
+     int32x4_t sub_vec_8_11 = vmovl_s16(vget_low_s16(sub_vec_8_15));
+     int32x4_t sub_vec_12_15 = vmovl_s16(vget_high_s16(sub_vec_8_15));
+     float32x4_t out_vec_8_11 = vmulq_f32(vcvtq_f32_s32(sub_vec_8_11), scales);
+     float32x4_t out_vec_12_15 = vmulq_f32(vcvtq_f32_s32(sub_vec_12_15), scales);
+
+     // Store the 16 dequantized floats back to the output buffer.
+     vst1q_f32(out + i, out_vec_0_3);
+     vst1q_f32(out + i + 4, out_vec_4_7);
+     vst1q_f32(out + i + 8, out_vec_8_11);
+     vst1q_f32(out + i + 12, out_vec_12_15);
+   }
+ #endif
+   // Scalar path: handles the tail on NEON builds and all elements otherwise.
+   for (; i < numel; i++) {
+     out[i] = (in[i] - zero_point) * scale;
+   }
+ }
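+ // Worked example (illustrative values): with scale = 0.5 and zero_point = 10,
+ // an int8 input value of 14 dequantizes to (14 - 10) * 0.5 = 2.0f; the NEON
+ // path and the scalar loop above compute the same result.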
+
+ bool can_use_optimized_dequantize_per_channel(
+     const Tensor& in,
+     const ScalarType in_dtype,
+     exec_aten::optional<ScalarType>& out_dtype) {
+   if (!executorch::runtime::is_contiguous_dim_order(
+           in.dim_order().data(), in.dim()) ||
+       (in_dtype != ScalarType::Char) ||
+       (out_dtype.has_value() && out_dtype.value() != ScalarType::Float)) {
+     return false;
+   }
+   return true;
+ }
+
+ void dequantize_per_channel_optimized(
+     const Tensor& in,
+     const Tensor& scales,
+     const optional<Tensor>& opt_zero_points,
+     Tensor& out,
+     int64_t axis,
+     int64_t quant_min,
+     int64_t quant_max,
+     ScalarType in_dtype,
+     exec_aten::optional<ScalarType>& out_dtype) {
+   check_dequantize_per_tensor_args(
+       in, quant_min, quant_max, in_dtype, out_dtype, out);
+   ET_CHECK_MSG(
+       executorch::runtime::is_contiguous_dim_order(
+           in.dim_order().data(), in.dim()),
+       "in must be in contiguous dim order");
+   ET_CHECK_MSG(
+       in_dtype == ScalarType::Char,
+       "in.scalar_type() %" PRId8 " is not supported:",
+       static_cast<int8_t>(in.scalar_type()));
+   if (out_dtype.has_value()) {
+     ET_CHECK_MSG(
+         out_dtype.value() == ScalarType::Float,
+         "Only float output is supported");
+   }
+   const int8_t* in_data = in.const_data_ptr<int8_t>();
+   float* out_data = out.mutable_data_ptr<float>();
+   const int64_t* zero_points_data = nullptr;
+   if (opt_zero_points.has_value()) {
+     zero_points_data = opt_zero_points.value().const_data_ptr<int64_t>();
+   }
+   const double* scales_data = scales.const_data_ptr<double>();
+   const StridesType axis_stride = in.strides()[axis];
+   const StridesType outer_stride = in.size(axis) * axis_stride;
+   apply_over_unpacked_dim(
+       [in_data,
+        out_data,
+        scales_data,
+        zero_points_data,
+        axis_stride,
+        outer_stride,
+        quant_min,
+        quant_max](
+           SizesType numel, SizesType outer_idx, SizesType unpacked_dim_idx) {
+         const int8_t* in_data_local =
+             in_data + outer_idx * outer_stride + unpacked_dim_idx * axis_stride;
+         const double scale = scales_data[unpacked_dim_idx];
+         const int64_t zero_point = zero_points_data != nullptr
+             ? zero_points_data[unpacked_dim_idx]
+             : 0;
+         float* out_data_local = out_data + outer_idx * outer_stride +
+             unpacked_dim_idx * axis_stride;
+         dequantize_optimized(
+             in_data_local,
+             scale,
+             zero_point,
+             out_data_local,
+             quant_min,
+             quant_max,
+             numel);
+       },
+       in,
+       axis);
+ }
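+ // Usage sketch (hypothetical names; assumes an int8 tensor `qin` in
+ // contiguous dim order, quantized per channel along axis 0 with double
+ // `scales` and int64 `zero_points`, plus a float output tensor `fout` of the
+ // same shape):
+ //
+ //   dequantize_per_channel_optimized(
+ //       qin, scales, zero_points, fout, /*axis=*/0,
+ //       /*quant_min=*/-128, /*quant_max=*/127, ScalarType::Char, out_dtype);
+ //
+ // Each index c along axis 0 is then dequantized with scales[c] and
+ // zero_points[c] applied to every element of that slice.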
+
} // namespace

/**
@@ -225,6 +387,20 @@ Tensor& dequantize_per_channel_out(
  check_dequantize_per_tensor_args(
      input, quant_min, quant_max, dtype, out_dtype, out);

+   if (can_use_optimized_dequantize_per_channel(input, dtype, out_dtype)) {
+     dequantize_per_channel_optimized(
+         input,
+         scale,
+         opt_zero_points,
+         out,
+         axis,
+         quant_min,
+         quant_max,
+         dtype,
+         out_dtype);
+     return out;
+   }
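+   // Inputs that do not qualify (non-contiguous dim order, non-int8 input, or
+   // an output dtype other than float) fall through to the generic
+   // per-channel path below.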
+
  // a list contains all dimensions except axis
  int64_t dims[kTensorDimensionLimit];
  for (int64_t i = 0; i < input.dim() - 1; i++) {