
Commit d25b57b

manuelcandales authored and facebook-github-bot committed

Back out "Enable embedding_byte output dtype be different than scales/zp dtype" (#2210)

Summary:
Pull Request resolved: #2210
Original commit changeset: f79754770ddc
Original Phabricator Diff: D54141337

Reviewed By: kjweng, mikekgfb

Differential Revision: D54454388

fbshipit-source-id: 13381d5e14f53edfaa5b57997e4b8b9ac57a27f4

1 parent: 2a42737

File tree: 7 files changed (+23, −99 lines)

examples/models/llama2/ops/quantized.yaml

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-- func: llama_quantized::embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+- func: llama_quantized::embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
   - arg_meta: null

examples/models/llama2/ops/quantized_ops.py

Lines changed: 3 additions & 8 deletions
@@ -14,12 +14,12 @@
 )  # to not be confused with torch.ops.quantized.* ops.
 quantized_lib.define(
     "embedding_byte(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
-    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None) -> Tensor",
+    "int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor",
 )

 quantized_lib.define(
     "embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
-    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)",
+    "int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)",
 )

@@ -31,8 +31,6 @@ def embedding_byte_meta(
     weight_quant_min,
     weight_quant_max,
     indices,
-    *,
-    dtype,
 ):
     assert weight.dtype in [
         torch.int8,

@@ -73,7 +71,7 @@ def embedding_byte_meta(
         weight_quant_max,
         weight.dtype,
     )
-    return torch.ops.aten.embedding.default(weight, indices).to(dtype)
+    return torch.ops.aten.embedding.default(weight, indices)


 @impl_abstract("llama_quantized::embedding_byte.out")

@@ -84,8 +82,6 @@ def embedding_byte_out_meta(
     weight_quant_min,
     weight_quant_max,
     indices,
-    *,
-    dtype,
     out,
 ):
     return embedding_byte_meta(

@@ -95,5 +91,4 @@ def embedding_byte_out_meta(
         weight_quant_min,
         weight_quant_max,
         indices,
-        dtype=dtype,
     )
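
For context, the `quantized_lib.define(...)` / `@impl_abstract(...)` pairing in this file follows PyTorch's custom-op registration flow: the schema string registers the op on a library namespace, and the abstract (meta) function tells export/compile tracing what metadata the op produces without running a real kernel. A minimal self-contained sketch of the same pattern, using a hypothetical `my_lib::scaled_add` op rather than the real `embedding_byte`:

```python
import torch
from torch.library import Library, impl_abstract

# Register a schema on a hypothetical library namespace (illustrative only).
my_lib = Library("my_lib", "DEF")
my_lib.define("scaled_add(Tensor a, Tensor b, float alpha) -> Tensor")

# Eager kernel so the op is actually runnable.
def scaled_add_impl(a, b, alpha):
    return a + alpha * b

my_lib.impl("scaled_add", scaled_add_impl, "CompositeExplicitAutograd")

# Abstract/meta implementation: describes only the output's shape and dtype,
# which is all the compiler needs during tracing.
@impl_abstract("my_lib::scaled_add")
def scaled_add_meta(a, b, alpha):
    assert a.shape == b.shape
    return torch.empty_like(a)
```

After the revert, the meta function above (like `embedding_byte_meta`) no longer takes a `dtype` argument, so its output metadata always matches what the kernel natively produces.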

examples/models/llama2/quantize.py

Lines changed: 2 additions & 2 deletions
@@ -887,8 +887,8 @@ def __init__(
     @torch.no_grad()
     def forward(self, indices: torch.Tensor) -> torch.Tensor:
         return torch.ops.llama_quantized.embedding_byte.default(
-            self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
-        )
+            self.weight, self.scales, None, 0, 0, indices
+        ).to(self.dtype)

         # result_weights = self.weight.index_select(0, indices.view(-1))
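
The revert moves dtype handling out of the op and back to the call site: the op now returns the dequantized embedding in the scales' dtype, and the caller casts with `.to(self.dtype)`. A rough eager-mode sketch of what `forward` computes, under assumptions drawn from the call above (per-channel scales, no zero points, matching the `None, 0, 0` arguments; the helper name and shapes are hypothetical):

```python
import torch

def embedding_byte_reference(weight_q, scales, indices, out_dtype):
    # weight_q: (num_embeddings, dim) int8; scales: (num_embeddings,) float
    rows = weight_q[indices]                                  # gather quantized rows
    deq = rows.to(scales.dtype) * scales[indices].unsqueeze(-1)  # dequantize
    return deq.to(out_dtype)                                  # cast at the call site

weight_q = torch.randint(-128, 127, (10, 4), dtype=torch.int8)
scales = torch.rand(10)
indices = torch.tensor([1, 3])
out = embedding_byte_reference(weight_q, scales, indices, torch.float16)
```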

exir/passes/_quant_patterns_and_replacements.py

Lines changed: 1 addition & 48 deletions
@@ -27,7 +27,7 @@

 quantized_decomposed_lib.define(
     "embedding_byte(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
-    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None) -> Tensor",
+    "int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor",
 )

 quantized_decomposed_lib.define(

@@ -482,48 +482,6 @@ def replacement(
         )
         return out

-    @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte")
-    def pattern_with_dtype(
-        weight,
-        weight_scales,
-        weight_zero_points,
-        weight_quant_min,
-        weight_quant_max,
-        indicies,
-        dtype,
-    ):
-        weight = torch.ops.quantized_decomposed.dequantize_per_channel.default(
-            weight,
-            weight_scales,
-            weight_zero_points,
-            0,
-            weight_quant_min,
-            weight_quant_max,
-            torch.uint8,
-        )
-        out = torch.ops.aten.embedding.default(weight, indicies).to(dtype)
-        return out
-
-    def replacement_with_dtype(
-        weight,
-        weight_scales,
-        weight_zero_points,
-        weight_quant_min,
-        weight_quant_max,
-        indicies,
-        dtype,
-    ):
-        out = torch.ops.quantized_decomposed.embedding_byte.default(
-            weight,
-            weight_scales,
-            weight_zero_points,
-            weight_quant_min,
-            weight_quant_max,
-            indicies,
-            dtype=dtype,
-        )
-        return out
-
     @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte")
     def pattern_with_padding_idx(
         weight,

@@ -571,11 +529,6 @@ def replacement_with_padding_idx(
             _trace_and_lower_to_edge_ops(replacement),
             [],
         ),
-        (
-            _trace_and_lower_to_edge_ops(pattern_with_dtype),
-            _trace_and_lower_to_edge_ops(replacement_with_dtype),
-            [],
-        ),
         (
             _trace_and_lower_to_edge_ops(pattern_with_padding_idx),
             _trace_and_lower_to_edge_ops(replacement_with_padding_idx),
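
The deleted `pattern_with_dtype` / `replacement_with_dtype` pair was one entry in a list of graph-rewrite rules: each pattern function is traced to a graph, matched as a subgraph in the exported program, and swapped for its replacement. A minimal sketch of that mechanism using `torch.fx.subgraph_rewriter.replace_pattern` on a toy pattern (the module and pattern here are illustrative, not the ExecuTorch pass infrastructure):

```python
import torch
from torch.fx import symbolic_trace, subgraph_rewriter

class M(torch.nn.Module):
    def forward(self, x):
        # A compute shape that a rewrite rule might target.
        return torch.relu(x * 2.0) + 1.0

def pattern(x):
    return torch.relu(x * 2.0)

def replacement(x):
    return torch.clamp(x * 2.0, min=0.0)

gm = symbolic_trace(M())
subgraph_rewriter.replace_pattern(gm, pattern, replacement)
print(gm.code)  # the relu subgraph has been rewritten to clamp
```

With the `dtype` overload gone from the schema, the dequantize-then-embed-then-`.to(dtype)` shape no longer needs its own pattern entry, which is why the whole pair and its registration tuple are removed.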

kernels/quantized/cpu/op_embedding.cpp

Lines changed: 15 additions & 30 deletions
@@ -31,7 +31,6 @@ void check_embedding_byte_args(
     const int64_t weight_quant_min,
     const int64_t weight_quant_max,
     const Tensor& indices,
-    exec_aten::optional<ScalarType> out_dtype,
     Tensor& out) {
   ET_CHECK_MSG(
       weight.dim() == 2, "weight must be 2D but got() %zd dims", weight.dim());

@@ -76,9 +75,8 @@ void check_embedding_byte_args(
       static_cast<int8_t>(out.scalar_type()));

   ET_CHECK_MSG(
-      weight_scales.scalar_type() == ScalarType::Float ||
-          weight_scales.scalar_type() == ScalarType::Half,
-      "weight_scales.scalar_type() %" PRId8 " is not supported:",
+      weight_scales.scalar_type() == out.scalar_type(),
+      "weight scales scalar type %" PRId8 " does not match out.scalar_type()",
       static_cast<int8_t>(weight_scales.scalar_type()));

   if (opt_weight_zero_points.has_value()) {

@@ -118,19 +116,13 @@ void check_embedding_byte_args(
       " is greater than weight quant max: %" PRId64,
       weight_quant_min,
       weight_quant_max);
-
-  if (out_dtype.has_value()) {
-    ET_CHECK_MSG(
-        out.scalar_type() == out_dtype.value(),
-        "output_dtype must match the dtype of the out tensor");
-  }
 }

 /**
  * Retrieves the embeddings specified by indices, dequantizes them, and stores
  * them in out
  */
-template <typename CTYPE_WEIGHT, typename CTYPE_PARAMS, typename CTYPE_OUT>
+template <class CTYPE_WEIGHT, class CTYPE_OUT>
 void embedding_byte_per_channel(
     const Tensor& weight,
     const Tensor& weight_scales,

@@ -150,19 +142,19 @@ void embedding_byte_per_channel(
   CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
   const int64_t* indices_ptr = indices.const_data_ptr<int64_t>();

-  const CTYPE_PARAMS* scales = weight_scales.const_data_ptr<CTYPE_PARAMS>();
-  const CTYPE_PARAMS* zero_points = nullptr;
+  const CTYPE_OUT* scales = weight_scales.const_data_ptr<CTYPE_OUT>();
+  const CTYPE_OUT* zero_points = nullptr;
   if (opt_weight_zero_points.has_value()) {
-    zero_points = opt_weight_zero_points.value().const_data_ptr<CTYPE_PARAMS>();
+    zero_points = opt_weight_zero_points.value().const_data_ptr<CTYPE_OUT>();
   }

   for (int i = 0; i < indices.numel(); i++) {
     int64_t index = indices_ptr[i];
     // If using groupwise embedding
     int32_t qparams_index = index * num_groups_per_channel;
-    CTYPE_PARAMS zp = 0.0;
-    const CTYPE_PARAMS* scale_ptr = scales + qparams_index;
-    const CTYPE_PARAMS* zero_points_ptr = nullptr;
+    CTYPE_OUT zp = 0.0;
+    const CTYPE_OUT* scale_ptr = scales + qparams_index;
+    const CTYPE_OUT* zero_points_ptr = nullptr;
     if (opt_weight_zero_points.has_value()) {
       zero_points_ptr = zero_points + qparams_index;
     }

@@ -172,7 +164,7 @@ void embedding_byte_per_channel(

     for (int j = 0; j < embedding_dim; ++j) {
       int32_t group_id = j / group_size;
-      const CTYPE_PARAMS scale = scale_ptr[group_id];
+      const CTYPE_OUT scale = scale_ptr[group_id];
       if (opt_weight_zero_points.has_value()) {
         zp = zero_points_ptr[group_id];
       }

@@ -227,7 +219,6 @@ Tensor& quantized_embedding_byte_out(
     const int64_t weight_quant_min,
     const int64_t weight_quant_max,
     const Tensor& indices,
-    exec_aten::optional<ScalarType> out_dtype,
     Tensor& out) {
   // TODO (jakeszwe): improve these to account for the size of out in relation
   // to weight and indices accounting for a possible batch dimension

@@ -238,20 +229,16 @@ Tensor& quantized_embedding_byte_out(
       weight_quant_min,
       weight_quant_max,
       indices,
-      out_dtype,
       out);

-  ScalarType weight_type = weight.scalar_type();
-  ScalarType params_type = weight_scales.scalar_type();
+  ScalarType w_type = weight.scalar_type();
   ScalarType out_type = out.scalar_type();

   constexpr auto name = "quantized_decomposed::embedding_byte.out";
-  ET_SWITCH_TWO_TYPES(Byte, Char, weight_type, ctx, name, CTYPE_W, [&]() {
-    ET_SWITCH_TWO_TYPES(Float, Half, params_type, ctx, name, CTYPE_P, [&]() {
-      ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
-        embedding_byte_per_channel<CTYPE_W, CTYPE_P, CTYPE_OUT>(
-            weight, weight_scales, opt_weight_zero_points, indices, out);
-      });
+  ET_SWITCH_TWO_TYPES(Byte, Char, w_type, ctx, name, CTYPE_W, [&]() {
+    ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
+      embedding_byte_per_channel<CTYPE_W, CTYPE_OUT>(
+          weight, weight_scales, opt_weight_zero_points, indices, out);
     });
   });

@@ -266,7 +253,6 @@ Tensor& quantized_embedding_byte_out(
     int64_t weight_quant_min,
     int64_t weight_quant_max,
     const Tensor& indices,
-    exec_aten::optional<ScalarType> out_dtype,
     Tensor& out) {
   // TODO(larryliu): Add a context arg to the real op function and remove this
   // wrapper

@@ -279,7 +265,6 @@ Tensor& quantized_embedding_byte_out(
       weight_quant_min,
       weight_quant_max,
       indices,
-      out_dtype,
       out);
 }
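
After the revert, the kernel's output type `CTYPE_OUT` also serves as the type of the scales and zero points, which is why the separate `CTYPE_PARAMS` template parameter and the third `ET_SWITCH_TWO_TYPES` level disappear. The per-element math is unchanged: each row's quantized bytes are dequantized groupwise as `(q - zero_point) * scale`. A Python sketch of that inner loop under assumed shapes (the helper name is hypothetical; one scale/zero point per group of `group_size` columns, per row):

```python
import torch

def embedding_byte_per_channel_ref(weight_q, scales, zero_points, indices, group_size):
    # weight_q: (N, dim) int8/uint8; scales, zero_points: (N, dim // group_size)
    num_rows, dim = weight_q.shape
    out = torch.empty(len(indices), dim, dtype=scales.dtype)  # out dtype == scales dtype
    for i, idx in enumerate(indices.tolist()):
        for j in range(dim):
            g = j // group_size  # which quant group column j falls in
            zp = zero_points[idx, g] if zero_points is not None else 0
            out[i, j] = (float(weight_q[idx, j]) - zp) * scales[idx, g]
    return out
```

Note how `out` inherits `scales.dtype`, mirroring the tightened `ET_CHECK_MSG` above that requires `weight_scales.scalar_type() == out.scalar_type()`.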

kernels/quantized/quantized.yaml

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@
   - arg_meta: null
     kernel_name: torch::executor::dequantize_per_channel_out

-- func: quantized_decomposed::embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+- func: quantized_decomposed::embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
   - arg_meta: null

kernels/quantized/test/op_embedding_test.cpp

Lines changed: 0 additions & 9 deletions
@@ -76,7 +76,6 @@ void test_dtype() {
       quant_min,
       quant_max,
       indices,
-      out.scalar_type(),
       out);

   // (8 - 1) * 0.5 = 3.5

@@ -140,7 +139,6 @@ TEST(OpQuantizedEmbeddingTest, ConsitencyWithReferencePattern) {
       quant_min,
       quant_max,
       indices,
-      out.scalar_type(),
       out);

   // Do Q DQ embedding

@@ -198,7 +196,6 @@ TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbedding) {
       quant_min,
       quant_max,
       indices,
-      out.scalar_type(),
       out);

   EXPECT_TENSOR_EQ(out, expected);

@@ -223,7 +220,6 @@ TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbedding) {
       quant_min,
       quant_max,
       indices,
-      out.scalar_type(),
       out);

   EXPECT_TENSOR_EQ(out, expected);

@@ -255,7 +251,6 @@ TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath1) {
           quant_min,
           quant_max,
           indices,
-          out.scalar_type(),
           out),
       "");
 }

@@ -286,7 +281,6 @@ TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath2) {
           quant_min,
           quant_max,
           indices,
-          out.scalar_type(),
           out),
       "");
 }

@@ -316,7 +310,6 @@ TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath3) {
           quant_min,
           quant_max,
           indices,
-          out.scalar_type(),
           out),
       "");
 }

@@ -346,7 +339,6 @@ TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath4) {
           quant_min,
           quant_max,
           indices,
-          out.scalar_type(),
           out),
       "");
 }

@@ -376,7 +368,6 @@ TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath5) {
           quant_min,
           quant_max,
           indices,
-          out.scalar_type(),
           out),
       "");
 }
