
Commit 610ade6

limit test to gpu
Signed-off-by: Eli Uriegas <[email protected]>
1 parent 8cff173 commit 610ade6

File tree

1 file changed: +7 -3 lines changed


.ci/scripts/gather_test_models.py

Lines changed: 7 additions & 3 deletions
@@ -61,7 +61,7 @@ def parse_args() -> Any:
     return parser.parse_args()
 
 
-def model_should_run_on_event(model: str, event: str) -> bool:
+def model_should_run_on_event(model: str, event: str, backend: str) -> bool:
     """
     A helper function to decide whether a model should be tested on an event (pull_request/push)
     We put higher priority and fast models to pull request and rest to push.
@@ -71,7 +71,11 @@ def model_should_run_on_event(model: str, event: str) -> bool:
     elif event == "push":
         return model in []
     elif event == "periodic":
-        return model in ["openlm-research/open_llama_7b", "huggingface-cli/meta-llama/Meta-Llama-3-8B"]
+        # test llama3 on gpu only, see description in https://github.com/pytorch/torchchat/pull/399 for reasoning
+        if backend == "gpu":
+            return model in ["openlm-research/open_llama_7b", "huggingface-cli/meta-llama/Meta-Llama-3-8B"]
+        else:
+            return model in ["openlm-research/open_llama_7b"]
     else:
         return False
 
@@ -106,7 +110,7 @@ def export_models_for_ci() -> dict[str, dict]:
         MODEL_REPOS.keys(),
         JOB_RUNNERS[backend].items(),
     ):
-        if not model_should_run_on_event(repo_name, event):
+        if not model_should_run_on_event(repo_name, event, backend):
             continue
 
         # This is mostly temporary to get this finished quickly while
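For context, here is a minimal, self-contained sketch of how the updated helper filters the CI matrix after this change. The MODEL_REPOS/JOB_RUNNERS tables and the pull_request model list below are illustrative stand-ins, not the repository's actual values; only the periodic branch mirrors the diff above.

from itertools import product

# Illustrative stand-ins; the real tables live in .ci/scripts/gather_test_models.py.
MODEL_REPOS = {
    "openlm-research/open_llama_7b": "open_llama_7b",
    "huggingface-cli/meta-llama/Meta-Llama-3-8B": "llama3",
}
JOB_RUNNERS = {
    "cpu": {"linux.24xlarge": "x86_64"},
    "gpu": {"linux.g5.4xlarge.nvidia.gpu": "cuda"},
}

def model_should_run_on_event(model: str, event: str, backend: str) -> bool:
    """Decide whether a model is tested for a given trigger event and backend."""
    if event == "pull_request":
        return model in ["hypothetical/fast_model"]  # placeholder fast-model list
    elif event == "push":
        return model in []
    elif event == "periodic":
        # llama3 is tested on gpu only (see pytorch/torchchat PR 399)
        if backend == "gpu":
            return model in ["openlm-research/open_llama_7b",
                             "huggingface-cli/meta-llama/Meta-Llama-3-8B"]
        else:
            return model in ["openlm-research/open_llama_7b"]
    else:
        return False

# On a periodic run, Meta-Llama-3-8B is scheduled only on the gpu backend:
for backend in JOB_RUNNERS:
    for repo_name, (runner, _arch) in product(
        MODEL_REPOS.keys(),
        JOB_RUNNERS[backend].items(),
    ):
        if not model_should_run_on_event(repo_name, "periodic", backend):
            continue
        print(f"{backend:>3}  {runner:<32}  {repo_name}")

Run as-is, this prints the open_llama_7b job for both backends but the Meta-Llama-3-8B job only for gpu, which is exactly the filtering the new backend parameter enables.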
