Skip to content

Commit a7252ec

Browse files
author
Rui Wang Napieralski
committed
modify huggingface_pytorch_version fixture
1 parent a0f97e2 commit a7252ec

File tree

2 files changed

+25
-9
lines changed

2 files changed

+25
-9
lines changed

tests/conftest.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -190,15 +190,31 @@ def pytorch_inference_py_version(pytorch_inference_version, request):
190190
return "py3"
191191

192192

193+
def _huggingface_pytorch_version(huggingface_vesion):
194+
config = image_uris.config_for_framework("huggingface")
195+
training_config = config.get("training")
196+
original_version = huggingface_vesion
197+
if "version_aliases" in training_config:
198+
huggingface_vesion = training_config.get("version_aliases").get(
199+
huggingface_vesion, huggingface_vesion
200+
)
201+
version_config = training_config.get("versions").get(huggingface_vesion)
202+
for key in list(version_config.keys()):
203+
if key.startswith("pytorch"):
204+
pt_version = key[7:]
205+
if len(original_version.split(".")) == 2:
206+
pt_version = ".".join(pt_version.split(".")[:-1])
207+
return pt_version
208+
209+
193210
@pytest.fixture(scope="module")
194211
def huggingface_pytorch_version(huggingface_training_version):
195-
if Version(huggingface_training_version) <= Version("4.4.2"):
196-
if len(huggingface_training_version.split(".")) == 3:
197-
return "1.6.0"
198-
else:
199-
return "1.6"
200-
else:
201-
pytest.skip("Skipping Huggingface version.")
212+
return _huggingface_pytorch_version(huggingface_training_version)
213+
214+
215+
@pytest.fixture(scope="module")
216+
def huggingface_pytorch_latest_version(huggingface_training_latest_version):
217+
return _huggingface_pytorch_version(huggingface_training_latest_version)
202218

203219

204220
@pytest.fixture(scope="module")

tests/integ/test_huggingface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def test_huggingface_training(
3030
sagemaker_session,
3131
gpu_instance_type,
3232
huggingface_training_latest_version,
33-
huggingface_pytorch_version,
33+
huggingface_pytorch_latest_version,
3434
):
3535
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
3636
data_path = os.path.join(DATA_DIR, "huggingface")
@@ -40,7 +40,7 @@ def test_huggingface_training(
4040
entry_point="examples/text-classification/run_glue.py",
4141
role="SageMakerRole",
4242
transformers_version=huggingface_training_latest_version,
43-
pytorch_version=huggingface_pytorch_version,
43+
pytorch_version=huggingface_pytorch_latest_version,
4444
instance_count=1,
4545
instance_type=gpu_instance_type,
4646
hyperparameters={

0 commit comments

Comments
 (0)