Commit f104e0f
metascroy authored and YIWENX14 committed
Assert quant_min/quant_max in embedding4bit (#7410)
* init
* format fix
1 parent 6f939f4 commit f104e0f

File tree

2 files changed: +31 -3 lines changed

examples/models/llama/source_transformation/quantize.py

Lines changed: 3 additions & 3 deletions

@@ -729,18 +729,18 @@ def __init__(
     def forward(self, indices: torch.Tensor) -> torch.Tensor:
         if not self.packed:  # 8bit
             return torch.ops.quantized_decomposed.embedding_byte.dtype(
-                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+                self.weight, self.scales, None, -128, 127, indices, dtype=self.dtype
             )
         else:  # packed
             if self.bitwidth == 2:
                 return torch.ops.quantized_decomposed.embedding_2bit.dtype(
-                    self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+                    self.weight, self.scales, None, -2, 1, indices, dtype=self.dtype
                 )
 
             # Remaining case (always return to make pyre happy)
             assert self.bitwidth == 4
             return torch.ops.quantized_decomposed.embedding_4bit.dtype(
-                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+                self.weight, self.scales, None, -8, 7, indices, dtype=self.dtype
             )
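
The quant_min/quant_max pairs passed here are the full signed two's-complement ranges for each bitwidth. A minimal sketch of that arithmetic (plain Python, not part of the patch):

def signed_qrange(bits: int) -> tuple[int, int]:
    # Full signed two's-complement range for a bits-wide integer:
    # qmin = -2**(bits - 1), qmax = 2**(bits - 1) - 1
    return -(2 ** (bits - 1)), 2 ** (bits - 1) - 1

assert signed_qrange(2) == (-2, 1)      # embedding_2bit
assert signed_qrange(4) == (-8, 7)      # embedding_4bit
assert signed_qrange(8) == (-128, 127)  # embedding_byte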

exir/passes/_quant_patterns_and_replacements.py

Lines changed: 28 additions & 0 deletions

@@ -202,6 +202,13 @@ def embedding_2bit(
     weight_quant_max: int,
     indices: torch.Tensor,
 ) -> torch.Tensor:
+    assert (
+        weight_quant_min == -2
+    ), "embedding_2bit in ExecuTorch expects weight_quant_min == -2"
+    assert (
+        weight_quant_max == 1
+    ), "embedding_2bit in ExecuTorch expects weight_quant_max == 1"
+
     embedding_weight_checks(weight, weight_scales, weight_zero_points)
     group_size = (4 * weight.size(1)) // (
         weight_scales.size(1) if weight_scales.dim() == 2 else 1

@@ -257,6 +264,13 @@ def embedding_2bit_dtype(
     indices: torch.Tensor,
     dtype: Optional[torch.dtype],
 ) -> torch.Tensor:
+    assert (
+        weight_quant_min == -2
+    ), "embedding_2bit_dtype in ExecuTorch expects weight_quant_min == -2"
+    assert (
+        weight_quant_max == 1
+    ), "embedding_2bit_dtype in ExecuTorch expects weight_quant_max == 1"
+
     embedding_weight_checks(weight, weight_scales, weight_zero_points)
     group_size = (4 * weight.size(1)) // (
         weight_scales.size(1) if weight_scales.dim() == 2 else 1
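
The group_size computation in the surrounding context lines reflects packing density: one byte holds four 2-bit values, so the unpacked embedding dim is 4 * weight.size(1); the 4-bit hunks below use 2 * weight.size(1), two values per byte. A worked example with made-up sizes, not from the patch:

# Hypothetical sizes illustrating the group_size math above.
packed_cols = 16  # weight.size(1): packed uint8 columns
n_groups = 2      # weight_scales.size(1) when scales are 2-D

group_size_2bit = (4 * packed_cols) // n_groups  # 64 // 2 == 32
group_size_4bit = (2 * packed_cols) // n_groups  # 32 // 2 == 16
assert (group_size_2bit, group_size_4bit) == (32, 16)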
@@ -334,6 +348,13 @@ def embedding_4bit(
     weight_quant_max: int,
     indices: torch.Tensor,
 ) -> torch.Tensor:
+    assert (
+        weight_quant_min == -8
+    ), "embedding_4bit in ExecuTorch expects weight_quant_min == -8"
+    assert (
+        weight_quant_max == 7
+    ), "embedding_4bit in ExecuTorch expects weight_quant_max == 7"
+
     embedding_weight_checks(weight, weight_scales, weight_zero_points)
     group_size = (2 * weight.size(1)) // (
         weight_scales.size(1) if weight_scales.dim() == 2 else 1

@@ -387,6 +408,13 @@ def embedding_4bit_dtype(
     indices: torch.Tensor,
     dtype: Optional[torch.dtype],
 ) -> torch.Tensor:
+    assert (
+        weight_quant_min == -8
+    ), "embedding_4bit_dtype in ExecuTorch expects weight_quant_min == -8"
+    assert (
+        weight_quant_max == 7
+    ), "embedding_4bit_dtype in ExecuTorch expects weight_quant_max == 7"
+
     embedding_weight_checks(weight, weight_scales, weight_zero_points)
     group_size = (2 * weight.size(1)) // (
         weight_scales.size(1) if weight_scales.dim() == 2 else 1
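
All four asserts follow one pattern: pin weight_quant_min/weight_quant_max to the full signed range for the op's bitwidth. A standalone sketch of that guard (hypothetical helper, mirroring rather than reusing the ExecuTorch code):

def check_embedding_qparams(op_name: str, bits: int, quant_min: int, quant_max: int) -> None:
    # Mirrors the new asserts: the packed embedding ops accept only the
    # full signed two's-complement range for their bitwidth.
    expected_min = -(2 ** (bits - 1))
    expected_max = 2 ** (bits - 1) - 1
    assert (
        quant_min == expected_min
    ), f"{op_name} in ExecuTorch expects weight_quant_min == {expected_min}"
    assert (
        quant_max == expected_max
    ), f"{op_name} in ExecuTorch expects weight_quant_max == {expected_max}"

check_embedding_qparams("embedding_2bit", 2, -2, 1)  # passes
check_embedding_qparams("embedding_4bit", 4, -8, 7)  # passes
# check_embedding_qparams("embedding_4bit", 4, 0, 15) would raise AssertionError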
