@@ -202,6 +202,13 @@ def embedding_2bit(
     weight_quant_max: int,
     indices: torch.Tensor,
 ) -> torch.Tensor:
+    assert (
+        weight_quant_min == -2
+    ), "embedding_2bit in ExecuTorch expects weight_quant_min == -2"
+    assert (
+        weight_quant_max == 1
+    ), "embedding_2bit in ExecuTorch expects weight_quant_max == 1"
+
     embedding_weight_checks(weight, weight_scales, weight_zero_points)
     group_size = (4 * weight.size(1)) // (
         weight_scales.size(1) if weight_scales.dim() == 2 else 1
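The asserted bounds follow from signed 2-bit quantization: two bits represent values in [-2^(2-1), 2^(2-1) - 1] = [-2, 1], and because four 2-bit values are packed per byte, the group size is recovered as 4 * weight.size(1) divided by the number of scale groups. A minimal sketch of that arithmetic, not ExecuTorch's implementation; the helper name and tensor shapes below are made up for illustration:

```python
import torch

# Signed N-bit quantization range: [-2**(N-1), 2**(N-1) - 1] (hypothetical helper).
def signed_range(nbits: int) -> tuple[int, int]:
    return -(2 ** (nbits - 1)), 2 ** (nbits - 1) - 1

assert signed_range(2) == (-2, 1)  # matches the embedding_2bit asserts

# Packed 2-bit weights store 4 values per byte, so a packed tensor with
# packed_cols columns holds 4 * packed_cols logical weight columns.
packed = torch.zeros(10, 16, dtype=torch.uint8)      # illustrative shapes
scales = torch.ones(10, 4)                           # 4 scale groups per row
group_size = (4 * packed.size(1)) // scales.size(1)  # 64 // 4 = 16
print(group_size)
```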
@@ -257,6 +264,13 @@ def embedding_2bit_dtype(
     indices: torch.Tensor,
     dtype: Optional[torch.dtype],
 ) -> torch.Tensor:
+    assert (
+        weight_quant_min == -2
+    ), "embedding_2bit_dtype in ExecuTorch expects weight_quant_min == -2"
+    assert (
+        weight_quant_max == 1
+    ), "embedding_2bit_dtype in ExecuTorch expects weight_quant_max == 1"
+
     embedding_weight_checks(weight, weight_scales, weight_zero_points)
     group_size = (4 * weight.size(1)) // (
         weight_scales.size(1) if weight_scales.dim() == 2 else 1
@@ -334,6 +348,13 @@ def embedding_4bit(
     weight_quant_max: int,
     indices: torch.Tensor,
 ) -> torch.Tensor:
+    assert (
+        weight_quant_min == -8
+    ), "embedding_4bit in ExecuTorch expects weight_quant_min == -8"
+    assert (
+        weight_quant_max == 7
+    ), "embedding_4bit in ExecuTorch expects weight_quant_max == 7"
+
     embedding_weight_checks(weight, weight_scales, weight_zero_points)
     group_size = (2 * weight.size(1)) // (
         weight_scales.size(1) if weight_scales.dim() == 2 else 1
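Analogously, signed 4-bit quantization spans [-2^3, 2^3 - 1] = [-8, 7], and two 4-bit values fit in one byte, which is why the group-size computation uses a factor of 2 instead of 4. A small sketch under the same assumptions as above (shapes are illustrative, not ExecuTorch's API):

```python
import torch

# Signed 4-bit range: [-8, 7], matching the embedding_4bit asserts.
assert (-(2 ** 3), 2 ** 3 - 1) == (-8, 7)

# Packed 4-bit weights store 2 values per byte.
packed = torch.zeros(10, 32, dtype=torch.uint8)      # illustrative shapes
scales = torch.ones(10, 4)
group_size = (2 * packed.size(1)) // scales.size(1)  # 64 // 4 = 16
print(group_size)
```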
@@ -387,6 +408,13 @@ def embedding_4bit_dtype(
     indices: torch.Tensor,
     dtype: Optional[torch.dtype],
 ) -> torch.Tensor:
+    assert (
+        weight_quant_min == -8
+    ), "embedding_4bit_dtype in ExecuTorch expects weight_quant_min == -8"
+    assert (
+        weight_quant_max == 7
+    ), "embedding_4bit_dtype in ExecuTorch expects weight_quant_max == 7"
+
     embedding_weight_checks(weight, weight_scales, weight_zero_points)
     group_size = (2 * weight.size(1)) // (
         weight_scales.size(1) if weight_scales.dim() == 2 else 1