@@ -202,6 +202,9 @@ def embedding_2bit(
202
202
weight_quant_max : int ,
203
203
indices : torch .Tensor ,
204
204
) -> torch .Tensor :
205
+ assert weight_quant_min == - 2 , "embedding_2bit in ExecuTorch expects weight_quant_min == -2"
206
+ assert weight_quant_max == 1 , "embedding_2bit in ExecuTorch expects weight_quant_max == 1"
207
+
205
208
embedding_weight_checks (weight , weight_scales , weight_zero_points )
206
209
group_size = (4 * weight .size (1 )) // (
207
210
weight_scales .size (1 ) if weight_scales .dim () == 2 else 1
@@ -257,6 +260,9 @@ def embedding_2bit_dtype(
257
260
indices : torch .Tensor ,
258
261
dtype : Optional [torch .dtype ],
259
262
) -> torch .Tensor :
263
+ assert weight_quant_min == - 2 , "embedding_2bit_dtype in ExecuTorch expects weight_quant_min == -2"
264
+ assert weight_quant_max == 1 , "embedding_2bit_dtype in ExecuTorch expects weight_quant_max == 1"
265
+
260
266
embedding_weight_checks (weight , weight_scales , weight_zero_points )
261
267
group_size = (4 * weight .size (1 )) // (
262
268
weight_scales .size (1 ) if weight_scales .dim () == 2 else 1
@@ -334,6 +340,9 @@ def embedding_4bit(
334
340
weight_quant_max : int ,
335
341
indices : torch .Tensor ,
336
342
) -> torch .Tensor :
343
+ assert weight_quant_min == - 8 , "embedding_4bit in ExecuTorch expects weight_quant_min == -8"
344
+ assert weight_quant_max == 7 , "embedding_4bit in ExecuTorch expects weight_quant_max == 7"
345
+
337
346
embedding_weight_checks (weight , weight_scales , weight_zero_points )
338
347
group_size = (2 * weight .size (1 )) // (
339
348
weight_scales .size (1 ) if weight_scales .dim () == 2 else 1
@@ -387,6 +396,9 @@ def embedding_4bit_dtype(
387
396
indices : torch .Tensor ,
388
397
dtype : Optional [torch .dtype ],
389
398
) -> torch .Tensor :
399
+ assert weight_quant_min == - 8 , "embedding_4bit_dtype in ExecuTorch expects weight_quant_min == -8"
400
+ assert weight_quant_max == 7 , "embedding_4bit_dtype in ExecuTorch expects weight_quant_max == 7"
401
+
390
402
embedding_weight_checks (weight , weight_scales , weight_zero_points )
391
403
group_size = (2 * weight .size (1 )) // (
392
404
weight_scales .size (1 ) if weight_scales .dim () == 2 else 1
0 commit comments