@@ -190,36 +190,9 @@ def pytorch_inference_py_version(pytorch_inference_version, request):
190
190
return "py3"
191
191
192
192
193
- def _huggingface_base_fm_version (huggingface_vesion , base_fw ):
194
- config = image_uris .config_for_framework ("huggingface" )
195
- training_config = config .get ("training" )
196
- original_version = huggingface_vesion
197
- if "version_aliases" in training_config :
198
- huggingface_vesion = training_config .get ("version_aliases" ).get (
199
- huggingface_vesion , huggingface_vesion
200
- )
201
- version_config = training_config .get ("versions" ).get (huggingface_vesion )
202
- for key in list (version_config .keys ()):
203
- if key .startswith (base_fw ):
204
- base_fw_version = key [len (base_fw ) :]
205
- if len (original_version .split ("." )) == 2 :
206
- base_fw_version = "." .join (base_fw_version .split ("." )[:- 1 ])
207
- return base_fw_version
208
-
209
-
210
193
@pytest .fixture (scope = "module" )
211
194
def huggingface_pytorch_version (huggingface_training_version ):
212
- return _huggingface_base_fm_version (huggingface_training_version , "pytorch" )
213
-
214
-
215
- @pytest .fixture (scope = "module" )
216
- def huggingface_pytorch_latest_version (huggingface_training_latest_version ):
217
- return _huggingface_base_fm_version (huggingface_training_latest_version , "pytorch" )
218
-
219
-
220
- @pytest .fixture (scope = "module" )
221
- def huggingface_tensorflow_latest_version (huggingface_training_latest_version ):
222
- return _huggingface_base_fm_version (huggingface_training_latest_version , "tensorflow" )
195
+ return _huggingface_base_fm_version (huggingface_training_version , "pytorch" )[0 ]
223
196
224
197
225
198
@pytest .fixture (scope = "module" )
@@ -395,6 +368,32 @@ def _generate_all_framework_version_fixtures(metafunc):
395
368
)
396
369
397
370
371
+ def _huggingface_base_fm_version (huggingface_vesion , base_fw ):
372
+ config = image_uris .config_for_framework ("huggingface" )
373
+ training_config = config .get ("training" )
374
+ original_version = huggingface_vesion
375
+ if "version_aliases" in training_config :
376
+ huggingface_vesion = training_config .get ("version_aliases" ).get (
377
+ huggingface_vesion , huggingface_vesion
378
+ )
379
+ version_config = training_config .get ("versions" ).get (huggingface_vesion )
380
+ versions = list ()
381
+ for key in list (version_config .keys ()):
382
+ if key .startswith (base_fw ):
383
+ base_fw_version = key [len (base_fw ) :]
384
+ if len (original_version .split ("." )) == 2 :
385
+ base_fw_version = "." .join (base_fw_version .split ("." )[:- 1 ])
386
+ versions .append (base_fw_version )
387
+ return versions
388
+
389
+
390
+ def _generate_huggingface_base_fw_latest_versions (metafunc , huggingface_version , base_fw ):
391
+ versions = _huggingface_base_fm_version (huggingface_version , base_fw )
392
+ fixture_name = f"huggingface_{ base_fw } _latest_version"
393
+ if fixture_name in metafunc .fixturenames :
394
+ metafunc .parametrize (fixture_name , versions , scope = "session" )
395
+
396
+
398
397
def _parametrize_framework_version_fixtures (metafunc , fixture_prefix , config ):
399
398
fixture_name = "{}_version" .format (fixture_prefix )
400
399
if fixture_name in metafunc .fixturenames :
@@ -407,6 +406,10 @@ def _parametrize_framework_version_fixtures(metafunc, fixture_prefix, config):
407
406
if fixture_name in metafunc .fixturenames :
408
407
metafunc .parametrize (fixture_name , (latest_version ,), scope = "session" )
409
408
409
+ if "huggingface" in fixture_prefix :
410
+ _generate_huggingface_base_fw_latest_versions (metafunc , latest_version , "pytorch" )
411
+ _generate_huggingface_base_fw_latest_versions (metafunc , latest_version , "tensorflow" )
412
+
410
413
fixture_name = "{}_latest_py_version" .format (fixture_prefix )
411
414
if fixture_name in metafunc .fixturenames :
412
415
config = config ["versions" ]
0 commit comments