Skip to content

Commit 6b6b1c7

Browse files
authored
open source quantized_fully_connected
Differential Revision: D69085419 Pull Request resolved: #8164
1 parent 56baff7 commit 6b6b1c7

File tree

3 files changed

+281
-3
lines changed

3 files changed

+281
-3
lines changed

backends/cadence/aot/functions_hifi.yaml

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,11 @@
7171
kernels:
7272
- arg_meta: null
7373
kernel_name: cadence::impl::HiFi::full_out
74-
74+
7575
- op: gt.Scalar_out
7676
kernels:
7777
- arg_meta: null
78-
kernel_name: torch::executor::gt_scalar_out
78+
kernel_name: torch::executor::gt_scalar_out
7979

8080
- op: gelu.out
8181
kernels:
@@ -100,7 +100,7 @@
100100
- op: mean.out
101101
kernels:
102102
- arg_meta: null
103-
kernel_name: cadence::impl::HiFi::mean_dim_out
103+
kernel_name: cadence::impl::HiFi::mean_dim_out
104104

105105
- op: minimum.out
106106
kernels:
@@ -213,3 +213,13 @@
213213
kernels:
214214
- arg_meta: null
215215
kernel_name: cadence::impl::HiFi::quantized_linear_per_tensor_out
- func: cadence::quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
  kernels:
    - arg_meta: null
      kernel_name: cadence::impl::HiFi::quantized_fully_connected_out

- func: cadence::quantized_fully_connected.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
  kernels:
    - arg_meta: null
      kernel_name: cadence::impl::HiFi::quantized_fully_connected_per_tensor_out
Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
#include <executorch/backends/cadence/hifi/kernels/kernels.h>
3+
#include <executorch/runtime/kernel/kernel_includes.h>
4+
5+
#include <algorithm>
6+
#include <cmath>
7+
8+
namespace cadence {
9+
namespace impl {
10+
namespace HiFi {
11+
namespace native {
12+
13+
using ::executorch::aten::ArrayRef;
14+
using ::executorch::aten::IntArrayRef;
15+
using ::executorch::aten::optional;
16+
using ::executorch::aten::Scalar;
17+
using ::executorch::aten::ScalarType;
18+
using ::executorch::aten::SizesType;
19+
using ::executorch::aten::Tensor;
20+
using ::executorch::runtime::KernelRuntimeContext;
21+
22+
void inline _quantized_fully_connected_asym8u(
23+
const Tensor& in,
24+
const Tensor& weight,
25+
const Tensor& bias,
26+
int64_t in_zero_point,
27+
const Tensor& weight_zero_point,
28+
const Tensor& out_multiplier,
29+
const Tensor& out_shift,
30+
int64_t out_zero_point,
31+
__ET_UNUSED const optional<Tensor>& offset,
32+
Tensor& out) {
33+
// input comes in shape [leading_dims, in_dim]
34+
// weight comes in shape [out_dim, in_dim]
35+
// output comes in empty with shape [leading_dims, out_dim]
36+
// Perform matrix multiply (M x N) x (N x P)' => M x P
37+
int64_t leading_dims = 1;
38+
int64_t out_dim = weight.size(0); // = out_dim
39+
int64_t in_dim = weight.size(1); // = in_dim
40+
41+
const uint8_t* __restrict__ in_data = in.const_data_ptr<uint8_t>();
42+
const uint8_t* __restrict__ weight_data = weight.const_data_ptr<uint8_t>();
43+
const int32_t* __restrict__ bias_data = bias.const_data_ptr<int32_t>();
44+
uint8_t* __restrict__ out_data = out.mutable_data_ptr<uint8_t>();
45+
46+
int32_t ret = xa_nn_fully_connected_asym8uxasym8u_asym8u(
47+
out_data,
48+
weight_data,
49+
in_data,
50+
bias_data,
51+
in_dim, // weight_depth, number of columns in weight
52+
out_dim, // out_depth, number of rows in weight
53+
-in_zero_point,
54+
-weight_zero_point.const_data_ptr<int32_t>()[0],
55+
out_multiplier.const_data_ptr<int32_t>()[0],
56+
out_shift.const_data_ptr<int32_t>()[0],
57+
out_zero_point);
58+
ET_DCHECK_MSG(ret == 0, "HiFi quantized::fully_connected failed");
59+
}
60+
61+
void inline _quantized_fully_connected_asym8s(
62+
const Tensor& in,
63+
const Tensor& weight,
64+
const Tensor& bias,
65+
int64_t in_zero_point,
66+
const Tensor& weight_zero_point,
67+
const Tensor& out_multiplier,
68+
const Tensor& out_shift,
69+
int64_t out_zero_point,
70+
__ET_UNUSED const optional<Tensor>& offset,
71+
Tensor& out) {
72+
// input comes in shape [leading_dims, in_dim]
73+
// weight comes in shape [out_dim, in_dim]
74+
// output comes in empty with shape [leading_dims, out_dim]
75+
// Perform matrix multiply (M x N) x (N x P)' => M x P
76+
int64_t leading_dims = 1;
77+
int64_t out_dim = weight.size(0); // = out_dim
78+
int64_t in_dim = weight.size(1); // = in_dim
79+
80+
const int8_t* __restrict__ in_data = in.const_data_ptr<int8_t>();
81+
const int8_t* __restrict__ weight_data = weight.const_data_ptr<int8_t>();
82+
const int32_t* __restrict__ bias_data = bias.const_data_ptr<int32_t>();
83+
int8_t* __restrict__ out_data = out.mutable_data_ptr<int8_t>();
84+
85+
int32_t ret = xa_nn_fully_connected_asym8sxasym8s_asym8s(
86+
out_data,
87+
weight_data,
88+
in_data,
89+
bias_data,
90+
in_dim, // weight_depth, number of columns in weight
91+
out_dim, // out_depth, number of rows in weight
92+
-in_zero_point,
93+
-weight_zero_point.const_data_ptr<int32_t>()[0],
94+
out_multiplier.const_data_ptr<int32_t>()[0],
95+
out_shift.const_data_ptr<int32_t>()[0],
96+
out_zero_point);
97+
ET_DCHECK_MSG(ret == 0, "HiFi quantized::fully_connected failed");
98+
}
99+
100+
// Entry point for cadence::quantized_fully_connected.out (tensor-valued
// quantization parameters). Dispatches on the output dtype to the uint8 or
// int8 NNLib-backed implementation; any other dtype aborts with a check
// failure.
void quantized_fully_connected_out(
    __ET_UNUSED KernelRuntimeContext& ctx,
    const Tensor& in,
    const Tensor& weight,
    const Tensor& bias,
    int64_t in_zero_point,
    const Tensor& weight_zero_point,
    const Tensor& out_multiplier,
    const Tensor& out_shift,
    int64_t out_zero_point,
    __ET_UNUSED const optional<Tensor>& offset,
    Tensor& out) {
  const ScalarType out_dtype = out.scalar_type();
  switch (out_dtype) {
    case ScalarType::Byte:
      _quantized_fully_connected_asym8u(
          in,
          weight,
          bias,
          in_zero_point,
          weight_zero_point,
          out_multiplier,
          out_shift,
          out_zero_point,
          offset,
          out);
      break;
    case ScalarType::Char:
      _quantized_fully_connected_asym8s(
          in,
          weight,
          bias,
          in_zero_point,
          weight_zero_point,
          out_multiplier,
          out_shift,
          out_zero_point,
          offset,
          out);
      break;
    default:
      ET_CHECK_MSG(
          false,
          "quantized fully connected only supported for uint8 and int8 dtypes");
  }
}
142+
143+
void inline _quantized_fully_connected_per_tensor_asym8u(
144+
const Tensor& in,
145+
const Tensor& weight,
146+
const Tensor& bias,
147+
int64_t in_zero_point,
148+
int64_t weight_zero_point,
149+
int64_t out_multiplier,
150+
int64_t out_shift,
151+
int64_t out_zero_point,
152+
__ET_UNUSED const optional<Tensor>& offset,
153+
Tensor& out) {
154+
// input comes in shape [leading_dims, in_dim]
155+
// weight comes in shape [out_dim, in_dim]
156+
// output comes in empty with shape [leading_dims, out_dim]
157+
// Perform matrix multiply (M x N) x (N x P)' => M x P
158+
int64_t leading_dims = 1;
159+
int64_t out_dim = weight.size(0); // = out_dim
160+
int64_t in_dim = weight.size(1); // = in_dim
161+
162+
const uint8_t* __restrict__ in_data = in.const_data_ptr<uint8_t>();
163+
const uint8_t* __restrict__ weight_data = weight.const_data_ptr<uint8_t>();
164+
const int32_t* __restrict__ bias_data = bias.const_data_ptr<int32_t>();
165+
uint8_t* __restrict__ out_data = out.mutable_data_ptr<uint8_t>();
166+
167+
int32_t ret = xa_nn_fully_connected_asym8uxasym8u_asym8u(
168+
out_data,
169+
weight_data,
170+
in_data,
171+
bias_data,
172+
in_dim, // weight_depth, number of columns in weight
173+
out_dim, // out_depth, number of rows in weight
174+
-in_zero_point,
175+
-static_cast<int32_t>(weight_zero_point),
176+
static_cast<int32_t>(out_multiplier),
177+
static_cast<int32_t>(out_shift),
178+
out_zero_point);
179+
ET_DCHECK_MSG(ret == 0, "HiFi quantized::fully_connected failed");
180+
}
181+
182+
void inline _quantized_fully_connected_per_tensor_asym8s(
183+
const Tensor& in,
184+
const Tensor& weight,
185+
const Tensor& bias,
186+
int64_t in_zero_point,
187+
int64_t weight_zero_point,
188+
int64_t out_multiplier,
189+
int64_t out_shift,
190+
int64_t out_zero_point,
191+
__ET_UNUSED const optional<Tensor>& offset,
192+
Tensor& out) {
193+
// input comes in shape [leading_dims, in_dim]
194+
// weight comes in shape [out_dim, in_dim]
195+
// output comes in empty with shape [leading_dims, out_dim]
196+
// Perform matrix multiply (M x N) x (N x P)' => M x P
197+
int64_t leading_dims = 1;
198+
int64_t out_dim = weight.size(0); // = out_dim
199+
int64_t in_dim = weight.size(1); // = in_dim
200+
201+
const int8_t* __restrict__ in_data = in.const_data_ptr<int8_t>();
202+
const int8_t* __restrict__ weight_data = weight.const_data_ptr<int8_t>();
203+
const int32_t* __restrict__ bias_data = bias.const_data_ptr<int32_t>();
204+
int8_t* __restrict__ out_data = out.mutable_data_ptr<int8_t>();
205+
206+
int32_t ret = xa_nn_fully_connected_asym8sxasym8s_asym8s(
207+
out_data,
208+
weight_data,
209+
in_data,
210+
bias_data,
211+
in_dim, // weight_depth, number of columns in weight
212+
out_dim, // out_depth, number of rows in weight
213+
-in_zero_point,
214+
-static_cast<int32_t>(weight_zero_point),
215+
static_cast<int32_t>(out_multiplier),
216+
static_cast<int32_t>(out_shift),
217+
out_zero_point);
218+
ET_DCHECK_MSG(ret == 0, "HiFi quantized::fully_connected failed");
219+
}
220+
221+
// Entry point for cadence::quantized_fully_connected.per_tensor_out (scalar
// quantization parameters). Dispatches on the output dtype to the uint8 or
// int8 NNLib-backed implementation; any other dtype aborts with a check
// failure.
void quantized_fully_connected_per_tensor_out(
    __ET_UNUSED KernelRuntimeContext& ctx,
    const Tensor& in,
    const Tensor& weight,
    const Tensor& bias,
    int64_t in_zero_point,
    int64_t weight_zero_point,
    int64_t out_multiplier,
    int64_t out_shift,
    int64_t out_zero_point,
    __ET_UNUSED const optional<Tensor>& offset,
    Tensor& out) {
  const ScalarType out_dtype = out.scalar_type();
  switch (out_dtype) {
    case ScalarType::Byte:
      _quantized_fully_connected_per_tensor_asym8u(
          in,
          weight,
          bias,
          in_zero_point,
          weight_zero_point,
          out_multiplier,
          out_shift,
          out_zero_point,
          offset,
          out);
      break;
    case ScalarType::Char:
      _quantized_fully_connected_per_tensor_asym8s(
          in,
          weight,
          bias,
          in_zero_point,
          weight_zero_point,
          out_multiplier,
          out_shift,
          out_zero_point,
          offset,
          out);
      break;
    default:
      ET_CHECK_MSG(
          false,
          "quantized fully connected only supported for uint8 and int8 dtypes");
  }
}
263+
264+
} // namespace native
265+
} // namespace HiFi
266+
} // namespace impl
267+
} // namespace cadence

backends/cadence/hifi/operators/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ OPERATORS = [
4545
"mul",
4646
"permute_copy",
4747
"pow",
48+
"quantized_fully_connected_out",
4849
"quantize_per_tensor",
4950
"quantized_layer_norm",
5051
"quantized_linear_out",

0 commit comments

Comments
 (0)