
Commit 4e752f7

Author: Guang Yang (committed)
Script to export HF models
1 parent ccaaa46 commit 4e752f7
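
In short, the added script loads a Hugging Face causal-LM with a static cache configuration, exports it with torch.export via the transformers ExecuTorch integration, lowers it to the XNNPACK backend, and serializes the result to a .pte file.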

File tree

1 file changed: 100 additions, 0 deletions


examples/models/export_hf_model.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
import argparse
import os

import torch
import torch.export._trace
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge
from torch.nn.attention import SDPBackend
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.configuration_utils import GenerationConfig
from transformers.integrations.executorch import convert_and_export
from transformers.modeling_utils import PreTrainedModel


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-hfm",
        "--hf_model_repo",
        required=False,
        default=None,
        help="a valid huggingface model repo name",
    )

    args = parser.parse_args()

    # Configs for the HF model
    device = "cpu"
    dtype = torch.float32
    batch_size = 1
    max_length = 123
    cache_implementation = "static"
    attn_implementation = "sdpa"

    # Load and configure a HF model
    model = AutoModelForCausalLM.from_pretrained(
        args.hf_model_repo,
        attn_implementation=attn_implementation,
        device_map=device,
        torch_dtype=dtype,
        generation_config=GenerationConfig(
            use_cache=True,
            cache_implementation=cache_implementation,
            max_length=max_length,
            cache_config={
                "batch_size": batch_size,
                "max_cache_len": max_length,
            },
        ),
    )
    print(f"{model.config}")
    print(f"{model.generation_config}")

    # Example inputs used to trace the model for export
    tokenizer = AutoTokenizer.from_pretrained(args.hf_model_repo)
    input_ids = tokenizer([""], return_tensors="pt").to(device)["input_ids"]
    cache_position = torch.tensor([0], dtype=torch.long)

    def _get_constant_methods(model: PreTrainedModel):
        # Metadata baked into the exported program as constant methods
        # (dtype codes follow the ScalarType enum: 5 = half, 6 = float)
        return {
            "get_dtype": 5 if model.config.torch_dtype == torch.float16 else 6,
            "get_bos_id": model.config.bos_token_id,
            "get_eos_id": model.config.eos_token_id,
            "get_head_dim": model.config.hidden_size / model.config.num_attention_heads,
            "get_max_batch_size": model.generation_config.cache_config.batch_size,
            "get_max_seq_len": model.generation_config.cache_config.max_cache_len,
            "get_n_bos": 1,
            "get_n_eos": 1,
            "get_n_kv_heads": model.config.num_key_value_heads,
            "get_n_layers": model.config.num_hidden_layers,
            "get_vocab_size": model.config.vocab_size,
            "use_kv_cache": model.generation_config.use_cache,
        }

    # Export under the math SDPA backend, lower to XNNPACK, and serialize to ExecuTorch
    with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
        exported_prog = convert_and_export(model, input_ids, cache_position)
        prog = (
            to_edge(
                exported_prog,
                compile_config=EdgeCompileConfig(
                    _check_ir_validity=False,
                    _skip_dim_order=True,
                ),
                constant_methods=_get_constant_methods(model),
            )
            .to_backend(XnnpackPartitioner())
            .to_executorch(
                ExecutorchBackendConfig(
                    extract_constant_segment=True, extract_delegate_segments=True
                )
            )
        )
        filename = os.path.join("./", f"{model.config.model_type}.pte")
        with open(filename, "wb") as f:
            prog.write_to_file(f)
            print(f"Saved exported program to {filename}")


if __name__ == "__main__":
    main()
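
Usage sketch (the placeholder repo name below is only an illustration, not part of this commit): the script is driven by the -hfm/--hf_model_repo flag, e.g.

    python examples/models/export_hf_model.py -hfm <huggingface-model-repo>

Running it downloads the model, exports it under the configuration hard-coded in main() (CPU, float32, batch size 1, max length 123, static cache), and writes <model_type>.pte to the current directory, e.g. llama.pte for a Llama checkpoint.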
