Skip to content

Commit 6109e08

Browse files
metascroymalfet
authored andcommitted
split cpu eval CI by dtype (#554)
* split cpu eval CI by dtype * fix * differentiate names with checks * keep one name the same as old * fix
1 parent 05bd844 commit 6109e08

File tree

2 files changed

+93
-3
lines changed

2 files changed

+93
-3
lines changed

.ci/scripts/validate.sh

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,10 +255,11 @@ function eval_model() {
255255
function eval_model_sanity_check() {
256256
local CHECKPOINT_PATH="$1"
257257
local TARGET_DEVICE="${2:-cpu}"
258+
local DTYPES="$3"
258259
local MODEL_DIR="${CHECKPOINT_PATH%/*}"
259260
local MODEL_NAME=$(basename "$CHECKPOINT_PATH" | sed 's/\.[^.]*$//')
260261

261-
for DTYPE in float32 bfloat16 float16; do
262+
for DTYPE in $DTYPES; do
262263
echo ""############### Run eval with torch.compile for dtype $DTYPE "###############"
263264
echo ""
264265
echo "******************************************"
@@ -320,7 +321,8 @@ function run_eval(){
320321
}
321322

322323
function run_eval_sanity_check(){
323-
eval_model_sanity_check "$CHECKPOINT_PATH" "$TARGET_DEVICE" || exit 1
324+
echo "Passing DTYPES=$DTYPES"
325+
eval_model_sanity_check "$CHECKPOINT_PATH" "$TARGET_DEVICE" "$DTYPES" || exit 1
324326
}
325327

326328
CHECKPOINT_PATH="$1"
@@ -365,6 +367,22 @@ if [ "$#" -gt 2 ]; then
365367
;;
366368
"eval_sanity_check")
367369
echo "arg:$arg"
370+
DTYPES="bfloat16 float16 float32"
371+
run_eval_sanity_check || exit 1
372+
;;
373+
"eval_sanity_check-bfloat16")
374+
echo "arg:$arg"
375+
DTYPES="bfloat16"
376+
run_eval_sanity_check || exit 1
377+
;;
378+
"eval_sanity_check-float16")
379+
echo "arg:$arg"
380+
DTYPES="float16"
381+
run_eval_sanity_check || exit 1
382+
;;
383+
"eval_sanity_check-float32")
384+
echo "arg:$arg"
385+
DTYPES="float32"
368386
run_eval_sanity_check || exit 1
369387
;;
370388
*)

.github/workflows/pull.yml

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,79 @@ jobs:
129129
python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
130130
pushd ${TORCHCHAT_ROOT}
131131
bash .ci/scripts/convert_checkpoint.sh ${REPO_NAME}
132-
bash .ci/scripts/validate.sh "./checkpoints/${REPO_NAME}/model.pth" "cpu" "eval_sanity_check"
132+
bash .ci/scripts/validate.sh "./checkpoints/${REPO_NAME}/model.pth" "cpu" "eval_sanity_check-bfloat16"
133+
134+
test-cpu-eval-sanity-check-float16:
135+
name: test-cpu-eval-sanity-check-float16 (${{ matrix.platform }}, ${{ matrix.model_name }})
136+
needs: gather-models-cpu
137+
strategy:
138+
matrix: ${{ fromJSON(needs.gather-models-cpu.outputs.models) }}
139+
fail-fast: false
140+
runs-on: ${{ matrix.runner }}
141+
env:
142+
TORCHCHAT_ROOT: ${{ github.workspace }}
143+
REPO_NAME: ${{ matrix.repo_name }}
144+
steps:
145+
- name: Checkout repo
146+
uses: actions/checkout@v3
147+
- name: Setup Python
148+
uses: actions/setup-python@v4
149+
with:
150+
python-version: '3.11'
151+
- name: Print machine info
152+
run: |
153+
echo "$(uname -a)"
154+
- name: Install dependencies
155+
run: |
156+
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
157+
pip3 install -r requirements.txt
158+
pip3 list
159+
python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
160+
- name: Download checkpoints
161+
run: |
162+
bash ${TORCHCHAT_ROOT}/.ci/scripts/wget_checkpoint.sh ${{ matrix.repo_name }} "${{ matrix.resources }}"
163+
- name: Run validation
164+
run: |
165+
python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
166+
pushd ${TORCHCHAT_ROOT}
167+
bash .ci/scripts/convert_checkpoint.sh ${REPO_NAME}
168+
bash .ci/scripts/validate.sh "./checkpoints/${REPO_NAME}/model.pth" "cpu" "eval_sanity_check-float16"
169+
170+
test-cpu-eval-sanity-check-float32:
171+
name: test-cpu-eval-sanity-check-float32 (${{ matrix.platform }}, ${{ matrix.model_name }})
172+
needs: gather-models-cpu
173+
strategy:
174+
matrix: ${{ fromJSON(needs.gather-models-cpu.outputs.models) }}
175+
fail-fast: false
176+
runs-on: ${{ matrix.runner }}
177+
env:
178+
TORCHCHAT_ROOT: ${{ github.workspace }}
179+
REPO_NAME: ${{ matrix.repo_name }}
180+
steps:
181+
- name: Checkout repo
182+
uses: actions/checkout@v3
183+
- name: Setup Python
184+
uses: actions/setup-python@v4
185+
with:
186+
python-version: '3.11'
187+
- name: Print machine info
188+
run: |
189+
echo "$(uname -a)"
190+
- name: Install dependencies
191+
run: |
192+
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
193+
pip3 install -r requirements.txt
194+
pip3 list
195+
python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
196+
- name: Download checkpoints
197+
run: |
198+
bash ${TORCHCHAT_ROOT}/.ci/scripts/wget_checkpoint.sh ${{ matrix.repo_name }} "${{ matrix.resources }}"
199+
- name: Run validation
200+
run: |
201+
python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
202+
pushd ${TORCHCHAT_ROOT}
203+
bash .ci/scripts/convert_checkpoint.sh ${REPO_NAME}
204+
bash .ci/scripts/validate.sh "./checkpoints/${REPO_NAME}/model.pth" "cpu" "eval_sanity_check-float32"
133205
134206
gather-models-gpu:
135207
runs-on: ubuntu-22.04

0 commit comments

Comments
 (0)