Skip to content

Commit 8452f5f

Browse files
committed
fix unit tests
1 parent 5ad0510 commit 8452f5f

File tree

2 files changed

+18
-12
lines changed

2 files changed

+18
-12
lines changed

tests/conftest.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import json
1616
import os
17-
import pdb
1817

1918
import boto3
2019
import pytest
@@ -193,8 +192,9 @@ def pytorch_inference_py_version(pytorch_inference_version, request):
193192

194193
@pytest.fixture(scope="module")
195194
def huggingface_pytorch_training_version(huggingface_training_version):
196-
return _huggingface_base_fm_version(huggingface_training_version, "pytorch", "huggingface_training")[0]
197-
195+
return _huggingface_base_fm_version(
196+
huggingface_training_version, "pytorch", "huggingface_training"
197+
)[0]
198198

199199

200200
@pytest.fixture(scope="module")
@@ -386,14 +386,16 @@ def _huggingface_base_fm_version(huggingface_version, base_fw, fixture_prefix):
386386

387387
for key in list(version_config.keys()):
388388
if key.startswith(base_fw):
389-
base_fw_version = key[len(base_fw):]
389+
base_fw_version = key[len(base_fw) :]
390390
if len(original_version.split(".")) == 2:
391391
base_fw_version = ".".join(base_fw_version.split(".")[:-1])
392392
versions.append(base_fw_version)
393393
return versions
394394

395395

396-
def _generate_huggingface_base_fw_latest_versions(metafunc, fixture_prefix, huggingface_version, base_fw):
396+
def _generate_huggingface_base_fw_latest_versions(
397+
metafunc, fixture_prefix, huggingface_version, base_fw
398+
):
397399
versions = _huggingface_base_fm_version(huggingface_version, base_fw, fixture_prefix)
398400
fixture_name = f"{fixture_prefix}_{base_fw}_latest_version"
399401

@@ -414,8 +416,12 @@ def _parametrize_framework_version_fixtures(metafunc, fixture_prefix, config):
414416
metafunc.parametrize(fixture_name, (latest_version,), scope="session")
415417

416418
if "huggingface" in fixture_prefix:
417-
_generate_huggingface_base_fw_latest_versions(metafunc, fixture_prefix, latest_version, "pytorch")
418-
_generate_huggingface_base_fw_latest_versions(metafunc, fixture_prefix, latest_version, "tensorflow")
419+
_generate_huggingface_base_fw_latest_versions(
420+
metafunc, fixture_prefix, latest_version, "pytorch"
421+
)
422+
_generate_huggingface_base_fw_latest_versions(
423+
metafunc, fixture_prefix, latest_version, "tensorflow"
424+
)
419425

420426
fixture_name = "{}_latest_py_version".format(fixture_prefix)
421427
if fixture_name in metafunc.fixturenames:

tests/unit/sagemaker/huggingface/test_estimator.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import json
1717
import os
18-
import pdb
1918

2019
import pytest
2120
from mock import MagicMock, Mock, patch
@@ -267,11 +266,12 @@ def test_huggingface(
267266
assert actual_train_args == expected_train_args
268267

269268

270-
def test_attach(sagemaker_session, huggingface_training_version, huggingface_training_pytorch_latest_version):
271-
print(huggingface_training_version)
269+
def test_attach(
270+
sagemaker_session, huggingface_training_version, huggingface_pytorch_training_version
271+
):
272272
training_image = (
273273
f"1.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:"
274-
f"{huggingface_training_pytorch_latest_version}-"
274+
f"{huggingface_pytorch_training_version}-"
275275
f"transformers{huggingface_training_version}-gpu-py36-cu110-ubuntu18.04"
276276
)
277277
returned_job_description = {
@@ -306,7 +306,7 @@ def test_attach(sagemaker_session, huggingface_training_version, huggingface_tra
306306
assert estimator.latest_training_job.job_name == "neo"
307307
assert estimator.py_version == "py36"
308308
assert estimator.framework_version == huggingface_training_version
309-
assert estimator.pytorch_version == huggingface_training_pytorch_latest_version
309+
assert estimator.pytorch_version == huggingface_pytorch_training_version
310310
assert estimator.role == "arn:aws:iam::366:role/SageMakerRole"
311311
assert estimator.instance_count == 1
312312
assert estimator.max_run == 24 * 60 * 60

0 commit comments

Comments
 (0)