Commit 525acfb

ci: Add llama3 gpu workflow in periodic (#399)
1 parent ea62e84 commit 525acfb

File tree

3 files changed: +72 −8 lines changed

  .ci/scripts/download_llama.sh
  .ci/scripts/gather_test_models.py
  .github/workflows/periodic.yml

.ci/scripts/download_llama.sh

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+#!/usr/bin/env bash
+
+set -xeou pipefail
+
+shopt -s globstar
+
+install_huggingface_cli() {
+  pip install -U "huggingface_hub[cli]"
+}
+
+download_checkpoint() {
+  # This function is "technically re-usable but ymmv"
+  # includes org name, like <org>/<repo>
+  local repo_name=$1
+  local include=$2
+  # basically just removes the org in <org>/<repo>
+  local local_dir="checkpoints/${repo_name}"
+
+  mkdir -p "${local_dir}"
+  huggingface-cli download \
+    "${repo_name}" \
+    --quiet \
+    --include "${include}" \
+    --local-dir "${local_dir}"
+}
+
+# install huggingface-cli if not already installed
+if ! command -v huggingface-cli; then
+  install_huggingface_cli
+fi
+
+# TODO: Eventually you could extend this to download different models
+# taking in some arguments similar to .ci/scripts/wget_checkpoint.sh
+download_checkpoint "meta-llama/Meta-Llama-3-8B" "original/*"

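For reference, the download this script performs can also be expressed through the huggingface_hub Python API rather than the CLI. A minimal sketch, assuming huggingface_hub is installed and an HF_TOKEN with access to the gated repo is set in the environment (this sketch is not part of the commit):

    # Rough Python equivalent of download_checkpoint() above (sketch only).
    from huggingface_hub import snapshot_download

    repo_name = "meta-llama/Meta-Llama-3-8B"
    snapshot_download(
        repo_id=repo_name,
        allow_patterns="original/*",           # same filter as --include "original/*"
        local_dir=f"checkpoints/{repo_name}",  # same layout the script creates
    )
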
.ci/scripts/gather_test_models.py

Lines changed: 23 additions & 5 deletions
@@ -19,6 +19,10 @@
     "mistralai/Mistral-7B-v0.1": "https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/config.json,https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/generation_config.json,https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/pytorch_model-00001-of-00002.bin,https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/pytorch_model-00002-of-00002.bin,https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/pytorch_model.bin.index.json,https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/special_tokens_map.json,https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/tokenizer.json,https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/tokenizer.model,https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/tokenizer_config.json",
     "mistralai/Mistral-7B-Instruct-v0.1": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/config.json,https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/generation_config.json,https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/pytorch_model-00001-of-00002.bin,https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/pytorch_model-00002-of-00002.bin,https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/pytorch_model.bin.index.json,https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/special_tokens_map.json,https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/tokenizer.json,https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/tokenizer.model,https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/tokenizer_config.json",
     "mistralai/Mistral-7B-Instruct-v0.2": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/resolve/main/config.json,https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/resolve/main/generation_config.json,https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/resolve/main/pytorch_model-00001-of-00003.bin,https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/resolve/main/pytorch_model-00002-of-00003.bin,https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/resolve/main/pytorch_model-00003-of-00003.bin,https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/resolve/main/pytorch_model.bin.index.json,https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/resolve/main/special_tokens_map.json,https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/resolve/main/tokenizer.json,https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/resolve/main/tokenizer.model,https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/resolve/main/tokenizer_config.json",
+
+    # huggingface-cli prefixed Models will download using the huggingface-cli tool
+    # TODO: Convert all of the MODEL_REPOS with a NamedTuple that includes the install_method
+    "huggingface-cli/meta-llama/Meta-Llama-3-8B": "",
 }

 JOB_RUNNERS = {

@@ -57,7 +61,7 @@ def parse_args() -> Any:
     return parser.parse_args()


-def model_should_run_on_event(model: str, event: str) -> bool:
+def model_should_run_on_event(model: str, event: str, backend: str) -> bool:
     """
     A helper function to decide whether a model should be tested on an event (pull_request/push)
     We put higher priority and fast models to pull request and rest to push.

@@ -67,7 +71,11 @@ def model_should_run_on_event(model: str, event: str) -> bool:
     elif event == "push":
         return model in []
     elif event == "periodic":
-        return model in ["openlm-research/open_llama_7b"]
+        # test llama3 on gpu only, see description in https://github.com/pytorch/torchchat/pull/399 for reasoning
+        if backend == "gpu":
+            return model in ["openlm-research/open_llama_7b", "huggingface-cli/meta-llama/Meta-Llama-3-8B"]
+        else:
+            return model in ["openlm-research/open_llama_7b"]
     else:
         return False

@@ -102,15 +110,25 @@ def export_models_for_ci() -> dict[str, dict]:
         MODEL_REPOS.keys(),
         JOB_RUNNERS[backend].items(),
     ):
-        if not model_should_run_on_event(repo_name, event):
+        if not model_should_run_on_event(repo_name, event, backend):
             continue

+        # This is mostly temporary to get this finished quickly while
+        # doing minimal changes, see TODO at the top of the file to
+        # see how this should probably be done
+        install_method = "wget"
+        final_repo_name = repo_name
+        if repo_name.startswith("huggingface-cli"):
+            install_method = "huggingface-cli"
+            final_repo_name = repo_name.replace("huggingface-cli/", "")
+
         record = {
-            "repo_name": repo_name,
-            "model_name": repo_name.split("/")[-1],
+            "repo_name": final_repo_name,
+            "model_name": final_repo_name.split("/")[-1],
             "resources": MODEL_REPOS[repo_name],
             "runner": runner[0],
             "platform": runner[1],
+            "install_method": install_method,
             "timeout": 90,
         }

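The TODO in the first hunk gestures at the intended cleanup: carry the install method next to the resources instead of encoding it in the repo-name key. A minimal sketch of that idea, where ModelSpec and its fields are hypothetical names, not part of this commit:

    # Hypothetical shape for MODEL_REPOS per the TODO above (sketch only).
    from typing import NamedTuple

    class ModelSpec(NamedTuple):
        resources: str                 # comma-separated URLs for wget, or "" when not needed
        install_method: str = "wget"   # "wget" or "huggingface-cli"

    MODEL_REPOS = {
        # URLs elided here; see the real entries above.
        "mistralai/Mistral-7B-v0.1": ModelSpec(resources="https://huggingface.co/..."),
        "meta-llama/Meta-Llama-3-8B": ModelSpec(resources="", install_method="huggingface-cli"),
    }

With a shape like this, export_models_for_ci() could read the install method off the spec directly instead of prefix-matching and rewriting the repo name.
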
.github/workflows/periodic.yml

Lines changed: 15 additions & 3 deletions
@@ -113,10 +113,12 @@ jobs:
     uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
     name: test-gpu (${{ matrix.platform }}, ${{ matrix.model_name }})
     needs: gather-models-gpu
+    secrets: inherit
     strategy:
       matrix: ${{ fromJSON(needs.gather-models-gpu.outputs.models) }}
       fail-fast: false
     with:
+      secrets-env: "HF_TOKEN_PERIODIC"
       runner: ${{ matrix.runner }}
       gpu-arch-type: cuda
       gpu-arch-version: "12.1"

@@ -126,15 +128,25 @@
         echo "::endgroup::"

         echo "::group::Install required packages"
-        pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
+        pip install --progress-bar off --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
         pip install -r ./requirements.txt
         pip list
         python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
         echo "::endgroup::"

         echo "::group::Download checkpoint"
-        export REPO_NAME=${{ matrix.repo_name }}
-        bash .ci/scripts/wget_checkpoint.sh ${REPO_NAME} ${{ matrix.resources }}
+        export REPO_NAME="${{ matrix.repo_name }}"
+        case "${{ matrix.install_method }}" in
+          wget)
+            bash .ci/scripts/wget_checkpoint.sh "${REPO_NAME}" "${{ matrix.resources }}"
+            ;;
+          huggingface-cli)
+            (
+              set +x
+              HF_TOKEN="${SECRET_HF_TOKEN_PERIODIC}" bash .ci/scripts/download_llama.sh
+            )
+            ;;
+        esac
         echo "::endgroup::"

         echo "::group::Convert checkpoint"

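For context on how the pieces connect: the test-gpu job consumes the JSON matrix emitted by gather_test_models.py, so for the new model each matrix entry looks roughly like the record below. This is a sketch; the runner and platform values are illustrative, since in reality they come from JOB_RUNNERS:

    # Approximate matrix entry for the new periodic GPU model (illustrative values).
    record = {
        "repo_name": "meta-llama/Meta-Llama-3-8B",  # "huggingface-cli/" prefix stripped
        "model_name": "Meta-Llama-3-8B",
        "resources": "",                            # no wget URLs; fetched via huggingface-cli
        "runner": "linux.g5.4xlarge.nvidia.gpu",    # assumption; real value comes from JOB_RUNNERS
        "platform": "linux",                        # assumption
        "install_method": "huggingface-cli",
        "timeout": 90,
    }

On the secrets plumbing: listing HF_TOKEN_PERIODIC under secrets-env is what makes it available to the job script, which reads it as SECRET_HF_TOKEN_PERIODIC (the reusable linux_job.yml workflow appears to expose secrets with a SECRET_ prefix), and the set +x subshell keeps the token out of the xtrace log.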