Commit a540bfe ("init")
1 parent a664d7b commit a540bfe

6 files changed, +116 -30 lines changed

backends/xnnpack/operators/quant_params.py

Lines changed: 9 additions & 3 deletions

@@ -185,11 +185,13 @@ def from_q_dq_node(
         quant_node_args = extract_qdq_affine_op_args_for_decomposed_ops(quant_node)

         scale = quant_node_args[1]
-        zp = quant_node_args[2]
+        zp = quant_node_args[2] if len(quant_node_args) > 2 else None
         axis = 0
         if per_channel:
             assert isinstance(scale, torch.fx.Node) and isinstance(scale.target, str)
-            assert isinstance(zp, torch.fx.Node) and isinstance(zp.target, str)
+            assert zp is None or (
+                isinstance(zp, torch.fx.Node) and isinstance(zp.target, str)
+            )
             assert (
                 ep is not None
             ), "ExportedProgram must be provided to extract per channel params"
@@ -200,7 +202,11 @@ def _get_tensor(node):
                 return cast(torch.Tensor, param)

             scale = _get_tensor(scale)
-            zp = _get_tensor(zp)
+            zp = (
+                _get_tensor(zp)
+                if zp is not None
+                else torch.zeros_like(scale, dtype=torch.int8)
+            )
             axis = cast(int, quant_node_args[3])

             if _groupwise:
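For context, a minimal standalone sketch (illustration only, not part of this commit) of the fallback added above: when a quantize/dequantize node carries no zero_point argument, an all-zero int8 tensor shaped like the per-channel scales is substituted.

import torch

scale = torch.tensor([0.1, 0.2, 0.4])  # per-channel scales extracted from the graph
zp = None  # zero_point omitted by the quantizer
zp = zp if zp is not None else torch.zeros_like(scale, dtype=torch.int8)
print(zp)  # tensor([0, 0, 0], dtype=torch.int8)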

backends/xnnpack/utils/quant_utils.py

Lines changed: 15 additions & 9 deletions

@@ -58,19 +58,25 @@ def is_dynamic_qdq(node: torch.fx.Node) -> bool:
     node_input_args = extract_qdq_affine_op_args_for_decomposed_ops(node)

     scale = node_input_args[1]
-    zp = node_input_args[2]
-    if not (isinstance(scale, torch.fx.Node) and isinstance(zp, torch.fx.Node)):
+    if not isinstance(scale, torch.fx.Node):
         return False
-
-    if not (scale.target == operator.getitem and zp.target == operator.getitem):
+    if not (scale.target == operator.getitem):
         return False
-
     scale_choose_qparam = scale.all_input_nodes[0]
-    zp_choose_qparam = zp.all_input_nodes[0]
-
-    if not (is_qparam(scale_choose_qparam) and is_qparam(zp_choose_qparam)):
+    if not is_qparam(scale_choose_qparam):
         return False

+    if len(node_input_args) > 2:
+        zp = node_input_args[2]
+        if not isinstance(zp, torch.fx.Node):
+            return False
+
+        if not (zp.target == operator.getitem):
+            return False
+        zp_choose_qparam = zp.all_input_nodes[0]
+        if not is_qparam(zp_choose_qparam):
+            return False
+
     return True


@@ -223,7 +229,7 @@ def extract_qdq_affine_op_args_for_decomposed_ops(node: torch.fx.Node):
         # add target_dtype_node after quant_min/quant_max
         args.append(target_dtype)
         # zero_point_domain
-        if len(node.args) > 7 and node.args[7] != "INT":
+        if len(node.args) > 7 and node.args[7] not in ["INT", "NONE"]:
             return None, None

     if is_per_channel_group(node):
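Condensed sketch of the relaxed is_dynamic_qdq check (helper names and the is_qparam predicate argument are illustrative, not the module's real signatures): the scale must reach back to a choose_qparams result through operator.getitem, while the zero_point chain is only validated when a third argument is actually present.

import operator

def _comes_from_choose_qparams(node, is_qparam) -> bool:
    # an operator.getitem node whose producer is a choose_qparams op
    return node.target is operator.getitem and is_qparam(node.all_input_nodes[0])

def looks_dynamic(qdq_args, is_qparam) -> bool:
    if not _comes_from_choose_qparams(qdq_args[1], is_qparam):  # scale
        return False
    if len(qdq_args) > 2:  # zero_point is now optional
        return _comes_from_choose_qparams(qdq_args[2], is_qparam)
    return True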

examples/models/llama/export_llama_lib.py

Lines changed: 0 additions & 1 deletion

@@ -763,7 +763,6 @@ def _to_edge_and_lower_llama_xnnpack(
         raise NotImplementedError(
             "export_llama does not support XNNPack and generating ETRecord at the moment."
         )
-
     builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
         partitioners
     )

examples/models/llama/model.py

Lines changed: 5 additions & 1 deletion

@@ -16,8 +16,8 @@
     get_default_model_resource_dir,
 )
 from executorch.examples.models.llama.llama_transformer import Transformer
-
 from executorch.examples.models.llama.model_args import ModelArgs
+from torchao.utils import TorchAOBaseTensor

 try:
     from .fairseq2 import convert_to_llama_checkpoint
@@ -101,6 +101,7 @@ def __init__(self, **kwargs):
         if fairseq2_checkpoint:
             print("Using fairseq2 checkpoint")
             checkpoint = convert_to_llama_checkpoint(checkpoint=checkpoint)
+        print("checkpoint", checkpoint)
         if "model" in checkpoint:
             # NB: some checkpoint contains a "model" field, which is the actual weights dict
             checkpoint = checkpoint["model"]
@@ -257,6 +258,9 @@ def __init__(self, **kwargs):
                 strict=False,
                 assign=True,
             )  # self.model_ = Transformer(gptconf)
+            for param in self.model_.parameters():
+                if isinstance(param, TorchAOBaseTensor):
+                    param.requires_grad = False
         else:
             print("Checkpoint not provided, defaulting weights to zeros.")
             self.model_.to_empty(device="cpu")
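The freeze above appears to guard against checkpoints whose weights were quantized with torchao: such weights are TorchAOBaseTensor subclasses, and after load_state_dict(..., assign=True) they are assigned into the module as-is, so they are explicitly marked non-trainable. A standalone sketch of the same loop (the function name is illustrative, not from the commit):

import torch
from torchao.utils import TorchAOBaseTensor

def freeze_torchao_tensors(model: torch.nn.Module) -> None:
    # mark torchao tensor-subclass parameters as non-trainable
    for param in model.parameters():
        if isinstance(param, TorchAOBaseTensor):
            param.requires_grad = False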

examples/models/phi_4_mini/convert_weights.py

Lines changed: 80 additions & 15 deletions

@@ -1,4 +1,5 @@
 import argparse
+import os
 from typing import Dict

 import torch
@@ -7,6 +8,63 @@

 from torchtune.training import FullModelHFCheckpointer

+_HF_PHI_4_FROM_META = {
+    "tok_embeddings.weight": "model.embed_tokens.weight",
+    "norm.weight": "model.norm.weight",
+    "layers.{}.attention.wq.weight": "model.layers.{}.self_attn.q_proj.weight",
+    "layers.{}.attention.wk.weight": "model.layers.{}.self_attn.k_proj.weight",
+    "layers.{}.attention.wv.weight": "model.layers.{}.self_attn.v_proj.weight",
+    "layers.{}.attention.wo.weight": "model.layers.{}.self_attn.o_proj.weight",
+    "layers.{}.attention_norm.weight": "model.layers.{}.input_layernorm.weight",
+    "layers.{}.ffn_norm.weight": "model.layers.{}.post_attention_layernorm.weight",
+    "layers.{}.feed_forward.w1.weight": "model.layers.{}.mlp.gate_proj.weight",
+    "layers.{}.feed_forward.w3.weight": "model.layers.{}.mlp.up_proj.weight",
+    "layers.{}.feed_forward.w2.weight": "model.layers.{}.mlp.down_proj.weight",
+    "output.weight": "lm_head.weight",
+}
+
+
+def phi_4_hf_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    """
+    Convert a state dict from hf's format to Meta's format.
+
+    Args:
+        state_dict (Dict[str, torch.Tensor]): State dict in hf's format.
+
+    Returns:
+        Dict[str, torch.Tensor]: State dict in Meta's format.
+    """
+    converted_state_dict = {}
+    inverted_mapping_dict = {v: k for k, v in _HF_PHI_4_FROM_META.items()}
+
+    for key, value in state_dict.items():
+        if key.endswith("mlp.gate_up_proj.weight"):
+            # Split the gate_up_proj into gate_proj and up_proj
+            hidden_dim = value.shape[0] // 2
+            assert 2 * hidden_dim == value.shape[0]
+            gate = value[0:hidden_dim, :]
+            up = value[hidden_dim:, :]
+            for new_key, new_value in [("gate_proj", gate), ("up_proj", up)]:
+                new_key = key.replace("gate_up_proj", new_key)
+                new_key = get_mapped_key(new_key, inverted_mapping_dict)
+                converted_state_dict[new_key] = new_value
+        elif key.endswith("self_attn.qkv_proj.weight"):
+            # Split the qkv_proj into q_proj, k_proj, and v_proj
+            q_dim = value.shape[1]
+            kv_dim = (value.shape[0] - q_dim) // 2
+            assert 2 * kv_dim + q_dim == value.shape[0]
+            q = value[0:q_dim, :]
+            k = value[q_dim : (q_dim + kv_dim), :]
+            v = value[(q_dim + kv_dim) :, :]
+            for new_key, new_value in [("q_proj", q), ("k_proj", k), ("v_proj", v)]:
+                new_key = key.replace("qkv_proj", new_key)
+                new_key = get_mapped_key(new_key, inverted_mapping_dict)
+                converted_state_dict[new_key] = new_value
+        else:
+            new_key = get_mapped_key(key, inverted_mapping_dict)
+            converted_state_dict[new_key] = value
+    return converted_state_dict
+

 # Standard _FROM_META weight mapping of Meta weights to TorchTune.
 _PHI_4_FROM_META = {
@@ -51,22 +109,29 @@ def phi_4_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.T
     return converted_state_dict


-def convert_weights(input_dir: str, output_file: str) -> None:
+def convert_weights(input_dir_or_checkpoint: str, output_file: str) -> None:
     # Don't necessarily need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves.
-    checkpointer = FullModelHFCheckpointer(
-        checkpoint_dir=input_dir,
-        checkpoint_files=[
-            "model-00001-of-00002.safetensors",
-            "model-00002-of-00002.safetensors",
-        ],
-        output_dir=".",
-        model_type="PHI4",
-    )
+    if os.path.isdir(input_dir_or_checkpoint):
+        checkpointer = FullModelHFCheckpointer(
+            checkpoint_dir=input_dir_or_checkpoint,
+            checkpoint_files=[
+                "model-00001-of-00002.safetensors",
+                "model-00002-of-00002.safetensors",
+            ],
+            output_dir=".",
+            model_type="PHI4",
+        )
+        print("Loading checkpoint from directory...")
+        sd = checkpointer.load_checkpoint()
+        sd = sd["model"]
+        print("Converting checkpoint...")
+        sd = phi_4_tune_to_meta(sd)
+    else:
+        print("Loading checkpoint from file...")
+        sd = torch.load(input_dir_or_checkpoint, map_location="cpu", weights_only=True)
+        print("Converting checkpoint...")
+        sd = phi_4_hf_to_meta(sd)

-    print("Loading checkpoint...")
-    sd = checkpointer.load_checkpoint()
-    print("Converting checkpoint...")
-    sd = phi_4_tune_to_meta(sd["model"])
     print("Saving checkpoint...")
     torch.save(sd, output_file)
     print("Done.")
@@ -79,7 +144,7 @@ def main():
     parser.add_argument(
         "input_dir",
         type=str,
-        help="Path to directory containing checkpoint files",
+        help="Path to directory containing checkpoint files, or path to a single checkpoint file.",
     )
     parser.add_argument("output", type=str, help="Path to the output checkpoint")
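For illustration, a possible invocation of the updated converter (the module path and file names are assumed examples, not from this commit): it now accepts either a Hugging Face checkpoint directory or a single checkpoint file.

from executorch.examples.models.phi_4_mini.convert_weights import convert_weights

# sharded HF safetensors directory: FullModelHFCheckpointer + phi_4_tune_to_meta
convert_weights("phi-4-mini/", "phi4_meta.pth")
# single checkpoint file: torch.load + phi_4_hf_to_meta
convert_weights("phi-4-mini/pytorch_model.bin", "phi4_meta.pth")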

extension/llm/export/builder.py

Lines changed: 7 additions & 1 deletion

@@ -31,7 +31,6 @@
 from executorch.exir.passes import MemoryPlanningPass
 from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
-
 from executorch.extension.export_util.utils import export_to_edge, save_pte_program

 from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes
@@ -41,6 +40,7 @@
 from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
 from torch.export import export_for_training, ExportedProgram
 from torch.nn.attention import SDPBackend
+from torchao.utils import unwrap_tensor_subclass

 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)
@@ -199,6 +199,11 @@ def _get_edge_config(self) -> EdgeCompileConfig:
         return edge_config

     def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
+        if module is not None:
+            unwrap_tensor_subclass(module)
+        else:
+            unwrap_tensor_subclass(self.model)
+
         dynamic_shape = self._get_dynamic_shape()
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
@@ -226,6 +231,7 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
                 logging.info("Re-exporting with:")
             else:
                 logging.info("Exporting with:")
+
             logging.info(f"inputs: {self.example_inputs}")
             logging.info(f"kwargs: {self.example_kwarg_inputs}")
             logging.info(f"dynamic shapes: {dynamic_shape}")
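For context, a hedged sketch (the toy model and quantization choice are assumptions, not part of this commit) of why _export now calls unwrap_tensor_subclass first: torchao quantization stores weights as tensor subclasses, and unwrapping them into plain tensors plus parametrizations lets torch.export trace the module.

import torch
from torchao.quantization import int8_weight_only, quantize_
from torchao.utils import unwrap_tensor_subclass

model = torch.nn.Sequential(torch.nn.Linear(16, 16))
quantize_(model, int8_weight_only())  # weights become torchao tensor subclasses
unwrap_tensor_subclass(model)  # flatten the subclasses so export can trace the module
ep = torch.export.export_for_training(model, (torch.randn(1, 16),))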
