|
| 1 | +#!/bin/bash |
| 2 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | +# All rights reserved. |
| 4 | +# |
| 5 | +# This source code is licensed under the BSD-style license found in the |
| 6 | +# LICENSE file in the root directory of this source tree. |
| 7 | + |
# Strict mode: abort on the first failing command (-e), trace every command
# (-x), and treat references to unset variables as errors (-u).
set -exu

# Positional arguments. Each one defaults to the empty string so that a
# missing argument reaches the explicit checks below instead of aborting with
# an "unbound variable" error under `set -u`.
MODEL_NAME="${1:-}" # stories110M.pt
BUILD_TOOL="${2:-}" # buck2
DTYPE="${3:-}"      # fp16 or fp32

if [[ -z "${MODEL_NAME}" ]]; then
  echo "Missing model name, exiting..."
  exit 1
fi

if [[ -z "${BUILD_TOOL}" ]]; then
  echo "Missing build tool (require buck2 or cmake), exiting..."
  # Explicit non-zero status: a bare `exit` would reuse the status of the
  # preceding echo (0) and report success.
  exit 1
fi

if [[ -z "${DTYPE}" ]]; then
  echo "Missing dtype, choose fp16 or fp32, exiting..."
  exit 1
fi
| 28 | + |
# Interpreter used for the Python steps. Fall back to python3 when the caller
# has not exported PYTHON_EXECUTABLE, so the lookup below does not abort with
# an "unbound variable" error under `set -u`.
PYTHON_EXECUTABLE="${PYTHON_EXECUTABLE:-python3}"
# Print how the interpreter resolves; a failed lookup returns non-zero, which
# stops the script under `set -e`.
command -v "${PYTHON_EXECUTABLE}"

# Check build tool.
if [[ "${BUILD_TOOL}" != "buck2" ]]; then
  echo "Invalid build tool ${BUILD_TOOL}. Only buck2 is supported atm"
  exit 1
fi
| 38 | + |
#######################################
# Delete the files this script downloads or generates in the current
# directory. Missing files are ignored so that a partially completed run can
# still be cleaned up without tripping `set -e`.
# Globals:
#   MODEL_NAME          (read) downloaded checkpoint file name
#   EXPORTED_MODEL_NAME (read) exported program file name
#   PARAMS              (read) params file name; defaults to params.json
# Arguments:
#   None
# Outputs:
#   Writes a status line to stdout.
# Returns:
#   Status of the last rm invocation (0 when files are merely absent).
#######################################
cleanup_files() {
  echo "Deleting downloaded and generated files"
  local artifact
  for artifact in \
    "${MODEL_NAME:-}" \
    tokenizer.model \
    tokenizer.bin \
    "${EXPORTED_MODEL_NAME:-}" \
    "${PARAMS:-params.json}"; do
    # Skip names that have not been assigned yet.
    if [[ -n "${artifact}" ]]; then
      rm -f -- "${artifact}"
    fi
  done
}
| 46 | + |
# Download and create artifacts.
# The params file is only written once the model name has been accepted, so
# the unsupported-model path leaves no stray file behind.
PARAMS="params.json"
if [[ "${MODEL_NAME}" == "stories110M.pt" ]]; then
  # Download the stories110M.pt checkpoint from Hugging Face and the tokenizer
  # model from GitHub. -O pins the output file name so a leftover file from an
  # earlier run is overwritten instead of the new download being saved under a
  # ".1" suffix and the stale copy being used.
  wget -O "${MODEL_NAME}" "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt"
  wget -O tokenizer.model "https://raw.githubusercontent.com/karpathy/llama2.c/master/tokenizer.model"
  # Create params.json file
  echo '{"dim": 768, "multiple_of": 32, "n_heads": 12, "n_layers": 12, "norm_eps": 1e-05, "vocab_size": 32000}' > "${PARAMS}"
else
  echo "Unsupported model name ${MODEL_NAME}"
  exit 1
fi
| 60 | + |
# Check dtype and derive the exported file name: "llama2" plus an "_h" suffix
# for fp16, then the ".pte" extension.
EXPORTED_MODEL_NAME="llama2"
case "${DTYPE}" in
  fp16)
    EXPORTED_MODEL_NAME="${EXPORTED_MODEL_NAME}_h"
    ;;
  fp32)
    # No suffix for fp32.
    ;;
  *)
    echo "Unsupported dtype ${DTYPE}"
    exit 1
    ;;
esac

# Export model.
# NOTE(review): no output path is passed to the exporter, so
# EXPORTED_MODEL_NAME presumably mirrors the exporter's default file naming;
# confirm against export_llama.
EXPORTED_MODEL_NAME="${EXPORTED_MODEL_NAME}.pte"
echo "Exporting ${EXPORTED_MODEL_NAME}"
# Run the exporter with the interpreter selected through PYTHON_EXECUTABLE
# (python3 when unset) and with the checkpoint validated earlier, rather than
# hard-coded values.
"${PYTHON_EXECUTABLE:-python3}" -m examples.models.llama2.export_llama -c "${MODEL_NAME}" -p "${PARAMS}" -d "${DTYPE}"
| 76 | + |
# Create tokenizer.bin from tokenizer.model using the buck2 tokenizer target.
echo "Creating tokenizer.bin"
tokenizer_cmd=(
  buck2 run examples/models/llama2/tokenizer:tokenizer_py --
  -t tokenizer.model
  -o tokenizer.bin
)
"${tokenizer_cmd[@]}"

# Run the exported model through the buck2 runner target, bounded by a 500s
# timeout.
echo "Running ${EXPORTED_MODEL_NAME} in portable mode"
runner_cmd=(
  timeout 500s
  buck2 run examples/models/llama2:main --
  --model_path="${EXPORTED_MODEL_NAME}"
  --tokenizer_path=tokenizer.bin
  --prompt="Once"
  --temperature=0
)
# A non-zero status (including a timeout) is tolerated on purpose: whatever
# the runner printed is still captured in RESULT for the check that follows.
RESULT=$("${runner_cmd[@]}") || true
| 84 | + |
# Check results: the run passes when the captured output starts with the
# expected prefix. Only the prefix is compared because the full story
# ("Once upon a time, there was a little girl named Lily. She loved to play
# outside" ...) may take too long to generate.
EXPECTED_PREFIX="Once upon a time,"

# Both outcomes report the same two lines first.
echo "Expected result prefix: ${EXPECTED_PREFIX}"
echo "Actual result: ${RESULT}"

if [[ "${RESULT}" != "${EXPECTED_PREFIX}"* ]]; then
  echo "Failure; results not the same"

  cleanup_files
  exit 1
fi

echo "Success"

cleanup_files
0 commit comments