
Commit 3b31eff

manuelcandales authored and facebook-github-bot committed
4b quantized embedding table operator (#3050)
Summary:
Pull Request resolved: #3050

4b quantized embedding table operator

Reviewed By: mikekgfb

Differential Revision: D56123408

fbshipit-source-id: 26293e2b09f93ccb8f14462de7ae0969efc7acc5
1 parent 458d743 commit 3b31eff

File tree

5 files changed: +501 −0 lines changed
Lines changed: 344 additions & 0 deletions
@@ -0,0 +1,344 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/runtime/kernel/kernel_includes.h>
#include <algorithm>
#include <cinttypes>
#include <cmath>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using Scalar = exec_aten::Scalar;
using ScalarType = exec_aten::ScalarType;

namespace {

/**
 * Asserts that the parameters are valid.
 */
void check_embedding_4bit_args(
    const Tensor& weight,
    const Tensor& weight_scales,
    const optional<Tensor>& opt_weight_zero_points,
    const int64_t weight_quant_min,
    const int64_t weight_quant_max,
    const Tensor& indices,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out) {
  ET_CHECK_MSG(
      weight.dim() == 2, "weight must be 2D but got %zd dims", weight.dim());

  ET_CHECK_MSG(
      weight_scales.dim() == 1 || weight_scales.dim() == 2,
      "weight_scales must be 1D or 2D but got %zd dims",
      weight_scales.dim());

  ET_CHECK_MSG(
      weight_scales.size(0) == weight.size(0),
      "Number of scales must be == weight.size(0)=%zd"
      ", but got %zd",
      weight.size(0),
      weight_scales.size(0));

  if (weight_scales.dim() == 2) {
    auto num_groups = weight_scales.size(1);
    ET_CHECK_MSG(
        // Each packed uint8 column holds two 4-bit columns.
        (2 * weight.size(1)) % num_groups == 0,
        "Number of groups must divide 2 * weight.size(1), with weight.size(1)=%zd"
        ", but got # of groups = %zd",
        weight.size(1),
        num_groups);
  }

  ET_CHECK_MSG(
      weight.scalar_type() == ScalarType::Byte,
      "weight.scalar_type() %" PRId8 " is not supported",
      static_cast<int8_t>(weight.scalar_type()));

  ET_CHECK_MSG(
      out.scalar_type() == ScalarType::Float ||
          out.scalar_type() == ScalarType::Half,
      "out.scalar_type() %" PRId8 " is not supported",
      static_cast<int8_t>(out.scalar_type()));

  ET_CHECK_MSG(
      weight_scales.scalar_type() == ScalarType::Float ||
          weight_scales.scalar_type() == ScalarType::Half,
      "weight_scales.scalar_type() %" PRId8 " is not supported",
      static_cast<int8_t>(weight_scales.scalar_type()));

  if (opt_weight_zero_points.has_value()) {
    ET_CHECK_MSG(
        opt_weight_zero_points.value().dim() == weight_scales.dim(),
        "weight_zero_points's rank must match that of weight_scales. "
        "weight_zero_points rank: %" PRId8 ", weight_scales rank: %" PRId8,
        static_cast<int8_t>(opt_weight_zero_points.value().dim()),
        static_cast<int8_t>(weight_scales.dim()));

    ET_CHECK_MSG(
        opt_weight_zero_points.value().scalar_type() == out.scalar_type(),
        "weight zero points scalar type %" PRId8
        " does not match out.scalar_type()",
        static_cast<int8_t>(opt_weight_zero_points.value().scalar_type()));

    for (int32_t i = 0; i < weight_scales.dim(); ++i) {
      ET_CHECK_MSG(
          opt_weight_zero_points.value().size(i) == weight_scales.size(i),
          "Dimension size mismatch at dim %" PRId8
          ": weight_zero_points size = %zd"
          ", weight_scales size = %zd.",
          i,
          opt_weight_zero_points.value().size(i),
          weight_scales.size(i));
    }
  }

  ET_CHECK_MSG(
      indices.scalar_type() == ScalarType::Long,
      "indices.scalar_type() %" PRId8 " is not Long; only Long is supported",
      static_cast<int8_t>(indices.scalar_type()));

  ET_CHECK_MSG(
      weight_quant_min <= weight_quant_max,
      "weight quant min: %" PRId64
      " is greater than weight quant max: %" PRId64,
      weight_quant_min,
      weight_quant_max);

  if (out_dtype.has_value()) {
    ET_CHECK_MSG(
        out.scalar_type() == out_dtype.value(),
        "output_dtype must match the dtype of the out tensor");
  }
}
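// Example of the groupwise layout checked above (illustrative shapes): a
// weight of shape [num_embeddings, 64] packs 128 4-bit values per row, so a
// weight_scales of shape [num_embeddings, 4] yields 4 groups covering
// 128 / 4 = 32 unpacked elements each.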
static inline int32_t weight_value(const unsigned char* w_data, int32_t index) {
  int32_t odd = index & 1;
  index >>= 1;
  if (odd) {
    return (int32_t)(w_data[index] & 0x0F) - 8;
  } else {
    return (int32_t)((w_data[index] >> 4) & 0x0F) - 8;
  }
}
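// Worked example: if w_data[0] == 0x2B, then weight_value(w_data, 0) reads
// the high nibble (0x2) and returns 2 - 8 = -6, while weight_value(w_data, 1)
// reads the low nibble (0xB) and returns 11 - 8 = 3.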
/**
 * Retrieves the embeddings specified by indices, dequantizes them, and stores
 * them in out. Weight will always be uint8.
 */
template <typename CTYPE_PARAMS, typename CTYPE_OUT>
void embedding_4bit_per_channel(
    const Tensor& weight,
    const Tensor& weight_scales,
    const optional<Tensor>& opt_weight_zero_points,
    const Tensor& indices,
    Tensor& out) {
  auto embedding_dim = weight.size(1) * 2;

  int32_t num_groups_per_channel = 1;
  if (weight_scales.dim() == 2) {
    num_groups_per_channel = weight_scales.size(1);
  }
  int32_t group_size = embedding_dim / num_groups_per_channel;

  CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
  const int64_t* indices_ptr = indices.const_data_ptr<int64_t>();

  const CTYPE_PARAMS* scales = weight_scales.const_data_ptr<CTYPE_PARAMS>();
  const CTYPE_PARAMS* zero_points = nullptr;
  if (opt_weight_zero_points.has_value()) {
    zero_points = opt_weight_zero_points.value().const_data_ptr<CTYPE_PARAMS>();
  }

  for (int i = 0; i < indices.numel(); i++) {
    int64_t index = indices_ptr[i];
    // If using groupwise embedding
    int32_t qparams_index = index * num_groups_per_channel;
    CTYPE_PARAMS zp = 0.0;
    const CTYPE_PARAMS* scale_ptr = scales + qparams_index;
    const CTYPE_PARAMS* zero_points_ptr = nullptr;
    if (opt_weight_zero_points.has_value()) {
      zero_points_ptr = zero_points + qparams_index;
    }

    const uint8_t* w_data = weight.data_ptr<uint8_t>() + weight.size(1) * index;

    for (int j = 0; j < embedding_dim; ++j) {
      int32_t group_id = j / group_size;
      const CTYPE_PARAMS scale = scale_ptr[group_id];
      if (opt_weight_zero_points.has_value()) {
        zp = zero_points_ptr[group_id];
      }
      out_data[j] = static_cast<CTYPE_OUT>(
          (static_cast<float>(weight_value(w_data, j)) -
           static_cast<float>(zp)) *
          static_cast<float>(scale));
    }
    out_data += embedding_dim;
  }
}
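// Per element j of a row, the loop above computes
//   out[j] = (weight_value(w_data, j) - zp) * scale
// with scale and zp taken from group j / group_size of the row's
// quantization parameters (zp stays 0 when no zero-point tensor is given).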
void resize_out_tensor(
    const Tensor& weight,
    const Tensor& indices,
    Tensor& out) {
  exec_aten::SizesType expected_output_size[kTensorDimensionLimit];
  for (size_t i = 0; i < indices.dim(); i++) {
    expected_output_size[i] = indices.size(i);
  }
  // Each packed uint8 in weight holds two 4-bit elements, so the unpacked
  // embedding dimension is twice the packed width.
  const size_t embedding_dim = 2 * weight.size(1);
  expected_output_size[out.dim() - 1] = embedding_dim;

  exec_aten::ArrayRef<exec_aten::SizesType> output_size{
      expected_output_size, static_cast<size_t>(out.dim())};

  torch::executor::Error err = resize_tensor(out, output_size);
  ET_CHECK_MSG(
      err == torch::executor::Error::Ok,
      "Failed to resize out Tensor in quantized_embedding_4bit_out");
}

} // namespace

/**
 * Retrieves the embeddings specified by indices, dequantizes them, and stores
 * them in out. The weight is quantized per channel, with a scale and
 * zero_point for each embedding.
 *
 * This is the out variant of torch.ops.quantized.embedding_4bit.
 *
 * NOTE: quant_min, quant_max, and Dtype are not used in computation, but
 * rather metadata that is passed around which can be useful for pattern
 * matching. See
 * https://github.com/pytorch/pytorch/pull/87093#discussion_r1000841181 for
 * more info.
 */
Tensor& quantized_embedding_4bit_out(
    // TODO Evaluate whether this name is appropriate for an operator that
    // takes non quant input and returns fp output
    const Tensor& weight,
    const Tensor& weight_scales,
    const optional<Tensor>& opt_weight_zero_points,
    const int64_t weight_quant_min,
    const int64_t weight_quant_max,
    const Tensor& indices,
    Tensor& out) {
  ScalarType out_type = out.scalar_type();

  // TODO (jakeszwe): improve these to account for the size of out in relation
  // to weight and indices accounting for a possible batch dimension
  check_embedding_4bit_args(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out_type,
      out);

  constexpr auto name = "quantized_decomposed::embedding_4bit.out";
  ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
    embedding_4bit_per_channel<CTYPE_OUT, CTYPE_OUT>(
        weight, weight_scales, opt_weight_zero_points, indices, out);
  });

  return out;
}

Tensor& quantized_embedding_4bit_out(
    RuntimeContext& context,
    const Tensor& weight,
    const Tensor& weight_scales,
    const optional<Tensor>& opt_weight_zero_points,
    int64_t weight_quant_min,
    int64_t weight_quant_max,
    const Tensor& indices,
    Tensor& out) {
  // TODO(larryliu): Add a context arg to the real op function and remove this
  // wrapper
  (void)context;
  resize_out_tensor(weight, indices, out);
  return quantized_embedding_4bit_out(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out);
}

Tensor& quantized_embedding_4bit_dtype_out(
    // TODO Evaluate whether this name is appropriate for an operator that
    // takes non quant input and returns fp output
    const Tensor& weight,
    const Tensor& weight_scales,
    const optional<Tensor>& opt_weight_zero_points,
    const int64_t weight_quant_min,
    const int64_t weight_quant_max,
    const Tensor& indices,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out) {
  // TODO (jakeszwe): improve these to account for the size of out in relation
  // to weight and indices accounting for a possible batch dimension
  check_embedding_4bit_args(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out_dtype,
      out);

  ScalarType params_type = weight_scales.scalar_type();
  ScalarType out_type = out.scalar_type();

  constexpr auto name = "quantized_decomposed::embedding_4bit.dtype_out";
  ET_SWITCH_TWO_TYPES(Float, Half, params_type, ctx, name, CTYPE_P, [&]() {
    ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
      embedding_4bit_per_channel<CTYPE_P, CTYPE_OUT>(
          weight, weight_scales, opt_weight_zero_points, indices, out);
    });
  });

  return out;
}

Tensor& quantized_embedding_4bit_dtype_out(
    RuntimeContext& context,
    const Tensor& weight,
    const Tensor& weight_scales,
    const optional<Tensor>& opt_weight_zero_points,
    int64_t weight_quant_min,
    int64_t weight_quant_max,
    const Tensor& indices,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out) {
  // TODO(larryliu): Add a context arg to the real op function and remove this
  // wrapper
  (void)context;
  resize_out_tensor(weight, indices, out);
  return quantized_embedding_4bit_dtype_out(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out_dtype,
      out);
}

} // namespace native
} // namespace executor
} // namespace torch
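For context when reading the kernel above, here is a minimal standalone C++ sketch of the packing and dequantization scheme it assumes: two 4-bit values per byte, the high nibble at even indices, and stored values offset by +8 so the signed range [-8, 7] maps onto [0, 15]. The pack_4bit and unpack_4bit helpers are illustrative names and are not part of this commit.

// pack_demo.cpp -- illustrative only; mirrors the nibble layout used by
// weight_value() above.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Pack signed 4-bit values (each in [-8, 7]) two per byte, high nibble first.
std::vector<uint8_t> pack_4bit(const std::vector<int8_t>& vals) {
  std::vector<uint8_t> packed((vals.size() + 1) / 2, 0);
  for (size_t i = 0; i < vals.size(); ++i) {
    uint8_t u = static_cast<uint8_t>(vals[i] + 8) & 0x0F;
    if (i & 1) {
      packed[i >> 1] |= u;       // odd index -> low nibble
    } else {
      packed[i >> 1] |= u << 4;  // even index -> high nibble
    }
  }
  return packed;
}

// Same unpacking rule as weight_value() in the kernel.
int32_t unpack_4bit(const uint8_t* data, int32_t index) {
  uint8_t byte = data[index >> 1];
  return ((index & 1) ? (byte & 0x0F) : ((byte >> 4) & 0x0F)) - 8;
}

int main() {
  std::vector<int8_t> row = {-6, 3, 0, 7};
  std::vector<uint8_t> packed = pack_4bit(row);  // yields {0x2B, 0x8F}
  float scale = 0.5f;
  float zero_point = 0.0f;  // per-group parameters in the real kernel
  for (size_t j = 0; j < row.size(); ++j) {
    int32_t v = unpack_4bit(packed.data(), static_cast<int32_t>(j));
    float dequant = (static_cast<float>(v) - zero_point) * scale;
    printf("w[%zu] = %d -> %.2f\n", j, v, dequant);
  }
  return 0;
}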

kernels/quantized/cpu/targets.bzl

Lines changed: 3 additions & 0 deletions
@@ -23,6 +23,9 @@ _QUANT_OPS = (
     op_target(
         name = "op_embedding",
     ),
+    op_target(
+        name = "op_embedding4b",
+    ),
     op_target(
         name = "op_mixed_mm",
         deps = [
kernels/quantized/quantized.yaml

Lines changed: 12 additions & 0 deletions
@@ -46,6 +46,18 @@
     - arg_meta: null
       kernel_name: torch::executor::quantized_embedding_byte_dtype_out
 
+- func: quantized_decomposed::embedding_4bit.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::quantized_embedding_4bit_out
+
+- func: quantized_decomposed::embedding_4bit.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::quantized_embedding_4bit_dtype_out
+
 - func: quantized_decomposed::mixed_mm.out(Tensor input, Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels: