
Commit fc4444f

manuelcandales authored and facebook-github-bot committed
4b quantized embedding table operator
Summary: 4b quantized embedding table operator

Differential Revision: D56123408
1 parent d0208d0 commit fc4444f

5 files changed: +518 -0 lines changed

kernels/quantized/cpu/op_embedding4b.cpp

Lines changed: 353 additions & 0 deletions
@@ -0,0 +1,353 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/runtime/kernel/kernel_includes.h>
#include <algorithm>
#include <cinttypes>
#include <cmath>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using Scalar = exec_aten::Scalar;
using ScalarType = exec_aten::ScalarType;

namespace {

/**
 * Asserts that the parameters are valid.
 */
void check_embedding_4bit_args(
    const Tensor& weight,
    const Tensor& weight_scales,
    const optional<Tensor>& opt_weight_zero_points,
    const int64_t weight_quant_min,
    const int64_t weight_quant_max,
    const Tensor& indices,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out) {
  ET_CHECK_MSG(
      weight.dim() == 2, "weight must be 2D but got %zd dims", weight.dim());

  ET_CHECK_MSG(
      weight_scales.dim() == 1 || weight_scales.dim() == 2,
      "weight_scales must be 1D or 2D but got %zd dims",
      weight_scales.dim());

  ET_CHECK_MSG(
      weight_scales.size(0) == weight.size(0),
      "Number of scales must be == weight.size(0)=%zd"
      ", but got %zd",
      weight.size(0),
      weight_scales.size(0));

  if (weight_scales.dim() == 2) {
    auto num_groups = weight_scales.size(1);
    ET_CHECK_MSG(
        // each packed uint8 column stores two 4-bit columns
        (2 * weight.size(1)) % num_groups == 0,
        "Number of groups must divide 2 * weight.size(1)=%zd"
        ", but got # of groups = %zd",
        2 * weight.size(1),
        num_groups);
  }

  ET_CHECK_MSG(
      weight.scalar_type() == ScalarType::Byte ||
          weight.scalar_type() == ScalarType::Char,
      "weight.scalar_type() %" PRId8 " is not supported",
      static_cast<int8_t>(weight.scalar_type()));

  ET_CHECK_MSG(
      out.scalar_type() == ScalarType::Float ||
          out.scalar_type() == ScalarType::Half,
      "out.scalar_type() %" PRId8 " is not supported",
      static_cast<int8_t>(out.scalar_type()));

  ET_CHECK_MSG(
      weight_scales.scalar_type() == ScalarType::Float ||
          weight_scales.scalar_type() == ScalarType::Half,
      "weight_scales.scalar_type() %" PRId8 " is not supported",
      static_cast<int8_t>(weight_scales.scalar_type()));

  if (opt_weight_zero_points.has_value()) {
    ET_CHECK_MSG(
        opt_weight_zero_points.value().dim() == weight_scales.dim(),
        "weight_zero_points' rank must match that of weight_scales. "
        "weight_zero_points rank: %" PRId8 ", weight_scales rank: %" PRId8,
        static_cast<int8_t>(opt_weight_zero_points.value().dim()),
        static_cast<int8_t>(weight_scales.dim()));

    ET_CHECK_MSG(
        opt_weight_zero_points.value().scalar_type() == out.scalar_type(),
        "weight zero points scalar type %" PRId8
        " does not match out.scalar_type()",
        static_cast<int8_t>(opt_weight_zero_points.value().scalar_type()));

    for (int32_t i = 0; i < weight_scales.dim(); ++i) {
      ET_CHECK_MSG(
          opt_weight_zero_points.value().size(i) == weight_scales.size(i),
          "Dimension size mismatch at dim %" PRId32
          ": weight_zero_points size = %zd"
          ", weight_scales size = %zd.",
          i,
          opt_weight_zero_points.value().size(i),
          weight_scales.size(i));
    }
  }

  ET_CHECK_MSG(
      indices.scalar_type() == ScalarType::Long,
      "indices.scalar_type() %" PRId8 " is not Long; only Long is supported",
      static_cast<int8_t>(indices.scalar_type()));

  ET_CHECK_MSG(
      weight_quant_min <= weight_quant_max,
      "weight quant min: %" PRId64
      " is greater than weight quant max: %" PRId64,
      weight_quant_min,
      weight_quant_max);

  if (out_dtype.has_value()) {
    ET_CHECK_MSG(
        out.scalar_type() == out_dtype.value(),
        "output_dtype must match the dtype of the out tensor");
  }
}

static inline int32_t weight_value(const unsigned char* w_data, int32_t index) {
  int32_t odd = index & 1;
  index >>= 1;
  if (odd) {
    return (int32_t)(w_data[index] & 0x0F) - 8;
  } else {
    return (int32_t)((w_data[index] >> 4) & 0x0F) - 8;
  }
}

/**
 * Retrieves the embeddings specified by indices, dequantizes them, and stores
 * them in out. Weight will always be uint8.
 */
template <typename CTYPE_PARAMS, typename CTYPE_OUT>
void embedding_4bit_per_channel(
    const Tensor& weight,
    const Tensor& weight_scales,
    const optional<Tensor>& opt_weight_zero_points,
    const Tensor& indices,
    Tensor& out) {
  // An embedding layer nn.Embedding(num_embeddings, embedding_dim) has a
  // weight of shape (num_embeddings, embedding_dim / 2).
  auto embedding_dim = weight.size(1) * 2;

  int32_t num_groups_per_channel = 1;
  if (weight_scales.dim() == 2) {
    num_groups_per_channel = weight_scales.size(1);
  }
  int32_t group_size = embedding_dim / num_groups_per_channel;

  CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
  const int64_t* indices_ptr = indices.const_data_ptr<int64_t>();

  const CTYPE_PARAMS* scales = weight_scales.const_data_ptr<CTYPE_PARAMS>();
  const CTYPE_PARAMS* zero_points = nullptr;
  if (opt_weight_zero_points.has_value()) {
    zero_points = opt_weight_zero_points.value().const_data_ptr<CTYPE_PARAMS>();
  }

  for (int i = 0; i < indices.numel(); i++) {
    int64_t index = indices_ptr[i];
    // If using groupwise embedding
    int32_t qparams_index = index * num_groups_per_channel;
    CTYPE_PARAMS zp = 0.0;
    const CTYPE_PARAMS* scale_ptr = scales + qparams_index;
    const CTYPE_PARAMS* zero_points_ptr = nullptr;
    if (opt_weight_zero_points.has_value()) {
      zero_points_ptr = zero_points + qparams_index;
    }

    const uint8_t* w_data = weight.data_ptr<uint8_t>() + weight.size(1) * index;

    for (int j = 0; j < embedding_dim; ++j) {
      int32_t group_id = j / group_size;
      const CTYPE_PARAMS scale = scale_ptr[group_id];
      if (opt_weight_zero_points.has_value()) {
        zp = zero_points_ptr[group_id];
      }
      out_data[j] = static_cast<CTYPE_OUT>(
          (static_cast<float>(weight_value(w_data, j)) -
           static_cast<float>(zp)) *
          static_cast<float>(scale));
    }
    out_data += embedding_dim;
  }
}

void resize_out_tensor(
    const Tensor& weight,
    const Tensor& indices,
    Tensor& out) {
  exec_aten::SizesType expected_output_size[kTensorDimensionLimit];
  for (size_t i = 0; i < indices.dim(); i++) {
    expected_output_size[i] = indices.size(i);
  }
  // Each uint8 column of weight packs two 4-bit values, so the unpacked
  // embedding dimension is twice the packed width.
  const size_t embedding_dim = 2 * weight.size(1);
  expected_output_size[out.dim() - 1] = embedding_dim;

  exec_aten::ArrayRef<exec_aten::SizesType> output_size{
      expected_output_size, static_cast<size_t>(out.dim())};

  torch::executor::Error err = resize_tensor(out, output_size);
  ET_CHECK_MSG(
      err == torch::executor::Error::Ok,
      "Failed to resize out Tensor in quantized_embedding_4bit_out");
}

} // namespace

/**
 * Retrieves the embeddings specified by indices, dequantizes them, and stores
 * them in out. The weight is quantized per channel, with a scale and zero_point
 * for each embedding.
 *
 * This is the out variant of torch.ops.quantized.embedding_4bit.
 *
 * NOTE: quant_min, quant_max, and Dtype are not used in computation, but rather
 * metadata that is passed around which can be useful for pattern matching. See
 * https://github.com/pytorch/pytorch/pull/87093#discussion_r1000841181 for more
 * info.
 */
Tensor& quantized_embedding_4bit_out(
    // TODO: Evaluate whether this name is appropriate for an operator that
    // takes non-quantized input and returns fp output.
    const Tensor& weight,
    const Tensor& weight_scales,
    const optional<Tensor>& opt_weight_zero_points,
    const int64_t weight_quant_min,
    const int64_t weight_quant_max,
    const Tensor& indices,
    Tensor& out) {
  ScalarType w_type = weight.scalar_type();
  ScalarType out_type = out.scalar_type();

  // TODO (jakeszwe): improve these to account for the size of out in relation
  // to weight and indices, accounting for a possible batch dimension.
  check_embedding_4bit_args(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out_type,
      out);

  constexpr auto name = "quantized_decomposed::embedding_4bit.out";
  ET_SWITCH_TWO_TYPES(Byte, Char, w_type, ctx, name, CTYPE_W, [&]() {
    ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
      embedding_4bit_per_channel<CTYPE_OUT, CTYPE_OUT>(
          weight, weight_scales, opt_weight_zero_points, indices, out);
    });
  });

  return out;
}

Tensor& quantized_embedding_4bit_out(
    RuntimeContext& context,
    const Tensor& weight,
    const Tensor& weight_scales,
    const optional<Tensor>& opt_weight_zero_points,
    int64_t weight_quant_min,
    int64_t weight_quant_max,
    const Tensor& indices,
    Tensor& out) {
  // TODO(larryliu): Add a context arg to the real op function and remove this
  // wrapper
  (void)context;
  resize_out_tensor(weight, indices, out);
  return quantized_embedding_4bit_out(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out);
}

Tensor& quantized_embedding_4bit_dtype_out(
    // TODO: Evaluate whether this name is appropriate for an operator that
    // takes non-quantized input and returns fp output.
    const Tensor& weight,
    const Tensor& weight_scales,
    const optional<Tensor>& opt_weight_zero_points,
    const int64_t weight_quant_min,
    const int64_t weight_quant_max,
    const Tensor& indices,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out) {
  // TODO (jakeszwe): improve these to account for the size of out in relation
  // to weight and indices, accounting for a possible batch dimension.
  check_embedding_4bit_args(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out_dtype,
      out);

  ScalarType weight_type = weight.scalar_type();
  ScalarType params_type = weight_scales.scalar_type();
  ScalarType out_type = out.scalar_type();

  constexpr auto name = "quantized_decomposed::embedding_4bit.dtype_out";
  ET_SWITCH_TWO_TYPES(Byte, Char, weight_type, ctx, name, CTYPE_W, [&]() {
    ET_SWITCH_TWO_TYPES(Float, Half, params_type, ctx, name, CTYPE_P, [&]() {
      ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
        embedding_4bit_per_channel<CTYPE_P, CTYPE_OUT>(
            weight, weight_scales, opt_weight_zero_points, indices, out);
      });
    });
  });

  return out;
}

Tensor& quantized_embedding_4bit_dtype_out(
    RuntimeContext& context,
    const Tensor& weight,
    const Tensor& weight_scales,
    const optional<Tensor>& opt_weight_zero_points,
    int64_t weight_quant_min,
    int64_t weight_quant_max,
    const Tensor& indices,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out) {
  // TODO(larryliu): Add a context arg to the real op function and remove this
  // wrapper
  (void)context;
  resize_out_tensor(weight, indices, out);
  return quantized_embedding_4bit_dtype_out(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out_dtype,
      out);
}

} // namespace native
} // namespace executor
} // namespace torch
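
For readers skimming the diff, the nibble convention decoded by weight_value above is: each uint8 of weight holds two 4-bit values, the even column in the high nibble and the odd column in the low nibble, each stored with a +8 offset so the unsigned range 0..15 represents the signed range -8..7; the kernel then computes out[j] = (q[j] - zero_point) * scale. The standalone C++ sketch below illustrates that convention. It is not part of the commit; pack_pair and dequant_row are hypothetical helper names.

#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical helper (not in the commit): pack two signed 4-bit values
// (each in [-8, 7]) into one byte, mirroring what weight_value() decodes:
// even column in the high nibble, odd column in the low nibble, both
// stored with a +8 offset.
uint8_t pack_pair(int32_t even, int32_t odd) {
  return static_cast<uint8_t>((((even + 8) & 0x0F) << 4) | ((odd + 8) & 0x0F));
}

// Reference dequantization of one packed row, matching the kernel's math:
// out[j] = (q[j] - zero_point) * scale.
void dequant_row(
    const uint8_t* packed,
    int32_t embedding_dim, // unpacked width = 2 * packed width
    float scale,
    float zero_point,
    float* out) {
  for (int32_t j = 0; j < embedding_dim; ++j) {
    const uint8_t byte = packed[j >> 1];
    const int32_t q = ((j & 1) ? (byte & 0x0F) : ((byte >> 4) & 0x0F)) - 8;
    out[j] = (static_cast<float>(q) - zero_point) * scale;
  }
}

int main() {
  // A row of four logical 4-bit values {-8, 7, 0, 3} packs into two bytes.
  std::vector<uint8_t> row = {pack_pair(-8, 7), pack_pair(0, 3)};
  float out[4];
  dequant_row(row.data(), 4, /*scale=*/0.5f, /*zero_point=*/0.0f, out);
  for (float v : out) {
    std::printf("%f\n", v); // prints -4.0, 3.5, 0.0, 1.5
  }
  return 0;
}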

kernels/quantized/cpu/targets.bzl

Lines changed: 3 additions & 0 deletions
@@ -23,6 +23,9 @@ _QUANT_OPS = (
     op_target(
         name = "op_embedding",
     ),
+    op_target(
+        name = "op_embedding4b",
+    ),
     op_target(
         name = "op_mixed_mm",
         deps = [

kernels/quantized/quantized.yaml

Lines changed: 12 additions & 0 deletions
@@ -46,6 +46,18 @@
     - arg_meta: null
       kernel_name: torch::executor::quantized_embedding_byte_dtype_out
 
+- func: quantized_decomposed::embedding_4bit.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::quantized_embedding_4bit_out
+
+- func: quantized_decomposed::embedding_4bit.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::quantized_embedding_4bit_dtype_out
+
 - func: quantized_decomposed::mixed_mm.out(Tensor input, Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:

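The schemas above expose the groupwise layout that check_embedding_4bit_args enforces: for a packed weight of shape (num_embeddings, D/2) and 2-D weight_scales of shape (num_embeddings, G), the unpacked dimension D must be divisible by G, and column j of a row is dequantized with scale group j / (D / G). A minimal C++ sketch of that geometry follows; it is illustrative only, and group_geometry is a hypothetical helper, not part of the commit.

#include <cassert>
#include <cstdint>
#include <cstdio>

// Recover the group geometry implied by the tensor shapes: the unpacked
// embedding dim is twice the packed width, and it must be divisible by
// the number of scale groups per channel.
struct GroupGeometry {
  int64_t embedding_dim;
  int64_t group_size;
};

GroupGeometry group_geometry(int64_t packed_width, int64_t num_groups) {
  const int64_t embedding_dim = 2 * packed_width;
  assert(embedding_dim % num_groups == 0);
  return {embedding_dim, embedding_dim / num_groups};
}

int main() {
  // e.g. an nn.Embedding(1000, 128) quantized with group size 32:
  // weight is (1000, 64) uint8, weight_scales is (1000, 4).
  const GroupGeometry g = group_geometry(/*packed_width=*/64, /*num_groups=*/4);
  std::printf(
      "embedding_dim=%lld group_size=%lld\n",
      static_cast<long long>(g.embedding_dim),
      static_cast<long long>(g.group_size));
  // Column j uses scale index j / group_size; e.g. column 70 -> group 2.
  std::printf(
      "column 70 -> group %lld\n", static_cast<long long>(70 / g.group_size));
  return 0;
}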