Commit 9508897

Revert "Revert "Embedding quantization per backend (#402)" (#411)"
This reverts commit 8b35acd.
1 parent 73a8f1e · commit 9508897

2 files changed: +17 −29 lines


build/utils.py — 1 addition & 1 deletion

@@ -23,7 +23,7 @@
 def set_backend(dso, pte):
     global active_builder_args_dso
     global active_builder_args_pte
-    active_builder_args_dso = dso
+    active_builder_args_dso = dso
     active_builder_args_pte = pte
 
 
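For context: set_backend stores the requested backend in module-level globals, and the use_et_backend import that quantize.py gains below presumably reads the same state. A minimal sketch of that plumbing, assuming "et" stands for ExecuTorch and that the function merely reports the stored flag (its body is not part of this commit):

# build/utils.py (sketch) -- set_backend() is taken from the hunk above;
# use_et_backend() is an assumption inferred from its name and its call
# sites in quantize.py, since its body does not appear in this diff.

active_builder_args_dso = None  # truthy when targeting an AOT Inductor .so build
active_builder_args_pte = None  # truthy when targeting an ExecuTorch .pte build


def set_backend(dso, pte):
    global active_builder_args_dso
    global active_builder_args_pte
    active_builder_args_dso = dso
    active_builder_args_pte = pte


def use_et_backend() -> bool:
    # Assumed: report whether the ExecuTorch (.pte) backend was selected.
    return bool(active_builder_args_pte)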
quantize.py — 16 additions & 28 deletions
@@ -15,7 +15,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from build.utils import find_multiple, get_precision
+from build.utils import find_multiple, get_precision, use_et_backend
 
 
 #########################################################################
@@ -92,30 +92,6 @@ def quantized_model(self) -> nn.Module:
         return self.quantizer.quantize(self.model_)
 
 
-#########################################################################
-###              QuantHandler API definition                          ###
-###              (unify with torchao in future)                       ###
-
-
-class QuantHandler:
-    def __init__(self, model: nn.Module, device="cpu", tokenizer=None):
-        self.model_ = model
-        self.device = device
-        self.tokenizer = tokenizer
-
-    def create_quantized_state_dict(self) -> Dict:  # "StateDict"
-        pass
-
-    def convert_for_runtime(self) -> nn.Module:
-        pass
-
-    def quantized_model(self) -> nn.Module:
-        model_updated_state_dict = self.create_quantized_state_dict()
-        self.convert_for_runtime()
-        self.model_.load_state_dict(model_updated_state_dict)
-        return self.model_
-
-
 #########################################################################
 ###        wrapper for setting precision as a QuantHandler            ###

@@ -647,6 +623,12 @@ def __init__(
         self.groupsize = groupsize
         self.dtype = dtype
         self.packed = packed
+
+        if use_et_backend():
+            self.forward = self.et_forward
+        else:
+            self.forward = self.aoti_forward
+
         if not packed:
             self.register_buffer(
                 "weight",
@@ -675,12 +657,18 @@
         )
 
     @torch.no_grad()
-    def forward(self, indices: torch.Tensor) -> torch.Tensor:
-        if False:  # Used for Executorch
-            return torch.ops.llama_quantized.embedding_byte.dtype(
+    def et_forward(self, indices: torch.Tensor) -> torch.Tensor:
+        if self.packed:
+            return torch.ops.quantized_decomposed.embedding_byte.dtype(
+                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+            )
+        else:
+            return torch.ops.quantized_decomposed.embedding_4bit.dtype(
                 self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
             )
 
+    @torch.no_grad()
+    def aoti_forward(self, indices: torch.Tensor) -> torch.Tensor:
         # result_weights = self.weight.index_select(0, indices.view(-1))
         # result_scales = self.scales.index_select(0, indices.view(-1))
 
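The surviving comments mark the start of aoti_forward: gather only the needed rows of the quantized weight and their scales, then dequantize in eager code. A rough guess at the rest of that path, assuming scale-only groupwise 8-bit quantization (consistent with the None zero-points passed to the ops above); the helper below is illustrative, not the commit's actual body:

import torch


def dequant_embedding_lookup(
    weight: torch.Tensor,   # (vocab, dim), int8 quantized values
    scales: torch.Tensor,   # (vocab, n_groups), per-group scales
    indices: torch.Tensor,  # token ids, any shape
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    # Gather just the rows we need, as the commented-out lines suggest.
    rows = weight.index_select(0, indices.view(-1)).to(dtype)
    row_scales = scales.index_select(0, indices.view(-1)).to(dtype)
    # One scale covers each contiguous group of columns; rescale per group.
    groupsize = weight.shape[1] // scales.shape[1]
    rows = rows.view(rows.shape[0], scales.shape[1], groupsize)
    out = rows * row_scales.unsqueeze(-1)
    return out.view(*indices.shape, weight.shape[1])


# e.g. a (10, 8) weight with 2 groups of 4 columns:
w = torch.randint(-128, 127, (10, 8), dtype=torch.int8)
s = torch.rand(10, 2)
print(dequant_embedding_lookup(w, s, torch.tensor([[1, 3]])).shape)  # (1, 2, 8)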