Commit 54309ec

Revert "Embedding quantization per backend (#402)"
This reverts commit 052fb1a.
1 parent c73b88a · commit 54309ec

File tree

2 files changed, +29 -17 lines changed


build/utils.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@
 def set_backend(dso, pte):
     global active_builder_args_dso
     global active_builder_args_pte
-    active_builder_args_dso = dso
+    active_builder_args_dso = dso
     active_builder_args_pte = pte


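For context, set_backend records the requested backend in module-level globals, and helpers such as use_et_backend (whose import the quantize.py hunk below removes) read them back. A minimal runnable sketch of that flag pattern; the use_et_backend body here is an assumed reconstruction, not copied from build/utils.py:

active_builder_args_dso = None
active_builder_args_pte = None


def set_backend(dso, pte):
    global active_builder_args_dso
    global active_builder_args_pte
    active_builder_args_dso = dso
    active_builder_args_pte = pte


def use_et_backend() -> bool:
    # Assumed helper: treat the build as ExecuTorch when a .pte target was set.
    return bool(active_builder_args_pte)


set_backend(dso=False, pte=True)
assert use_et_backend()
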
quantize.py

Lines changed: 28 additions & 16 deletions
@@ -15,7 +15,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from build.utils import find_multiple, get_precision, use_et_backend
+from build.utils import find_multiple, get_precision


 #########################################################################
@@ -92,6 +92,30 @@ def quantized_model(self) -> nn.Module:
         return self.quantizer.quantize(self.model_)


+#########################################################################
+###                 QuantHandler API definition                        ###
+###               (unify with torchao in future)                       ###
+
+
+class QuantHandler:
+    def __init__(self, model: nn.Module, device="cpu", tokenizer=None):
+        self.model_ = model
+        self.device = device
+        self.tokenizer = tokenizer
+
+    def create_quantized_state_dict(self) -> Dict:  # "StateDict"
+        pass
+
+    def convert_for_runtime(self) -> nn.Module:
+        pass
+
+    def quantized_model(self) -> nn.Module:
+        model_updated_state_dict = self.create_quantized_state_dict()
+        self.convert_for_runtime()
+        self.model_.load_state_dict(model_updated_state_dict)
+        return self.model_
+
+
 #########################################################################
 ###        wrapper for setting precision as a QuantHandler            ###
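
The QuantHandler block restored above is a small template-method API: quantized_model() asks the subclass for a quantized state dict, lets it rewrite modules for runtime, then loads the weights back. A minimal sketch of a concrete handler against this API; IdentityHandler is a hypothetical example, not part of the diff:

from typing import Dict

import torch.nn as nn


class QuantHandler:
    def __init__(self, model: nn.Module, device="cpu", tokenizer=None):
        self.model_ = model
        self.device = device
        self.tokenizer = tokenizer

    def create_quantized_state_dict(self) -> Dict:
        pass

    def convert_for_runtime(self) -> nn.Module:
        pass

    def quantized_model(self) -> nn.Module:
        model_updated_state_dict = self.create_quantized_state_dict()
        self.convert_for_runtime()
        self.model_.load_state_dict(model_updated_state_dict)
        return self.model_


class IdentityHandler(QuantHandler):
    # Hypothetical no-op handler: round-trips the fp32 weights unchanged.
    def create_quantized_state_dict(self) -> Dict:
        return {k: v.clone() for k, v in self.model_.state_dict().items()}

    def convert_for_runtime(self) -> nn.Module:
        return self.model_  # no module surgery needed in the no-op case


quantized = IdentityHandler(nn.Linear(4, 4)).quantized_model()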

@@ -623,12 +647,6 @@ def __init__(
         self.groupsize = groupsize
         self.dtype = dtype
         self.packed = packed
-
-        if use_et_backend():
-            self.forward = self.et_forward
-        else:
-            self.forward = self.aoti_forward
-
         if not packed:
             self.register_buffer(
                 "weight",
@@ -657,18 +675,12 @@ def __init__(
         )

     @torch.no_grad()
-    def et_forward(self, indices: torch.Tensor) -> torch.Tensor:
-        if self.packed:
-            return torch.ops.quantized_decomposed.embedding_byte.dtype(
-                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
-            )
-        else:
-            return torch.ops.quantized_decomposed.embedding_4bit.dtype(
+    def forward(self, indices: torch.Tensor) -> torch.Tensor:
+        if False:  # Used for Executorch
+            return torch.ops.llama_quantized.embedding_byte.dtype(
                 self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
             )

-    @torch.no_grad()
-    def aoti_forward(self, indices: torch.Tensor) -> torch.Tensor:
         # result_weights = self.weight.index_select(0, indices.view(-1))
         # result_scales = self.scales.index_select(0, indices.view(-1))
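
After the revert, forward keeps the ExecuTorch op call behind a dead "if False" branch and falls through to the eager path hinted at by the commented-out index_select lines: gather the quantized rows and their per-group scales, then rescale. A minimal sketch of such a dequantizing lookup; the (vocab, dim) int8 weight and the one-scale-per-groupsize-columns layout are assumptions, not taken from the diff:

import torch

vocab, dim, groupsize = 8, 16, 4
weight = torch.randint(-128, 127, (vocab, dim), dtype=torch.int8)
scales = torch.rand(vocab, dim // groupsize)


def embedding_dequant(indices: torch.Tensor) -> torch.Tensor:
    # Mirrors the commented-out index_select lines in the diff.
    result_weights = weight.index_select(0, indices.view(-1))
    result_scales = scales.index_select(0, indices.view(-1))
    # Broadcast each group's scale over its groupsize columns and rescale.
    grouped = result_weights.to(torch.float32).view(-1, dim // groupsize, groupsize)
    out = grouped * result_scales.unsqueeze(-1)
    return out.view(*indices.shape, dim)


print(embedding_dequant(torch.tensor([[1, 2], [3, 4]])).shape)  # (2, 2, 16)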