Skip to content

Commit e49d36e

Browse files
guangy10malfet
authored andcommitted
Some simple bug fixes in the scripts (#155)
1 parent 741d7ba commit e49d36e

File tree

3 files changed

+72
-18
lines changed

3 files changed

+72
-18
lines changed

.ci/scripts/validate.sh

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
#!/bin/bash
23
# Copyright (c) Meta Platforms, Inc. and affiliates.
34
# All rights reserved.
@@ -35,7 +36,7 @@ function generate_aoti_model_output() {
3536
local MODEL_NAME=$(basename "$CHECKPOINT_PATH" | sed 's/\.[^.]*$//')
3637
echo ""############### Run inference with AOTInductor for $MODEL_NAME "###############"
3738
python -W ignore export.py --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path "${MODEL_DIR}/${MODEL_NAME}.so" --device "$TARGET_DEVICE"
38-
python -W ignore generate.py --checkpoint-path "$CHECKPOINT_PATH" --dso-path "$MODEL_DIR/${MODEL_NAME}.so" --prompt "$PROMPT" > "$MODEL_DIR/output_aoti"
39+
python -W ignore generate.py --checkpoint-path "$CHECKPOINT_PATH" --dso-path "$MODEL_DIR/${MODEL_NAME}.so" --prompt "$PROMPT" --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti"
3940
cat "$MODEL_DIR/output_aoti"
4041
}
4142

@@ -50,11 +51,49 @@ function generate_executorch_model_output() {
5051
cat "$MODEL_DIR/output_et"
5152
}
5253

54+
function run_compile() {
55+
generate_compiled_model_output "$CHECKPOINT_PATH" "$TARGET_DEVICE"
56+
}
57+
58+
function run_aoti() {
59+
generate_aoti_model_output "$CHECKPOINT_PATH" "$TARGET_DEVICE"
60+
}
61+
62+
function run_executorch() {
63+
if [ "$TARGET_DEVICE" = "cpu" ]; then
64+
generate_executorch_model_output "$CHECKPOINT_PATH" "$TARGET_DEVICE"
65+
else
66+
echo "Error: Executorch doesn't run on ${TARGET_DEVICE}"
67+
fi
68+
}
69+
5370

5471
CHECKPOINT_PATH="$1"
5572
TARGET_DEVICE="${2:-cpu}"
5673
PROMPT="Hello, my name is"
5774

58-
generate_compiled_model_output $CHECKPOINT_PATH $TARGET_DEVICE
59-
generate_aoti_model_output $CHECKPOINT_PATH $TARGET_DEVICE
60-
generate_executorch_model_output $CHECKPOINT_PATH $TARGET_DEVICE
75+
76+
if [ "$#" -gt 2 ]; then
77+
# Additional arguments provided
78+
for arg in "${@:3}"; do
79+
case "$arg" in
80+
"compile")
81+
run_compile
82+
;;
83+
"aoti")
84+
run_aoti
85+
;;
86+
"executorch")
87+
run_executorch
88+
;;
89+
*)
90+
echo "Unknown argument: $arg" >&2
91+
;;
92+
esac
93+
done
94+
else
95+
# No additional arguments provided, run all functions
96+
run_compile
97+
run_aoti
98+
run_executorch
99+
fi

.ci/scripts/wget_checkpoint.sh

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,9 @@ MODEL_REPO="$1"
1111
RESOURCES_STRING="$2"
1212
CHECKPOINT_NAME="${MODEL_REPO##*/}"
1313

14-
pushd "${LLAMA_FAST_ROOT}" || exit
15-
1614
# Create the directory for the checkpoint
1715
mkdir -p "checkpoints/${MODEL_REPO}"
18-
cd "checkpoints/${MODEL_REPO}" || exit
16+
pushd "checkpoints/${MODEL_REPO}" || exit
1917

2018
# Download all resources
2119
IFS=',' # Set the field separator to comma

scripts/workflow.sh

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,23 @@ function download_checkpoint() {
4949
fi
5050
}
5151

52+
function run_validation_e2e() {
53+
local MODEL_REPO="$1"
54+
55+
echo ""
56+
echo "############### Validating ${MODEL_REPO##*/} ###############"
57+
download_checkpoint "$MODEL_REPO"
58+
bash .ci/scripts/convert_checkpoint.sh "$MODEL_REPO"
59+
60+
set +e
61+
CHECKPOINT_PATH="checkpoints/$MODEL_REPO/$CHECKPOINT_FILENAME"
62+
if [ -z "$ADDITIONAL_ARG" ]; then
63+
bash .ci/scripts/validate.sh "$CHECKPOINT_PATH" "$DEVICE"
64+
else
65+
bash .ci/scripts/validate.sh "$CHECKPOINT_PATH" "$DEVICE" "$ADDITIONAL_ARG"
66+
fi
67+
}
68+
5269

5370
# List of models to validate
5471
MODEL_REPOS=(
@@ -59,26 +76,26 @@ MODEL_REPOS=(
5976
"mistralai/Mistral-7B-Instruct-v0.1"
6077
"mistralai/Mistral-7B-Instruct-v0.2"
6178
# "openlm-research/open_llama_7b"
62-
# "codellama/CodeLlama-7b-Python-hf"
63-
# "codellama/CodeLlama-34b-Python-hf"
79+
"codellama/CodeLlama-7b-Python-hf"
80+
"codellama/CodeLlama-34b-Python-hf"
6481
# "meta-llama/Llama-2-7b-chat-hf"
6582
# "meta-llama/Llama-2-13b-chat-hf"
6683
# "meta-llama/Llama-2-70b-chat-hf"
6784
)
6885

6986
PROMPT="Hello, my name is"
7087
DEVICE="${1:-cpu}"
88+
INPUT_MODEL_REPO="${2:-}"
89+
ADDITIONAL_ARG="${3:-}"
7190
CHECKPOINT_FILENAME="model.pth"
7291

7392
echo "###############################################################"
7493
echo "############## Start LLama-fast Model Validation ##############"
7594
echo "###############################################################"
76-
for MODEL_REPO in "${MODEL_REPOS[@]}"; do
77-
echo "############### Validating ${MODEL_REPO##*/} ###############"
78-
download_checkpoint "$MODEL_REPO"
79-
bash .ci/scripts/convert_checkpoint.sh "$MODEL_REPO"
80-
81-
set +e
82-
CHECKPOINT_PATH="checkpoints/$MODEL_REPO/$CHECKPOINT_FILENAME"
83-
bash .ci/scripts/validate.sh "$CHECKPOINT_PATH" "$DEVICE"
84-
done
95+
if [ -z "$INPUT_MODEL_REPO" ]; then
96+
for MODEL_REPO in "${MODEL_REPOS[@]}"; do
97+
run_validation_e2e "$MODEL_REPO" "$DEVICE"
98+
done
99+
else
100+
run_validation_e2e "$INPUT_MODEL_REPO" "$DEVICE" "$ADDITIONAL_ARG"
101+
fi

0 commit comments

Comments
 (0)