Commit a9f134c

guangy10 authored and malfet committed
Limit what could run for 7b model on cuda (#311)
1 parent e5363e4 commit a9f134c

File tree

3 files changed: +53 -31 lines changed

.ci/scripts/convert_checkpoint.sh

Lines changed: 3 additions & 1 deletion
@@ -22,12 +22,14 @@ function convert_checkpoint() {
     return 0
   fi
 
+  [ -f "build/convert_hf_checkpoint.py" ] || exit 1
+
   if [ -f "checkpoints/$MODEL_REPO/model.pth" ]; then
     echo "Converted checkpoint already exists. Skipping conversion for $MODEL_REPO."
     return 0
   fi
   echo "Convert Huggingface checkpoint for $MODEL_REPO"
-  python3 scripts/convert_hf_checkpoint.py --checkpoint-dir "checkpoints/$MODEL_REPO"
+  python3 build/convert_hf_checkpoint.py --checkpoint-dir "checkpoints/$MODEL_REPO"
 }
 
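For context, this hunk does two things: it fails fast when the converter script is missing from its expected build/ location, and it switches the conversion call from scripts/ to build/. A minimal standalone sketch of the same fail-fast pattern, assuming MODEL_REPO is set by the caller as in the CI script (the error message is illustrative, not from the diff):

#!/bin/bash
# Sketch only: reproduce the fail-fast guard added above.
set -eu

if [ ! -f "build/convert_hf_checkpoint.py" ]; then
  # Hypothetical message; the CI script itself simply exits with status 1.
  echo "build/convert_hf_checkpoint.py not found" >&2
  exit 1
fi

python3 build/convert_hf_checkpoint.py --checkpoint-dir "checkpoints/${MODEL_REPO}"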
.ci/scripts/validate.sh

Lines changed: 49 additions & 29 deletions
@@ -25,7 +25,15 @@ function generate_compiled_model_output() {
   local MODEL_DIR="${CHECKPOINT_PATH%/*}"
   local MODEL_NAME=$(basename "$CHECKPOINT_PATH" | sed 's/\.[^.]*$//')
 
-  for DTYPE in bfloat16 float16 float32; do
+  if [[ $CHECKPOINT_PATH != *"stories"* && $TARGET_DEVICE == "cuda" ]]; then
+    DTYPES="bfloat16"
+    EXCLUDE_INT8_QUANT=true
+  else
+    DTYPES="float32 bfloat16 float16"
+    EXCLUDE_INT8_QUANT=false
+  fi
+
+  for DTYPE in $DTYPES; do
     echo ""############### Run inference with torch.compile for dtype $DTYPE "###############"
     echo ""
     echo "******************************************"
@@ -66,21 +74,23 @@ function generate_compiled_model_output() {
     python3 -W ignore generate.py --dtype ${DTYPE} --compile --quant '{"embedding" : {"bitwidth": 4, "groupsize": 8, "packed": "True"}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_compiled" || exit 1
     cat "$MODEL_DIR/output_compiled"
 
-    echo "******************************************"
-    echo "******* INT8 channel-wise quantized ******"
-    echo "******************************************"
-    python3 -W ignore generate.py --dtype ${DTYPE} --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_eager" || exit 1
-    cat "$MODEL_DIR/output_eager"
-    python3 -W ignore generate.py --dtype ${DTYPE} --compile --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_compiled" || exit 1
-    cat "$MODEL_DIR/output_compiled"
-
-    echo "******************************************"
-    echo "******** INT8 group-wise quantized *******"
-    echo "******************************************"
-    python3 -W ignore generate.py --dtype ${DTYPE} --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_eager" || exit 1
-    cat "$MODEL_DIR/output_eager"
-    python3 -W ignore generate.py --dtype ${DTYPE} --compile --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_compiled" || exit 1
-    cat "$MODEL_DIR/output_compiled"
+    if [ "$EXCLUDE_INT8_QUANT" = false ]; then
+      echo "******************************************"
+      echo "******* INT8 channel-wise quantized ******"
+      echo "******************************************"
+      python3 -W ignore generate.py --dtype ${DTYPE} --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_eager" || exit 1
+      cat "$MODEL_DIR/output_eager"
+      python3 -W ignore generate.py --dtype ${DTYPE} --compile --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_compiled" || exit 1
+      cat "$MODEL_DIR/output_compiled"
+
+      echo "******************************************"
+      echo "******** INT8 group-wise quantized *******"
+      echo "******************************************"
+      python3 -W ignore generate.py --dtype ${DTYPE} --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_eager" || exit 1
+      cat "$MODEL_DIR/output_eager"
+      python3 -W ignore generate.py --dtype ${DTYPE} --compile --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_compiled" || exit 1
+      cat "$MODEL_DIR/output_compiled"
+    fi
 
     echo "******************************************"
     echo "******** INT4 group-wise quantized *******"
@@ -98,7 +108,15 @@ function generate_aoti_model_output() {
   local MODEL_DIR="${CHECKPOINT_PATH%/*}"
   local MODEL_NAME=$(basename "$CHECKPOINT_PATH" | sed 's/\.[^.]*$//')
 
-  for DTYPE in bfloat16 float16 float32; do
+  if [[ $CHECKPOINT_PATH != *"stories"* && $TARGET_DEVICE == "cuda" ]]; then
+    DTYPES="bfloat16"
+    EXCLUDE_INT8_QUANT=true
+  else
+    DTYPES="float32 bfloat16 float16"
+    EXCLUDE_INT8_QUANT=false
+  fi
+
+  for DTYPE in $DTYPES; do
     echo ""############### Run inference with AOT Inductor for dtype $DTYPE "###############"
     echo ""
     echo "******************************************"
@@ -136,19 +154,21 @@ function generate_aoti_model_output() {
     python3 -W ignore generate.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
     cat "$MODEL_DIR/output_aoti"
 
-    echo "******************************************"
-    echo "******* INT8 channel-wise quantized ******"
-    echo "******************************************"
-    python3 -W ignore export.py --dtype ${DTYPE} --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
-    python3 -W ignore generate.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
-    cat "$MODEL_DIR/output_aoti"
+    if [ "$EXCLUDE_INT8_QUANT" = false ]; then
+      echo "******************************************"
+      echo "******* INT8 channel-wise quantized ******"
+      echo "******************************************"
+      python3 -W ignore export.py --dtype ${DTYPE} --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
+      python3 -W ignore generate.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
+      cat "$MODEL_DIR/output_aoti"
 
-    echo "******************************************"
-    echo "******** INT8 group-wise quantized *******"
-    echo "******************************************"
-    python3 -W ignore export.py --dtype ${DTYPE} --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
-    python3 -W ignore generate.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
-    cat "$MODEL_DIR/output_aoti"
+      echo "******************************************"
+      echo "******** INT8 group-wise quantized *******"
+      echo "******************************************"
+      python3 -W ignore export.py --dtype ${DTYPE} --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
+      python3 -W ignore generate.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
+      cat "$MODEL_DIR/output_aoti"
+    fi
 
     echo "******************************************"
     echo "******** INT4 group-wise quantized *******"

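Taken together, the two validate.sh hunks apply one gate in both generate_compiled_model_output and generate_aoti_model_output: a non-"stories" (i.e. 7B-class) checkpoint on CUDA runs only the bfloat16 dtype and skips the INT8 quantization paths, while everything else keeps the previous coverage. A minimal sketch of that gate in isolation; the pick_dtypes helper and the sample checkpoint path are illustrative, not part of the diff:

#!/bin/bash
# Sketch only: the dtype / INT8 gating added to validate.sh, wrapped in a helper.
pick_dtypes() {
  local checkpoint_path="$1"
  local target_device="$2"
  if [[ $checkpoint_path != *"stories"* && $target_device == "cuda" ]]; then
    # Large checkpoint on CUDA: keep the CI job small.
    DTYPES="bfloat16"
    EXCLUDE_INT8_QUANT=true
  else
    # Small "stories" checkpoints, or any non-CUDA device: full coverage.
    DTYPES="float32 bfloat16 float16"
    EXCLUDE_INT8_QUANT=false
  fi
}

# Hypothetical 7B checkpoint validated on a CUDA runner:
pick_dtypes "checkpoints/some-org/some-7b-model/model.pth" "cuda"
echo "$DTYPES"              # bfloat16
echo "$EXCLUDE_INT8_QUANT"  # true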
.github/workflows/periodic.yml

Lines changed: 1 addition & 1 deletion
@@ -117,7 +117,7 @@ jobs:
       matrix: ${{ fromJSON(needs.gather-models-gpu.outputs.models) }}
       fail-fast: false
     with:
-      runner: linux.g5.12xlarge.nvidia.gpu
+      runner: ${{ matrix.runner }}
       gpu-arch-type: cuda
       gpu-arch-version: "12.1"
       script: |

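With this change the GPU runner is no longer hardcoded to linux.g5.12xlarge.nvidia.gpu; each entry of the matrix produced by gather-models-gpu supplies its own runner through matrix.runner. A hedged sketch of the JSON shape such a matrix could take; only the runner key is implied by the diff, while the include wrapper, model names, and the smaller runner label are assumptions for illustration:

# Sketch only: print a matrix of the shape fromJSON(needs.gather-models-gpu.outputs.models) might consume.
cat <<'EOF'
{
  "include": [
    { "repo_name": "a-small-stories-model", "runner": "linux.g5.4xlarge.nvidia.gpu" },
    { "repo_name": "a-7b-model",            "runner": "linux.g5.12xlarge.nvidia.gpu" }
  ]
}
EOF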