
Commit a007d75

format fix

1 parent d0ca255 commit a007d75

File tree: 1 file changed, +24 -8 lines changed


exir/passes/_quant_patterns_and_replacements.py

Lines changed: 24 additions & 8 deletions
@@ -202,8 +202,12 @@ def embedding_2bit(
     weight_quant_max: int,
     indices: torch.Tensor,
 ) -> torch.Tensor:
-    assert weight_quant_min == -2, "embedding_2bit in ExecuTorch expects weight_quant_min == -2"
-    assert weight_quant_max == 1, "embedding_2bit in ExecuTorch expects weight_quant_max == 1"
+    assert (
+        weight_quant_min == -2
+    ), "embedding_2bit in ExecuTorch expects weight_quant_min == -2"
+    assert (
+        weight_quant_max == 1
+    ), "embedding_2bit in ExecuTorch expects weight_quant_max == 1"
 
     embedding_weight_checks(weight, weight_scales, weight_zero_points)
     group_size = (4 * weight.size(1)) // (
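
As an aside (not part of this commit), the asserted bounds are simply the signed two's-complement range for the bit width. A minimal sketch, using a hypothetical helper signed_range that is not in the ExecuTorch source:

# Minimal sketch (hypothetical helper, not in the ExecuTorch source):
# the asserted bounds equal [-2**(bits - 1), 2**(bits - 1) - 1].
def signed_range(bits: int) -> tuple[int, int]:
    return -(2 ** (bits - 1)), 2 ** (bits - 1) - 1

assert signed_range(2) == (-2, 1)  # embedding_2bit / embedding_2bit_dtype
assert signed_range(4) == (-8, 7)  # embedding_4bit / embedding_4bit_dtype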
@@ -260,8 +264,12 @@ def embedding_2bit_dtype(
     indices: torch.Tensor,
     dtype: Optional[torch.dtype],
 ) -> torch.Tensor:
-    assert weight_quant_min == -2, "embedding_2bit_dtype in ExecuTorch expects weight_quant_min == -2"
-    assert weight_quant_max == 1, "embedding_2bit_dtype in ExecuTorch expects weight_quant_max == 1"
+    assert (
+        weight_quant_min == -2
+    ), "embedding_2bit_dtype in ExecuTorch expects weight_quant_min == -2"
+    assert (
+        weight_quant_max == 1
+    ), "embedding_2bit_dtype in ExecuTorch expects weight_quant_max == 1"
 
     embedding_weight_checks(weight, weight_scales, weight_zero_points)
     group_size = (4 * weight.size(1)) // (
@@ -340,8 +348,12 @@ def embedding_4bit(
     weight_quant_max: int,
     indices: torch.Tensor,
 ) -> torch.Tensor:
-    assert weight_quant_min == -8, "embedding_4bit in ExecuTorch expects weight_quant_min == -8"
-    assert weight_quant_max == 7, "embedding_4bit in ExecuTorch expects weight_quant_max == 7"
+    assert (
+        weight_quant_min == -8
+    ), "embedding_4bit in ExecuTorch expects weight_quant_min == -8"
+    assert (
+        weight_quant_max == 7
+    ), "embedding_4bit in ExecuTorch expects weight_quant_max == 7"
 
     embedding_weight_checks(weight, weight_scales, weight_zero_points)
     group_size = (2 * weight.size(1)) // (
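
For context (again not part of the diff), the factor in front of weight.size(1) visible in these hunks appears to reflect how many quantized values fit in one packed byte. A rough sketch under that assumption, with values_per_byte as an illustrative name:

# Rough sketch, assuming the packed-uint8 layout implied by the factors above
# (4 * weight.size(1) for 2-bit, 2 * weight.size(1) for 4-bit): each byte of a
# packed row holds 8 // bits quantized values.
def values_per_byte(bits: int) -> int:
    return 8 // bits

assert values_per_byte(2) == 4  # matches the 4 * weight.size(1) factor
assert values_per_byte(4) == 2  # matches the 2 * weight.size(1) factor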
@@ -396,8 +408,12 @@ def embedding_4bit_dtype(
     indices: torch.Tensor,
     dtype: Optional[torch.dtype],
 ) -> torch.Tensor:
-    assert weight_quant_min == -8, "embedding_4bit_dtype in ExecuTorch expects weight_quant_min == -8"
-    assert weight_quant_max == 7, "embedding_4bit_dtype in ExecuTorch expects weight_quant_max == 7"
+    assert (
+        weight_quant_min == -8
+    ), "embedding_4bit_dtype in ExecuTorch expects weight_quant_min == -8"
+    assert (
+        weight_quant_max == 7
+    ), "embedding_4bit_dtype in ExecuTorch expects weight_quant_max == 7"
 
     embedding_weight_checks(weight, weight_scales, weight_zero_points)
     group_size = (2 * weight.size(1)) // (
