Skip to content

Commit a3a525a

Browse files
Michael Gschwindfacebook-github-bot
authored andcommitted
Quantization support for groupwise embedding, various fp16 support
Summary: Quantization support for groupwise embedding, various fp16 support Differential Revision: D54549727
1 parent 1ad35a8 commit a3a525a

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

examples/models/llama2/ops/quantized_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ def embedding_byte_weight_checks(weight, weight_scales, weight_zero_points):
5757
weight_zero_points is None or weight_zero_points.dtype == weight_scales.dtype
5858
), "Expecting weight_zero_points to be None or have same dtype as weight_scales"
5959
assert (
60-
weight_zero_points is None or weight_zero_points.dim() == 1
61-
), f"Expecting weight_zero_points tensor to be None or have dim()==1, but found {weight_zero_points.dim()}"
60+
weight_zero_points is None or weight_zero_points.dim() == weight_scales.dim()
61+
), f"Expecting weight_zero_points tensor to be None or have dim() same as weight scales, but found {weight_zero_points.dim()}"
6262
assert weight_zero_points is None or weight_zero_points.size(0) == weight.size(
6363
0
6464
), f"Expecting weight_zero_points tensor to be None or have same number of rows as weights, but found {weight.size()} and {weight_zero_points.size()}"

examples/models/llama2/quantize.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def dynamically_quantize_per_channel(
7676

7777
if group_size is None or group_size == 0:
7878
items = x_shape_1
79-
elif not enable_non_multiple_groups:
79+
elif (x_shape_1 % group_size == 0) or not enable_non_multiple_groups:
8080
assert group_size > 0, "group size must be positive"
8181
assert (
8282
x_shape_1 % group_size
@@ -128,6 +128,7 @@ def dynamically_quantize_per_channel(
128128
scales = scales.to(dtype=scales_dtype)
129129
quant = quant[:, :x_shape_1]
130130

131+
print(f"quant shape {quant.shape} scales shape {scales.shape}")
131132
return quant, scales, zero_points
132133

133134

0 commit comments

Comments
 (0)