@@ -190,7 +190,7 @@ def pytorch_inference_py_version(pytorch_inference_version, request):
190
190
return "py3"
191
191
192
192
193
- def _huggingface_pytorch_version (huggingface_vesion ):
193
+ def _huggingface_base_fm_version (huggingface_vesion , base_fw ):
194
194
config = image_uris .config_for_framework ("huggingface" )
195
195
training_config = config .get ("training" )
196
196
original_version = huggingface_vesion
@@ -200,21 +200,26 @@ def _huggingface_pytorch_version(huggingface_vesion):
200
200
)
201
201
version_config = training_config .get ("versions" ).get (huggingface_vesion )
202
202
for key in list (version_config .keys ()):
203
- if key .startswith ("pytorch" ):
204
- pt_version = key [7 :]
203
+ if key .startswith (base_fw ):
204
+ base_fw_version = key [len ( base_fw ) :]
205
205
if len (original_version .split ("." )) == 2 :
206
- pt_version = "." .join (pt_version .split ("." )[:- 1 ])
207
- return pt_version
206
+ base_fw_version = "." .join (base_fw_version .split ("." )[:- 1 ])
207
+ return base_fw_version
208
208
209
209
210
210
@pytest .fixture (scope = "module" )
211
211
def huggingface_pytorch_version (huggingface_training_version ):
212
- return _huggingface_pytorch_version (huggingface_training_version )
212
+ return _huggingface_base_fm_version (huggingface_training_version , "pytorch" )
213
213
214
214
215
215
@pytest .fixture (scope = "module" )
216
216
def huggingface_pytorch_latest_version (huggingface_training_latest_version ):
217
- return _huggingface_pytorch_version (huggingface_training_latest_version )
217
+ return _huggingface_base_fm_version (huggingface_training_latest_version , "pytorch" )
218
+
219
+
220
+ @pytest .fixture (scope = "module" )
221
+ def huggingface_tensorflow_latest_version (huggingface_training_latest_version ):
222
+ return _huggingface_base_fm_version (huggingface_training_latest_version , "tensorflow" )
218
223
219
224
220
225
@pytest .fixture (scope = "module" )
0 commit comments