
Commit d63b352

metascroy authored and facebook-github-bot committed
Add 2b embedding op (#5800)
Summary: Pull Request resolved: #5800

Reviewed By: kimishpatel

Differential Revision: D64011080

Pulled By: metascroy

fbshipit-source-id: 449005a54fd34bd41004a6c19cc5fc25b996003e
1 parent 7ba7990 commit d63b352

File tree: 11 files changed, +1034 -257 lines changed


examples/models/llama2/source_transformation/quantize.py

Lines changed: 61 additions & 22 deletions
@@ -494,6 +494,7 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
                     group_size=group_size,
                     dtype=child.weight.dtype,
                     packed=packed,
+                    bitwidth=bitwidth,
                 ),
             )
         else:
@@ -519,14 +520,17 @@ def __init__(
         self.group_size = group_size
         self.bitwidth = bitwidth
         self.packed = packed
-        if (bitwidth != 4) and packed:
-            raise RuntimeError("pack only works with bitsize 4")
+        if (bitwidth not in [2, 4]) and packed:
+            raise RuntimeError("pack only works with bitsize 2, 4")
 
     @torch.no_grad()
     def create_quantized_state_dict(self, packed=False) -> Dict:
         cur_state_dict = self.mod.state_dict()
 
-        if self.bitwidth == 4:
+        if self.bitwidth == 2:
+            range_min = -2
+            range_max = 1
+        elif self.bitwidth == 4:
             range_min = -8
             range_max = 7
         elif self.bitwidth == 8:
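For context, signed 2-bit quantization maps each value to the integer range [-2, 1] chosen above. Below is a minimal, illustrative sketch of a group-wise symmetric quantizer targeting that range; the helper this file actually uses to compute scales is not shown in this diff, so the function name and rounding details here are assumptions.

import torch

def quantize_group_2bit(w: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Illustrative only: symmetric scale for the 2-bit range [-2, 1].
    range_min, range_max = -2, 1
    scale = w.abs().amax().clamp(min=1e-6) / 2.0  # 2 = max(|range_min|, |range_max|)
    q = torch.clamp(torch.round(w / scale), range_min, range_max).to(torch.int8)
    return q, scale

q, scale = quantize_group_2bit(torch.randn(32))
print(q.unique(), scale)  # codes drawn from {-2, -1, 0, 1}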
@@ -555,17 +559,30 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
                 )
 
             if packed:
-                if weight.shape[-1] % 2 != 0:
-                    raise RuntimeError("automatic padding not implemented yet")
-
-                weight_range_shifted = weight.add(8).view(torch.uint8)
-                weight_view = weight_range_shifted.view(
-                    weight.shape[0], weight.shape[1] // 2, 2
-                )
-                weight_even = weight_view[:, :, 0] * 16  # left shift 4
-                weight_odd = weight_view[:, :, 1]
-                weight_packed = weight_even + weight_odd
-                weight = weight_packed
+                if self.bitwidth == 2:
+                    if weight.shape[-1] % 4 != 0:
+                        raise RuntimeError("automatic padding not implemented yet")
+                    weight_range_shifted = weight.add(2).view(torch.uint8)
+                    weight_view = weight_range_shifted.view(
+                        weight.shape[0], weight.shape[1] // 4, 4
+                    )
+                    weight_0 = weight_view[:, :, 0]
+                    weight_1 = weight_view[:, :, 1] << 2
+                    weight_2 = weight_view[:, :, 2] << 4
+                    weight_3 = weight_view[:, :, 3] << 6
+                    weight_packed = weight_0 + weight_1 + weight_2 + weight_3
+                    weight = weight_packed
+                elif self.bitwidth == 4:
+                    if weight.shape[-1] % 2 != 0:
+                        raise RuntimeError("automatic padding not implemented yet")
+                    weight_range_shifted = weight.add(8).view(torch.uint8)
+                    weight_view = weight_range_shifted.view(
+                        weight.shape[0], weight.shape[1] // 2, 2
+                    )
+                    weight_even = weight_view[:, :, 0] * 16  # left shift 4
+                    weight_odd = weight_view[:, :, 1]
+                    weight_packed = weight_even + weight_odd
+                    weight = weight_packed
 
             weight = weight.to(device=self.device)
             scales = scales.to(device=self.device)
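The new branch stores four 2-bit codes per uint8 byte (versus two 4-bit codes per byte). The following standalone round-trip sketch reproduces the same shift-and-mask arithmetic for sanity-checking the layout; it is an illustration, not code from this PR.

import torch

# Values in [-2, 1] are shifted by +2 into [0, 3]; four of them share one byte.
torch.manual_seed(0)
qweight = torch.randint(-2, 2, (3, 8), dtype=torch.int8)  # already-quantized 2-bit codes

# Pack: shift to unsigned, then combine groups of 4 along the last dimension.
shifted = qweight.add(2).view(torch.uint8)
groups = shifted.view(qweight.shape[0], qweight.shape[1] // 4, 4)
packed = (
    groups[:, :, 0]
    + (groups[:, :, 1] << 2)
    + (groups[:, :, 2] << 4)
    + (groups[:, :, 3] << 6)
)

# Unpack: mask out each 2-bit field, interleave, and shift back to signed range.
w0 = packed & 3
w1 = (packed & 12) >> 2
w2 = (packed & 48) >> 4
w3 = (packed & 192) >> 6
unpacked = torch.stack((w0, w1, w2, w3), dim=-1).reshape(packed.shape[0], -1)
restored = unpacked.view(torch.int8).add(-2)

assert torch.equal(restored, qweight)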
@@ -598,13 +615,15 @@ def __init__(
         group_size: Optional[int] = None,
         dtype=torch.half,
         packed=False,
+        bitwidth: int = 8,
     ) -> None:
         super().__init__()
         if group_size is None or group_size == 0:
             group_size = embedding_dim
         self.group_size = group_size
         self.dtype = dtype
         self.packed = packed
+        self.bitwidth = bitwidth
         if not packed:
             self.register_buffer(
                 "weight",
@@ -613,12 +632,25 @@ def __init__(
                 ),
             )
         else:  # packed
-            self.register_buffer(
-                "weight",
-                torch.empty(
-                    (vocab_size, embedding_dim // 2), dtype=torch.uint8, device=device
-                ),
-            )
+            if bitwidth == 2:
+                self.register_buffer(
+                    "weight",
+                    torch.empty(
+                        (vocab_size, embedding_dim // 4),
+                        dtype=torch.uint8,
+                        device=device,
+                    ),
+                )
+            elif bitwidth == 4:
+                self.register_buffer(
+                    "weight",
+                    torch.empty(
+                        (vocab_size, embedding_dim // 2),
+                        dtype=torch.uint8,
+                        device=device,
+                    ),
+                )
+
         groups_per_row = (embedding_dim + group_size - 1) // group_size
         if groups_per_row > 1:
             self.register_buffer(
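With packing enabled, the registered weight buffer shrinks accordingly: embedding_dim // 4 bytes per row at 2 bits and embedding_dim // 2 at 4 bits. A quick illustration of the per-row storage, using a hypothetical embedding_dim of 4096 (the value is assumed purely for illustration):

embedding_dim = 4096
for bitwidth in (8, 4, 2):
    packed = bitwidth in (2, 4)
    bytes_per_row = embedding_dim // (8 // bitwidth) if packed else embedding_dim
    print(f"bitwidth={bitwidth} packed={packed} bytes_per_row={bytes_per_row}")
# bitwidth=8 packed=False bytes_per_row=4096
# bitwidth=4 packed=True  bytes_per_row=2048
# bitwidth=2 packed=True  bytes_per_row=1024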
@@ -638,7 +670,14 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
             return torch.ops.quantized_decomposed.embedding_byte.dtype(
                 self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
             )
-        else:  # 4bit packed
+        else:  # packed
+            if self.bitwidth == 2:
+                return torch.ops.quantized_decomposed.embedding_2bit.dtype(
+                    self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+                )
+
+            # Remaining case (always return to make pyre happy)
+            assert self.bitwidth == 4
             return torch.ops.quantized_decomposed.embedding_4bit.dtype(
                 self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
             )
@@ -658,7 +697,7 @@ def get_quant_embedding_transform(args):
         model,
         bitwidth=bitwidth,
         group_size=group_size,
-        packed=(bitwidth == 4),
+        packed=(bitwidth in [2, 4]),
     ).quantized_model()
 
 
exir/passes/_quant_patterns_and_replacements.py

Lines changed: 132 additions & 0 deletions
@@ -172,6 +172,138 @@ def embedding_byte_dtype_out_meta(
     )
 
 
+quantized_decomposed_lib.define(
+    "embedding_2bit(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor",
+)
+
+quantized_decomposed_lib.define(
+    "embedding_2bit.dtype(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None) -> Tensor",
+)
+
+quantized_decomposed_lib.define(
+    "embedding_2bit.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)",
+)
+
+quantized_decomposed_lib.define(
+    "embedding_2bit.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)",
+)
+
+
+@impl(quantized_decomposed_lib, "embedding_2bit", "CompositeExplicitAutograd")
+def embedding_2bit(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+) -> torch.Tensor:
+    embedding_weight_checks(weight, weight_scales, weight_zero_points)
+    group_size = (4 * weight.size(1)) // (
+        weight_scales.size(1) if weight_scales.dim() == 2 else 1
+    )
+    weight_0 = weight & 3
+    weight_1 = (weight & 12) >> 2
+    weight_2 = (weight & 48) >> 4
+    weight_3 = (weight & 192) >> 6
+    weight_unpacked = torch.stack((weight_0, weight_1, weight_2, weight_3), dim=-1)
+    weight = weight_unpacked.view(weight.shape[0], -1)
+    weight = weight.view(torch.int8).add(-2)
+
+    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        weight.dtype,
+        group_size,
+        weight_scales.dtype,
+    )
+    return torch.ops.aten.embedding.default(weight, indices)
+
+
+@register_fake("quantized_decomposed::embedding_2bit.out")
+def embedding_2bit_out_meta(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    out: torch.Tensor,
+) -> torch.Tensor:
+    return embedding_2bit(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        indices,
+    )
+
+
+@impl(quantized_decomposed_lib, "embedding_2bit.dtype", "CompositeExplicitAutograd")
+def embedding_2bit_dtype(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    dtype: Optional[torch.dtype],
+) -> torch.Tensor:
+    embedding_weight_checks(weight, weight_scales, weight_zero_points)
+    group_size = (4 * weight.size(1)) // (
+        weight_scales.size(1) if weight_scales.dim() == 2 else 1
+    )
+    weight_0 = weight & 3
+    weight_1 = (weight & 12) >> 2
+    weight_2 = (weight & 48) >> 4
+    weight_3 = (weight & 192) >> 6
+    weight_unpacked = torch.stack((weight_0, weight_1, weight_2, weight_3), dim=-1)
+    weight = weight_unpacked.view(weight.shape[0], -1)
+    weight = weight.view(torch.int8).add(-2)
+
+    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        weight.dtype,
+        group_size,
+        dtype,
+    )
+    return torch.ops.aten.embedding.default(weight, indices)
+
+
+@register_fake("quantized_decomposed::embedding_2bit.dtype_out")
+def embedding_2bit_dtype_out_meta(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    dtype: Optional[torch.dtype],
+    out: torch.Tensor,
+) -> torch.Tensor:
+    return embedding_2bit_dtype(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        indices,
+        dtype,
+    )
+
+
 quantized_decomposed_lib.define(
     "embedding_4bit(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
     "int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor",
