Skip to content

Revert "Embedding quantization per backend" #411

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
def set_backend(dso, pte):
global active_builder_args_dso
global active_builder_args_pte
active_builder_args_dso = dso
active_builder_args_dso = dso
active_builder_args_pte = pte


Expand Down
44 changes: 28 additions & 16 deletions quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from build.utils import find_multiple, get_precision, use_et_backend
from build.utils import find_multiple, get_precision


#########################################################################
Expand Down Expand Up @@ -92,6 +92,30 @@ def quantized_model(self) -> nn.Module:
return self.quantizer.quantize(self.model_)


#########################################################################
### QuantHandler API definition ###
### (unify with torchao in future) ###


class QuantHandler:
def __init__(self, model: nn.Module, device="cpu", tokenizer=None):
self.model_ = model
self.device = device
self.tokenizer = tokenizer

def create_quantized_state_dict(self) -> Dict: # "StateDict"
pass

def convert_for_runtime(self) -> nn.Module:
pass

def quantized_model(self) -> nn.Module:
model_updated_state_dict = self.create_quantized_state_dict()
self.convert_for_runtime()
self.model_.load_state_dict(model_updated_state_dict)
return self.model_


#########################################################################
### wrapper for setting precision as a QuantHandler ###

Expand Down Expand Up @@ -623,12 +647,6 @@ def __init__(
self.groupsize = groupsize
self.dtype = dtype
self.packed = packed

if use_et_backend():
self.forward = self.et_forward
else:
self.forward = self.aoti_forward

if not packed:
self.register_buffer(
"weight",
Expand Down Expand Up @@ -657,18 +675,12 @@ def __init__(
)

@torch.no_grad()
def et_forward(self, indices: torch.Tensor) -> torch.Tensor:
if self.packed:
return torch.ops.quantized_decomposed.embedding_byte.dtype(
self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
)
else:
return torch.ops.quantized_decomposed.embedding_4bit.dtype(
def forward(self, indices: torch.Tensor) -> torch.Tensor:
if False: # Used for Executorch
return torch.ops.llama_quantized.embedding_byte.dtype(
self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
)

@torch.no_grad()
def aoti_forward(self, indices: torch.Tensor) -> torch.Tensor:
# result_weights = self.weight.index_select(0, indices.view(-1))
# result_scales = self.scales.index_select(0, indices.view(-1))

Expand Down