Commit 5ebcf7f

create quantized_linear_per_tensor_out in cpu

Author: Zonglin Peng
Parent: ab455df

File tree

4 files changed: +336 -0 lines changed

backends/cadence/aot/functions.yaml

Lines changed: 5 additions & 0 deletions

@@ -183,3 +183,8 @@
   kernels:
     - arg_meta: null
       kernel_name: impl::reference::quantized_matmul_out
+
+- func: cadence::quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_linear_per_tensor_out
Lines changed: 59 additions & 0 deletions

@@ -0,0 +1,59 @@
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

#pragma once

#include <executorch/runtime/core/array_ref.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <optional>

namespace cadence {
namespace impl {
namespace cpu {
namespace native {
namespace {

using ::executorch::runtime::getLeadingDims;

#define ET_FORALL_CADENCE_QUANTIZED_TYPES(_) \
  _(uint8_t, Byte)                           \
  _(int8_t, Char)

inline __attribute__((always_inline)) void linear_(
    const ::executorch::aten::Tensor& input,
    const ::executorch::aten::Tensor& weight,
    const ::executorch::aten::optional<::executorch::aten::Tensor>& bias,
    ::executorch::aten::Tensor& output) {
  const float* __restrict__ input_data = input.const_data_ptr<float>();
  const float* __restrict__ weight_data = weight.const_data_ptr<float>();
  const float* __restrict__ bias_data = bias.value().const_data_ptr<float>();
  float* __restrict__ output_data = output.mutable_data_ptr<float>();

  // input comes in shape [batch_size, in_dim]
  // weight comes in shape [out_dim, in_dim]
  // output comes in empty with shape [batch_size, out_dim]
  // Perform matrix multiply (M x N) x (N x P) => M x P
  int64_t M = weight.size(0); // = out_dim
  int64_t N = weight.size(1); // = in_dim

  // Given an N-dimensional input [d0, d1, d2, ..., d_{N-2}, d_{N-1}], the
  // leading dimension is d0 * d1 * ... * d_{N-2}
  int64_t leading_dims = getLeadingDims(input, input.dim() - 1);

  for (int i = 0; i < leading_dims; ++i) {
    for (int j = 0; j < M; ++j) {
      float sum = bias_data[j];
      for (int k = 0; k < N; ++k) {
        sum += input_data[i * N + k] * weight_data[j * N + k];
      }
      output_data[i * M + j] = sum;
    }
  }
}

} // namespace
} // namespace native
} // namespace cpu
} // namespace impl
} // namespace cadence
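The float linear_ helper above is a plain row-major inner-product loop over the flattened leading dimensions. A minimal standalone sketch of the same arithmetic on raw arrays (illustrative names and sizes, no ExecuTorch tensor plumbing):

// Sketch of the same math as linear_ above, on raw float arrays.
// linear_ref and the example shapes are illustrative, not part of the commit.
#include <cstdio>

// out[i][j] = bias[j] + sum_k in[i][k] * w[j][k], with w stored as [out_dim][in_dim]
void linear_ref(const float* in, const float* w, const float* bias,
                float* out, int leading_dims, int in_dim, int out_dim) {
  for (int i = 0; i < leading_dims; ++i) {
    for (int j = 0; j < out_dim; ++j) {
      float sum = bias[j];
      for (int k = 0; k < in_dim; ++k) {
        sum += in[i * in_dim + k] * w[j * in_dim + k];
      }
      out[i * out_dim + j] = sum;
    }
  }
}

int main() {
  const float in[2 * 3] = {1, 2, 3, 4, 5, 6};  // input  [2, 3]
  const float w[2 * 3] = {1, 0, 0, 0, 1, 0};   // weight [2, 3]: rows pick x0 and x1
  const float bias[2] = {0.5f, -0.5f};
  float out[2 * 2];
  linear_ref(in, w, bias, out, /*leading_dims=*/2, /*in_dim=*/3, /*out_dim=*/2);
  for (float v : out) std::printf("%.1f ", v);  // prints: 1.5 1.5 4.5 4.5
  return 0;
}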

backends/cadence/reference/operators/quantized_linear_out.cpp

Lines changed: 39 additions & 0 deletions

@@ -7,6 +7,8 @@
  */
 
 #include <executorch/backends/cadence/reference/kernels/kernels.h>
+#include <executorch/backends/cadence/reference/operators/operators.h>
+#include <executorch/backends/cadence/reference/operators/quantized_ops.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
 namespace impl {
@@ -85,6 +87,7 @@ void quantized_linear_out(
     int64_t out_zero_point,
     __ET_UNUSED const executorch::aten::optional<Tensor>& offset,
     Tensor& out) {
+  // TODO: refactor to use a switch case, as in quantized_linear_per_tensor_out
   if (out.scalar_type() == executorch::aten::ScalarType::Byte) {
     _typed_quantized_linear<uint8_t>(
         src,
@@ -115,6 +118,42 @@ void quantized_linear_out(
   }
 }
 
+void quantized_linear_per_tensor_out(
+    __ET_UNUSED KernelRuntimeContext& ctx,
+    const Tensor& src,
+    const Tensor& weight,
+    const Tensor& bias,
+    const int64_t src_zero_point,
+    const int64_t weight_zero_point,
+    const int64_t out_multiplier,
+    const int64_t out_shift,
+    const int64_t out_zero_point,
+    __ET_UNUSED const executorch::aten::optional<Tensor>& offset,
+    Tensor& out) {
+#define typed_quantized_linear_per_tensor(ctype, dtype) \
+  case executorch::aten::ScalarType::dtype: {           \
+    quantized_linear_per_tensor_<ctype>(                 \
+        src,                                             \
+        weight,                                          \
+        bias,                                            \
+        src_zero_point,                                  \
+        weight_zero_point,                               \
+        out_multiplier,                                  \
+        out_shift,                                       \
+        out_zero_point,                                  \
+        out);                                            \
+    break;                                               \
+  }
+
+  executorch::aten::ScalarType dtype = out.scalar_type();
+  switch (dtype) {
+    ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_linear_per_tensor);
+    default:
+      ET_DCHECK_MSG(false, "Unhandled dtype %s", toString(dtype));
+  }
+#undef typed_quantized_linear_per_tensor
+}
+
 }; // namespace native
 }; // namespace reference
 }; // namespace impl
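The new out-variant dispatches on the output dtype through the ET_FORALL_CADENCE_QUANTIZED_TYPES X-macro from the header above, which applies the case macro to (uint8_t, Byte) and (int8_t, Char). Hand-expanding it (an illustrative expansion of the switch body, not code from the commit) gives:

// Hand-expanded view of the switch inside quantized_linear_per_tensor_out.
// Fragment only; src, weight, bias, the zero points, multiplier, shift and out
// are the function arguments shown in the diff above.
switch (out.scalar_type()) {
  case executorch::aten::ScalarType::Byte: {
    quantized_linear_per_tensor_<uint8_t>(
        src, weight, bias, src_zero_point, weight_zero_point,
        out_multiplier, out_shift, out_zero_point, out);
    break;
  }
  case executorch::aten::ScalarType::Char: {
    quantized_linear_per_tensor_<int8_t>(
        src, weight, bias, src_zero_point, weight_zero_point,
        out_multiplier, out_shift, out_zero_point, out);
    break;
  }
  default:
    ET_DCHECK_MSG(false, "Unhandled dtype %s", toString(out.scalar_type()));
}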
Lines changed: 233 additions & 0 deletions

@@ -0,0 +1,233 @@
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

#pragma once

#include <executorch/backends/cadence/reference/kernels/kernels.h>
#include <executorch/backends/cadence/reference/operators/operators.h>

using executorch::runtime::getLeadingDims;

// Generate kernels that perform elementwise arithmetic on two quantized
// tensors. The tensors are either the same size, or the second tensor is a
// scalar.
#define DECLARE_POINTWISE_TENSOR_QUANTIZED_BINARY_OP(BINARY_FUNC_NAME, OP)    \
  template <typename T>                                                       \
  void BINARY_FUNC_NAME(                                                      \
      const ::executorch::aten::Tensor& X,                                    \
      float X_scale,                                                          \
      int32_t X_zero_point,                                                   \
      const ::executorch::aten::Tensor& Y,                                    \
      float Y_scale,                                                          \
      int32_t Y_zero_point,                                                   \
      float out_scale,                                                        \
      int32_t out_zero_point,                                                 \
      ::executorch::aten::Tensor& out) {                                      \
    const T* __restrict__ X_data = X.const_data_ptr<T>();                     \
    const T* __restrict__ Y_data = Y.const_data_ptr<T>();                     \
    T* __restrict__ out_data = out.mutable_data_ptr<T>();                     \
    size_t Y_numel = Y.numel();                                               \
    size_t X_numel = X.numel();                                               \
    float inv_out_scale = 1.0f / out_scale;                                   \
    /* Tensor that has the same number of elements as X */                    \
    if (Y_numel == X_numel) {                                                 \
      for (size_t i = 0; i < X_numel; ++i) {                                  \
        float x = kernels::dequantize<T>(X_data[i], X_scale, X_zero_point);   \
        float y = kernels::dequantize<T>(Y_data[i], Y_scale, Y_zero_point);   \
        float z = x OP y;                                                     \
        out_data[i] = kernels::quantize<T>(z, inv_out_scale, out_zero_point); \
      }                                                                       \
    } /* if Y is a scalar Tensor */                                           \
    else if (Y_numel == 1) {                                                  \
      float y = kernels::dequantize<T>(Y_data[0], Y_scale, Y_zero_point);     \
      for (size_t i = 0; i < X_numel; ++i) {                                  \
        float x = kernels::dequantize<T>(X_data[i], X_scale, X_zero_point);   \
        float z = x OP y;                                                     \
        out_data[i] = kernels::quantize<T>(z, inv_out_scale, out_zero_point); \
      }                                                                       \
    } /* other broadcasting cases */                                          \
    else {                                                                    \
      ET_DCHECK_MSG(false, "Unsupported broadcasting");                       \
    }                                                                         \
  }

template <typename T>
inline __attribute__((always_inline)) void quantized_linear_per_tensor_(
    const ::executorch::aten::Tensor& src,
    const ::executorch::aten::Tensor& weight,
    const ::executorch::aten::Tensor& bias,
    const int64_t src_zero_point,
    const int64_t weight_zero_point,
    const int64_t out_multiplier,
    const int64_t out_shift,
    const int64_t out_zero_point,
    ::executorch::aten::Tensor& out) {
  // input comes in shape [leading_dims, in_dim]
  // weight comes in shape [out_dim, in_dim]
  // output comes in empty with shape [leading_dims, out_dim]
  // Perform matrix multiply (M x N) x (N x P)' => M x P
  const int64_t leading_dims = getLeadingDims(src, src.dim() - 1);
  const int64_t out_dim = weight.size(0); // = out_dim
  const int64_t in_dim = weight.size(1); // = in_dim

  const T* __restrict__ in_data = src.const_data_ptr<T>();
  const T* __restrict__ weight_data = weight.const_data_ptr<T>();
  const int32_t* __restrict__ bias_data = bias.const_data_ptr<int32_t>();
  T* __restrict__ out_data = out.mutable_data_ptr<T>();

  // Compute the requant_scale from out_multiplier and out_shift
  const float requant_scale =
      -out_multiplier * 1.0 / (1 << 31) * pow(2, out_shift);

  for (size_t i = 0; i < leading_dims; ++i) {
    for (size_t j = 0; j < out_dim; ++j) {
      int32_t sum = bias_data[j];
      for (size_t k = 0; k < in_dim; ++k) {
        int32_t x = (int32_t)in_data[i * in_dim + k] - src_zero_point;
        int32_t w =
            (int32_t)weight_data[j * in_dim + k] - (int32_t)weight_zero_point;
        sum += x * w;
      }
      out_data[i * out_dim + j] = ::impl::reference::kernels::quantize<T>(
          sum, requant_scale, out_zero_point);
    }
  }
}

template <typename T>
inline __attribute__((always_inline)) void quantized_linear_per_tensor_(
    const ::executorch::aten::Tensor& src,
    const ::executorch::aten::Tensor& weight,
    const ::executorch::aten::Tensor& bias,
    int64_t src_zero_point,
    const ::executorch::aten::Tensor& weight_zero_point_t,
    int64_t out_multiplier,
    int64_t out_shift,
    int64_t out_zero_point,
    ::executorch::aten::Tensor& out) {
  // Get the zero_point of weight.
  int32_t weight_zero_point = weight_zero_point_t.const_data_ptr<int32_t>()[0];
  quantized_linear_per_tensor_<T>(
      src,
      weight,
      bias,
      src_zero_point,
      weight_zero_point,
      out_multiplier,
      out_shift,
      out_zero_point,
      out);
}

template <typename T>
inline __attribute__((always_inline)) void quantized_linear_per_channel_(
    const ::executorch::aten::Tensor& src,
    const ::executorch::aten::Tensor& weight,
    const ::executorch::aten::Tensor& bias,
    int64_t src_zero_point,
    int64_t weight_zero_point,
    const ::executorch::aten::Tensor& out_multiplier,
    const ::executorch::aten::Tensor& out_shift,
    int64_t out_zero_point,
    ::executorch::aten::Tensor& out) {
  // input comes in shape [leading_dims, in_dim]
  // weight comes in shape [out_dim, in_dim]
  // output comes in empty with shape [leading_dims, out_dim]
  // Perform matrix multiply (M x N) x (N x P)' => M x P
  int64_t leading_dims = getLeadingDims(src, src.dim() - 1);
  const int64_t out_dim = weight.size(0); // = out_dim
  const int64_t in_dim = weight.size(1); // = in_dim

  const T* __restrict__ in_data = src.const_data_ptr<T>();
  const T* __restrict__ weight_data = weight.const_data_ptr<T>();
  const int32_t* __restrict__ bias_data = bias.const_data_ptr<int32_t>();
  T* __restrict__ out_data = out.mutable_data_ptr<T>();
  const int32_t* __restrict__ out_multiplier_data =
      out_multiplier.const_data_ptr<int32_t>();
  const int32_t* __restrict__ out_shift_data =
      out_shift.const_data_ptr<int32_t>();

  for (size_t i = 0; i < leading_dims; ++i) {
    for (size_t j = 0; j < out_dim; ++j) {
      int32_t sum = bias_data[j];
      for (size_t k = 0; k < in_dim; ++k) {
        int32_t x = (int32_t)in_data[i * in_dim + k] - src_zero_point;
        int32_t w =
            (int32_t)weight_data[j * in_dim + k] - (int32_t)weight_zero_point;
        sum += x * w;
      }
      // Compute the out_scale from out_multiplier and out_shift
      const float out_scale =
          -out_multiplier_data[j] * 1.0 / (1 << 31) * pow(2, out_shift_data[j]);
      out_data[i * out_dim + j] = ::impl::reference::kernels::quantize<T>(
          sum, out_scale, out_zero_point);
    }
  }
}

template <typename T>
inline __attribute__((always_inline)) void quantized_linear_(
    const ::executorch::aten::Tensor& src,
    const ::executorch::aten::Tensor& weight,
    const ::executorch::aten::Tensor& bias,
    int64_t src_zero_point,
    int64_t weight_zero_point,
    const ::executorch::aten::Tensor& out_multiplier,
    const ::executorch::aten::Tensor& out_shift,
    int64_t out_zero_point,
    ::executorch::aten::Tensor& out) {
  if (out_multiplier.numel() == 1) {
    // Use per-tensor quantization kernel.
    const int32_t* __restrict__ out_multiplier_data =
        out_multiplier.const_data_ptr<int32_t>();
    const int32_t* __restrict__ out_shift_data =
        out_shift.const_data_ptr<int32_t>();
    quantized_linear_per_tensor_<T>(
        src,
        weight,
        bias,
        src_zero_point,
        weight_zero_point,
        out_multiplier_data[0],
        out_shift_data[0],
        out_zero_point,
        out);
    return;
  }

  // Use per-channel quantization kernel.
  quantized_linear_per_channel_<T>(
      src,
      weight,
      bias,
      src_zero_point,
      weight_zero_point,
      out_multiplier,
      out_shift,
      out_zero_point,
      out);
}

template <typename T>
inline __attribute__((always_inline)) void quantized_linear_(
    const ::executorch::aten::Tensor& src,
    const ::executorch::aten::Tensor& weight,
    const ::executorch::aten::Tensor& bias,
    int64_t src_zero_point,
    const ::executorch::aten::Tensor& weight_zero_point_t,
    const ::executorch::aten::Tensor& out_multiplier,
    const ::executorch::aten::Tensor& out_shift,
    int64_t out_zero_point,
    ::executorch::aten::Tensor& out) {
  // Get the zero_point of weight.
  int32_t weight_zero_point = weight_zero_point_t.const_data_ptr<int32_t>()[0];
  quantized_linear_<T>(
      src,
      weight,
      bias,
      src_zero_point,
      weight_zero_point,
      out_multiplier,
      out_shift,
      out_zero_point,
      out);
}
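The requantization in the kernels above folds out_multiplier and out_shift into a single float scale: out_multiplier is a Q0.31 fixed-point value, so dividing by 2^31 recovers the fractional multiplier, and 2^out_shift applies the shift (the leading minus sign in the code compensates for 1 << 31 wrapping to a negative value on typical 32-bit int targets). Below is a minimal numeric sketch of that scale computation plus a plausible quantize step; the rounding/clamping behavior of kernels::quantize is an assumption here, and the sketch uses an unsigned shift instead of the sign trick in the commit's code.

// Sketch of the per-tensor requantization arithmetic; quantize_i8 models an
// assumed round-then-clamp behavior, not the actual Cadence kernels helper.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// out_multiplier is a Q0.31 fixed-point multiplier, out_shift a power-of-two exponent.
float requant_scale(int32_t out_multiplier, int32_t out_shift) {
  return static_cast<float>(out_multiplier) / (1ull << 31) * std::pow(2.0f, out_shift);
}

// Assumed quantize step: scale the int32 accumulator, round, add zero point, clamp to int8.
int8_t quantize_i8(int32_t acc, float scale, int32_t out_zero_point) {
  float v = std::round(acc * scale) + out_zero_point;
  return static_cast<int8_t>(std::min(127.0f, std::max(-128.0f, v)));
}

int main() {
  // Example: effective scale 0.5 -> multiplier = 2^30 (0.5 in Q0.31), shift = 0.
  const int32_t mult = 1 << 30, shift = 0, zp = 3;
  const int32_t acc = 42;  // int32 accumulator from the matmul inner loop
  const float scale = requant_scale(mult, shift);
  std::printf("scale=%f q=%d\n", scale, quantize_i8(acc, scale, zp));
  // Expected: scale=0.5, q = round(42 * 0.5) + 3 = 24
  return 0;
}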
