
Commit f3cc50a

mikekgfb authored and malfet committed
Revert "Revert "Embedding quantization per backend"" (#415)
* Revert "Revert "Embedding quantization per backend (#402)" (#411)" This reverts commit 8b35acd. * 4b and 8b embedding table quantization * minor changes * remove extra et workflow
1 parent 767a9ae · commit f3cc50a

File tree

* .github/workflows/et-gguf.yml
* build/utils.py
* cli.py
* quantize.py

4 files changed: +47 additions, −186 deletions

.github/workflows/et-gguf.yml

Lines changed: 0 additions & 133 deletions
This file was deleted.

build/utils.py

Lines changed: 9 additions & 5 deletions
@@ -5,10 +5,11 @@
 # LICENSE file in the root directory of this source tree.
 
 from __future__ import annotations
-from typing import List
-from pathlib import Path
-import os
+
 import logging
+import os
+from pathlib import Path
+from typing import List
 
 import torch
 
@@ -23,7 +24,7 @@
 def set_backend(dso, pte):
     global active_builder_args_dso
     global active_builder_args_pte
-    active_builder_args_dso = dso
+    active_builder_args_dso = dso
     active_builder_args_pte = pte
 
 
@@ -83,9 +84,11 @@ def name_to_dtype(name):
     else:
         raise RuntimeError(f"unsupported dtype name {name} specified")
 
+
 def allowable_dtype_names() -> List[str]:
     return name_to_dtype_dict.keys()
 
+
 name_to_dtype_dict = {
     "fp32": torch.float,
     "fp16": torch.float16,
@@ -101,7 +104,8 @@ def allowable_dtype_names() -> List[str]:
 #########################################################################
 ### general model build utility functions for CLI ###
 
-def allowable_params_table() -> List[dtr]:
+
+def allowable_params_table() -> List[str]:
     config_path = Path(f"{str(Path(__file__).parent)}/known_model_params")
     known_model_params = [
         config.replace(".json", "") for config in os.listdir(config_path)

cli.py

Lines changed: 3 additions & 3 deletions
@@ -7,11 +7,11 @@
 import json
 from pathlib import Path
 
+import torch
+
 from build.utils import allowable_dtype_names, allowable_params_table
 from download import download_and_convert, is_model_downloaded
 
-import torch
-
 # CPU is always available and also exportable to ExecuTorch
 default_device = "cpu"  # 'cuda' if torch.cuda.is_available() else 'cpu'
 
@@ -223,7 +223,7 @@ def add_arguments(parser):
         "-d",
         "--dtype",
         default="float32",
-        choices = allowable_dtype_names(),
+        choices=allowable_dtype_names(),
         help="Override the dtype of the model (default is the checkpoint dtype). Options: bf16, fp16, fp32",
     )
     parser.add_argument(
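
Note: the --dtype values accepted here come from build/utils.py's name_to_dtype_dict, and name_to_dtype turns the chosen string back into a torch dtype. A self-contained sketch of that wiring follows; the dict is a truncated, illustrative subset, and this name_to_dtype body is a stand-in, not the project's implementation.

import argparse

import torch

# Illustrative subset of name_to_dtype_dict; the real table has more entries.
name_to_dtype_dict = {
    "fp32": torch.float,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}


def allowable_dtype_names():
    return name_to_dtype_dict.keys()


def name_to_dtype(name):
    # Stand-in for build/utils.py's name_to_dtype: map a CLI string to a dtype.
    try:
        return name_to_dtype_dict[name]
    except KeyError:
        raise RuntimeError(f"unsupported dtype name {name} specified")


parser = argparse.ArgumentParser()
parser.add_argument("-d", "--dtype", default="fp32", choices=allowable_dtype_names())
args = parser.parse_args(["--dtype", "bf16"])
print(name_to_dtype(args.dtype))  # torch.bfloat16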

quantize.py

Lines changed: 35 additions & 45 deletions
@@ -15,7 +15,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from build.utils import find_multiple, get_precision
+from build.utils import find_multiple, get_precision, use_et_backend
 
 
 #########################################################################
@@ -92,30 +92,6 @@ def quantized_model(self) -> nn.Module:
         return self.quantizer.quantize(self.model_)
 
 
-#########################################################################
-### QuantHandler API definition ###
-### (unify with torchao in future) ###
-
-
-class QuantHandler:
-    def __init__(self, model: nn.Module, device="cpu", tokenizer=None):
-        self.model_ = model
-        self.device = device
-        self.tokenizer = tokenizer
-
-    def create_quantized_state_dict(self) -> Dict:  # "StateDict"
-        pass
-
-    def convert_for_runtime(self) -> nn.Module:
-        pass
-
-    def quantized_model(self) -> nn.Module:
-        model_updated_state_dict = self.create_quantized_state_dict()
-        self.convert_for_runtime()
-        self.model_.load_state_dict(model_updated_state_dict)
-        return self.model_
-
-
 #########################################################################
 ### wrapper for setting precision as a QuantHandler ###
 
@@ -521,7 +497,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 
 
 def replace_embedding_weight_only_grouped_int8_per_channel(
-    module, device, bitwidth: int = 8, groupsize: Optional[int] = None, packed=False
+    module, device, bitwidth: int, groupsize: Optional[int]
 ):
     for name, child in module.named_children():
         # print(f"name: {name}")
@@ -535,13 +511,13 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
                     device=device,
                     vocab_size=child.weight.shape[0],
                     embedding_dim=child.weight.shape[1],
+                    bitwidth=bitwidth,
                     groupsize=groupsize,
-                    packed=packed,
                 ),
             )
         else:
             replace_embedding_weight_only_grouped_int8_per_channel(
-                child, device, bitwidth, groupsize, packed
+                child, device, bitwidth, groupsize
             )
 
 
@@ -554,19 +530,15 @@ def __init__(
         *,
         bitwidth: int = 8,
         groupsize: Optional[int] = None,
-        packed=True,
+        packed=True,  # we always pack bitwidth 4 now
     ):
-        # when quantization dictionary comes from JSON, packed is a string
-        if isinstance(packed, str):
-            packed = packed.lower() != "false"
         self.model_ = model
         self.device = device
         self.groupsize = groupsize
         self.bitwidth = bitwidth
-        self.packed = packed
 
     @torch.no_grad()
-    def create_quantized_state_dict(self, packed=False) -> Dict:
+    def create_quantized_state_dict(self) -> Dict:
         cur_state_dict = self.model_.state_dict()
 
         if self.bitwidth == 4:
@@ -596,7 +568,7 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
                     scales_dtype=mod.weight.dtype,
                 )
 
-                if packed:
+                if self.bitwidth == 4:
                     if weight.shape[-1] % 2 != 0:
                         raise RuntimeError("automatic padding not implemented yet")
 
@@ -620,12 +592,12 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
 
     def convert_for_runtime(self) -> nn.Module:
         replace_embedding_weight_only_grouped_int8_per_channel(
-            self.model_, self.device, self.bitwidth, self.groupsize, self.packed
+            self.model_, self.device, self.bitwidth, self.groupsize
         )
         return self.model_
 
     def quantized_model(self) -> nn.Module:
-        model_updated_state_dict = self.create_quantized_state_dict(self.packed)
+        model_updated_state_dict = self.create_quantized_state_dict()
         self.convert_for_runtime()
         self.model_.load_state_dict(model_updated_state_dict)
         return self.model_
@@ -637,30 +609,42 @@ def __init__(
         device,
         vocab_size: int,
         embedding_dim: int,
+        bitwidth: int,
         groupsize: Optional[int] = None,
+        *,
         dtype=torch.half,
-        packed=False,
     ) -> None:
         super().__init__()
         if groupsize is None or groupsize == 0:
             groupsize = embedding_dim
         self.groupsize = groupsize
         self.dtype = dtype
-        self.packed = packed
-        if not packed:
+        self.bitwidth = bitwidth
+
+        if use_et_backend():
+            self.forward = self.et_forward
+        else:
+            self.forward = self.aoti_forward
+
+        if bitwidth == 8:
             self.register_buffer(
                 "weight",
                 torch.empty(
                     (vocab_size, embedding_dim), dtype=torch.int8, device=device
                 ),
             )
-        else:  # packed
+        elif bitwidth == 4:  # packed
             self.register_buffer(
                 "weight",
                 torch.empty(
                     (vocab_size, embedding_dim // 2), dtype=torch.uint8, device=device
                 ),
             )
+        else:
+            raise RuntimeError(
+                f"QUantized embedding does not support bitwidth={bitwidth}"
+            )
+
         groups_per_row = (embedding_dim + groupsize - 1) // groupsize
         if groups_per_row > 1:
             self.register_buffer(
@@ -675,16 +659,22 @@ def __init__(
             )
 
     @torch.no_grad()
-    def forward(self, indices: torch.Tensor) -> torch.Tensor:
-        if False:  # Used for Executorch
-            return torch.ops.llama_quantized.embedding_byte.dtype(
+    def et_forward(self, indices: torch.Tensor) -> torch.Tensor:
+        if self.bitwidth == 8:
+            return torch.ops.quantized_decomposed.embedding_byte.dtype(
+                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+            )
+        else:
+            return torch.ops.quantized_decomposed.embedding_4bit.dtype(
                 self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
             )
 
+    @torch.no_grad()
+    def aoti_forward(self, indices: torch.Tensor) -> torch.Tensor:
         # result_weights = self.weight.index_select(0, indices.view(-1))
         # result_scales = self.scales.index_select(0, indices.view(-1))
 
-        if self.packed:
+        if self.bitwidth == 4:
             weight_even = self.weight.div(16, rounding_mode="trunc")
             weight_odd = self.weight.remainder(16)
             weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1)
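
Note: in the 4-bit path, QuantizedGroupEmbedding stores two quantized values per uint8 byte and aoti_forward unpacks them with div/remainder, high nibble first. A small round-trip sketch of that packing convention follows (toy tensors; it assumes the quantized values were already shifted into the unsigned range 0..15, a step outside the lines shown in this diff).

import torch

# Toy demonstration of the nibble packing implied by aoti_forward's unpack code.
# Assumption: values are already in [0, 15]; the commit's packing code is not shown here.
q = torch.randint(0, 16, (4, 8), dtype=torch.uint8)  # pretend-quantized rows

# Pack: high nibble = even-indexed column, low nibble = odd-indexed column.
packed = q[:, ::2] * 16 + q[:, 1::2]  # shape (4, 4), dtype uint8

# Unpack exactly as aoti_forward does above.
weight_even = packed.div(16, rounding_mode="trunc")
weight_odd = packed.remainder(16)
unpacked = torch.stack((weight_even, weight_odd), dim=-1).view(q.shape)

assert torch.equal(unpacked, q)

The scales buffer registered in __init__ (one scale per groupsize columns of each row) is then applied to the unpacked values to produce outputs in the requested dtype; on the ExecuTorch side that dequantization is delegated to the quantized_decomposed.embedding_byte / embedding_4bit ops called in et_forward.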
