
Commit 39c93aa

manuelcandales authored and facebook-github-bot committed
use dequantize per channel group for embedding (#2374)
Summary:
Pull Request resolved: #2374

- add type hints to quantized_ops
- use dequantize per channel group in embedding byte decomposition

bypass-github-export-checks

Reviewed By: mikekgfb

Differential Revision: D54813256

fbshipit-source-id: 79b8f39d820378faa90d908e3ea56d4201a61598
1 parent e76cd88 commit 39c93aa
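The core of the change: instead of dequantizing the whole table per channel along dim 0, the embedding_byte decomposition now derives a group size from the weight and scale shapes and dequantizes per channel group. A minimal sketch of the group-size derivation, using illustrative shapes that are not part of the commit:

import torch

# Example shapes only: an 8 x 64 quantized table with 4 scale groups per row.
weight = torch.randint(-128, 128, (8, 64), dtype=torch.int8)
weight_scales = torch.rand(8, 4)

# Same derivation as the new decomposition: 2-D scales mean one scale per
# (row, group); 1-D scales mean a single group spanning the whole row.
group_size = weight.size(1) // (
    weight_scales.size(1) if weight_scales.dim() == 2 else 1
)
print(group_size)  # 16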

File tree

2 files changed: +169 -57 lines changed


examples/models/llama2/ops/quantized_ops.py

Lines changed: 51 additions & 43 deletions
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Optional
+
 import torch
 from torch.library import impl, impl_abstract
 
@@ -62,43 +64,45 @@ def embedding_byte_weight_checks(weight, weight_scales, weight_zero_points):
     assert weight_zero_points is None or weight_zero_points.size(0) == weight.size(
         0
     ), f"Expecting weight_zero_points tensor to be None or have same number of rows as weights, but found {weight.size()} and {weight_zero_points.size()}"
-    if not weight_zero_points:
-        weight_zero_points = torch.zeros(weight.size(0))
 
 
 @impl(quantized_lib, "embedding_byte", "CompositeExplicitAutograd")
-def embedding_byte_meta(
-    weight,
-    weight_scales,
-    weight_zero_points,
-    weight_quant_min,
-    weight_quant_max,
-    indices,
-):
+def embedding_byte(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+) -> torch.Tensor:
     embedding_byte_weight_checks(weight, weight_scales, weight_zero_points)
-    weight = torch.ops.quantized_decomposed.dequantize_per_channel.default(
+    group_size = weight.size(1) // (
+        weight_scales.size(1) if weight_scales.dim() == 2 else 1
+    )
+    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
         weight,
         weight_scales,
         weight_zero_points,
-        0,
         weight_quant_min,
         weight_quant_max,
         weight.dtype,
+        group_size,
+        weight_scales.dtype,
     )
     return torch.ops.aten.embedding.default(weight, indices)
 
 
 @impl_abstract("llama_quantized::embedding_byte.out")
 def embedding_byte_out_meta(
-    weight,
-    weight_scales,
-    weight_zero_points,
-    weight_quant_min,
-    weight_quant_max,
-    indices,
-    out,
-):
-    return embedding_byte_meta(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    out: torch.Tensor,
+) -> torch.Tensor:
+    return embedding_byte(
         weight,
         weight_scales,
         weight_zero_points,
@@ -109,42 +113,46 @@ def embedding_byte_out_meta(
 
 
 @impl(quantized_lib, "embedding_byte.dtype", "CompositeExplicitAutograd")
-def embedding_byte_dtype_meta(
-    weight,
-    weight_scales,
-    weight_zero_points,
-    weight_quant_min,
-    weight_quant_max,
-    indices,
+def embedding_byte_dtype(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
     *,
-    dtype,
-):
+    dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
     embedding_byte_weight_checks(weight, weight_scales, weight_zero_points)
-    weight = torch.ops.quantized_decomposed.dequantize_per_channel.default(
+    group_size = weight.size(1) // (
+        weight_scales.size(1) if weight_scales.dim() == 2 else 1
+    )
+    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
         weight,
         weight_scales,
         weight_zero_points,
-        0,
         weight_quant_min,
         weight_quant_max,
         weight.dtype,
+        group_size,
+        dtype,
     )
-    return torch.ops.aten.embedding.default(weight, indices).to(dtype)
+    return torch.ops.aten.embedding.default(weight, indices)
 
 
 @impl_abstract("llama_quantized::embedding_byte.dtype_out")
 def embedding_byte_dtype_out_meta(
-    weight,
-    weight_scales,
-    weight_zero_points,
-    weight_quant_min,
-    weight_quant_max,
-    indices,
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
     *,
-    dtype,
-    out,
-):
-    return embedding_byte_dtype_meta(
+    dtype: Optional[torch.dtype] = None,
+    out: torch.Tensor,
+) -> torch.Tensor:
+    return embedding_byte_dtype(
         weight,
         weight_scales,
         weight_zero_points,
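For reference, a pure-PyTorch sketch of the arithmetic the decomposition above relies on: group-wise affine dequantization, (q - zero_point) * scale applied per (row, group), followed by the plain embedding lookup. This is an assumption-level illustration of what dequantize_per_channel_group computes, not the registered kernel, and the helper name below is made up for the example:

import torch

def dequantize_per_channel_group_ref(weight, scales, zero_points, group_size, out_dtype):
    # Reshape each row into [num_groups, group_size], apply that group's
    # scale and zero point, then restore the original [rows, cols] shape.
    rows, cols = weight.shape
    grouped = weight.to(out_dtype).reshape(rows, cols // group_size, group_size)
    zp = zero_points.to(out_dtype).unsqueeze(-1) if zero_points is not None else 0
    return ((grouped - zp) * scales.to(out_dtype).unsqueeze(-1)).reshape(rows, cols)

weight_q = torch.randint(-8, 8, (10, 32), dtype=torch.int8)  # toy quantized table
scales = torch.rand(10, 4)                                    # 4 groups per row -> group_size 8
zeros = torch.zeros(10, 4)
table = dequantize_per_channel_group_ref(weight_q, scales, zeros, 8, torch.float32)
out = torch.ops.aten.embedding.default(table, torch.tensor([0, 3, 7]))  # look up rows 0, 3, 7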

exir/passes/_quant_patterns_and_replacements.py

Lines changed: 118 additions & 14 deletions
@@ -9,7 +9,6 @@
 
 import torch
 from executorch.exir.dialects._ops import bind_pattern_to_op, ops as exir_ops
-
 from executorch.exir.passes.replace_aten_with_edge_pass import (
     aten_to_edge,
     should_lower_to_edge,
@@ -487,6 +486,50 @@ def replacement(
         )
         return out
 
+    @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte")
+    def pattern_groupwise(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        indices,
+        group_size,
+    ):
+        weight = (
+            torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
+                weight,
+                weight_scales,
+                weight_zero_points,
+                weight_quant_min,
+                weight_quant_max,
+                weight.dtype,
+                group_size,
+                weight_scales.dtype,
+            )
+        )
+        out = torch.ops.aten.embedding.default(weight, indices)
+        return out
+
+    def replacement_groupwise(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        indices,
+        group_size,
+    ):
+        out = torch.ops.quantized_decomposed.embedding_byte.default(
+            weight,
+            weight_scales,
+            weight_zero_points,
+            weight_quant_min,
+            weight_quant_max,
+            indices,
+        )
+        return out
+
     @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte")
     def pattern_with_padding_idx(
         weight,
@@ -528,35 +571,86 @@ def replacement_with_padding_idx(
         )
         return out
 
-    @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte.dtype")
-    def pattern_with_dtype(
+    @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte")
+    def pattern_with_padding_idx_groupwise(
         weight,
         weight_scales,
         weight_zero_points,
         weight_quant_min,
         weight_quant_max,
-        indicies,
-        dtype,
+        indices,
+        group_size,
+        padding_idx,
     ):
-        weight = torch.ops.quantized_decomposed.dequantize_per_channel.default(
+        weight = (
+            torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
+                weight,
+                weight_scales,
+                weight_zero_points,
+                weight_quant_min,
+                weight_quant_max,
+                weight.dtype,
+                group_size,
+                weight_scales.dtype,
+            )
+        )
+        out = torch.ops.aten.embedding.default(weight, indices, padding_idx)
+        return out
+
+    def replacement_with_padding_idx_groupwise(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        indices,
+        group_size,
+        _,  # padding_idx only matters for training and not when running op for inference
+    ):
+        out = torch.ops.quantized_decomposed.embedding_byte.default(
            weight,
            weight_scales,
            weight_zero_points,
-            0,
            weight_quant_min,
            weight_quant_max,
-            torch.uint8,
+            indices,
        )
-        out = torch.ops.aten.embedding.default(weight, indicies).to(dtype)
        return out
 
-    def replacement_with_dtype(
+    @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte.dtype")
+    def pattern_with_dtype_groupwise(
        weight,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
-        indicies,
+        indices,
+        group_size,
+        dtype,
+    ):
+        weight = (
+            torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
+                weight,
+                weight_scales,
+                weight_zero_points,
+                weight_quant_min,
+                weight_quant_max,
+                weight.dtype,
+                group_size,
+                dtype,
+            )
+        )
+        out = torch.ops.aten.embedding.default(weight, indices)
+        return out
+
+    def replacement_with_dtype_groupwise(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        indices,
+        group_size,
        dtype,
    ):
        out = torch.ops.quantized_decomposed.embedding_byte.dtype(
@@ -565,7 +659,7 @@ def replacement_with_dtype(
            weight_zero_points,
            weight_quant_min,
            weight_quant_max,
-            indicies,
+            indices,
            dtype=dtype,
        )
        return out
@@ -576,14 +670,24 @@ def replacement_with_dtype(
            _trace_and_lower_to_edge_ops(replacement),
            [],
        ),
+        (
+            _trace_and_lower_to_edge_ops(pattern_groupwise),
+            _trace_and_lower_to_edge_ops(replacement_groupwise),
+            [],
+        ),
        (
            _trace_and_lower_to_edge_ops(pattern_with_padding_idx),
            _trace_and_lower_to_edge_ops(replacement_with_padding_idx),
            [],
        ),
        (
-            _trace_and_lower_to_edge_ops(pattern_with_dtype),
-            _trace_and_lower_to_edge_ops(replacement_with_dtype),
+            _trace_and_lower_to_edge_ops(pattern_with_padding_idx_groupwise),
+            _trace_and_lower_to_edge_ops(replacement_with_padding_idx_groupwise),
+            [],
+        ),
+        (
+            _trace_and_lower_to_edge_ops(pattern_with_dtype_groupwise),
+            _trace_and_lower_to_edge_ops(replacement_with_dtype_groupwise),
            [],
        ),
    ]
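As an aside on how the pairs registered above are consumed: each entry is a traceable pattern callable paired with a replacement callable. A generic illustration of that mechanism with torch.fx's subgraph rewriter, using stand-in ops rather than the quantized embedding ops, and assuming the ExecuTorch pass applies its pairs in an analogous way:

import torch
import torch.fx as fx
from torch.fx import subgraph_rewriter

class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y) * 2  # the add is the subgraph to rewrite

# Stand-in pattern/replacement pair: rewrite add(x, y) into sub(x, y).
def pattern(a, b):
    return torch.add(a, b)

def replacement(a, b):
    return torch.sub(a, b)

gm = fx.symbolic_trace(M())
subgraph_rewriter.replace_pattern(gm, pattern, replacement)
print(gm.code)  # the traced graph now calls torch.sub where torch.add used to be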
0 commit comments