
Commit ef99fff

Bump torchao pin, adjust llama export to support pre-quantization via quantize_ (phi4-mini load/export)
Differential Revision: D73147002
Pull Request resolved: #10142
1 parent f911567 commit ef99fff

File tree

10 files changed: +123, -40 lines

backends/xnnpack/operators/op_dynamic_dequantize_ops.py

Lines changed: 1 addition & 1 deletion
@@ -78,7 +78,7 @@ def define_node(

 @register_node_visitor
 class OpDequantizeAffine(NodeVisitor):
-    target = "quant.dequantize_affine.default"
+    target = "torchao.dequantize_affine.default"

     def __init__(self, *args) -> None:
         super().__init__(*args)

backends/xnnpack/operators/op_dynamic_quantize_ops.py

Lines changed: 1 addition & 1 deletion
@@ -127,7 +127,7 @@ def define_node(

 @register_node_visitor
 class OpQuantizeAffine(NodeVisitor):
-    target = "quant.quantize_affine.default"
+    target = "torchao.quantize_affine.default"

     def define_node(
         self,

backends/xnnpack/operators/op_skip_ops.py

Lines changed: 1 addition & 1 deletion
@@ -85,7 +85,7 @@ class OpChooseQparamsAffine(OpSkipOps):
     do nothing if node is choose_qparams_affine.default
     """

-    target = "quant.choose_qparams_affine.default"
+    target = "torchao.choose_qparams_affine.default"


 @register_node_visitor

backends/xnnpack/partition/config/quant_affine_configs.py

Lines changed: 3 additions & 3 deletions
@@ -36,7 +36,7 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
         try:
             import torchao.quantization.quant_primitives  # noqa

-            return torch.ops.quant.quantize_affine.default
+            return torch.ops.torchao.quantize_affine.default
         except:
             return None

@@ -48,7 +48,7 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
         try:
             import torchao.quantization.quant_primitives  # noqa

-            return torch.ops.quant.dequantize_affine.default
+            return torch.ops.torchao.dequantize_affine.default
         except:
             return None

@@ -60,6 +60,6 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
         try:
             import torchao.quantization.quant_primitives  # noqa

-            return torch.ops.quant.choose_qparams_affine.default
+            return torch.ops.torchao.choose_qparams_affine.default
         except:
             return None
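
Because the affine quant primitives moved from the quant namespace to torchao, a quick standalone check can confirm that the pinned torchao build exposes the renamed ops. This is a sketch of the same lookup pattern the partition configs use above, not part of the commit:

import torch

try:
    # Importing torchao's quant primitives registers the custom ops.
    import torchao.quantization.quant_primitives  # noqa: F401

    ops = [
        torch.ops.torchao.choose_qparams_affine.default,
        torch.ops.torchao.quantize_affine.default,
        torch.ops.torchao.dequantize_affine.default,
    ]
    print("torchao affine ops available:", [str(op) for op in ops])
except Exception:
    # Older torchao pins exposed these under torch.ops.quant instead.
    print("torchao affine ops not available")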

backends/xnnpack/test/ops/test_linear.py

Lines changed: 3 additions & 3 deletions
@@ -402,9 +402,9 @@ def _test_groupwise_dq_linear(
             .export()
             .check_count(
                 {
-                    "torch.ops.quant.choose_qparams_affine.default": 1 * num_linears,
-                    "torch.ops.quant.quantize_affine.default": 1 * num_linears,
-                    "torch.ops.quant.dequantize_affine.default": 2 * num_linears,
+                    "torch.ops.torchao.choose_qparams_affine.default": 1 * num_linears,
+                    "torch.ops.torchao.quantize_affine.default": 1 * num_linears,
+                    "torch.ops.torchao.dequantize_affine.default": 2 * num_linears,
                     "torch.ops.aten.linear.default": 1 * num_linears,
                 }
             )

examples/models/llama/README.md

Lines changed: 1 addition & 1 deletion
@@ -416,7 +416,7 @@ python -m examples.models.llama.export_llama \
 ```

 A few notes:
-- If your model shares embedding/unembedding weights (like Llama1B and Llama3B do), you can add `--use_shared_embedding` to take advantage of this and reduce memory. When this option is enabled, you can specify whether embeddings are quantized with weight zeros or not by specifying a third argument. For example, `-E "torchao:4,32,true"` means that the embedding is quantized to 4-bits with group_size=32 and uses weight zeros (this is the default behavior if you simply use `-E "torchao:4,32"`), whereas `-E "torchao:4,32,false"` means that the embedding is quantized to 4-bits with group_size=32, but is quantized with scales-only. If `--use_shared_embedding` is specified, the unembedding (i.e., the final linear layer) is quantized in the same way, but also uses 8-bit dynamically quantized activations.
+- If your model shares embedding/unembedding weights (like Llama1B and Llama3B do), you can add `--use_shared_embedding` to take advantage of this and reduce memory. When this option is enabled, you can specify whether embeddings are quantized asymmetrically or not by specifying a third argument. For example, `-E "torchao:4,32,true"` means that the embedding is quantized to 4-bits with group_size=32 and is asymmetric (this is the default behavior if you simply use `-E "torchao:4,32"`), whereas `-E "torchao:4,32,false"` means that the embedding is quantized to 4-bits with group_size=32 and is symmetric. If `--use_shared_embedding` is specified, the unembedding (i.e., the final linear layer) is quantized in the same way, but also uses 8-bit dynamically quantized activations.
 - To do channelwise quantization, specify group_size to 0. This works for both linear and embedding layers.

 Once the model is exported, we need to build ExecuTorch and the runner with the low-bit kernels.
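
For readers mapping the README wording to code: the third value in `-E "torchao:<bitwidth>,<group_size>,<asymmetric>"` now selects a torchao MappingType rather than a weight-zeros flag. The sketch below mirrors the selection logic in get_quant_embedding_transform (see the quantize.py diff further down); the concrete values are illustrative only.

import torch

from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import MappingType

# Illustrative values: 4-bit embedding weights, group_size=32, symmetric ("false").
bitwidth, group_size, is_asymmetric = 4, 32, False

weight_dtype = getattr(torch, f"int{bitwidth}")  # torch.int4
# group_size == 0 means channelwise (PerAxis over dim 0); otherwise groupwise.
granularity = PerAxis(0) if group_size == 0 else PerGroup(group_size)
mapping_type = MappingType.ASYMMETRIC if is_asymmetric else MappingType.SYMMETRIC

print(weight_dtype, granularity, mapping_type)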

examples/models/llama/source_transformation/quantize.py

Lines changed: 29 additions & 12 deletions
@@ -107,14 +107,24 @@ def quantize( # noqa C901
         print("quantized model:", model)
         return model
     elif qmode.startswith("torchao:8da"):
+        # Check for required args
+        if group_size is None:
+            raise Exception(
+                "For torchao:8daxw quantization, group size must be specified."
+            )
+
         pattern = r"torchao:8da(\d+)w"
         matches = re.findall(pattern, qmode)
         assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
         bitwidth = int(matches[0][0])

-        from torchao.experimental.quant_api import Int8DynamicActivationIntxWeightConfig
-        from torchao.quantization.granularity import PerGroup, PerRow
-        from torchao.quantization.quant_api import quantize_
+        from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
+        from torchao.quantization.granularity import PerAxis, PerGroup
+        from torchao.quantization.quant_api import (
+            Int8DynamicActivationIntxWeightConfig,
+            MappingType,
+            quantize_,
+        )
         from torchao.utils import unwrap_tensor_subclass

         with torch.no_grad():
@@ -124,8 +134,11 @@ def quantize( # noqa C901
                 model,
                 Int8DynamicActivationIntxWeightConfig(
                     weight_dtype=getattr(torch, f"int{bitwidth}"),
-                    granularity=(PerRow() if group_size == 0 else PerGroup(group_size)),
-                    has_weight_zeros=False,
+                    weight_granularity=(
+                        PerAxis(0) if group_size == 0 else PerGroup(group_size)
+                    ),
+                    weight_mapping_type=MappingType.SYMMETRIC,
+                    layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
                 ),
             )
             model = unwrap_tensor_subclass(model)
@@ -777,38 +790,42 @@ def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
         EmbeddingQuantizer,
         SharedEmbeddingQuantizer,
     )
-    from torchao.quantization.granularity import PerGroup, PerRow
+    from torchao.quantization.granularity import PerAxis, PerGroup
+    from torchao.quantization.quant_api import MappingType

     quant_args = args.embedding_quantize.split(":")[1].split(",")
     if len(quant_args) == 2:
         bitwidth, group_size = quant_args
-        has_weight_zeros = True
+        is_asymmetric = True
     else:
-        bitwidth, group_size, has_weight_zeros = quant_args
+        bitwidth, group_size, is_asymmetric = quant_args

     if group_size in ["none", "None", "0"]:
         group_size = 0

     group_size = int(group_size)
     bitwidth = int(bitwidth)
-    has_weight_zeros = bool(has_weight_zeros)
+    is_asymmetric = bool(is_asymmetric)
     weight_dtype = getattr(torch, f"int{bitwidth}")
-    granularity = PerRow() if group_size == 0 else PerGroup(group_size)
+    granularity = PerAxis(0) if group_size == 0 else PerGroup(group_size)
+    mapping_type = (
+        MappingType.ASYMMETRIC if is_asymmetric else MappingType.SYMMETRIC
+    )

     def _torchao_embedding_quantizer(model):
         with torch.no_grad():
             if not args.use_shared_embedding:
                 EmbeddingQuantizer(
                     weight_dtype=weight_dtype,
                     granularity=granularity,
-                    has_weight_zeros=has_weight_zeros,
+                    mapping_type=mapping_type,
                     use_fallback=False,
                 ).quantize(model)
             else:
                 SharedEmbeddingQuantizer(
                     weight_dtype=weight_dtype,
                     granularity=granularity,
-                    has_weight_zeros=has_weight_zeros,
+                    mapping_type=mapping_type,
                 ).quantize(model)
         return model
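
As a standalone sanity check of the new torchao API used above, the sketch below applies the same Int8DynamicActivationIntxWeightConfig to a toy linear layer. It assumes the torchao version pinned by this commit, with the experimental low-bit kernels available for the packed layout; the layer sizes and group size are illustrative.

import torch
import torch.nn as nn

from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    MappingType,
    quantize_,
)
from torchao.utils import unwrap_tensor_subclass

# Toy model; the input dim must be divisible by the group size.
model = nn.Sequential(nn.Linear(256, 128, bias=False)).eval()

with torch.no_grad():
    # 8-bit dynamic activations, 4-bit symmetric weights, groupwise with group_size=32,
    # packed for the low-bit kernels (the same config the export path now builds).
    quantize_(
        model,
        Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4,
            weight_granularity=PerGroup(32),
            weight_mapping_type=MappingType.SYMMETRIC,
            layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
        ),
    )
    model = unwrap_tensor_subclass(model)

print(model(torch.randn(1, 256)).shape)  # torch.Size([1, 128])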

examples/models/phi_4_mini/convert_weights.py

Lines changed: 82 additions & 16 deletions
@@ -1,4 +1,5 @@
 import argparse
+import os
 from typing import Dict

 import torch
@@ -7,6 +8,63 @@

 from torchtune.training import FullModelHFCheckpointer

+_HF_PHI_4_FROM_META = {
+    "tok_embeddings.weight": "model.embed_tokens.weight",
+    "norm.weight": "model.norm.weight",
+    "layers.{}.attention.wq.weight": "model.layers.{}.self_attn.q_proj.weight",
+    "layers.{}.attention.wk.weight": "model.layers.{}.self_attn.k_proj.weight",
+    "layers.{}.attention.wv.weight": "model.layers.{}.self_attn.v_proj.weight",
+    "layers.{}.attention.wo.weight": "model.layers.{}.self_attn.o_proj.weight",
+    "layers.{}.attention_norm.weight": "model.layers.{}.input_layernorm.weight",
+    "layers.{}.ffn_norm.weight": "model.layers.{}.post_attention_layernorm.weight",
+    "layers.{}.feed_forward.w1.weight": "model.layers.{}.mlp.gate_proj.weight",
+    "layers.{}.feed_forward.w3.weight": "model.layers.{}.mlp.up_proj.weight",
+    "layers.{}.feed_forward.w2.weight": "model.layers.{}.mlp.down_proj.weight",
+    "output.weight": "lm_head.weight",
+}
+
+
+def phi_4_hf_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    """
+    Convert a state dict from hf's format to Meta's format.
+
+    Args:
+        state_dict (Dict[str, torch.Tensor]): State dict in hf's format.
+
+    Returns:
+        Dict[str, torch.Tensor]: State dict in Meta's format.
+    """
+    converted_state_dict = {}
+    inverted_mapping_dict = {v: k for k, v in _HF_PHI_4_FROM_META.items()}
+
+    for key, value in state_dict.items():
+        if key.endswith("mlp.gate_up_proj.weight"):
+            # Split the gate_up_proj into gate_proj and up_proj
+            hidden_dim = value.shape[0] // 2
+            assert 2 * hidden_dim == value.shape[0]
+            gate = value[0:hidden_dim, :]
+            up = value[hidden_dim:, :]
+            for new_key, new_value in [("gate_proj", gate), ("up_proj", up)]:
+                new_key = key.replace("gate_up_proj", new_key)
+                new_key = get_mapped_key(new_key, inverted_mapping_dict)
+                converted_state_dict[new_key] = new_value
+        elif key.endswith("self_attn.qkv_proj.weight"):
+            # Split the qkv_proj into q_proj, k_proj, and v_proj
+            q_dim = value.shape[1]
+            kv_dim = (value.shape[0] - q_dim) // 2
+            assert 2 * kv_dim + q_dim == value.shape[0]
+            q = value[0:q_dim, :]
+            k = value[q_dim : (q_dim + kv_dim), :]
+            v = value[(q_dim + kv_dim) :, :]
+            for new_key, new_value in [("q_proj", q), ("k_proj", k), ("v_proj", v)]:
+                new_key = key.replace("qkv_proj", new_key)
+                new_key = get_mapped_key(new_key, inverted_mapping_dict)
+                converted_state_dict[new_key] = new_value
+        else:
+            new_key = get_mapped_key(key, inverted_mapping_dict)
+            converted_state_dict[new_key] = value
+    return converted_state_dict
+

 # Standard _FROM_META weight mapping of Meta weights to TorchTune.
 _PHI_4_FROM_META = {
@@ -51,22 +109,30 @@ def phi_4_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.T
     return converted_state_dict


-def convert_weights(input_dir: str, output_file: str) -> None:
-    # Don't necessarily need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves.
-    checkpointer = FullModelHFCheckpointer(
-        checkpoint_dir=input_dir,
-        checkpoint_files=[
-            "model-00001-of-00002.safetensors",
-            "model-00002-of-00002.safetensors",
-        ],
-        output_dir=".",
-        model_type="PHI4",
-    )
+def convert_weights(input_dir_or_checkpoint: str, output_file: str) -> None:
+    # If input_dir_or_checkpoint is a directory downloaded from HF, FullModelHFCheckpointer is used to extract the state dict
+    # If input_dir_or_checkpoint is a checkpoint (from eager model model), it is loaded directly
+    if os.path.isdir(input_dir_or_checkpoint):
+        checkpointer = FullModelHFCheckpointer(
+            checkpoint_dir=input_dir_or_checkpoint,
+            checkpoint_files=[
+                "model-00001-of-00002.safetensors",
+                "model-00002-of-00002.safetensors",
+            ],
+            output_dir=".",
+            model_type="PHI4",
+        )
+        print("Loading checkpoint from directory...")
+        sd = checkpointer.load_checkpoint()
+        sd = sd["model"]
+        print("Converting checkpoint...")
+        sd = phi_4_tune_to_meta(sd)
+    else:
+        print("Loading checkpoint from file...")
+        sd = torch.load(input_dir_or_checkpoint, map_location="cpu", weights_only=True)
+        print("Converting checkpoint...")
+        sd = phi_4_hf_to_meta(sd)

-    print("Loading checkpoint...")
-    sd = checkpointer.load_checkpoint()
-    print("Converting checkpoint...")
-    sd = phi_4_tune_to_meta(sd["model"])
     print("Saving checkpoint...")
     torch.save(sd, output_file)
     print("Done.")
@@ -79,7 +145,7 @@ def main():
     parser.add_argument(
         "input_dir",
         type=str,
-        help="Path to directory containing checkpoint files",
+        help="Path to directory containing checkpoint files, or path to a single checkpoint file.",
     )
     parser.add_argument("output", type=str, help="Path to the output checkpoint")
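
The subtlest part of phi_4_hf_to_meta above is splitting the fused HF projections row-wise. Below is a small, self-contained demonstration of those splits on toy tensors; the dimensions are made up and do not match phi-4-mini's real sizes.

import torch

# Hypothetical toy sizes, not phi-4-mini's real dimensions.
model_dim, kv_dim, ffn_dim = 64, 16, 128
q_dim = model_dim  # phi_4_hf_to_meta reads q_dim from the weight's second axis

# Fused projections as stored in the HF checkpoint (output rows stacked).
qkv_proj = torch.randn(q_dim + 2 * kv_dim, model_dim)
gate_up_proj = torch.randn(2 * ffn_dim, model_dim)

# The same row-wise splits phi_4_hf_to_meta performs.
q = qkv_proj[:q_dim, :]
k = qkv_proj[q_dim : q_dim + kv_dim, :]
v = qkv_proj[q_dim + kv_dim :, :]
gate, up = gate_up_proj[:ffn_dim, :], gate_up_proj[ffn_dim:, :]

assert q.shape == (q_dim, model_dim)
assert k.shape == v.shape == (kv_dim, model_dim)
assert gate.shape == up.shape == (ffn_dim, model_dim)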

exir/program/_program.py

Lines changed: 1 addition & 1 deletion
@@ -986,7 +986,7 @@ def keep(op):
     try:
         # Ops in torch.ops.quant are not always loaded, so we use try/except
         # Aliases output, but we need to allow it for XNNPACK
-        allow_list.append(torch.ops.quant.choose_qparams_affine.default)
+        allow_list.append(torch.ops.torchao.choose_qparams_affine.default)
     except:
         pass

third-party/ao

Submodule ao updated 103 files
