
Revert "Revert "Embedding quantization per backend"" #415

Merged
Changes from all commits
133 changes: 0 additions & 133 deletions .github/workflows/et-gguf.yml

This file was deleted.

14 changes: 9 additions & 5 deletions build/utils.py
@@ -5,10 +5,11 @@
# LICENSE file in the root directory of this source tree.

from __future__ import annotations
from typing import List
from pathlib import Path
import os

import logging
import os
from pathlib import Path
from typing import List

import torch

@@ -23,7 +24,7 @@
def set_backend(dso, pte):
global active_builder_args_dso
global active_builder_args_pte
active_builder_args_dso = dso
active_builder_args_dso = dso
active_builder_args_pte = pte


@@ -83,9 +84,11 @@ def name_to_dtype(name):
else:
raise RuntimeError(f"unsupported dtype name {name} specified")


def allowable_dtype_names() -> List[str]:
return name_to_dtype_dict.keys()


name_to_dtype_dict = {
"fp32": torch.float,
"fp16": torch.float16,
@@ -101,7 +104,8 @@ def allowable_dtype_names() -> List[str]:
#########################################################################
### general model build utility functions for CLI ###

def allowable_params_table() -> List[dtr]:

def allowable_params_table() -> List[str]:
config_path = Path(f"{str(Path(__file__).parent)}/known_model_params")
known_model_params = [
config.replace(".json", "") for config in os.listdir(config_path)
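
The build/utils.py changes above are mostly import reordering and formatting, but this module also hosts the backend flags that quantize.py now consumes through use_et_backend() (see the new import in the quantize.py diff below). The body of use_et_backend() is not part of this diff, so the following is only a sketch of its assumed shape, with set_backend() copied from the hunk above:

```python
# Sketch only: set_backend() appears in the hunk above; use_et_backend() is
# not shown in this diff, so its implementation here is an assumption.
active_builder_args_dso = None
active_builder_args_pte = None


def set_backend(dso, pte):
    global active_builder_args_dso
    global active_builder_args_pte
    active_builder_args_dso = dso
    active_builder_args_pte = pte


def use_et_backend() -> bool:
    # True when targeting an ExecuTorch .pte artifact rather than an
    # AOT Inductor .dso; quantize.py uses this to pick embedding kernels.
    return bool(active_builder_args_pte)
```
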
6 changes: 3 additions & 3 deletions cli.py
@@ -7,11 +7,11 @@
import json
from pathlib import Path

import torch

from build.utils import allowable_dtype_names, allowable_params_table
from download import download_and_convert, is_model_downloaded

import torch

# CPU is always available and also exportable to ExecuTorch
default_device = "cpu" # 'cuda' if torch.cuda.is_available() else 'cpu'

@@ -223,7 +223,7 @@ def add_arguments(parser):
"-d",
"--dtype",
default="float32",
choices = allowable_dtype_names(),
choices=allowable_dtype_names(),
help="Override the dtype of the model (default is the checkpoint dtype). Options: bf16, fp16, fp32",
)
parser.add_argument(
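
Aside from import reordering, the cli.py hunk normalizes choices=allowable_dtype_names(), which validates --dtype against the keys of name_to_dtype_dict from build/utils.py. A self-contained sketch of that pattern (dtype table truncated for illustration):

```python
import argparse

import torch

# Truncated copy of the table in build/utils.py above, for illustration only.
name_to_dtype_dict = {
    "fp32": torch.float,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}


def allowable_dtype_names():
    return name_to_dtype_dict.keys()


parser = argparse.ArgumentParser()
parser.add_argument(
    "-d",
    "--dtype",
    default="float32",
    choices=allowable_dtype_names(),  # any iterable of valid names works here
    help="Override the dtype of the model (default is the checkpoint dtype).",
)

# parser.parse_args(["--dtype", "fp16"]) passes; an unknown name is rejected
# with a usage error. Choices are only enforced for values passed on the
# command line, not for the default.
```
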
80 changes: 35 additions & 45 deletions quantize.py
@@ -15,7 +15,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from build.utils import find_multiple, get_precision
from build.utils import find_multiple, get_precision, use_et_backend


#########################################################################
@@ -92,30 +92,6 @@ def quantized_model(self) -> nn.Module:
return self.quantizer.quantize(self.model_)


#########################################################################
### QuantHandler API definition ###
### (unify with torchao in future) ###


class QuantHandler:
def __init__(self, model: nn.Module, device="cpu", tokenizer=None):
self.model_ = model
self.device = device
self.tokenizer = tokenizer

def create_quantized_state_dict(self) -> Dict: # "StateDict"
pass

def convert_for_runtime(self) -> nn.Module:
pass

def quantized_model(self) -> nn.Module:
model_updated_state_dict = self.create_quantized_state_dict()
self.convert_for_runtime()
self.model_.load_state_dict(model_updated_state_dict)
return self.model_


#########################################################################
### wrapper for setting precision as a QuantHandler ###

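The hunk above removes a duplicated copy of the QuantHandler base class; the surviving definition follows the same three-step protocol. For reference, a minimal sketch of that protocol based on the deleted copy (with the pass stubs tightened to NotImplementedError):

```python
from typing import Dict

import torch.nn as nn


class QuantHandler:
    def __init__(self, model: nn.Module, device="cpu", tokenizer=None):
        self.model_ = model
        self.device = device
        self.tokenizer = tokenizer

    def create_quantized_state_dict(self) -> Dict:  # "StateDict"
        raise NotImplementedError

    def convert_for_runtime(self) -> nn.Module:
        raise NotImplementedError

    def quantized_model(self) -> nn.Module:
        # 1) quantize the weights, 2) swap modules for their quantized
        # counterparts, 3) load the quantized weights into the new modules.
        model_updated_state_dict = self.create_quantized_state_dict()
        self.convert_for_runtime()
        self.model_.load_state_dict(model_updated_state_dict)
        return self.model_
```
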
@@ -521,7 +497,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:


def replace_embedding_weight_only_grouped_int8_per_channel(
module, device, bitwidth: int = 8, groupsize: Optional[int] = None, packed=False
module, device, bitwidth: int, groupsize: Optional[int]
):
for name, child in module.named_children():
# print(f"name: {name}")
@@ -535,13 +511,13 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
device=device,
vocab_size=child.weight.shape[0],
embedding_dim=child.weight.shape[1],
bitwidth=bitwidth,
groupsize=groupsize,
packed=packed,
),
)
else:
replace_embedding_weight_only_grouped_int8_per_channel(
child, device, bitwidth, groupsize, packed
child, device, bitwidth, groupsize
)


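The signature change above drops the packed flag from the replacement helper; packing is now derived from bitwidth inside QuantizedGroupEmbedding. The helper itself is a standard recursive module swap. A standalone sketch of that pattern, with an illustrative factory in place of the real constructor call:

```python
import torch.nn as nn


def replace_embeddings(module: nn.Module, make_quantized):
    """Recursively swap every nn.Embedding child for a quantized stand-in.

    make_quantized is an illustrative factory; in this PR the replacement is
    a QuantizedGroupEmbedding built from child.weight.shape, bitwidth, and
    groupsize.
    """
    for name, child in module.named_children():
        if isinstance(child, nn.Embedding):
            setattr(module, name, make_quantized(child))
        else:
            replace_embeddings(child, make_quantized)
```
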
@@ -554,19 +530,15 @@ def __init__(
*,
bitwidth: int = 8,
groupsize: Optional[int] = None,
packed=True,
packed=True, # we always pack bitwidth 4 now
):
# when quantization dictionary comes from JSON, packed is a string
if isinstance(packed, str):
packed = packed.lower() != "false"
self.model_ = model
self.device = device
self.groupsize = groupsize
self.bitwidth = bitwidth
self.packed = packed

@torch.no_grad()
def create_quantized_state_dict(self, packed=False) -> Dict:
def create_quantized_state_dict(self) -> Dict:
cur_state_dict = self.model_.state_dict()

if self.bitwidth == 4:
@@ -596,7 +568,7 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
scales_dtype=mod.weight.dtype,
)

if packed:
if self.bitwidth == 4:
if weight.shape[-1] % 2 != 0:
raise RuntimeError("automatic padding not implemented yet")

@@ -620,12 +592,12 @@

def convert_for_runtime(self) -> nn.Module:
replace_embedding_weight_only_grouped_int8_per_channel(
self.model_, self.device, self.bitwidth, self.groupsize, self.packed
self.model_, self.device, self.bitwidth, self.groupsize
)
return self.model_

def quantized_model(self) -> nn.Module:
model_updated_state_dict = self.create_quantized_state_dict(self.packed)
model_updated_state_dict = self.create_quantized_state_dict()
self.convert_for_runtime()
self.model_.load_state_dict(model_updated_state_dict)
return self.model_
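
create_quantized_state_dict() above relies on group_quantize_tensor_symmetric to produce an int8 weight plus per-group scales; that helper is not part of this diff, so the following is only an illustrative sketch of the underlying arithmetic (it assumes embedding_dim is divisible by groupsize):

```python
import torch


def group_quantize_symmetric_sketch(w: torch.Tensor, groupsize: int, bitwidth: int = 8):
    """Illustrative per-group symmetric quantization, not the helper used in the PR.

    Each row is split into groups of `groupsize` columns; each group gets one
    scale so that its largest magnitude maps to the top of the integer range.
    """
    qmax = 2 ** (bitwidth - 1) - 1  # 127 for int8, 7 for int4
    vocab, dim = w.shape
    w_grouped = w.reshape(vocab, dim // groupsize, groupsize)
    scales = w_grouped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / qmax
    q = torch.clamp(torch.round(w_grouped / scales), -qmax - 1, qmax).to(torch.int8)
    return q.reshape(vocab, dim), scales.reshape(vocab, -1)


# Dequantization mirrors aoti_forward below: a group-wise multiply by scales.
w = torch.randn(8, 16)
q, s = group_quantize_symmetric_sketch(w, groupsize=8)
w_hat = (q.reshape(8, 2, 8).float() * s.unsqueeze(-1)).reshape(8, 16)
```
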
@@ -637,30 +609,42 @@ def __init__(
device,
vocab_size: int,
embedding_dim: int,
bitwidth: int,
groupsize: Optional[int] = None,
*,
dtype=torch.half,
packed=False,
) -> None:
super().__init__()
if groupsize is None or groupsize == 0:
groupsize = embedding_dim
self.groupsize = groupsize
self.dtype = dtype
self.packed = packed
if not packed:
self.bitwidth = bitwidth

if use_et_backend():
self.forward = self.et_forward
else:
self.forward = self.aoti_forward

if bitwidth == 8:
self.register_buffer(
"weight",
torch.empty(
(vocab_size, embedding_dim), dtype=torch.int8, device=device
),
)
else: # packed
elif bitwidth == 4: # packed
self.register_buffer(
"weight",
torch.empty(
(vocab_size, embedding_dim // 2), dtype=torch.uint8, device=device
),
)
else:
raise RuntimeError(
f"QUantized embedding does not support bitwidth={bitwidth}"
)

groups_per_row = (embedding_dim + groupsize - 1) // groupsize
if groups_per_row > 1:
self.register_buffer(
@@ -675,16 +659,22 @@ def __init__(
)

@torch.no_grad()
def forward(self, indices: torch.Tensor) -> torch.Tensor:
if False: # Used for Executorch
return torch.ops.llama_quantized.embedding_byte.dtype(
def et_forward(self, indices: torch.Tensor) -> torch.Tensor:
if self.bitwidth == 8:
return torch.ops.quantized_decomposed.embedding_byte.dtype(
self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
)
else:
return torch.ops.quantized_decomposed.embedding_4bit.dtype(
self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
)

@torch.no_grad()
def aoti_forward(self, indices: torch.Tensor) -> torch.Tensor:
# result_weights = self.weight.index_select(0, indices.view(-1))
# result_scales = self.scales.index_select(0, indices.view(-1))

if self.packed:
if self.bitwidth == 4:
weight_even = self.weight.div(16, rounding_mode="trunc")
weight_odd = self.weight.remainder(16)
weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1)
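
With the per-backend split above, et_forward dispatches to the torch.ops.quantized_decomposed embedding kernels selected by bitwidth, while aoti_forward dequantizes in eager PyTorch, unpacking two 4-bit values from each uint8 byte via trunc-division and remainder by 16. The packing side lives in the elided portion of create_quantized_state_dict, so pack_int4 below is an assumed counterpart, written only to be consistent with the unpack code shown:

```python
import torch


def pack_int4(weight_4bit: torch.Tensor) -> torch.Tensor:
    # Assumed convention, matching the unpack in aoti_forward above:
    # even column -> high nibble, odd column -> low nibble.
    assert weight_4bit.shape[-1] % 2 == 0, "automatic padding not implemented yet"
    w = weight_4bit.to(torch.uint8) & 0xF  # keep the low 4 bits of each value
    return (w[..., ::2] * 16 + w[..., 1::2]).to(torch.uint8)


def unpack_int4(weight_packed: torch.Tensor) -> torch.Tensor:
    # Mirrors aoti_forward: high nibble via trunc-division, low nibble via remainder.
    weight_even = weight_packed.div(16, rounding_mode="trunc")
    weight_odd = weight_packed.remainder(16)
    return torch.stack((weight_even, weight_odd), dim=-1).view(
        *weight_packed.shape[:-1], weight_packed.shape[-1] * 2
    )


# Round trip on values already in the 4-bit range [0, 15].
w = torch.randint(0, 16, (4, 8), dtype=torch.uint8)
assert (unpack_int4(pack_int4(w)) == w).all()
```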