Quantized embedding #536

Merged 3 commits on Apr 29, 2024
1 change: 0 additions & 1 deletion eval.py
@@ -28,7 +28,6 @@
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.triton.cudagraphs = True
torch._dynamo.config.cache_size_limit = 100000
import time

try:
import lm_eval
125 changes: 125 additions & 0 deletions qops.py
@@ -3,6 +3,15 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from build.utils import (
find_multiple,
get_device_str,
get_precision,
name_to_dtype,
state_dict_device,
use_et_backend,
)
from torch.nn.parameter import Parameter


@@ -84,3 +93,119 @@ def __init__(

def forward(self, input: torch.Tensor) -> torch.Tensor:
return linear_int8(input, self.weight, self.scales)


class QuantizedEmbedding(torch.nn.Module):
def __init__(
self,
num_embeddings: int, # vocab_size: int,
embedding_dim: int,
device=None,
dtype=None,
*,
bitwidth: int,
groupsize: Optional[int] = None,
) -> None:
super().__init__()
if dtype is None:
dtype = torch.half

if groupsize is None or groupsize == 0:
groupsize = embedding_dim
self.groupsize = groupsize
self.dtype = dtype
self.bitwidth = bitwidth

if use_et_backend():
self.forward = self.et_forward
else:
self.forward = self.aoti_forward

if bitwidth == 8:
self.register_buffer(
"weight",
torch.empty(
(num_embeddings, embedding_dim), dtype=torch.int8, device=device
),
)
elif bitwidth == 4: # packed
self.register_buffer(
"weight",
torch.empty(
(num_embeddings, embedding_dim // 2),
dtype=torch.uint8,
device=device,
),
)
else:
raise RuntimeError(
f"QUantized embedding does not support bitwidth={bitwidth}"
)

groups_per_row = (embedding_dim + groupsize - 1) // groupsize
if groups_per_row > 1:
self.register_buffer(
"scales",
torch.ones(
(num_embeddings, groups_per_row), dtype=torch.float16, device=device
),
)
else:
self.register_buffer(
"scales",
torch.ones((num_embeddings,), dtype=torch.float16, device=device),
)

@torch.no_grad()
def et_forward(self, indices: torch.Tensor) -> torch.Tensor:
if self.bitwidth == 8:
return torch.ops.quantized_decomposed.embedding_byte.dtype(
self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
)
else:
return torch.ops.quantized_decomposed.embedding_4bit.dtype(
self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
)

@torch.no_grad()
def aoti_forward(self, indices: torch.Tensor) -> torch.Tensor:
# result_weights = self.weight.index_select(0, indices.view(-1))
# result_scales = self.scales.index_select(0, indices.view(-1))

if self.bitwidth == 4:
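# Each uint8 byte packs two 4-bit codes: the high nibble unpacks to the
# even output column and the low nibble to the odd one; subtracting 8
# below restores the signed int4 range [-8, 7].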
weight_even = self.weight.div(16, rounding_mode="trunc")
weight_odd = self.weight.remainder(16)
weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1)
weight = weight_unpacked.view(self.weight.shape[0], -1)
weight = weight.to(torch.int8).add(-8)
else:
weight = self.weight

scales = self.scales.view(weight.shape[0], -1)

result_weights = F.embedding(indices, weight)
result_scales = F.embedding(indices, scales)

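# View the gathered weights as (..., n_groups, group_size) and the gathered
# scales as (..., n_groups, 1) so each group is rescaled by its own scale.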
rw_view = result_weights.to(dtype=result_scales.dtype).view(
tuple(
result_weights.shape[:-1]
+ (
scales.shape[1],
-1,
)
)
)
rs_view = result_scales.view(
tuple(result_scales.shape[:-1])
+ (
scales.shape[1],
1,
)
)
# print(f"rw_view {rw_view.shape}")
# print(f"rs_view {rs_view.shape}")

r = rw_view * rs_view
return r.view(indices.size() + (-1,))

# r = result_weights.to(dtype=result_scales.dtype).view(list(result_weights.shape[:-1] + (scales.shape[1], -1, )) * result_scales.view(scales.shape[-1] + (scales.shape[1], 1, ))
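For reference, a minimal sketch of the 4-bit packing convention that the aoti_forward unpacking above implies: two values per uint8 byte, offset by 8 into [0, 15], with the even column stored in the high nibble. The helper name pack_int4_rows is hypothetical and not part of this PR; the sketch only checks that the round trip matches the unpack code.

import torch

def pack_int4_rows(q: torch.Tensor) -> torch.Tensor:
    # q: int8 values in [-8, 7], shape (rows, cols) with an even number of cols.
    # Offset into [0, 15] and store even columns in the high nibble, odd
    # columns in the low nibble, mirroring the unpacking in aoti_forward.
    q = (q + 8).to(torch.uint8)
    return q[:, 0::2] * 16 + q[:, 1::2]

rows, cols = 4, 8
q = torch.randint(-8, 8, (rows, cols), dtype=torch.int8)
packed = pack_int4_rows(q)                       # (rows, cols // 2), uint8
even = packed.div(16, rounding_mode="trunc")     # high nibble
odd = packed.remainder(16)                       # low nibble
unpacked = torch.stack((even, odd), dim=-1).view(rows, -1).to(torch.int8) - 8
assert torch.equal(unpacked, q)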
116 changes: 3 additions & 113 deletions quantize.py
@@ -23,8 +23,8 @@
state_dict_device,
use_et_backend,
)
from qops import LinearInt8 as WeightOnlyInt8Linear, QuantizedEmbedding

from qops import LinearInt8 as WeightOnlyInt8Linear

#########################################################################
### torchchat quantization API ###
@@ -489,9 +489,9 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
setattr(
module,
name,
QuantizedGroupEmbedding(
QuantizedEmbedding(
device=device,
vocab_size=child.weight.shape[0],
num_embeddings=child.weight.shape[0],
embedding_dim=child.weight.shape[1],
bitwidth=bitwidth,
groupsize=groupsize,
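The rest of this constructor call is collapsed in the diff view. As a hedged, simplified sketch of the swap pattern this helper applies (the swap_embeddings name is hypothetical, and quantizing and loading the original weights into the new buffers is omitted here):

import torch.nn as nn

from qops import QuantizedEmbedding

def swap_embeddings(module: nn.Module, *, bitwidth: int = 8, groupsize=None, device=None):
    # Walk the module tree and replace every nn.Embedding child with a
    # QuantizedEmbedding of matching shape, recursing into other modules.
    for name, child in module.named_children():
        if isinstance(child, nn.Embedding):
            setattr(
                module,
                name,
                QuantizedEmbedding(
                    num_embeddings=child.weight.shape[0],
                    embedding_dim=child.weight.shape[1],
                    device=device,
                    bitwidth=bitwidth,
                    groupsize=groupsize,
                ),
            )
        else:
            swap_embeddings(child, bitwidth=bitwidth, groupsize=groupsize, device=device)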
@@ -586,116 +586,6 @@ def quantized_model(self) -> nn.Module:
return self.model_


class QuantizedGroupEmbedding(torch.nn.Module):
def __init__(
self,
device,
vocab_size: int,
embedding_dim: int,
bitwidth: int,
groupsize: Optional[int] = None,
*,
dtype=torch.half,
) -> None:
super().__init__()
if groupsize is None or groupsize == 0:
groupsize = embedding_dim
self.groupsize = groupsize
self.dtype = dtype
self.bitwidth = bitwidth

if use_et_backend():
self.forward = self.et_forward
else:
self.forward = self.aoti_forward

if bitwidth == 8:
self.register_buffer(
"weight",
torch.empty(
(vocab_size, embedding_dim), dtype=torch.int8, device=device
),
)
elif bitwidth == 4: # packed
self.register_buffer(
"weight",
torch.empty(
(vocab_size, embedding_dim // 2), dtype=torch.uint8, device=device
),
)
else:
raise RuntimeError(
f"QUantized embedding does not support bitwidth={bitwidth}"
)

groups_per_row = (embedding_dim + groupsize - 1) // groupsize
if groups_per_row > 1:
self.register_buffer(
"scales",
torch.ones(
(vocab_size, groups_per_row), dtype=torch.float16, device=device
),
)
else:
self.register_buffer(
"scales", torch.ones((vocab_size,), dtype=torch.float16, device=device)
)

@torch.no_grad()
def et_forward(self, indices: torch.Tensor) -> torch.Tensor:
if self.bitwidth == 8:
return torch.ops.quantized_decomposed.embedding_byte.dtype(
self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
)
else:
return torch.ops.quantized_decomposed.embedding_4bit.dtype(
self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
)

@torch.no_grad()
def aoti_forward(self, indices: torch.Tensor) -> torch.Tensor:
# result_weights = self.weight.index_select(0, indices.view(-1))
# result_scales = self.scales.index_select(0, indices.view(-1))

if self.bitwidth == 4:
weight_even = self.weight.div(16, rounding_mode="trunc")
weight_odd = self.weight.remainder(16)
weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1)
weight = weight_unpacked.view(self.weight.shape[0], -1)
weight = weight.to(torch.int8).add(-8)
else:
weight = self.weight

scales = self.scales.view(weight.shape[0], -1)

result_weights = F.embedding(indices, weight)
result_scales = F.embedding(indices, scales)

rw_view = result_weights.to(dtype=result_scales.dtype).view(
tuple(
result_weights.shape[:-1]
+ (
scales.shape[1],
-1,
)
)
)
rs_view = result_scales.view(
tuple(result_scales.shape[:-1])
+ (
scales.shape[1],
1,
)
)
# print(f"rw_view {rw_view.shape}")
# print(f"rs_view {rs_view.shape}")

r = rw_view * rs_view
return r.view(indices.size() + (-1,))

# r = result_weights.to(dtype=result_scales.dtype).view(list(result_weights.shape[:-1] + (scales.shape[1], -1, )) * result_scales.view(scales.shape[-1] + (scales.shape[1], 1, ))


#########################################################################
##### weight only int4 per channel groupwise quantized code ######
