Skip to content

Commit 2063ec9

Browse files
author
Wang Napieralski
committed
fix: parameterize PT and TF version for HuggingFace tests
1 parent 5d8ce16 commit 2063ec9

File tree

3 files changed

+599
-32
lines changed

3 files changed

+599
-32
lines changed

tests/conftest.py

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -190,38 +190,11 @@ def pytorch_inference_py_version(pytorch_inference_version, request):
190190
return "py3"
191191

192192

193-
def _huggingface_base_fm_version(huggingface_vesion, base_fw):
194-
config = image_uris.config_for_framework("huggingface")
195-
training_config = config.get("training")
196-
original_version = huggingface_vesion
197-
if "version_aliases" in training_config:
198-
huggingface_vesion = training_config.get("version_aliases").get(
199-
huggingface_vesion, huggingface_vesion
200-
)
201-
version_config = training_config.get("versions").get(huggingface_vesion)
202-
for key in list(version_config.keys()):
203-
if key.startswith(base_fw):
204-
base_fw_version = key[len(base_fw) :]
205-
if len(original_version.split(".")) == 2:
206-
base_fw_version = ".".join(base_fw_version.split(".")[:-1])
207-
return base_fw_version
208-
209-
210193
@pytest.fixture(scope="module")
211194
def huggingface_pytorch_version(huggingface_training_version):
212195
return _huggingface_base_fm_version(huggingface_training_version, "pytorch")
213196

214197

215-
@pytest.fixture(scope="module")
216-
def huggingface_pytorch_latest_version(huggingface_training_latest_version):
217-
return _huggingface_base_fm_version(huggingface_training_latest_version, "pytorch")
218-
219-
220-
@pytest.fixture(scope="module")
221-
def huggingface_tensorflow_latest_version(huggingface_training_latest_version):
222-
return _huggingface_base_fm_version(huggingface_training_latest_version, "tensorflow")
223-
224-
225198
@pytest.fixture(scope="module")
226199
def pytorch_eia_py_version():
227200
return "py3"
@@ -395,6 +368,32 @@ def _generate_all_framework_version_fixtures(metafunc):
395368
)
396369

397370

371+
def _huggingface_base_fm_version(huggingface_vesion, base_fw):
372+
config = image_uris.config_for_framework("huggingface")
373+
training_config = config.get("training")
374+
original_version = huggingface_vesion
375+
if "version_aliases" in training_config:
376+
huggingface_vesion = training_config.get("version_aliases").get(
377+
huggingface_vesion, huggingface_vesion
378+
)
379+
version_config = training_config.get("versions").get(huggingface_vesion)
380+
versions = list()
381+
for key in list(version_config.keys()):
382+
if key.startswith(base_fw):
383+
base_fw_version = key[len(base_fw) :]
384+
if len(original_version.split(".")) == 2:
385+
base_fw_version = ".".join(base_fw_version.split(".")[:-1])
386+
versions.append(base_fw_version)
387+
return versions
388+
389+
390+
def _generate_huggingface_base_fw_latest_versions(metafunc, huggingface_version, base_fw):
391+
versions = _huggingface_base_fm_version(huggingface_version, base_fw)
392+
fixture_name = f"huggingface_{base_fw}_latest_version"
393+
if fixture_name in metafunc.fixturenames:
394+
metafunc.parametrize(fixture_name, versions, scope="session")
395+
396+
398397
def _parametrize_framework_version_fixtures(metafunc, fixture_prefix, config):
399398
fixture_name = "{}_version".format(fixture_prefix)
400399
if fixture_name in metafunc.fixturenames:
@@ -407,6 +406,10 @@ def _parametrize_framework_version_fixtures(metafunc, fixture_prefix, config):
407406
if fixture_name in metafunc.fixturenames:
408407
metafunc.parametrize(fixture_name, (latest_version,), scope="session")
409408

409+
if "huggingface" in fixture_prefix:
410+
_generate_huggingface_base_fw_latest_versions(metafunc, latest_version, "pytorch")
411+
_generate_huggingface_base_fw_latest_versions(metafunc, latest_version, "tensorflow")
412+
410413
fixture_name = "{}_latest_py_version".format(fixture_prefix)
411414
if fixture_name in metafunc.fixturenames:
412415
config = config["versions"]

0 commit comments

Comments
 (0)