Skip to content

Commit c55492e

Browse files
committed
breaking: small updates based on PR feedback
1 parent a1d4192 commit c55492e

File tree

3 files changed

+15
-17
lines changed

3 files changed

+15
-17
lines changed

tests/integ/test_git.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
from tests.integ import lock as lock
2323
from sagemaker.mxnet.estimator import MXNet
24-
from sagemaker.pytorch.defaults import PYTORCH_VERSION
2524
from sagemaker.pytorch.estimator import PyTorch
2625
from sagemaker.sklearn.estimator import SKLearn
2726
from sagemaker.sklearn.model import SKLearnModel
@@ -52,15 +51,15 @@
5251

5352

5453
@pytest.mark.local_mode
55-
def test_github(sagemaker_local_session):
54+
def test_github(sagemaker_local_session, pytorch_full_version):
5655
script_path = "mnist.py"
5756
data_path = os.path.join(DATA_DIR, "pytorch_mnist")
5857
git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}
5958
pytorch = PyTorch(
6059
entry_point=script_path,
6160
role="SageMakerRole",
6261
source_dir="pytorch",
63-
framework_version=PYTORCH_VERSION,
62+
framework_version=pytorch_full_version,
6463
py_version=PYTHON_VERSION,
6564
train_instance_count=1,
6665
train_instance_type="local",

tests/integ/test_tuner.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
from sagemaker.mxnet.estimator import MXNet
3737
from sagemaker.predictor import json_deserializer
3838
from sagemaker.pytorch import PyTorch
39-
from sagemaker.pytorch.defaults import PYTORCH_VERSION
4039
from sagemaker.tensorflow import TensorFlow
4140
from sagemaker.tensorflow.defaults import LATEST_VERSION
4241
from sagemaker.tuner import (
@@ -820,15 +819,15 @@ def test_tuning_chainer(sagemaker_session, cpu_instance_type):
820819
reason="This test has always failed, but the failure was masked by a bug. "
821820
"This test should be fixed. Details in https://github.com/aws/sagemaker-python-sdk/pull/968"
822821
)
823-
def test_attach_tuning_pytorch(sagemaker_session, cpu_instance_type):
822+
def test_attach_tuning_pytorch(sagemaker_session, cpu_instance_type, pytorch_full_version):
824823
mnist_dir = os.path.join(DATA_DIR, "pytorch_mnist")
825824
mnist_script = os.path.join(mnist_dir, "mnist.py")
826825

827826
estimator = PyTorch(
828827
entry_point=mnist_script,
829828
role="SageMakerRole",
830829
train_instance_count=1,
831-
framework_version=PYTORCH_VERSION,
830+
framework_version=pytorch_full_version,
832831
py_version=PYTHON_VERSION,
833832
train_instance_type=cpu_instance_type,
834833
sagemaker_session=sagemaker_session,

tests/unit/test_pytorch.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,7 @@ def _get_full_cpu_image_uri_with_ei(version, py_version=PYTHON_VERSION):
9393

9494

9595
def _pytorch_estimator(
96-
sagemaker_session,
97-
framework_version=defaults.PYTORCH_VERSION,
98-
train_instance_type=None,
99-
base_job_name=None,
100-
**kwargs
96+
sagemaker_session, framework_version, train_instance_type=None, base_job_name=None, **kwargs
10197
):
10298
return PyTorch(
10399
entry_point=SCRIPT_PATH,
@@ -572,13 +568,17 @@ def test_model_py2_warning(warning, sagemaker_session, pytorch_version):
572568
warning.assert_called_with(model.__framework_name__, defaults.LATEST_PY2_VERSION)
573569

574570

575-
def test_pt_enable_sm_metrics(sagemaker_session):
576-
pytorch = _pytorch_estimator(sagemaker_session, enable_sagemaker_metrics=True)
571+
def test_pt_enable_sm_metrics(sagemaker_session, pytorch_full_version):
572+
pytorch = _pytorch_estimator(
573+
sagemaker_session, framework_version=pytorch_full_version, enable_sagemaker_metrics=True
574+
)
577575
assert pytorch.enable_sagemaker_metrics
578576

579577

580-
def test_pt_disable_sm_metrics(sagemaker_session):
581-
pytorch = _pytorch_estimator(sagemaker_session, enable_sagemaker_metrics=False)
578+
def test_pt_disable_sm_metrics(sagemaker_session, pytorch_full_version):
579+
pytorch = _pytorch_estimator(
580+
sagemaker_session, framework_version=pytorch_full_version, enable_sagemaker_metrics=False
581+
)
582582
assert not pytorch.enable_sagemaker_metrics
583583

584584

@@ -594,9 +594,9 @@ def test_pt_enable_sm_metrics_if_fw_ver_is_at_least_1_15(sagemaker_session):
594594
assert pytorch.enable_sagemaker_metrics
595595

596596

597-
def test_custom_image_estimator_deploy(sagemaker_session):
597+
def test_custom_image_estimator_deploy(sagemaker_session, pytorch_full_version):
598598
custom_image = "mycustomimage:latest"
599-
pytorch = _pytorch_estimator(sagemaker_session)
599+
pytorch = _pytorch_estimator(sagemaker_session, framework_version=pytorch_full_version)
600600
pytorch.fit(inputs="s3://mybucket/train", job_name="new_name")
601601
model = pytorch.create_model(image=custom_image)
602602
assert model.image == custom_image

0 commit comments

Comments
 (0)