Skip to content

Commit 1a8e3b5

Browse files
committed
init
1 parent f95fd44 commit 1a8e3b5

File tree

2 files changed

+87
-10
lines changed

2 files changed

+87
-10
lines changed

.ci/scripts/validate.sh

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ function generate_compiled_model_output() {
2525
local MODEL_DIR="${CHECKPOINT_PATH%/*}"
2626
local MODEL_NAME=$(basename "$CHECKPOINT_PATH" | sed 's/\.[^.]*$//')
2727

28+
2829
if [[ $CHECKPOINT_PATH != *"stories"* && $TARGET_DEVICE == "cuda" ]]; then
2930
DTYPES="bfloat16"
3031
EXCLUDE_INT8_QUANT=true
@@ -109,15 +110,18 @@ function generate_compiled_model_output() {
109110
function generate_aoti_model_output() {
110111
local CHECKPOINT_PATH="$1"
111112
local TARGET_DEVICE="${2:-cpu}"
113+
local DTYPES="${3-default}"
112114
local MODEL_DIR="${CHECKPOINT_PATH%/*}"
113115
local MODEL_NAME=$(basename "$CHECKPOINT_PATH" | sed 's/\.[^.]*$//')
114116

115-
if [[ $CHECKPOINT_PATH != *"stories"* && $TARGET_DEVICE == "cuda" ]]; then
116-
DTYPES="bfloat16"
117-
EXCLUDE_INT8_QUANT=true
118-
else
119-
DTYPES="float32 bfloat16 float16"
120-
EXCLUDE_INT8_QUANT=false
117+
if [[ DTYPES="default" ]]; then
118+
if [[ $CHECKPOINT_PATH != *"stories"* && $TARGET_DEVICE == "cuda" ]]; then
119+
DTYPES="bfloat16"
120+
EXCLUDE_INT8_QUANT=true
121+
else
122+
DTYPES="float32 bfloat16 float16"
123+
EXCLUDE_INT8_QUANT=false
124+
fi
121125
fi
122126

123127
for DTYPE in $DTYPES; do
@@ -295,7 +299,7 @@ function run_compile() {
295299
}
296300

297301
function run_aoti() {
298-
generate_aoti_model_output "$CHECKPOINT_PATH" "$TARGET_DEVICE" || exit 1
302+
generate_aoti_model_output "$CHECKPOINT_PATH" "$TARGET_DEVICE" "$DTYPES" || exit 1
299303
}
300304

301305
function run_executorch() {
@@ -327,6 +331,7 @@ if [ "$#" -gt 2 ]; then
327331
run_compile || exit 1
328332
;;
329333
"aoti")
334+
DTYPES="${4:-default}"
330335
run_aoti || exit 1
331336
;;
332337
"executorch")

.github/workflows/pull.yml

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ jobs:
183183
bash .ci/scripts/validate.sh "./checkpoints/${REPO_NAME}/model.pth" "cuda" "compile"
184184
echo "::endgroup::"
185185
186-
test-gpu-aoti:
186+
test-gpu-aoti-bfloat16:
187187
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
188188
name: test-gpu-aoti (${{ matrix.platform }}, ${{ matrix.model_name }})
189189
needs: gather-models-gpu
@@ -216,7 +216,79 @@ jobs:
216216
echo "::endgroup::"
217217
218218
echo "::group::Run inference"
219-
bash .ci/scripts/validate.sh "./checkpoints/${REPO_NAME}/model.pth" "cuda" "aoti"
219+
bash .ci/scripts/validate.sh "./checkpoints/${REPO_NAME}/model.pth" "cuda" "aoti" "bfloat16"
220+
echo "::endgroup::"
221+
222+
test-gpu-aoti-float32:
223+
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
224+
name: test-gpu-aoti (${{ matrix.platform }}, ${{ matrix.model_name }})
225+
needs: gather-models-gpu
226+
strategy:
227+
matrix: ${{ fromJSON(needs.gather-models-gpu.outputs.models) }}
228+
fail-fast: false
229+
with:
230+
runner: linux.g5.4xlarge.nvidia.gpu
231+
gpu-arch-type: cuda
232+
gpu-arch-version: "12.1"
233+
script: |
234+
echo "::group::Print machine info"
235+
nvidia-smi
236+
echo "::endgroup::"
237+
238+
echo "::group::Install required packages"
239+
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
240+
pip install -r ./requirements.txt
241+
pip list
242+
python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
243+
echo "::endgroup::"
244+
245+
echo "::group::Download checkpoint"
246+
export REPO_NAME=${{ matrix.repo_name }}
247+
bash .ci/scripts/wget_checkpoint.sh ${REPO_NAME} ${{ matrix.resources }}
248+
echo "::endgroup::"
249+
250+
echo "::group::Convert checkpoint"
251+
bash .ci/scripts/convert_checkpoint.sh ${REPO_NAME}
252+
echo "::endgroup::"
253+
254+
echo "::group::Run inference"
255+
bash .ci/scripts/validate.sh "./checkpoints/${REPO_NAME}/model.pth" "cuda" "aoti" "float32"
256+
echo "::endgroup::"
257+
258+
test-gpu-aoti-float16:
259+
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
260+
name: test-gpu-aoti (${{ matrix.platform }}, ${{ matrix.model_name }})
261+
needs: gather-models-gpu
262+
strategy:
263+
matrix: ${{ fromJSON(needs.gather-models-gpu.outputs.models) }}
264+
fail-fast: false
265+
with:
266+
runner: linux.g5.4xlarge.nvidia.gpu
267+
gpu-arch-type: cuda
268+
gpu-arch-version: "12.1"
269+
script: |
270+
echo "::group::Print machine info"
271+
nvidia-smi
272+
echo "::endgroup::"
273+
274+
echo "::group::Install required packages"
275+
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
276+
pip install -r ./requirements.txt
277+
pip list
278+
python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
279+
echo "::endgroup::"
280+
281+
echo "::group::Download checkpoint"
282+
export REPO_NAME=${{ matrix.repo_name }}
283+
bash .ci/scripts/wget_checkpoint.sh ${REPO_NAME} ${{ matrix.resources }}
284+
echo "::endgroup::"
285+
286+
echo "::group::Convert checkpoint"
287+
bash .ci/scripts/convert_checkpoint.sh ${REPO_NAME}
288+
echo "::endgroup::"
289+
290+
echo "::group::Run inference"
291+
bash .ci/scripts/validate.sh "./checkpoints/${REPO_NAME}/model.pth" "cuda" "aoti" "float16"
220292
echo "::endgroup::"
221293
222294
test-gpu-eval-sanity-check:
@@ -749,7 +821,7 @@ jobs:
749821
750822
echo "Running compiled"
751823
python3 torchchat.py generate --gguf-path ${GGUF_PATH} --tokenizer-path ${TOKENIZER_PATH} --max-new-tokens 20 --temperature 0 --compile
752-
824+
753825
echo "******************************************"
754826
echo "******* Emb: channel-wise quantized ******"
755827
echo "******************************************"

0 commit comments

Comments (0)