@@ -190,15 +190,31 @@ def pytorch_inference_py_version(pytorch_inference_version, request):
190
190
return "py3"
191
191
192
192
193
+ def _huggingface_pytorch_version (huggingface_vesion ):
194
+ config = image_uris .config_for_framework ("huggingface" )
195
+ training_config = config .get ("training" )
196
+ original_version = huggingface_vesion
197
+ if "version_aliases" in training_config :
198
+ huggingface_vesion = training_config .get ("version_aliases" ).get (
199
+ huggingface_vesion , huggingface_vesion
200
+ )
201
+ version_config = training_config .get ("versions" ).get (huggingface_vesion )
202
+ for key in list (version_config .keys ()):
203
+ if key .startswith ("pytorch" ):
204
+ pt_version = key [7 :]
205
+ if len (original_version .split ("." )) == 2 :
206
+ pt_version = "." .join (pt_version .split ("." )[:- 1 ])
207
+ return pt_version
208
+
209
+
193
210
@pytest .fixture (scope = "module" )
194
211
def huggingface_pytorch_version (huggingface_training_version ):
195
- if Version (huggingface_training_version ) <= Version ("4.4.2" ):
196
- if len (huggingface_training_version .split ("." )) == 3 :
197
- return "1.6.0"
198
- else :
199
- return "1.6"
200
- else :
201
- pytest .skip ("Skipping Huggingface version." )
212
+ return _huggingface_pytorch_version (huggingface_training_version )
213
+
214
+
215
+ @pytest .fixture (scope = "module" )
216
+ def huggingface_pytorch_latest_version (huggingface_training_latest_version ):
217
+ return _huggingface_pytorch_version (huggingface_training_latest_version )
202
218
203
219
204
220
@pytest .fixture (scope = "module" )
0 commit comments