Skip to content

Commit 3eca367

Browse files
Michael Gschwindmalfet
authored andcommitted
create architecture neutral forward for channel/group-wise embedding quantization operator
1 parent 815508b commit 3eca367

File tree

2 files changed

+23
-9
lines changed

2 files changed

+23
-9
lines changed

.github/workflows/compile.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,12 @@ jobs:
6565
# echo "******************************************"
6666
# echo "******** Emb: group-wise quantized *******"
6767
# echo "******************************************"
68-
# python generate.py --quant '{"embedding" : {"bitwidth": 8, "group_size": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
69-
# cat ./output_eager
70-
# python generate.py --compile --quant '{"embedding" : {"bitwidth": 8, "group_size": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
71-
# cat ./output_compiled
72-
# python export.py --quant "embedding" : {"bitwidth": 8, "group_size": 8}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
73-
# python generate.py --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
68+
python generate.py --quant '{"embedding" : {"bitwidth": 8, "group_size": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
69+
cat ./output_eager
70+
python generate.py --compile --quant '{"embedding" : {"bitwidth": 8, "group_size": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
71+
cat ./output_compiled
72+
python export.py --quant "embedding" : {"bitwidth": 8, "group_size": 8}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
73+
python generate.py --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
7474
cat ./output_aoti
7575
7676
# echo "******************************************"

quantize.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -461,11 +461,25 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
461461
self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
462462
)
463463

464-
result_weights = self.weight.index_select(0, indices.view(-1))
465-
result_scales = self.scales.index_select(0, indices.view(-1))
464+
465+
# result_weights = self.weight.index_select(0, indices.view(-1))
466+
# result_scales = self.scales.index_select(0, indices.view(-1))
467+
468+
weight = self.weight
469+
scales = self.scales.view(weight.shape[0], -1)
470+
471+
result_weights = F.embedding(indices, weight)
472+
result_scales = F.embedding(indices, scales)
473+
474+
rw_view = result_weights.to(dtype=result_scales.dtype).view(tuple(result_weights.shape[:-1] + (scales.shape[1], -1, )))
475+
rs_view = result_scales.view(tuple(result_scales.shape[:-1]) + (scales.shape[1], 1, ))
476+
# print(f"rw_view {rw_view.shape}")
477+
# print(f"rs_view {rs_view.shape}")
466478

467-
r = result_weights.to(dtype=result_scales.dtype) * result_scales
479+
r = rw_view * rs_view
468480
return r.view(indices.size() + (-1,))
481+
482+
# r = result_weights.to(dtype=result_scales.dtype).view(list(result_weights.shape[:-1] + (scales.shape[1], -1, )) * result_scales.view(scales.shape[-1] + (scales.shape[1], 1, ))
469483

470484
##################################################################
471485
##### weight only int4 per channel groupwise quantized code ######

0 commit comments

Comments
 (0)