
Commit 2966e38

manuelcandales authored and facebook-github-bot committed
Enable embedding_byte output dtype to differ from scales/zp dtype (#2091)
Summary: Pull Request resolved: #2091
Reviewed By: mikekgfb, cbilgin
Differential Revision: D54141337
fbshipit-source-id: f79754770ddca459e0e23680b42f84d6ff5ce21a
1 parent 5a18cc6 commit 2966e38

File tree

7 files changed: +99 −23 lines
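
The net effect across these files: both embedding_byte ops gain a keyword-only ScalarType? dtype=None argument, so the dequantized output can land in a dtype other than that of the scales/zero-points. A minimal usage sketch (illustrative shapes and quant ranges, not taken from the commit; assumes the llama_quantized op library from examples/models/llama2/ops/quantized_ops.py has been imported so the op is registered):

import torch

# Quantized embedding table (int8) with fp16 per-channel scales.
weight = torch.randint(-128, 127, (10, 4), dtype=torch.int8)
scales = torch.rand(10, dtype=torch.float16)
indices = torch.tensor([0, 3, 7])

# dtype is the new keyword-only argument: the output comes back as fp32
# even though the scales are fp16, which the old schema could not express.
out = torch.ops.llama_quantized.embedding_byte.default(
    weight, scales, None, -128, 127, indices, dtype=torch.float32
)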

examples/models/llama2/ops/quantized.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-- func: llama_quantized::embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)
+- func: llama_quantized::embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
   - arg_meta: null

examples/models/llama2/ops/quantized_ops.py

Lines changed: 8 additions & 3 deletions
@@ -14,12 +14,12 @@
 ) # to not be confused with torch.ops.quantized.* ops.
 quantized_lib.define(
     "embedding_byte(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
-    "int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor",
+    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None) -> Tensor",
 )

 quantized_lib.define(
     "embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
-    "int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)",
+    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)",
 )

@@ -31,6 +31,8 @@ def embedding_byte_meta(
     weight_quant_min,
     weight_quant_max,
     indices,
+    *,
+    dtype,
 ):
     assert weight.dtype in [
         torch.int8,
@@ -71,7 +73,7 @@ def embedding_byte_meta(
         weight_quant_max,
         weight.dtype,
     )
-    return torch.ops.aten.embedding.default(weight, indices)
+    return torch.ops.aten.embedding.default(weight, indices).to(dtype)


 @impl_abstract("llama_quantized::embedding_byte.out")
@@ -82,6 +84,8 @@ def embedding_byte_out_meta(
     weight_quant_min,
     weight_quant_max,
     indices,
+    *,
+    dtype,
     out,
 ):
     return embedding_byte_meta(
@@ -91,4 +95,5 @@ def embedding_byte_out_meta(
         weight_quant_min,
         weight_quant_max,
         indices,
+        dtype=dtype,
     )
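
These meta ("abstract") implementations exist so export-time tracing, which runs on fake tensors instead of the real kernel, reports the correct output dtype; without the trailing .to(dtype) the traced graph would claim the output has the dequantized weight's dtype. A hedged sketch of that behavior under FakeTensorMode (assumes the op library above has been imported; FakeTensorMode is a torch-internal API):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    weight = torch.empty(10, 4, dtype=torch.int8)
    scales = torch.empty(10, dtype=torch.float16)
    indices = torch.empty(3, dtype=torch.long)
    # Dispatches to embedding_byte_meta above, not a real kernel.
    out = torch.ops.llama_quantized.embedding_byte.default(
        weight, scales, None, -128, 127, indices, dtype=torch.float32
    )
    assert out.dtype == torch.float32  # follows dtype, not scales.dtype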

examples/models/llama2/quantize.py

Lines changed: 2 additions & 2 deletions
@@ -818,8 +818,8 @@ def __init__(
     @torch.no_grad()
     def forward(self, indices: torch.Tensor) -> torch.Tensor:
         return torch.ops.llama_quantized.embedding_byte.default(
-            self.weight, self.scales, None, 0, 0, indices
-        ).to(self.dtype)
+            self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+        )


         # result_weights = self.weight.index_select(0, indices.view(-1))
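
Folding the cast into the op (dtype=self.dtype) rather than appending .to(self.dtype) means the exported graph carries one embedding_byte node instead of an op plus a separate cast, and the kernel can write the target dtype directly rather than materializing an intermediate in the scales dtype. An illustrative stand-in with plain tensors (hypothetical names, not the module's code):

import torch

scales_dtype, model_dtype = torch.float16, torch.float32
dequant = torch.randn(3, 4, dtype=scales_dtype)  # stand-in for the op output

# Old style: dequantize in the scales dtype, then cast; this leaves an extra
# cast node in the graph and an extra fp16 intermediate at runtime.
old_out = dequant.to(model_dtype)
# New style: the op itself receives dtype=model_dtype and writes fp32
# directly (see the CTYPE_PARAMS/CTYPE_OUT split in op_embedding.cpp below).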

exir/passes/_quant_patterns_and_replacements.py

Lines changed: 48 additions & 1 deletion
@@ -27,7 +27,7 @@

 quantized_decomposed_lib.define(
     "embedding_byte(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
-    "int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor",
+    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None) -> Tensor",
 )

 quantized_decomposed_lib.define(
@@ -482,6 +482,48 @@ def replacement(
         )
         return out

+    @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte")
+    def pattern_with_dtype(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        indicies,
+        dtype,
+    ):
+        weight = torch.ops.quantized_decomposed.dequantize_per_channel.default(
+            weight,
+            weight_scales,
+            weight_zero_points,
+            0,
+            weight_quant_min,
+            weight_quant_max,
+            torch.uint8,
+        )
+        out = torch.ops.aten.embedding.default(weight, indicies).to(dtype)
+        return out
+
+    def replacement_with_dtype(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        indicies,
+        dtype,
+    ):
+        out = torch.ops.quantized_decomposed.embedding_byte.default(
+            weight,
+            weight_scales,
+            weight_zero_points,
+            weight_quant_min,
+            weight_quant_max,
+            indicies,
+            dtype=dtype,
+        )
+        return out
+
     @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte")
     def pattern_with_padding_idx(
         weight,
@@ -529,6 +571,11 @@ def replacement_with_padding_idx(
             _trace_and_lower_to_edge_ops(replacement),
             [],
         ),
+        (
+            _trace_and_lower_to_edge_ops(pattern_with_dtype),
+            _trace_and_lower_to_edge_ops(replacement_with_dtype),
+            [],
+        ),
         (
             _trace_and_lower_to_edge_ops(pattern_with_padding_idx),
             _trace_and_lower_to_edge_ops(replacement_with_padding_idx),
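
These (pattern, replacement) pairs are traced to edge-dialect graphs and applied by a subgraph rewriter, so a dequantize_per_channel -> embedding -> to(dtype) chain in an exported program collapses into a single embedding_byte node that carries dtype. A toy demonstration of the same mechanism using torch.fx's public rewriter (stand-in ops, not exir's internal pass):

import torch
from torch.fx import symbolic_trace, subgraph_rewriter

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x).to(torch.float16)  # an op followed by a cast

def pattern(x):
    return torch.relu(x).to(torch.float16)

def replacement(x):
    return torch.relu(x.to(torch.float16))  # stand-in for a fused op

gm = symbolic_trace(M())
subgraph_rewriter.replace_pattern(gm, pattern, replacement)
print(gm.code)  # the relu -> to chain has been rewritten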

kernels/quantized/cpu/op_embedding.cpp

Lines changed: 30 additions & 15 deletions
@@ -31,6 +31,7 @@ void check_embedding_byte_args(
     const int64_t weight_quant_min,
     const int64_t weight_quant_max,
     const Tensor& indices,
+    exec_aten::optional<ScalarType> out_dtype,
     Tensor& out) {
   ET_CHECK_MSG(
       weight.dim() == 2, "weight must be 2D but got() %zd dims", weight.dim());
@@ -75,8 +76,9 @@ void check_embedding_byte_args(
       static_cast<int8_t>(out.scalar_type()));

   ET_CHECK_MSG(
-      weight_scales.scalar_type() == out.scalar_type(),
-      "weight scales scalar type %" PRId8 " does not match out.scalar_type()",
+      weight_scales.scalar_type() == ScalarType::Float ||
+          weight_scales.scalar_type() == ScalarType::Half,
+      "weight_scales.scalar_type() %" PRId8 " is not supported:",
       static_cast<int8_t>(weight_scales.scalar_type()));

   if (opt_weight_zero_points.has_value()) {
@@ -116,13 +118,19 @@ void check_embedding_byte_args(
       " is greater than weight quant max: %" PRId64,
       weight_quant_min,
       weight_quant_max);
+
+  if (out_dtype.has_value()) {
+    ET_CHECK_MSG(
+        out.scalar_type() == out_dtype.value(),
+        "output_dtype must match the dtype of the out tensor");
+  }
 }

 /**
  * Retrieves the embeddings specified by indices, dequantizes them, and stores
  * them in out
  */
-template <class CTYPE_WEIGHT, class CTYPE_OUT>
+template <typename CTYPE_WEIGHT, typename CTYPE_PARAMS, typename CTYPE_OUT>
 void embedding_byte_per_channel(
     const Tensor& weight,
     const Tensor& weight_scales,
@@ -142,19 +150,19 @@ void embedding_byte_per_channel(
   CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
   const int64_t* indices_ptr = indices.const_data_ptr<int64_t>();

-  const CTYPE_OUT* scales = weight_scales.const_data_ptr<CTYPE_OUT>();
-  const CTYPE_OUT* zero_points = nullptr;
+  const CTYPE_PARAMS* scales = weight_scales.const_data_ptr<CTYPE_PARAMS>();
+  const CTYPE_PARAMS* zero_points = nullptr;
   if (opt_weight_zero_points.has_value()) {
-    zero_points = opt_weight_zero_points.value().const_data_ptr<CTYPE_OUT>();
+    zero_points = opt_weight_zero_points.value().const_data_ptr<CTYPE_PARAMS>();
   }

   for (int i = 0; i < indices.numel(); i++) {
     int64_t index = indices_ptr[i];
     // If using groupwise embedding
     int32_t qparams_index = index * num_groups_per_channel;
-    CTYPE_OUT zp = 0.0;
-    const CTYPE_OUT* scale_ptr = scales + qparams_index;
-    const CTYPE_OUT* zero_points_ptr = nullptr;
+    CTYPE_PARAMS zp = 0.0;
+    const CTYPE_PARAMS* scale_ptr = scales + qparams_index;
+    const CTYPE_PARAMS* zero_points_ptr = nullptr;
     if (opt_weight_zero_points.has_value()) {
       zero_points_ptr = zero_points + qparams_index;
     }
@@ -164,7 +172,7 @@ void embedding_byte_per_channel(

     for (int j = 0; j < embedding_dim; ++j) {
       int32_t group_id = j / group_size;
-      const CTYPE_OUT scale = scale_ptr[group_id];
+      const CTYPE_PARAMS scale = scale_ptr[group_id];
       if (opt_weight_zero_points.has_value()) {
         zp = zero_points_ptr[group_id];
       }
@@ -219,6 +227,7 @@ Tensor& quantized_embedding_byte_out(
     const int64_t weight_quant_min,
     const int64_t weight_quant_max,
     const Tensor& indices,
+    exec_aten::optional<ScalarType> out_dtype,
     Tensor& out) {
   // TODO (jakeszwe): improve these to account for the size of out in relation
   // to weight and indices accounting for a possible batch dimension
@@ -229,16 +238,20 @@ Tensor& quantized_embedding_byte_out(
       weight_quant_min,
       weight_quant_max,
       indices,
+      out_dtype,
       out);

-  ScalarType w_type = weight.scalar_type();
+  ScalarType weight_type = weight.scalar_type();
+  ScalarType params_type = weight_scales.scalar_type();
   ScalarType out_type = out.scalar_type();

   constexpr auto name = "quantized_decomposed::embedding_byte.out";
-  ET_SWITCH_TWO_TYPES(Byte, Char, w_type, ctx, name, CTYPE_W, [&]() {
-    ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
-      embedding_byte_per_channel<CTYPE_W, CTYPE_OUT>(
-          weight, weight_scales, opt_weight_zero_points, indices, out);
+  ET_SWITCH_TWO_TYPES(Byte, Char, weight_type, ctx, name, CTYPE_W, [&]() {
+    ET_SWITCH_TWO_TYPES(Float, Half, params_type, ctx, name, CTYPE_P, [&]() {
+      ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
+        embedding_byte_per_channel<CTYPE_W, CTYPE_P, CTYPE_OUT>(
+            weight, weight_scales, opt_weight_zero_points, indices, out);
+      });
     });
   });

@@ -253,6 +266,7 @@ Tensor& quantized_embedding_byte_out(
     int64_t weight_quant_min,
     int64_t weight_quant_max,
     const Tensor& indices,
+    exec_aten::optional<ScalarType> out_dtype,
     Tensor& out) {
   // TODO(larryliu): Add a context arg to the real op function and remove this
   // wrapper
@@ -265,6 +279,7 @@ Tensor& quantized_embedding_byte_out(
       weight_quant_min,
       weight_quant_max,
       indices,
+      out_dtype,
       out);
 }
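
The kernel change splits the old two-parameter template into three roles: CTYPE_WEIGHT (Byte/Char) for the quantized table, CTYPE_PARAMS (Float/Half) for the scales/zero-points, and CTYPE_OUT (Float/Half) for the output, with the triple-nested ET_SWITCH instantiating every combination. A Python reference of the per-row math the kernel performs, simplified to one quant group per channel (an illustrative sketch, not the shipped code):

import torch

def embedding_byte_ref(weight, scales, zero_points, indices, out_dtype):
    # Gather quantized rows and dequantize in the params dtype (CTYPE_PARAMS).
    rows = weight[indices].to(scales.dtype)
    if zero_points is not None:
        rows = rows - zero_points[indices].unsqueeze(-1)
    dequant = rows * scales[indices].unsqueeze(-1)
    return dequant.to(out_dtype)  # cast to the output dtype (CTYPE_OUT) last

w = torch.randint(-128, 127, (10, 4), dtype=torch.int8)
s = torch.rand(10, dtype=torch.float16)
y = embedding_byte_ref(w, s, None, torch.tensor([1, 9]), torch.float32)
assert y.dtype == torch.float32 and y.shape == (2, 4)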

kernels/quantized/quantized.yaml

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@
   - arg_meta: null
     kernel_name: torch::executor::dequantize_per_channel_out

-- func: quantized_decomposed::embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)
+- func: quantized_decomposed::embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
   - arg_meta: null

kernels/quantized/test/op_embedding_test.cpp

Lines changed: 9 additions & 0 deletions
@@ -76,6 +76,7 @@ void test_dtype() {
       quant_min,
       quant_max,
       indices,
+      out.scalar_type(),
       out);

   // (8 - 1) * 0.5 = 3.5
@@ -139,6 +140,7 @@ TEST(OpQuantizedEmbeddingTest, ConsitencyWithReferencePattern) {
       quant_min,
       quant_max,
       indices,
+      out.scalar_type(),
       out);

   // Do Q DQ embedding
@@ -196,6 +198,7 @@ TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbedding) {
       quant_min,
       quant_max,
       indices,
+      out.scalar_type(),
       out);

   EXPECT_TENSOR_EQ(out, expected);
@@ -220,6 +223,7 @@ TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbedding) {
       quant_min,
       quant_max,
       indices,
+      out.scalar_type(),
       out);

   EXPECT_TENSOR_EQ(out, expected);
@@ -251,6 +255,7 @@ TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath1) {
           quant_min,
           quant_max,
           indices,
+          out.scalar_type(),
           out),
       "");
 }
@@ -281,6 +286,7 @@ TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath2) {
           quant_min,
           quant_max,
           indices,
+          out.scalar_type(),
           out),
       "");
 }
@@ -310,6 +316,7 @@ TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath3) {
           quant_min,
           quant_max,
           indices,
+          out.scalar_type(),
           out),
       "");
 }
@@ -339,6 +346,7 @@ TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath4) {
           quant_min,
           quant_max,
           indices,
+          out.scalar_type(),
           out),
       "");
 }
@@ -368,6 +376,7 @@ TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath5) {
           quant_min,
           quant_max,
           indices,
+          out.scalar_type(),
           out),
       "");
 }
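
Each test now threads out.scalar_type() through as the new out_dtype argument, and the ET_CHECK added in op_embedding.cpp rejects an explicit dtype that disagrees with the out tensor (the C++ suite exercises that path as death tests). A hedged Python-side sketch of the accepted path (assumes the quantized kernel library is built and loaded in ATen mode so the op is callable):

import torch

weight = torch.randint(-128, 127, (10, 4), dtype=torch.int8)
scales = torch.rand(10, dtype=torch.float32)
indices = torch.tensor([0, 3])
out = torch.empty(2, 4, dtype=torch.float32)

# dtype matches out.dtype, so the new check passes.
torch.ops.quantized_decomposed.embedding_byte.out(
    weight, scales, None, -128, 127, indices, dtype=out.dtype, out=out
)
# Passing dtype=torch.float16 with this fp32 out would trip the new check:
# "output_dtype must match the dtype of the out tensor".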
