Commit b3eefd7

Author: Guang Yang (committed)
Script to export HF models
1 parent 933685b commit b3eefd7

File tree

2 files changed, +194 -0 lines changed

.github/workflows/trunk.yml

Lines changed: 90 additions & 0 deletions
@@ -351,3 +351,93 @@ jobs:
           PYTHON_EXECUTABLE=python ${CONDA_RUN} bash .ci/scripts/test_model.sh "${MODEL_NAME}" "${BUILD_TOOL}" "${BACKEND}"
           echo "::endgroup::"
         done
+
+  test-huggingface-transformers:
+    name: test-huggingface-transformers
+    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+    secrets: inherit
+    strategy:
+      matrix:
+        hf_model_repo: [google/gemma-2b]
+      fail-fast: false
+    with:
+      secrets-env: EXECUTORCH_HF_TOKEN
+      runner: linux.12xlarge
+      docker-image: executorch-ubuntu-22.04-clang12
+      submodules: 'true'
+      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
+      timeout: 90
+      script: |
+        echo "::group::Set up ExecuTorch"
+        # The generic Linux job chooses to use base env, not the one setup by the image
+        CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
+        conda activate "${CONDA_ENV}"
+        PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh cmake
+
+        echo "Installing libexecutorch.a, libextension_module.so, libportable_ops_lib.a"
+        rm -rf cmake-out
+        cmake \
+            -DCMAKE_INSTALL_PREFIX=cmake-out \
+            -DCMAKE_BUILD_TYPE=Release \
+            -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
+            -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
+            -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
+            -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
+            -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
+            -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
+            -DEXECUTORCH_BUILD_XNNPACK=ON \
+            -DPYTHON_EXECUTABLE=python \
+            -Bcmake-out .
+        cmake --build cmake-out -j9 --target install --config Release
+
+        echo "Build llama runner"
+        dir="examples/models/llama2"
+        cmake \
+            -DCMAKE_INSTALL_PREFIX=cmake-out \
+            -DCMAKE_BUILD_TYPE=Release \
+            -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
+            -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
+            -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
+            -DEXECUTORCH_BUILD_XNNPACK=ON \
+            -DPYTHON_EXECUTABLE=python \
+            -Bcmake-out/${dir} \
+            ${dir}
+        cmake --build cmake-out/${dir} -j9 --config Release
+        echo "::endgroup::"
+
+        echo "::group::Set up HuggingFace Dependencies"
+        pip install -U "huggingface_hub[cli]"
+        huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN
+        pip install accelerate sentencepiece
+        # TODO(guangyang): Switch to use released transformers library after all required patches are included
+        pip install "git+https://github.com/huggingface/transformers.git@6cc4dfe3f1e8d421c6d6351388e06e9b123cbfe1"
+        pip list
+        echo "::endgroup::"
+
+        echo "::group::Export to ExecuTorch"
+        TOKENIZER_FILE=tokenizer.model
+        TOKENIZER_BIN_FILE=tokenizer.bin
+        ET_MODEL_NAME=et_model
+        # Fetch the file using a Python one-liner
+        DOWNLOADED_TOKENIZER_FILE_PATH=$(python -c "
+        from huggingface_hub import hf_hub_download
+        # Download the file from the Hugging Face Hub
+        downloaded_path = hf_hub_download(
+            repo_id='${{ matrix.hf_model_repo }}',
+            filename='${TOKENIZER_FILE}'
+        )
+        print(downloaded_path)
+        ")
+        if [ -f "$DOWNLOADED_TOKENIZER_FILE_PATH" ]; then
+            echo "${TOKENIZER_FILE} downloaded successfully at: $DOWNLOADED_TOKENIZER_FILE_PATH"
+            python -m extension.llm.tokenizer.tokenizer -t $DOWNLOADED_TOKENIZER_FILE_PATH -o ./${TOKENIZER_BIN_FILE}
+            ls ./tokenizer.bin
+        else
+            echo "Failed to download ${TOKENIZER_FILE} from ${{ matrix.hf_model_repo }}."
+            exit 1
+        fi
+
+        python -m extension.export_util.export_hf_model -hfm=${{ matrix.hf_model_repo }} -o ${ET_MODEL_NAME}
+
+        cmake-out/examples/models/llama2/llama_main --model_path=${ET_MODEL_NAME}.pte --tokenizer_path=${TOKENIZER_BIN_FILE} --prompt="My name is"
+        echo "::endgroup::"
extension/export_util/export_hf_model.py

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
+import argparse
+import os
+
+import torch
+import torch.export._trace
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge
+from torch.nn.attention import SDPBackend
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation.configuration_utils import GenerationConfig
+from transformers.integrations.executorch import convert_and_export_with_cache
+from transformers.modeling_utils import PreTrainedModel
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-hfm",
+        "--hf_model_repo",
+        required=True,
+        default=None,
+        help="a valid huggingface model repo name",
+    )
+    parser.add_argument(
+        "-o",
+        "--output_name",
+        required=False,
+        default=None,
+        help="output name of the exported model",
+    )
+
+    args = parser.parse_args()
+
+    # Configs to HF model
+    device = "cpu"
+    dtype = torch.float32
+    batch_size = 1
+    max_length = 123
+    cache_implementation = "static"
+    attn_implementation = "sdpa"
+
+    # Load and configure a HF model
+    model = AutoModelForCausalLM.from_pretrained(
+        args.hf_model_repo,
+        attn_implementation=attn_implementation,
+        device_map=device,
+        torch_dtype=dtype,
+        generation_config=GenerationConfig(
+            use_cache=True,
+            cache_implementation=cache_implementation,
+            max_length=max_length,
+            cache_config={
+                "batch_size": batch_size,
+                "max_cache_len": max_length,
+            },
+        ),
+    )
+    print(f"{model.config}")
+    print(f"{model.generation_config}")
+
+    tokenizer = AutoTokenizer.from_pretrained(args.hf_model_repo)
+    input_ids = tokenizer([""], return_tensors="pt").to(device)["input_ids"]
+    cache_position = torch.tensor([0], dtype=torch.long)
+
+    def _get_constant_methods(model: PreTrainedModel):
+        return {
+            "get_dtype": 5 if model.config.torch_dtype == torch.float16 else 6,
+            "get_bos_id": model.config.bos_token_id,
+            "get_eos_id": model.config.eos_token_id,
+            "get_head_dim": model.config.hidden_size / model.config.num_attention_heads,
+            "get_max_batch_size": model.generation_config.cache_config.batch_size,
+            "get_max_seq_len": model.generation_config.cache_config.max_cache_len,
+            "get_n_bos": 1,
+            "get_n_eos": 1,
+            "get_n_kv_heads": model.config.num_key_value_heads,
+            "get_n_layers": model.config.num_hidden_layers,
+            "get_vocab_size": model.config.vocab_size,
+            "use_kv_cache": model.generation_config.use_cache,
+        }
+
+    with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
+
+        exported_prog = convert_and_export_with_cache(model, input_ids, cache_position)
+        prog = (
+            to_edge(
+                exported_prog,
+                compile_config=EdgeCompileConfig(
+                    _check_ir_validity=False,
+                    _skip_dim_order=True,
+                ),
+                constant_methods=_get_constant_methods(model),
+            )
+            .to_backend(XnnpackPartitioner())
+            .to_executorch(ExecutorchBackendConfig(extract_delegate_segments=True))
+        )
+        out_name = args.output_name if args.output_name else model.config.model_type
+        filename = os.path.join("./", f"{out_name}.pte")
+        with open(filename, "wb") as f:
+            prog.write_to_file(f)
+        print(f"Saved exported program to {filename}")
+
+
+if __name__ == "__main__":
+    main()
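
For reference, a minimal local run of the new export path, mirroring the CI step in trunk.yml above. This is a sketch rather than part of the commit: it assumes the same pinned transformers commit and the google/gemma-2b model from the workflow matrix, and that the llama runner and tokenizer.bin have already been built as shown in the workflow.

    pip install accelerate sentencepiece "huggingface_hub[cli]"
    pip install "git+https://github.com/huggingface/transformers.git@6cc4dfe3f1e8d421c6d6351388e06e9b123cbfe1"
    # Export the HF model to an ExecuTorch program (writes et_model.pte)
    python -m extension.export_util.export_hf_model -hfm=google/gemma-2b -o et_model
    # Run the exported program with the llama runner built in cmake-out
    cmake-out/examples/models/llama2/llama_main --model_path=et_model.pte --tokenizer_path=tokenizer.bin --prompt="My name is"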
