Back out "Enable embedding_byte output dtype be different than scales/zp dtype" #2210

Closed
wants to merge 1 commit into from
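
This PR backs out the change that allowed embedding_byte to produce an output dtype different from the scales/zero-points dtype. It removes the optional ScalarType? dtype argument from the op schemas, the Python op definitions and meta kernels, the dtype-specific quant fusion pattern, the CPU kernel, and the kernel tests; callers that want a specific output dtype now cast the op's result instead. A minimal sketch of the call-site change, assuming the llama_quantized library defined in examples/models/llama2/ops/quantized_ops.py has been imported so the op is registered (the tensor arguments are illustrative):

import torch

# Assumes importing examples/models/llama2/ops/quantized_ops.py has registered
# torch.ops.llama_quantized.embedding_byte.
def embed_to(weight, scales, indices, target_dtype):
    # Before the backout the op accepted dtype=target_dtype directly;
    # after the backout the cast is applied to the returned tensor.
    out = torch.ops.llama_quantized.embedding_byte.default(
        weight, scales, None, 0, 0, indices
    )
    return out.to(target_dtype)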
2 changes: 1 addition & 1 deletion examples/models/llama2/ops/quantized.yaml
@@ -1,4 +1,4 @@
- func: llama_quantized::embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
- func: llama_quantized::embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)
variants: function
kernels:
- arg_meta: null
11 changes: 3 additions & 8 deletions examples/models/llama2/ops/quantized_ops.py
@@ -14,12 +14,12 @@
) # to not be confused with torch.ops.quantized.* ops.
quantized_lib.define(
"embedding_byte(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
"int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None) -> Tensor",
"int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor",
)

quantized_lib.define(
"embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
"int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)",
"int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)",
)


@@ -31,8 +31,6 @@ def embedding_byte_meta(
weight_quant_min,
weight_quant_max,
indices,
*,
dtype,
):
assert weight.dtype in [
torch.int8,
@@ -73,7 +71,7 @@ def embedding_byte_meta(
weight_quant_max,
weight.dtype,
)
return torch.ops.aten.embedding.default(weight, indices).to(dtype)
return torch.ops.aten.embedding.default(weight, indices)


@impl_abstract("llama_quantized::embedding_byte.out")
@@ -84,8 +82,6 @@ def embedding_byte_out_meta(
weight_quant_min,
weight_quant_max,
indices,
*,
dtype,
out,
):
return embedding_byte_meta(
@@ -95,5 +91,4 @@ def embedding_byte_out_meta(
weight_quant_min,
weight_quant_max,
indices,
dtype=dtype,
)
4 changes: 2 additions & 2 deletions examples/models/llama2/quantize.py
@@ -818,8 +818,8 @@ def __init__(
@torch.no_grad()
def forward(self, indices: torch.Tensor) -> torch.Tensor:
return torch.ops.llama_quantized.embedding_byte.default(
self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
)
self.weight, self.scales, None, 0, 0, indices
).to(self.dtype)


# result_weights = self.weight.index_select(0, indices.view(-1))
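
In the backed-out version, QuantizedEmbedding.forward converts the op's result to self.dtype with an explicit .to() call instead of passing dtype into the op, so the choice of output dtype lives at the call site rather than inside the kernel.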
49 changes: 1 addition & 48 deletions exir/passes/_quant_patterns_and_replacements.py
@@ -27,7 +27,7 @@

quantized_decomposed_lib.define(
"embedding_byte(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
"int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None) -> Tensor",
"int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor",
)

quantized_decomposed_lib.define(
@@ -482,48 +482,6 @@ def replacement(
)
return out

@bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte")
def pattern_with_dtype(
weight,
weight_scales,
weight_zero_points,
weight_quant_min,
weight_quant_max,
indicies,
dtype,
):
weight = torch.ops.quantized_decomposed.dequantize_per_channel.default(
weight,
weight_scales,
weight_zero_points,
0,
weight_quant_min,
weight_quant_max,
torch.uint8,
)
out = torch.ops.aten.embedding.default(weight, indicies).to(dtype)
return out

def replacement_with_dtype(
weight,
weight_scales,
weight_zero_points,
weight_quant_min,
weight_quant_max,
indicies,
dtype,
):
out = torch.ops.quantized_decomposed.embedding_byte.default(
weight,
weight_scales,
weight_zero_points,
weight_quant_min,
weight_quant_max,
indicies,
dtype=dtype,
)
return out

@bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte")
def pattern_with_padding_idx(
weight,
@@ -571,11 +529,6 @@ def replacement_with_padding_idx(
_trace_and_lower_to_edge_ops(replacement),
[],
),
(
_trace_and_lower_to_edge_ops(pattern_with_dtype),
_trace_and_lower_to_edge_ops(replacement_with_dtype),
[],
),
(
_trace_and_lower_to_edge_ops(pattern_with_padding_idx),
_trace_and_lower_to_edge_ops(replacement_with_padding_idx),
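
For reference, the generic pattern/replacement pair registered earlier in this file presumably mirrors the deleted dtype-specific pair minus the dtype handling. A rough sketch reconstructed from the removed pattern_with_dtype/replacement_with_dtype above, not copied from the source:

import torch

# Reconstruction for illustration; assumes the quantized_decomposed ops
# (dequantize_per_channel, embedding_byte) are registered.
def pattern(weight, weight_scales, weight_zero_points,
            weight_quant_min, weight_quant_max, indicies):
    weight = torch.ops.quantized_decomposed.dequantize_per_channel.default(
        weight,
        weight_scales,
        weight_zero_points,
        0,
        weight_quant_min,
        weight_quant_max,
        torch.uint8,
    )
    return torch.ops.aten.embedding.default(weight, indicies)

def replacement(weight, weight_scales, weight_zero_points,
                weight_quant_min, weight_quant_max, indicies):
    return torch.ops.quantized_decomposed.embedding_byte.default(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        indicies,
    )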
45 changes: 15 additions & 30 deletions kernels/quantized/cpu/op_embedding.cpp
@@ -31,7 +31,6 @@ void check_embedding_byte_args(
const int64_t weight_quant_min,
const int64_t weight_quant_max,
const Tensor& indices,
exec_aten::optional<ScalarType> out_dtype,
Tensor& out) {
ET_CHECK_MSG(
weight.dim() == 2, "weight must be 2D but got() %zd dims", weight.dim());
@@ -76,9 +75,8 @@ void check_embedding_byte_args(
static_cast<int8_t>(out.scalar_type()));

ET_CHECK_MSG(
weight_scales.scalar_type() == ScalarType::Float ||
weight_scales.scalar_type() == ScalarType::Half,
"weight_scales.scalar_type() %" PRId8 " is not supported:",
weight_scales.scalar_type() == out.scalar_type(),
"weight scales scalar type %" PRId8 " does not match out.scalar_type()",
static_cast<int8_t>(weight_scales.scalar_type()));

if (opt_weight_zero_points.has_value()) {
@@ -118,19 +116,13 @@ void check_embedding_byte_args(
" is greater than weight quant max: %" PRId64,
weight_quant_min,
weight_quant_max);

if (out_dtype.has_value()) {
ET_CHECK_MSG(
out.scalar_type() == out_dtype.value(),
"output_dtype must match the dtype of the out tensor");
}
}

/**
* Retrieves the embeddings specified by indices, dequantizes them, and stores
* them in out
*/
template <typename CTYPE_WEIGHT, typename CTYPE_PARAMS, typename CTYPE_OUT>
template <class CTYPE_WEIGHT, class CTYPE_OUT>
void embedding_byte_per_channel(
const Tensor& weight,
const Tensor& weight_scales,
@@ -150,19 +142,19 @@ void embedding_byte_per_channel(
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
const int64_t* indices_ptr = indices.const_data_ptr<int64_t>();

const CTYPE_PARAMS* scales = weight_scales.const_data_ptr<CTYPE_PARAMS>();
const CTYPE_PARAMS* zero_points = nullptr;
const CTYPE_OUT* scales = weight_scales.const_data_ptr<CTYPE_OUT>();
const CTYPE_OUT* zero_points = nullptr;
if (opt_weight_zero_points.has_value()) {
zero_points = opt_weight_zero_points.value().const_data_ptr<CTYPE_PARAMS>();
zero_points = opt_weight_zero_points.value().const_data_ptr<CTYPE_OUT>();
}

for (int i = 0; i < indices.numel(); i++) {
int64_t index = indices_ptr[i];
// If using groupwise embedding
int32_t qparams_index = index * num_groups_per_channel;
CTYPE_PARAMS zp = 0.0;
const CTYPE_PARAMS* scale_ptr = scales + qparams_index;
const CTYPE_PARAMS* zero_points_ptr = nullptr;
CTYPE_OUT zp = 0.0;
const CTYPE_OUT* scale_ptr = scales + qparams_index;
const CTYPE_OUT* zero_points_ptr = nullptr;
if (opt_weight_zero_points.has_value()) {
zero_points_ptr = zero_points + qparams_index;
}
@@ -172,7 +164,7 @@

for (int j = 0; j < embedding_dim; ++j) {
int32_t group_id = j / group_size;
const CTYPE_PARAMS scale = scale_ptr[group_id];
const CTYPE_OUT scale = scale_ptr[group_id];
if (opt_weight_zero_points.has_value()) {
zp = zero_points_ptr[group_id];
}
@@ -227,7 +219,6 @@ Tensor& quantized_embedding_byte_out(
const int64_t weight_quant_min,
const int64_t weight_quant_max,
const Tensor& indices,
exec_aten::optional<ScalarType> out_dtype,
Tensor& out) {
// TODO (jakeszwe): improve these to account for the size of out in relation
// to weight and indices accounting for a possible batch dimension
@@ -238,20 +229,16 @@
weight_quant_min,
weight_quant_max,
indices,
out_dtype,
out);

ScalarType weight_type = weight.scalar_type();
ScalarType params_type = weight_scales.scalar_type();
ScalarType w_type = weight.scalar_type();
ScalarType out_type = out.scalar_type();

constexpr auto name = "quantized_decomposed::embedding_byte.out";
ET_SWITCH_TWO_TYPES(Byte, Char, weight_type, ctx, name, CTYPE_W, [&]() {
ET_SWITCH_TWO_TYPES(Float, Half, params_type, ctx, name, CTYPE_P, [&]() {
ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
embedding_byte_per_channel<CTYPE_W, CTYPE_P, CTYPE_OUT>(
weight, weight_scales, opt_weight_zero_points, indices, out);
});
ET_SWITCH_TWO_TYPES(Byte, Char, w_type, ctx, name, CTYPE_W, [&]() {
ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
embedding_byte_per_channel<CTYPE_W, CTYPE_OUT>(
weight, weight_scales, opt_weight_zero_points, indices, out);
});
});

@@ -266,7 +253,6 @@ Tensor& quantized_embedding_byte_out(
int64_t weight_quant_min,
int64_t weight_quant_max,
const Tensor& indices,
exec_aten::optional<ScalarType> out_dtype,
Tensor& out) {
// TODO(larryliu): Add a context arg to the real op function and remove this
// wrapper
@@ -279,7 +265,6 @@ Tensor& quantized_embedding_byte_out(
weight_quant_min,
weight_quant_max,
indices,
out_dtype,
out);
}
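
With the out_dtype argument removed, check_embedding_byte_args now requires weight_scales (and the optional zero points, which the kernel reads in the same type) to share the out tensor's scalar type, and the dispatch collapses from three type switches (weight, params, out) to two (weight, out). A small sketch of inputs that satisfy the new check; shapes and values are purely illustrative and not taken from this PR:

import torch

# Illustrative shapes: a 4 x 8 embedding table quantized to uint8,
# one scale/zero point per row.
num_rows, embedding_dim = 4, 8
weight = torch.randint(0, 255, (num_rows, embedding_dim), dtype=torch.uint8)

out_dtype = torch.float16                             # desired output dtype (Float or Half)
scales = torch.rand(num_rows, dtype=out_dtype)        # must now match the out dtype
zero_points = torch.zeros(num_rows, dtype=out_dtype)  # read in the out dtype as well
indices = torch.tensor([0, 2], dtype=torch.int64)     # kernel reads int64 indices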

2 changes: 1 addition & 1 deletion kernels/quantized/quantized.yaml
@@ -34,7 +34,7 @@
- arg_meta: null
kernel_name: torch::executor::dequantize_per_channel_out

- func: quantized_decomposed::embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
- func: quantized_decomposed::embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)
variants: function
kernels:
- arg_meta: null
9 changes: 0 additions & 9 deletions kernels/quantized/test/op_embedding_test.cpp
@@ -76,7 +76,6 @@ void test_dtype() {
quant_min,
quant_max,
indices,
out.scalar_type(),
out);

// (8 - 1) * 0.5 = 3.5
@@ -140,7 +139,6 @@ TEST(OpQuantizedEmbeddingTest, ConsitencyWithReferencePattern) {
quant_min,
quant_max,
indices,
out.scalar_type(),
out);

// Do Q DQ embedding
@@ -198,7 +196,6 @@ TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbedding) {
quant_min,
quant_max,
indices,
out.scalar_type(),
out);

EXPECT_TENSOR_EQ(out, expected);
@@ -223,7 +220,6 @@ TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbedding) {
quant_min,
quant_max,
indices,
out.scalar_type(),
out);

EXPECT_TENSOR_EQ(out, expected);
@@ -255,7 +251,6 @@ TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath1) {
quant_min,
quant_max,
indices,
out.scalar_type(),
out),
"");
}
@@ -286,7 +281,6 @@ TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath2) {
quant_min,
quant_max,
indices,
out.scalar_type(),
out),
"");
}
@@ -316,7 +310,6 @@ TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath3) {
quant_min,
quant_max,
indices,
out.scalar_type(),
out),
"");
}
@@ -346,7 +339,6 @@ TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath4) {
quant_min,
quant_max,
indices,
out.scalar_type(),
out),
"");
}
@@ -376,7 +368,6 @@ TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath5) {
quant_min,
quant_max,
indices,
out.scalar_type(),
out),
"");
}