
make --device fast the default #515


Merged: 27 commits, Apr 30, 2024
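This PR makes --device fast the default, presumably resolving to the fastest backend the host actually supports (cuda, mps, or cpu); per the commit list, the mps case is guarded by a VM-safe availability check. A rough before/after sketch, assuming the --device and --checkpoint-path flags that generate.py is shown taking in the validate.sh diff below — the exact resolution order of "fast" is not visible on this page:

    # Before: the device had to be named explicitly (or fell back to a fixed default)
    python3 generate.py --device cuda --checkpoint-path "$CHECKPOINT_PATH"
    # After: omitting the flag implies --device fast
    python3 generate.py --checkpoint-path "$CHECKPOINT_PATH"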
Commits
227fce2  make --device fast the default (Apr 27, 2024)
397bcb4  Update iOS.md (#517) (shoumikhin, Apr 27, 2024)
7a5d30f  Pip to pip3 (#504) (mikekgfb, Apr 27, 2024)
626be1c  break aoti CI jobs separately (#500) (metascroy, Apr 27, 2024)
e26c528  Support llama3 in chat in run.cpp (#486) (metascroy, Apr 27, 2024)
a7bc5ad  Add tests for quantize json, add cuda device specification and precis… (mikekgfb, Apr 27, 2024)
6d29eb6  remove code for no KV Cache path (#527) (mikekgfb, Apr 27, 2024)
875feff  Update ADVANCED-USERS.md (#529) (mikekgfb, Apr 27, 2024)
53714b7  runner-aoti on cuda (#531) (mikekgfb, Apr 28, 2024)
e039cad  Update runner_build.md (#530) (mikekgfb, Apr 28, 2024)
665a5da  clean up runner code a little (#532) (metascroy, Apr 28, 2024)
c53d44a  move int8 linear class and function into qops.py (#534) (mikekgfb, Apr 29, 2024)
a7a2457  add dtype tests for runner-aoti + runner-et (#539) (metascroy, Apr 29, 2024)
3394d36  Quantized embedding (#536) (mikekgfb, Apr 29, 2024)
6f7bc61  Move Linear int4 to qops (#537) (mikekgfb, Apr 29, 2024)
d2864b9  Revert "add dtype tests for runner-aoti + runner-et (#539)" (#548) (malfet, Apr 29, 2024)
af38291  fix generate for llama3 (#538) (metascroy, Apr 29, 2024)
aadac48  add delegation visualization instructions (#551) (lucylq, Apr 29, 2024)
33dc210  Add dtype runner aoti (#552) (metascroy, Apr 29, 2024)
e1c0815  test sdpa with fp16 (#553) (mikekgfb, Apr 29, 2024)
8903ab9  update (#560) (mikekgfb, Apr 29, 2024)
e12a164  Only support newest versions of lm-eval (#556) (jerryzh168, Apr 29, 2024)
c26589a  split cpu eval CI by dtype (#554) (metascroy, Apr 29, 2024)
f867aba  Removing duplicate HF issue message from README (#559) (Jack-Khuu, Apr 29, 2024)
2afcc29  doc updates (#567) (metascroy, Apr 30, 2024)
cc84c6d  Merge branch 'default_device_fast' of https://github.com/pytorch/torc… (Apr 30, 2024)
db547d4  Add VM-safe MPS check (Apr 30, 2024)
73 changes: 60 additions & 13 deletions .ci/scripts/validate.sh
@@ -25,6 +25,7 @@ function generate_compiled_model_output() {
local MODEL_DIR="${CHECKPOINT_PATH%/*}"
local MODEL_NAME=$(basename "$CHECKPOINT_PATH" | sed 's/\.[^.]*$//')


if [[ $CHECKPOINT_PATH != *"stories"* && $TARGET_DEVICE == "cuda" ]]; then
DTYPES="bfloat16"
EXCLUDE_INT8_QUANT=true
@@ -74,7 +75,7 @@ function generate_compiled_model_output() {
python3 -W ignore generate.py --dtype ${DTYPE} --compile --quant '{"embedding" : {"bitwidth": 4, "groupsize": 8, "packed": "True"}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_compiled" || exit 1
cat "$MODEL_DIR/output_compiled"

if [ "$EXCLUDE_INT8_QUANT" = false ]; then
if [ "${EXCLUDE_INT8_QUANT:-false}" == false ]; then
echo "******************************************"
echo "******* INT8 channel-wise quantized ******"
echo "******************************************"
@@ -109,17 +110,24 @@ function generate_compiled_model_output() {
function generate_aoti_model_output() {
local CHECKPOINT_PATH="$1"
local TARGET_DEVICE="${2:-cpu}"
local DTYPES="${3:-default}"
local MODEL_DIR="${CHECKPOINT_PATH%/*}"
local MODEL_NAME=$(basename "$CHECKPOINT_PATH" | sed 's/\.[^.]*$//')

-if [[ $CHECKPOINT_PATH != *"stories"* && $TARGET_DEVICE == "cuda" ]]; then
-  DTYPES="bfloat16"
-  EXCLUDE_INT8_QUANT=true
-else
-  DTYPES="float32 bfloat16 float16"
-  EXCLUDE_INT8_QUANT=false
+echo "Local DTYPES=$DTYPES"
+
+if [[ $DTYPES == "default" ]]; then
+  if [[ $CHECKPOINT_PATH != *"stories"* && $TARGET_DEVICE == "cuda" ]]; then
+    DTYPES="bfloat16"
+    EXCLUDE_INT8_QUANT=true
+  else
+    DTYPES="float32 bfloat16 float16"
+    EXCLUDE_INT8_QUANT=false
+  fi
fi
+
+echo "Local after default DTYPES=$DTYPES"

for DTYPE in $DTYPES; do
echo ""############### Run inference with AOT Inductor for dtype $DTYPE "###############"
echo ""
@@ -158,7 +166,7 @@ function generate_aoti_model_output() {
python3 -W ignore generate.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
cat "$MODEL_DIR/output_aoti"

if [ "$EXCLUDE_INT8_QUANT" = false ]; then
if [ "${EXCLUDE_INT8_QUANT:-false}" == false ]; then
echo "******************************************"
echo "******* INT8 channel-wise quantized ******"
echo "******************************************"
@@ -247,10 +255,11 @@ function eval_model() {
function eval_model_sanity_check() {
local CHECKPOINT_PATH="$1"
local TARGET_DEVICE="${2:-cpu}"
local DTYPES="$3"
local MODEL_DIR="${CHECKPOINT_PATH%/*}"
local MODEL_NAME=$(basename "$CHECKPOINT_PATH" | sed 's/\.[^.]*$//')

-for DTYPE in float32 bfloat16 float16; do
+for DTYPE in $DTYPES; do
echo ""############### Run eval with torch.compile for dtype $DTYPE "###############"
echo ""
echo "******************************************"
@@ -295,11 +304,12 @@ function run_compile() {
}

function run_aoti() {
-generate_aoti_model_output "$CHECKPOINT_PATH" "$TARGET_DEVICE" || exit 1
+echo "Passing DTYPES=$DTYPES"
+generate_aoti_model_output "$CHECKPOINT_PATH" "$TARGET_DEVICE" "$DTYPES" || exit 1
}

function run_executorch() {
if [ "$TARGET_DEVICE" = "cpu" ]; then
if [ "$TARGET_DEVICE" == "cpu" ]; then
generate_executorch_model_output "$CHECKPOINT_PATH" "$TARGET_DEVICE" || exit 1
else
echo "Skipped: Executorch doesn't run on ${TARGET_DEVICE}"
@@ -311,31 +321,68 @@ function run_eval(){
}

function run_eval_sanity_check(){
-eval_model_sanity_check "$CHECKPOINT_PATH" "$TARGET_DEVICE" || exit 1
+echo "Passing DTYPES=$DTYPES"
+eval_model_sanity_check "$CHECKPOINT_PATH" "$TARGET_DEVICE" "$DTYPES" || exit 1
}

CHECKPOINT_PATH="$1"
TARGET_DEVICE="${2:-cpu}"
PROMPT="Hello, my name is"


if [ "$#" -gt 2 ]; then
# Additional arguments provided
for arg in "${@:3}"; do
case "$arg" in
"compile")
echo "arg:$arg"
run_compile || exit 1
;;
"aoti")
echo "arg:$arg"
DTYPES="default"
run_aoti || exit 1
;;
"aoti-bfloat16")
echo "arg:$arg"
DTYPES="bfloat16"
run_aoti || exit 1
;;
"aoti-float16")
echo "arg:$arg"
DTYPES="float16"
run_aoti || exit 1
;;
"aoti-float32")
echo "arg:$arg"
DTYPES="float32"
run_aoti || exit 1
;;
"executorch")
echo "arg:$arg"
run_executorch || exit 1
;;
"eval")
echo "arg:$arg"
run_eval || exit 1
;;
"eval_sanity_check")
echo "arg:$arg"
DTYPES="bfloat16 float16 float32"
run_eval_sanity_check || exit 1
;;
"eval_sanity_check-bfloat16")
echo "arg:$arg"
DTYPES="bfloat16"
run_eval_sanity_check || exit 1
;;
"eval_sanity_check-float16")
echo "arg:$arg"
DTYPES="float16"
run_eval_sanity_check || exit 1
;;
"eval_sanity_check-float32")
echo "arg:$arg"
DTYPES="float32"
run_eval_sanity_check || exit 1
;;
*)
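The new dtype-suffixed arguments let CI split one long validation run into per-dtype jobs (see the "break aoti CI jobs separately" and "split cpu eval CI by dtype" commits above). Hypothetical invocations, with an illustrative checkpoint path not taken from this diff:

    bash .ci/scripts/validate.sh checkpoints/stories15M/stories15M.pt cuda aoti-bfloat16
    bash .ci/scripts/validate.sh checkpoints/stories15M/stories15M.pt cpu eval_sanity_check-float32
    bash .ci/scripts/validate.sh checkpoints/stories15M/stories15M.pt cpu aoti   # DTYPES stays "default"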
6 changes: 3 additions & 3 deletions .github/workflows/hqq-dtype.yml
@@ -27,9 +27,9 @@ jobs:
echo "::group::Download checkpoints"
# Install requirements
-pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
-pip install -r requirements.txt
-pip list
+pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
+pip3 install -r requirements.txt
+pip3 list
python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
echo "::endgroup::"
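The pip-to-pip3 rename pins these install steps to Python 3's pip on runners where bare pip can point at a different interpreter. An even more explicit variant (an alternative, not what this PR does) is to invoke pip through the interpreter itself:

    python3 -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
    python3 -m pip install -r requirements.txt
    python3 -m pip list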
18 changes: 9 additions & 9 deletions .github/workflows/periodic.yml
@@ -47,9 +47,9 @@ jobs:
echo "$(uname -a)"
- name: Install dependencies
run: |
-pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
-pip install -r requirements.txt
-pip list
+pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
+pip3 install -r requirements.txt
+pip3 list
python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
- name: Download checkpoints
run: |
@@ -81,9 +81,9 @@ jobs:
echo "$(uname -a)"
- name: Install dependencies
run: |
-pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
-pip install -r requirements.txt
-pip list
+pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
+pip3 install -r requirements.txt
+pip3 list
python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
- name: Download checkpoints
run: |
@@ -128,9 +128,9 @@ jobs:
echo "::endgroup::"
echo "::group::Install required packages"
-pip install --progress-bar off --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
-pip install -r ./requirements.txt
-pip list
+pip3 install --progress-bar off --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
+pip3 install -r ./requirements.txt
+pip3 list
python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
echo "::endgroup::"