Skip to content

Commit 179fd69

Browse files
helunwencser authored and facebook-github-bot committed
add ci job for eval_llama with mmlu task (#6337)
Summary: Pull Request resolved: #6337 imported-using-ghimport Test Plan: Imported from OSS Reviewed By: cccclai Differential Revision: D64565303 Pulled By: helunwencser fbshipit-source-id: 2c20fd277775887a717208c27a791ab9a5482662
1 parent a343884 commit 179fd69

File tree

3 files changed

+98
-0
lines changed

3 files changed

+98
-0
lines changed

.ci/scripts/test_eval_llama_mmlu.sh

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
#!/bin/bash
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
set -exu
9+
10+
if [[ -z "${PYTHON_EXECUTABLE:-}" ]]; then
11+
PYTHON_EXECUTABLE=python3
12+
fi
13+
14+
# Download and prepare stories model artifacts
15+
prepare_model_artifacts() {
16+
echo "Preparing stories model artifacts"
17+
wget -O stories110M.pt "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt"
18+
wget -O tokenizer.model "https://raw.githubusercontent.com/karpathy/llama2.c/master/tokenizer.model"
19+
echo '{"dim": 768, "multiple_of": 32, "n_heads": 12, "n_layers": 12, "norm_eps": 1e-05, "vocab_size": 32000}' > params.json
20+
}
21+
22+
run_and_verify() {
23+
NOW=$(date +"%H:%M:%S")
24+
echo "Starting to run eval_llama at ${NOW}"
25+
if [[ ! -f "stories110M.pt" ]]; then
26+
echo "stories110M.pt is missing."
27+
exit 1
28+
fi
29+
if [[ ! -f "tokenizer.model" ]]; then
30+
echo "tokenizer.model is missing."
31+
exit 1
32+
fi
33+
if [[ ! -f "params.json" ]]; then
34+
echo "params.json is missing."
35+
exit 1
36+
fi
37+
$PYTHON_EXECUTABLE -m examples.models.llama.eval_llama \
38+
-c stories110M.pt \
39+
-p params.json \
40+
-t tokenizer.model \
41+
-kv \
42+
-d fp32 \
43+
--tasks mmlu \
44+
-f 5 \
45+
--max_seq_length 2048 \
46+
--limit 5 > result.txt
47+
48+
# Verify result.txt
49+
RESULT=$(cat result.txt)
50+
EXPECTED_TASK="mmlu"
51+
EXPECTED_RESULT="acc"
52+
if [[ "${RESULT}" == "${EXPECTED_TASK}: {"*"${EXPECTED_RESULT}"* ]]; then
53+
echo "Actual result: ${RESULT}"
54+
echo "Success"
55+
exit 0
56+
else
57+
echo "Actual result: ${RESULT}"
58+
echo "Failure; results not the same"
59+
exit 1
60+
fi
61+
}
62+
63+
prepare_model_artifacts
64+
run_and_verify

.github/workflows/pull.yml

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,3 +474,30 @@ jobs:
474474
475475
# run eval_llama wikitext task
476476
PYTHON_EXECUTABLE=python bash .ci/scripts/test_eval_llama_wikitext.sh
477+
478+
test-eval_llama-mmlu-linux:
479+
name: test-eval_llama-mmlu-linux
480+
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
481+
strategy:
482+
fail-fast: false
483+
with:
484+
runner: linux.24xlarge
485+
docker-image: executorch-ubuntu-22.04-clang12
486+
submodules: 'true'
487+
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
488+
timeout: 90
489+
script: |
490+
# The generic Linux job chooses to use base env, not the one setup by the image
491+
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
492+
conda activate "${CONDA_ENV}"
493+
494+
PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh "cmake"
495+
496+
# install pybind
497+
bash install_requirements.sh --pybind xnnpack
498+
499+
# install llama requirements
500+
bash examples/models/llama/install_requirements.sh
501+
502+
# run eval_llama mmlu task
503+
PYTHON_EXECUTABLE=python bash .ci/scripts/test_eval_llama_mmlu.sh

examples/models/llama/eval_llama_lib.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,13 @@ def eval_llama(
291291
# Generate the eval wrapper
292292
eval_wrapper = gen_eval_wrapper(model_name, args)
293293

294+
# Needed for loading mmlu dataset.
295+
# See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
296+
if args.tasks and "mmlu" in args.tasks:
297+
import datasets
298+
299+
datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
300+
294301
# Evaluate the model
295302
with torch.no_grad():
296303
eval_results = simple_evaluate(

0 commit comments

Comments
 (0)