Skip to content

Commit a1d4192

Browse files
committed
breaking: addressing comments in PR feedback
1 parent 5f4926e commit a1d4192

File tree

6 files changed

+15
-23
lines changed

6 files changed

+15
-23
lines changed

src/sagemaker/fw_utils.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -686,17 +686,15 @@ def _region_supports_debugger(region_name):
686686
def validate_version_or_image_args(framework_version, py_version, image_name):
687687
"""Checks if version or image arguments are specified.
688688
689-
Used to validate framework and model arguments to enforce version or image specification.
690-
Raises ValueError if version or image arguments are not specified.
689+
Validates framework and model arguments to enforce version or image specification.
691690
692691
Args:
693-
framework_version (str): the version of the framework
694-
py_version (str): the version of python
695-
image_name (str): the uri of the image
692+
framework_version (str): The version of the framework.
693+
py_version (str): The version of Python.
694+
image_name (str): The URI of the image.
696695
"""
697696
if (framework_version is None or py_version is None) and image_name is None:
698697
raise ValueError(
699698
"framework_version or py_version was None, yet image_name was also None. "
700699
"Either specify both framework_version and py_version, or specify image_name."
701700
)
702-
return True

src/sagemaker/pytorch/estimator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,12 @@ def __init__(
7070
If ``source_dir`` is specified, then ``entry_point``
7171
must point to a file located at the root of ``source_dir``.
7272
framework_version (str): PyTorch version you want to use for
73-
executing your model training code. Defaults to None. List of supported versions.
73+
executing your model training code. Defaults to ``None``. Required unless
74+
``image_name`` is provided. List of supported versions:
7475
https://github.com/aws/sagemaker-python-sdk#pytorch-sagemaker-estimators.
7576
py_version (str): Python version you want to use for executing your
76-
model training code. One of 'py2' or 'py3'. Defaults to None.
77+
model training code. One of 'py2' or 'py3'. Defaults to ``None``. Required
78+
unless ``image_name`` is provided.
7779
source_dir (str): Path (absolute, relative or an S3 URI) to a directory
7880
with any other training source code dependencies aside from the entry
7981
point file (default: None). If ``source_dir`` is an S3 URI, it must

src/sagemaker/pytorch/model.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,16 @@ def __init__(
8888
hosting. If ``source_dir`` is specified, then ``entry_point``
8989
must point to a file located at the root of ``source_dir``.
9090
framework_version (str): PyTorch version you want to use for
91-
executing your model training code. Defaults to None.
91+
executing your model training code. Defaults to None. Required
92+
unless ``image`` is provided.
9293
py_version (str): Python version you want to use for executing your
93-
model training code. Defaults to None.
94+
model training code. Defaults to ``None``. Required unless
95+
``image`` is provided.
9496
image (str): A Docker image URI (default: None). If not specified, a
9597
default image for PyTorch will be used.
9698
9799
If ``framework_version`` or ``py_version`` are ``None``, then
98-
``image_name`` is required. If also ``None``, then a ``ValueError``
100+
``image`` is required. If also ``None``, then a ``ValueError``
99101
will be raised.
100102
predictor_cls (callable[str, sagemaker.session.Session]): A function
101103
to call to create a predictor with an endpoint name and

tests/integ/test_pytorch_train.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,6 @@ def test_fit_deploy(sagemaker_local_session, pytorch_full_version):
9898
predictor.delete_endpoint()
9999

100100

101-
@pytest.mark.skipif(
102-
PYTHON_VERSION == "py2",
103-
reason="Python 2 is supported by PyTorch {} and lower versions.".format(LATEST_PY2_VERSION),
104-
)
105101
def test_deploy_model(
106102
pytorch_training_job, sagemaker_session, cpu_instance_type, pytorch_full_version
107103
):
@@ -129,10 +125,6 @@ def test_deploy_model(
129125
assert output.shape == (batch_size, 10)
130126

131127

132-
@pytest.mark.skipif(
133-
PYTHON_VERSION == "py2",
134-
reason="Python 2 is supported by PyTorch {} and lower versions.".format(LATEST_PY2_VERSION),
135-
)
136128
def test_deploy_packed_model_with_entry_point_name(sagemaker_session, cpu_instance_type):
137129
endpoint_name = "test-pytorch-deploy-model-{}".format(sagemaker_timestamp())
138130

tests/integ/test_source_dirs.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,10 @@
2020
from tests.integ import DATA_DIR, PYTHON_VERSION
2121

2222
from sagemaker.pytorch.estimator import PyTorch
23-
from sagemaker.pytorch.defaults import PYTORCH_VERSION
2423

2524

2625
@pytest.mark.local_mode
27-
def test_source_dirs(tmpdir, sagemaker_local_session):
26+
def test_source_dirs(tmpdir, sagemaker_local_session, pytorch_full_version):
2827
source_dir = os.path.join(DATA_DIR, "pytorch_source_dirs")
2928
lib = os.path.join(str(tmpdir), "alexa.py")
3029

@@ -36,7 +35,7 @@ def test_source_dirs(tmpdir, sagemaker_local_session):
3635
role="SageMakerRole",
3736
source_dir=source_dir,
3837
dependencies=[lib],
39-
framework_version=PYTORCH_VERSION,
38+
framework_version=pytorch_full_version,
4039
py_version=PYTHON_VERSION,
4140
train_instance_count=1,
4241
train_instance_type="local",

tests/unit/test_pytorch.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ def fixture_sagemaker_session():
8080
return session
8181

8282

83-
# TODO: push assertions regarding uri generation particulars to create_image_uri tests
8483
def _get_full_cpu_image_uri(version, py_version=PYTHON_VERSION):
8584
return IMAGE_URI_FORMAT_STRING.format(REGION, IMAGE_NAME, version, "cpu", py_version)
8685

0 commit comments

Comments
 (0)