@@ -61,7 +61,7 @@ def parse_args() -> Any:
     return parser.parse_args()
 
 
-def model_should_run_on_event(model: str, event: str) -> bool:
+def model_should_run_on_event(model: str, event: str, backend: str) -> bool:
     """
     A helper function to decide whether a model should be tested on an event (pull_request/push)
     We put higher priority and fast models to pull request and rest to push.
@@ -71,7 +71,11 @@ def model_should_run_on_event(model: str, event: str) -> bool:
     elif event == "push":
         return model in []
     elif event == "periodic":
-        return model in ["openlm-research/open_llama_7b", "huggingface-cli/meta-llama/Meta-Llama-3-8B"]
+        # test llama3 on gpu only, see description in https://github.com/pytorch/torchchat/pull/399 for reasoning
+        if backend == "gpu":
+            return model in ["openlm-research/open_llama_7b", "huggingface-cli/meta-llama/Meta-Llama-3-8B"]
+        else:
+            return model in ["openlm-research/open_llama_7b"]
     else:
         return False
 
@@ -106,7 +110,7 @@ def export_models_for_ci() -> dict[str, dict]:
         MODEL_REPOS.keys(),
         JOB_RUNNERS[backend].items(),
     ):
-        if not model_should_run_on_event(repo_name, event):
+        if not model_should_run_on_event(repo_name, event, backend):
            continue
 
        # This is mostly temporary to get this finished quickly while
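For readers skimming the hunks, here is a minimal, self-contained sketch of the gating logic this change introduces. The function body is reconstructed from the hunks above; branches not visible in the diff (e.g. the pull_request case) are omitted, and the "cpu" backend value in the checks is an assumption for illustration, not taken from the diff.

# Sketch reconstructed from the diff above; only the branches shown in
# the hunks are included. The "cpu" value below is assumed, not shown
# in the diff.
def model_should_run_on_event(model: str, event: str, backend: str) -> bool:
    if event == "push":
        return model in []
    elif event == "periodic":
        # Meta-Llama-3-8B is exercised on GPU runners only; every
        # backend still covers open_llama_7b.
        if backend == "gpu":
            return model in [
                "openlm-research/open_llama_7b",
                "huggingface-cli/meta-llama/Meta-Llama-3-8B",
            ]
        else:
            return model in ["openlm-research/open_llama_7b"]
    else:
        return False

# Expected behavior per the diff: llama3 runs periodically on gpu only,
# while open_llama_7b runs periodically on all backends.
assert model_should_run_on_event(
    "huggingface-cli/meta-llama/Meta-Llama-3-8B", "periodic", "gpu"
)
assert not model_should_run_on_event(
    "huggingface-cli/meta-llama/Meta-Llama-3-8B", "periodic", "cpu"
)
assert model_should_run_on_event("openlm-research/open_llama_7b", "periodic", "cpu")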