
Commit d69915a

mikekgfb authored and malfet committed
Quantized embedding (#536)
* move int8 linear class and function into qops.py
* move Quantized Embedding to qops.py
1 parent 5e266fb commit d69915a

3 files changed: +128 −114 lines changed

eval.py

Lines changed: 0 additions & 1 deletion

@@ -28,7 +28,6 @@
 torch._inductor.config.epilogue_fusion = False
 torch._inductor.config.triton.cudagraphs = True
 torch._dynamo.config.cache_size_limit = 100000
-import time

 try:
     import lm_eval

qops.py

Lines changed: 125 additions & 0 deletions

@@ -3,6 +3,15 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+
+from build.utils import (
+    find_multiple,
+    get_device_str,
+    get_precision,
+    name_to_dtype,
+    state_dict_device,
+    use_et_backend,
+)
 from torch.nn.parameter import Parameter

@@ -84,3 +93,119 @@ def __init__(

     def forward(self, input: torch.Tensor) -> torch.Tensor:
         return linear_int8(input, self.weight, self.scales)
+
+
+class QuantizedEmbedding(torch.nn.Module):
+    def __init__(
+        self,
+        num_embeddings: int,  # vocab_size: int,
+        embedding_dim: int,
+        device=None,
+        dtype=None,
+        *,
+        bitwidth: int,
+        groupsize: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+        if dtype is None:
+            dtype = torch.half
+
+        if groupsize is None or groupsize == 0:
+            groupsize = embedding_dim
+        self.groupsize = groupsize
+        self.dtype = dtype
+        self.bitwidth = bitwidth
+
+        if use_et_backend():
+            self.forward = self.et_forward
+        else:
+            self.forward = self.aoti_forward
+
+        if bitwidth == 8:
+            self.register_buffer(
+                "weight",
+                torch.empty(
+                    (num_embeddings, embedding_dim), dtype=torch.int8, device=device
+                ),
+            )
+        elif bitwidth == 4:  # packed
+            self.register_buffer(
+                "weight",
+                torch.empty(
+                    (num_embeddings, embedding_dim // 2),
+                    dtype=torch.uint8,
+                    device=device,
+                ),
+            )
+        else:
+            raise RuntimeError(
+                f"Quantized embedding does not support bitwidth={bitwidth}"
+            )
+
+        groups_per_row = (embedding_dim + groupsize - 1) // groupsize
+        if groups_per_row > 1:
+            self.register_buffer(
+                "scales",
+                torch.ones(
+                    (num_embeddings, groups_per_row), dtype=torch.float16, device=device
+                ),
+            )
+        else:
+            self.register_buffer(
+                "scales",
+                torch.ones((num_embeddings,), dtype=torch.float16, device=device),
+            )
+
+    @torch.no_grad()
+    def et_forward(self, indices: torch.Tensor) -> torch.Tensor:
+        if self.bitwidth == 8:
+            return torch.ops.quantized_decomposed.embedding_byte.dtype(
+                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+            )
+        else:
+            return torch.ops.quantized_decomposed.embedding_4bit.dtype(
+                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+            )
+
+    @torch.no_grad()
+    def aoti_forward(self, indices: torch.Tensor) -> torch.Tensor:
+        # result_weights = self.weight.index_select(0, indices.view(-1))
+        # result_scales = self.scales.index_select(0, indices.view(-1))
+
+        if self.bitwidth == 4:
+            weight_even = self.weight.div(16, rounding_mode="trunc")
+            weight_odd = self.weight.remainder(16)
+            weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1)
+            weight = weight_unpacked.view(self.weight.shape[0], -1)
+            weight = weight.to(torch.int8).add(-8)
+        else:
+            weight = self.weight
+
+        scales = self.scales.view(weight.shape[0], -1)
+
+        result_weights = F.embedding(indices, weight)
+        result_scales = F.embedding(indices, scales)
+
+        rw_view = result_weights.to(dtype=result_scales.dtype).view(
+            tuple(
+                result_weights.shape[:-1]
+                + (
+                    scales.shape[1],
+                    -1,
+                )
+            )
+        )
+        rs_view = result_scales.view(
+            tuple(result_scales.shape[:-1])
+            + (
+                scales.shape[1],
+                1,
+            )
+        )
+        # print(f"rw_view {rw_view.shape}")
+        # print(f"rs_view {rs_view.shape}")
+
+        r = rw_view * rs_view
+        return r.view(indices.size() + (-1,))
+
+        # r = result_weights.to(dtype=result_scales.dtype).view(list(result_weights.shape[:-1] + (scales.shape[1], -1, )) * result_scales.view(scales.shape[-1] + (scales.shape[1], 1, ))
quantize.py

Lines changed: 3 additions & 113 deletions

@@ -23,8 +23,8 @@
     state_dict_device,
     use_et_backend,
 )
+from qops import LinearInt8 as WeightOnlyInt8Linear, QuantizedEmbedding

-from qops import LinearInt8 as WeightOnlyInt8Linear

 #########################################################################
 ### torchchat quantization API ###

@@ -489,9 +489,9 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
         setattr(
             module,
             name,
-            QuantizedGroupEmbedding(
+            QuantizedEmbedding(
                 device=device,
-                vocab_size=child.weight.shape[0],
+                num_embeddings=child.weight.shape[0],
                 embedding_dim=child.weight.shape[1],
                 bitwidth=bitwidth,
                 groupsize=groupsize,
@@ -586,116 +586,6 @@ def quantized_model(self) -> nn.Module:
         return self.model_


-class QuantizedGroupEmbedding(torch.nn.Module):
-    def __init__(
-        self,
-        device,
-        vocab_size: int,
-        embedding_dim: int,
-        bitwidth: int,
-        groupsize: Optional[int] = None,
-        *,
-        dtype=torch.half,
-    ) -> None:
-        super().__init__()
-        if groupsize is None or groupsize == 0:
-            groupsize = embedding_dim
-        self.groupsize = groupsize
-        self.dtype = dtype
-        self.bitwidth = bitwidth
-
-        if use_et_backend():
-            self.forward = self.et_forward
-        else:
-            self.forward = self.aoti_forward
-
-        if bitwidth == 8:
-            self.register_buffer(
-                "weight",
-                torch.empty(
-                    (vocab_size, embedding_dim), dtype=torch.int8, device=device
-                ),
-            )
-        elif bitwidth == 4:  # packed
-            self.register_buffer(
-                "weight",
-                torch.empty(
-                    (vocab_size, embedding_dim // 2), dtype=torch.uint8, device=device
-                ),
-            )
-        else:
-            raise RuntimeError(
-                f"QUantized embedding does not support bitwidth={bitwidth}"
-            )
-
-        groups_per_row = (embedding_dim + groupsize - 1) // groupsize
-        if groups_per_row > 1:
-            self.register_buffer(
-                "scales",
-                torch.ones(
-                    (vocab_size, groups_per_row), dtype=torch.float16, device=device
-                ),
-            )
-        else:
-            self.register_buffer(
-                "scales", torch.ones((vocab_size,), dtype=torch.float16, device=device)
-            )
-
-    @torch.no_grad()
-    def et_forward(self, indices: torch.Tensor) -> torch.Tensor:
-        if self.bitwidth == 8:
-            return torch.ops.quantized_decomposed.embedding_byte.dtype(
-                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
-            )
-        else:
-            return torch.ops.quantized_decomposed.embedding_4bit.dtype(
-                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
-            )
-
-    @torch.no_grad()
-    def aoti_forward(self, indices: torch.Tensor) -> torch.Tensor:
-        # result_weights = self.weight.index_select(0, indices.view(-1))
-        # result_scales = self.scales.index_select(0, indices.view(-1))
-
-        if self.bitwidth == 4:
-            weight_even = self.weight.div(16, rounding_mode="trunc")
-            weight_odd = self.weight.remainder(16)
-            weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1)
-            weight = weight_unpacked.view(self.weight.shape[0], -1)
-            weight = weight.to(torch.int8).add(-8)
-        else:
-            weight = self.weight
-
-        scales = self.scales.view(weight.shape[0], -1)
-
-        result_weights = F.embedding(indices, weight)
-        result_scales = F.embedding(indices, scales)
-
-        rw_view = result_weights.to(dtype=result_scales.dtype).view(
-            tuple(
-                result_weights.shape[:-1]
-                + (
-                    scales.shape[1],
-                    -1,
-                )
-            )
-        )
-        rs_view = result_scales.view(
-            tuple(result_scales.shape[:-1])
-            + (
-                scales.shape[1],
-                1,
-            )
-        )
-        # print(f"rw_view {rw_view.shape}")
-        # print(f"rs_view {rs_view.shape}")
-
-        r = rw_view * rs_view
-        return r.view(indices.size() + (-1,))
-
-        # r = result_weights.to(dtype=result_scales.dtype).view(list(result_weights.shape[:-1] + (scales.shape[1], -1, )) * result_scales.view(scales.shape[-1] + (scales.shape[1], 1, ))
-
-
 #########################################################################
 ##### weight only int4 per channel groupwise quantized code ######

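The buffers that QuantizedEmbedding registers are filled by the quantization handler with group-wise int8 data. As a rough illustration of the expected shapes, here is a hedged sketch of per-group symmetric quantization; the helper quantize_embedding_per_group_int8 is hypothetical and not taken from this commit, and it assumes embedding_dim is divisible by groupsize.

```python
# Illustrative sketch only: per-group symmetric int8 quantization of an
# embedding table, producing tensors shaped like QuantizedEmbedding's
# "weight" (int8, num_embeddings x embedding_dim) and "scales"
# (float16, num_embeddings x groups_per_row) buffers.
import torch


def quantize_embedding_per_group_int8(weight: torch.Tensor, groupsize: int):
    num_embeddings, embedding_dim = weight.shape
    if groupsize == 0:
        groupsize = embedding_dim
    assert embedding_dim % groupsize == 0  # simplifying assumption for this sketch
    groups_per_row = embedding_dim // groupsize

    w = weight.view(num_embeddings, groups_per_row, groupsize)
    # symmetric scheme: each group's max magnitude maps to 127
    scales = w.abs().amax(dim=-1).clamp(min=1e-6) / 127.0
    q = torch.round(w / scales.unsqueeze(-1)).clamp(-128, 127).to(torch.int8)
    return q.view(num_embeddings, embedding_dim), scales.to(torch.float16)


# Example: 8 rows, 16 columns, groupsize 8 -> scales of shape (8, 2)
q, s = quantize_embedding_per_group_int8(torch.randn(8, 16), groupsize=8)
print(q.shape, q.dtype, s.shape, s.dtype)
```

Note that when groups_per_row is 1, the class registers a 1-D scales buffer, so the second dimension in this sketch would be squeezed before copying the data in.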