
Commit d0ca255

init
1 parent f341da8

2 files changed: +15 −3 lines changed

examples/models/llama/source_transformation/quantize.py

Lines changed: 3 additions & 3 deletions
@@ -729,18 +729,18 @@ def __init__(
     def forward(self, indices: torch.Tensor) -> torch.Tensor:
         if not self.packed:  # 8bit
             return torch.ops.quantized_decomposed.embedding_byte.dtype(
-                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+                self.weight, self.scales, None, -128, 127, indices, dtype=self.dtype
             )
         else:  # packed
             if self.bitwidth == 2:
                 return torch.ops.quantized_decomposed.embedding_2bit.dtype(
-                    self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+                    self.weight, self.scales, None, -2, 1, indices, dtype=self.dtype
                 )

             # Remaining case (always return to make pyre happy)
             assert self.bitwidth == 4
             return torch.ops.quantized_decomposed.embedding_4bit.dtype(
-                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+                self.weight, self.scales, None, -8, 7, indices, dtype=self.dtype
             )
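For reference, the new bounds are the canonical signed ranges for each bit width: qmin = -(2 ** (b - 1)) and qmax = 2 ** (b - 1) - 1. A minimal sketch of that relationship (the helper name is hypothetical, not part of this commit):

# Hypothetical helper, not in this commit: signed quant range for a bit width.
def signed_quant_range(bitwidth: int) -> tuple[int, int]:
    qmin = -(2 ** (bitwidth - 1))   # -128 for 8-bit, -2 for 2-bit, -8 for 4-bit
    qmax = 2 ** (bitwidth - 1) - 1  #  127 for 8-bit,  1 for 2-bit,  7 for 4-bit
    return qmin, qmax

assert signed_quant_range(8) == (-128, 127)
assert signed_quant_range(2) == (-2, 1)
assert signed_quant_range(4) == (-8, 7)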

exir/passes/_quant_patterns_and_replacements.py

Lines changed: 12 additions & 0 deletions
@@ -202,6 +202,9 @@ def embedding_2bit(
     weight_quant_max: int,
     indices: torch.Tensor,
 ) -> torch.Tensor:
+    assert weight_quant_min == -2, "embedding_2bit in ExecuTorch expects weight_quant_min == -2"
+    assert weight_quant_max == 1, "embedding_2bit in ExecuTorch expects weight_quant_max == 1"
+
     embedding_weight_checks(weight, weight_scales, weight_zero_points)
     group_size = (4 * weight.size(1)) // (
         weight_scales.size(1) if weight_scales.dim() == 2 else 1
@@ -257,6 +260,9 @@ def embedding_2bit_dtype(
     indices: torch.Tensor,
     dtype: Optional[torch.dtype],
 ) -> torch.Tensor:
+    assert weight_quant_min == -2, "embedding_2bit_dtype in ExecuTorch expects weight_quant_min == -2"
+    assert weight_quant_max == 1, "embedding_2bit_dtype in ExecuTorch expects weight_quant_max == 1"
+
     embedding_weight_checks(weight, weight_scales, weight_zero_points)
     group_size = (4 * weight.size(1)) // (
         weight_scales.size(1) if weight_scales.dim() == 2 else 1
@@ -334,6 +340,9 @@ def embedding_4bit(
     weight_quant_max: int,
     indices: torch.Tensor,
 ) -> torch.Tensor:
+    assert weight_quant_min == -8, "embedding_4bit in ExecuTorch expects weight_quant_min == -8"
+    assert weight_quant_max == 7, "embedding_4bit in ExecuTorch expects weight_quant_max == 7"
+
     embedding_weight_checks(weight, weight_scales, weight_zero_points)
     group_size = (2 * weight.size(1)) // (
         weight_scales.size(1) if weight_scales.dim() == 2 else 1
@@ -387,6 +396,9 @@ def embedding_4bit_dtype(
     indices: torch.Tensor,
     dtype: Optional[torch.dtype],
 ) -> torch.Tensor:
+    assert weight_quant_min == -8, "embedding_4bit_dtype in ExecuTorch expects weight_quant_min == -8"
+    assert weight_quant_max == 7, "embedding_4bit_dtype in ExecuTorch expects weight_quant_max == 7"
+
     embedding_weight_checks(weight, weight_scales, weight_zero_points)
     group_size = (2 * weight.size(1)) // (
         weight_scales.size(1) if weight_scales.dim() == 2 else 1
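The group_size arithmetic visible in these patterns reflects the byte packing: 2-bit weights store four codes per uint8 byte (hence 4 * weight.size(1)) and 4-bit weights store two (hence 2 * weight.size(1)). A toy sketch with made-up shapes, only to illustrate the formula:

import torch

# Made-up shapes: a 2-bit packed row of 16 bytes holds 4 * 16 = 64 logical columns.
packed_2bit = torch.zeros(10, 16, dtype=torch.uint8)
scales_2bit = torch.ones(10, 4)  # 4 scale groups per row

# Same formula as in embedding_2bit above.
group_size = (4 * packed_2bit.size(1)) // (
    scales_2bit.size(1) if scales_2bit.dim() == 2 else 1
)
assert group_size == 16  # each scale covers 16 of the 64 logical columns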
