Skip to content

Commit a343884

Browse files
helunwencser authored and facebook-github-bot committed
add CI job for eval_llama with wikitext task (#6336)
Summary: Pull Request resolved: #6336 imported-using-ghimport Test Plan: Imported from OSS Reviewed By: cccclai Differential Revision: D64565304 Pulled By: helunwencser fbshipit-source-id: 58559df0a007830ddecea8e20297bbacb20dd773
1 parent 3ce5741 commit a343884

File tree

3 files changed

+95
-1
lines changed

3 files changed

+95
-1
lines changed
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
#!/bin/bash
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
set -exu
9+
10+
if [[ -z "${PYTHON_EXECUTABLE:-}" ]]; then
11+
PYTHON_EXECUTABLE=python3
12+
fi
13+
14+
# Download and prepare stories model artifacts
15+
prepare_model_artifacts() {
16+
echo "Preparing stories model artifacts"
17+
wget -O stories110M.pt "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt"
18+
wget -O tokenizer.model "https://raw.githubusercontent.com/karpathy/llama2.c/master/tokenizer.model"
19+
echo '{"dim": 768, "multiple_of": 32, "n_heads": 12, "n_layers": 12, "norm_eps": 1e-05, "vocab_size": 32000}' > params.json
20+
}
21+
22+
run_and_verify() {
23+
NOW=$(date +"%H:%M:%S")
24+
echo "Starting to run eval_llama at ${NOW}"
25+
if [[ ! -f "stories110M.pt" ]]; then
26+
echo "stories110M.pt is missing."
27+
exit 1
28+
fi
29+
if [[ ! -f "tokenizer.model" ]]; then
30+
echo "tokenizer.model is missing."
31+
exit 1
32+
fi
33+
if [[ ! -f "params.json" ]]; then
34+
echo "params.json is missing."
35+
exit 1
36+
fi
37+
$PYTHON_EXECUTABLE -m examples.models.llama.eval_llama \
38+
-c stories110M.pt \
39+
-p params.json \
40+
-t tokenizer.model \
41+
-kv \
42+
-d fp32 \
43+
--max_seq_length 2048 \
44+
--limit 5 > result.txt
45+
46+
# Verify result.txt
47+
RESULT=$(cat result.txt)
48+
EXPECTED_TASK="wikitext"
49+
EXPECTED_RESULT="word_perplexity"
50+
if [[ "${RESULT}" == "${EXPECTED_TASK}: {"*"${EXPECTED_RESULT}"* ]]; then
51+
echo "Actual result: ${RESULT}"
52+
echo "Success"
53+
exit 0
54+
else
55+
echo "Actual result: ${RESULT}"
56+
echo "Failure; results not the same"
57+
exit 1
58+
fi
59+
}
60+
61+
prepare_model_artifacts
62+
run_and_verify

.github/workflows/pull.yml

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,3 +447,30 @@ jobs:
447447
448448
# run e2e (export, tokenizer and runner)
449449
PYTHON_EXECUTABLE=python bash .ci/scripts/test_phi_3_mini.sh
450+
451+
test-eval_llama-wikitext-linux:
452+
name: test-eval_llama-wikitext-linux
453+
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
454+
strategy:
455+
fail-fast: false
456+
with:
457+
runner: linux.24xlarge
458+
docker-image: executorch-ubuntu-22.04-clang12
459+
submodules: 'true'
460+
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
461+
timeout: 90
462+
script: |
463+
# The generic Linux job chooses to use base env, not the one setup by the image
464+
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
465+
conda activate "${CONDA_ENV}"
466+
467+
PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh "cmake"
468+
469+
# install pybind
470+
bash install_requirements.sh --pybind xnnpack
471+
472+
# install llama requirements
473+
bash examples/models/llama/install_requirements.sh
474+
475+
# run eval_llama wikitext task
476+
PYTHON_EXECUTABLE=python bash .ci/scripts/test_eval_llama_wikitext.sh

examples/models/llama/evaluate/eager_eval.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,12 @@ def __init__(
4040

4141
@property
4242
def eot_token_id(self):
43-
return self._tokenizer.eot_id
43+
"""
44+
The stories model does not have an EOT token, so we use the EOS token instead.
45+
"""
46+
if hasattr(self._tokenizer, "eot_id"):
47+
return self._tokenizer.eot_id
48+
return self._tokenizer.eos_id
4449

4550
@property
4651
def max_length(self):

0 commit comments

Comments
 (0)