@@ -202,8 +202,12 @@ def embedding_2bit(
     weight_quant_max: int,
     indices: torch.Tensor,
 ) -> torch.Tensor:
-    assert weight_quant_min == -2, "embedding_2bit in ExecuTorch expects weight_quant_min == -2"
-    assert weight_quant_max == 1, "embedding_2bit in ExecuTorch expects weight_quant_max == 1"
+    assert (
+        weight_quant_min == -2
+    ), "embedding_2bit in ExecuTorch expects weight_quant_min == -2"
+    assert (
+        weight_quant_max == 1
+    ), "embedding_2bit in ExecuTorch expects weight_quant_max == 1"
 
     embedding_weight_checks(weight, weight_scales, weight_zero_points)
     group_size = (4 * weight.size(1)) // (
@@ -260,8 +264,12 @@ def embedding_2bit_dtype(
     indices: torch.Tensor,
     dtype: Optional[torch.dtype],
 ) -> torch.Tensor:
-    assert weight_quant_min == -2, "embedding_2bit_dtype in ExecuTorch expects weight_quant_min == -2"
-    assert weight_quant_max == 1, "embedding_2bit_dtype in ExecuTorch expects weight_quant_max == 1"
+    assert (
+        weight_quant_min == -2
+    ), "embedding_2bit_dtype in ExecuTorch expects weight_quant_min == -2"
+    assert (
+        weight_quant_max == 1
+    ), "embedding_2bit_dtype in ExecuTorch expects weight_quant_max == 1"
 
     embedding_weight_checks(weight, weight_scales, weight_zero_points)
     group_size = (4 * weight.size(1)) // (
@@ -340,8 +348,12 @@ def embedding_4bit(
     weight_quant_max: int,
     indices: torch.Tensor,
 ) -> torch.Tensor:
-    assert weight_quant_min == -8, "embedding_4bit in ExecuTorch expects weight_quant_min == -8"
-    assert weight_quant_max == 7, "embedding_4bit in ExecuTorch expects weight_quant_max == 7"
+    assert (
+        weight_quant_min == -8
+    ), "embedding_4bit in ExecuTorch expects weight_quant_min == -8"
+    assert (
+        weight_quant_max == 7
+    ), "embedding_4bit in ExecuTorch expects weight_quant_max == 7"
 
     embedding_weight_checks(weight, weight_scales, weight_zero_points)
     group_size = (2 * weight.size(1)) // (
@@ -396,8 +408,12 @@ def embedding_4bit_dtype(
     indices: torch.Tensor,
     dtype: Optional[torch.dtype],
 ) -> torch.Tensor:
-    assert weight_quant_min == -8, "embedding_4bit_dtype in ExecuTorch expects weight_quant_min == -8"
-    assert weight_quant_max == 7, "embedding_4bit_dtype in ExecuTorch expects weight_quant_max == 7"
+    assert (
+        weight_quant_min == -8
+    ), "embedding_4bit_dtype in ExecuTorch expects weight_quant_min == -8"
+    assert (
+        weight_quant_max == 7
+    ), "embedding_4bit_dtype in ExecuTorch expects weight_quant_max == 7"
 
     embedding_weight_checks(weight, weight_scales, weight_zero_points)
     group_size = (2 * weight.size(1)) // (
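Context for the asserted bounds, as a hedged sketch that is not part of the commit: the ranges follow from signed 2-bit and 4-bit integers, and the numerators in the group_size expressions reflect how many quantized values fit in each packed uint8 column. The divisor of group_size is truncated in the hunks above, so it is not reproduced here; the weight tensor below is a hypothetical stand-in.

import torch

# Signed n-bit two's-complement range is [-2**(n-1), 2**(n-1) - 1]:
# 2-bit -> [-2, 1] and 4-bit -> [-8, 7], matching the asserted quant_min/quant_max.
for n, (lo, hi) in [(2, (-2, 1)), (4, (-8, 7))]:
    assert (-(2 ** (n - 1)), 2 ** (n - 1) - 1) == (lo, hi)

# Each packed uint8 column holds 8 // n quantized values, so the unpacked row
# width is (8 // n) * weight.size(1) -- hence the "4 *" (2-bit) and "2 *"
# (4-bit) numerators in the group_size expressions above.
weight = torch.zeros(10, 16, dtype=torch.uint8)  # hypothetical packed 2-bit weights
unpacked_width_2bit = 4 * weight.size(1)  # 64 unpacked values per embedding row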