
Commit 93d4791

manuelcandales authored and facebook-github-bot committed

fix embedding_4bit resize (#3118)

Summary: Pull Request resolved: #3118
Reviewed By: larryliu0820
Differential Revision: D56282683
1 parent b19d586 commit 93d4791

File tree: 2 files changed, +16 / -1 lines

kernels/quantized/cpu/op_embedding4b.cpp

Lines changed: 1 addition & 1 deletion
@@ -195,7 +195,7 @@ void resize_out_tensor(
   for (size_t i = 0; i < indices.dim(); i++) {
     expected_output_size[i] = indices.size(i);
   }
-  const size_t embedding_dim = weight.size(1);
+  const size_t embedding_dim = weight.size(1) * 2;
   expected_output_size[out.dim() - 1] = embedding_dim;
 
   exec_aten::ArrayRef<exec_aten::SizesType> output_size{
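The change accounts for 4-bit packing: each byte of the quantized weight tensor stores two 4-bit values, so the logical embedding dimension is twice the packed byte width given by `weight.size(1)`. The following is a minimal standalone sketch of that shape calculation, not the ExecuTorch implementation; `expected_embedding_out_shape` and its parameters are hypothetical names used only for illustration.

#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical helper: given the indices shape and the packed weight's byte
// width per row, compute the shape the output tensor should be resized to.
std::vector<size_t> expected_embedding_out_shape(
    const std::vector<size_t>& indices_sizes,
    size_t packed_bytes_per_row) {
  std::vector<size_t> out_sizes = indices_sizes;
  // Two 4-bit values are packed into every byte, so the unpacked
  // embedding dimension is twice the packed width.
  out_sizes.push_back(packed_bytes_per_row * 2);
  return out_sizes;
}

int main() {
  // indices of shape {3}, weight packed as 2 bytes per row
  // -> output shape {3, 4}, matching the shapes used in the updated test.
  for (size_t s : expected_embedding_out_shape({3}, 2)) {
    std::cout << s << " ";
  }
  std::cout << "\n";  // prints: 3 4
  return 0;
}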

kernels/quantized/test/op_embedding4b_test.cpp

Lines changed: 15 additions & 0 deletions
@@ -19,6 +19,7 @@
 using namespace ::testing;
 using exec_aten::ArrayRef;
 using exec_aten::optional;
+using exec_aten::RuntimeContext;
 using exec_aten::ScalarType;
 using exec_aten::Tensor;
 using torch::executor::native::quantized_embedding_4bit_out;
@@ -60,6 +61,20 @@ TEST(OpQuantizedEmbedding4bTest, TestGroupWiseQuantizedEmbedding) {
 
   EXPECT_TENSOR_EQ(out, expected);
 
+  out = tf.zeros({3, 4});
+  auto context = RuntimeContext();
+  torch::executor::native::quantized_embedding_4bit_out(
+      context,
+      qweight,
+      weight_scales,
+      weight_zero_points,
+      quant_min,
+      quant_max,
+      indices,
+      out);
+
+  EXPECT_TENSOR_EQ(out, expected);
+
   // Groupwise quantization. groupsize = 2
   weight_scales = tf.make({3, 2}, {0.5, 1.0, 1.5, 2.0, 2.5, 3.0});
   weight_zero_points = tf.make({3, 2}, {1, -5, 0, 2, -3, -1});
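The added test case calls the context-taking out-variant so the kernel's resize path runs, and expects the same {3, 4} result as before: a qweight row of 2 packed bytes unpacks into 4 values. The sketch below illustrates that expansion in isolation; it is not the kernel's actual unpacking code, and the nibble order, the signed mapping via "stored value minus 8", and the helper name `dequantize_4bit_row` are assumptions made purely for illustration.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper: unpack a row of bytes into 2 * packed_row.size()
// 4-bit values and dequantize them with one per-row scale and zero point.
std::vector<float> dequantize_4bit_row(
    const std::vector<uint8_t>& packed_row,
    float scale,
    int32_t zero_point) {
  std::vector<float> out;
  out.reserve(packed_row.size() * 2);
  for (uint8_t byte : packed_row) {
    // High nibble first, then low nibble (assumed packing order); the
    // "- 8" recenters the unsigned nibble into a signed 4-bit range.
    const int32_t hi = static_cast<int32_t>(byte >> 4) - 8;
    const int32_t lo = static_cast<int32_t>(byte & 0x0F) - 8;
    out.push_back(scale * static_cast<float>(hi - zero_point));
    out.push_back(scale * static_cast<float>(lo - zero_point));
  }
  return out;
}

int main() {
  // Two packed bytes per row -> four dequantized values per row, which is
  // why an output of shape {3, 4} pairs with a qweight of shape {3, 2}.
  const std::vector<uint8_t> row = {0x2B, 0x9C};
  for (float v : dequantize_4bit_row(row, /*scale=*/0.5f, /*zero_point=*/0)) {
    std::cout << v << " ";
  }
  std::cout << "\n";
  return 0;
}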
