Commit 9d7e16f

Author: Guang Yang
Script to export HF models
1 parent cac2c05 commit 9d7e16f

File tree

2 files changed: +190 -0 lines changed


.github/workflows/trunk.yml

Lines changed: 90 additions & 0 deletions
@@ -351,3 +351,93 @@ jobs:
          PYTHON_EXECUTABLE=python ${CONDA_RUN} bash .ci/scripts/test_model.sh "${MODEL_NAME}" "${BUILD_TOOL}" "${BACKEND}"
          echo "::endgroup::"
        done

  test-huggingface-transformers:
    name: test-huggingface-transformers
    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
    strategy:
      matrix:
        hf_model_repo: [google/gemma-2b]
      fail-fast: false
    with:
      secrets-env: "HF_TOKEN_PERIODIC"
      runner: linux.12xlarge
      docker-image: executorch-ubuntu-22.04-clang12
      submodules: 'true'
      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
      timeout: 90
    steps:
      - name: Set up ExecuTorch
        run: |
          # The generic Linux job chooses to use the base env, not the one set up by the image
          CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
          conda activate "${CONDA_ENV}"
          PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh cmake

          pushd executorch
          echo "Installing libexecutorch.a, libextension_module.so, libportable_ops_lib.a"
          rm -rf cmake-out
          retry cmake \
              -DCMAKE_INSTALL_PREFIX=cmake-out \
              -DCMAKE_BUILD_TYPE=Release \
              -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
              -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
              -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
              -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
              -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
              -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
              -DEXECUTORCH_BUILD_XNNPACK=ON \
              -DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
              -Bcmake-out .
          cmake --build cmake-out -j9 --target install --config Release

          echo "Build llama runner"
          dir="examples/models/llama2"
          retry cmake \
              -DCMAKE_INSTALL_PREFIX=cmake-out \
              -DCMAKE_BUILD_TYPE=Release \
              -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
              -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
              -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
              -DEXECUTORCH_BUILD_XNNPACK=ON \
              -DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
              -Bcmake-out/${dir} \
              ${dir}
          cmake --build cmake-out/${dir} -j9 --config Release
          popd

      - name: Set up HuggingFace Hub
        run: |
          pip install -U "huggingface_hub[cli]"
          HF_TOKEN="${SECRET_HF_TOKEN_PERIODIC}" huggingface-cli login

      - name: Set up HuggingFace Transformers
        run: |
          # TODO(guangyang): Switch to the released transformers library once all required patches are included
          git clone --branch main https://github.com/huggingface/transformers.git
          pushd transformers
          pip install .
          popd

      - name: Export to ExecuTorch
        run: |
          pushd executorch
          python -m extension.export_util.export_hf_model -hfm=${{ matrix.hf_model_repo }}

          # Transform the Hugging Face model repo name into its cache directory name
          MODEL_NAME="${{ matrix.hf_model_repo }}"
          TRANSFORMED_MODEL_NAME="models--$(echo "$MODEL_NAME" | sed 's/\//--/g')"

          # Search for tokenizer.model within the transformed model directory
          TOKENIZER_PATH=$(find "$HOME/.cache/huggingface/hub" -type f -name "tokenizer.model" -path "*/$TRANSFORMED_MODEL_NAME/*" -print -quit)
          if [ -z "$TOKENIZER_PATH" ]; then
            echo "tokenizer.model not found for model ${{ matrix.hf_model_repo }}"
            exit 1
          else
            echo "Found tokenizer.model at: $TOKENIZER_PATH"
            cp "$TOKENIZER_PATH" ./
          fi
          python -m extension.llm.tokenizer.tokenizer -t tokenizer.model -o tokenizer.bin

          cmake-out/examples/models/llama2/llama_main --model_path=gemma.pte --tokenizer_path=tokenizer.bin --prompt="My name is"
          popd
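
Note: the cache-directory rewrite in the "Export to ExecuTorch" step (turning google/gemma-2b into models--google--gemma-2b and then finding tokenizer.model under ~/.cache/huggingface/hub) can also be done through the huggingface_hub library. The sketch below is not what the workflow runs; it is a minimal alternative assuming the repo ships a sentencepiece tokenizer.model (true for google/gemma-2b) and that the Hub login from the previous step has already happened:

    # Resolve tokenizer.model via huggingface_hub instead of globbing the HF cache.
    # hf_hub_download returns the local cached path, downloading the file if needed.
    from huggingface_hub import hf_hub_download

    repo_id = "google/gemma-2b"  # the single entry in the workflow's hf_model_repo matrix
    tokenizer_path = hf_hub_download(repo_id=repo_id, filename="tokenizer.model")
    print(f"Found tokenizer.model at: {tokenizer_path}")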
extension/export_util/export_hf_model.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
import argparse
import os

import torch
import torch.export._trace
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge
from torch.nn.attention import SDPBackend
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.configuration_utils import GenerationConfig
from transformers.integrations.executorch import convert_and_export_with_cache
from transformers.modeling_utils import PreTrainedModel


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-hfm",
        "--hf_model_repo",
        required=False,
        default=None,
        help="a valid huggingface model repo name",
    )

    args = parser.parse_args()

    # Configs for the HF model
    device = "cpu"
    dtype = torch.float32
    batch_size = 1
    max_length = 123
    cache_implementation = "static"
    attn_implementation = "sdpa"

    # Load and configure a HF model
    model = AutoModelForCausalLM.from_pretrained(
        args.hf_model_repo,
        attn_implementation=attn_implementation,
        device_map=device,
        torch_dtype=dtype,
        generation_config=GenerationConfig(
            use_cache=True,
            cache_implementation=cache_implementation,
            max_length=max_length,
            cache_config={
                "batch_size": batch_size,
                "max_cache_len": max_length,
            },
        ),
    )
    print(f"{model.config}")
    print(f"{model.generation_config}")

    tokenizer = AutoTokenizer.from_pretrained(args.hf_model_repo)
    input_ids = tokenizer([""], return_tensors="pt").to(device)["input_ids"]
    cache_position = torch.tensor([0], dtype=torch.long)

    def _get_constant_methods(model: PreTrainedModel):
        # Metadata baked into the .pte as constant methods so the runner can query it
        return {
            # 5 and 6 are the ScalarType enum values for half and float, respectively
            "get_dtype": 5 if model.config.torch_dtype == torch.float16 else 6,
            "get_bos_id": model.config.bos_token_id,
            "get_eos_id": model.config.eos_token_id,
            "get_head_dim": model.config.hidden_size / model.config.num_attention_heads,
            "get_max_batch_size": model.generation_config.cache_config.batch_size,
            "get_max_seq_len": model.generation_config.cache_config.max_cache_len,
            "get_n_bos": 1,
            "get_n_eos": 1,
            "get_n_kv_heads": model.config.num_key_value_heads,
            "get_n_layers": model.config.num_hidden_layers,
            "get_vocab_size": model.config.vocab_size,
            "use_kv_cache": model.generation_config.use_cache,
        }

    # Export with the math SDPA backend, then lower: to_edge -> XNNPACK partition -> ExecuTorch program
    with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
        exported_prog = convert_and_export_with_cache(model, input_ids, cache_position)
        prog = (
            to_edge(
                exported_prog,
                compile_config=EdgeCompileConfig(
                    _check_ir_validity=False,
                    _skip_dim_order=True,
                ),
                constant_methods=_get_constant_methods(model),
            )
            .to_backend(XnnpackPartitioner())
            .to_executorch(
                ExecutorchBackendConfig(
                    extract_delegate_segments=True,
                ),
            )
        )
        filename = os.path.join("./", f"{model.config.model_type}.pte")
        with open(filename, "wb") as f:
            prog.write_to_file(f)
            print(f"Saved exported program to {filename}")


if __name__ == "__main__":
    main()
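
As a usage note, the CI job above drives this script and the two follow-up commands from bash. A hypothetical Python driver for the same flow, run from the executorch checkout after the cmake build and after tokenizer.model has been copied into the working directory, would look roughly like this sketch; it simply mirrors the workflow commands rather than any packaged API:

    # Sketch of the end-to-end flow the CI job runs, mirroring the workflow commands.
    # Assumes: executorch checkout as CWD, llama_main already built under cmake-out,
    # and tokenizer.model already copied here from the HF cache.
    import subprocess

    repo = "google/gemma-2b"
    # 1) Export the HF model; the output name comes from model.config.model_type (gemma.pte here)
    subprocess.run(
        ["python", "-m", "extension.export_util.export_hf_model", f"-hfm={repo}"],
        check=True,
    )
    # 2) Convert the sentencepiece tokenizer.model into the runner's tokenizer.bin format
    subprocess.run(
        ["python", "-m", "extension.llm.tokenizer.tokenizer", "-t", "tokenizer.model", "-o", "tokenizer.bin"],
        check=True,
    )
    # 3) Run the exported program with the prebuilt llama runner
    subprocess.run(
        [
            "cmake-out/examples/models/llama2/llama_main",
            "--model_path=gemma.pte",
            "--tokenizer_path=tokenizer.bin",
            "--prompt=My name is",
        ],
        check=True,
    )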
