Skip to content

Commit 839be3c

Browse files
committed
breaking: updates based on PR feedback
1 parent 9b92c02 commit 839be3c

File tree

4 files changed

+7
-11
lines changed

4 files changed

+7
-11
lines changed

src/sagemaker/pytorch/model.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,9 @@ def __init__(
9494
model training code. Defaults to ``None``. Required unless
9595
``image`` is provided.
9696
image (str): A Docker image URI (default: None). If not specified, a
97-
default image for PyTorch will be used.
98-
99-
If ``framework_version`` or ``py_version`` are ``None``, then
100-
``image`` is required. If also ``None``, then a ``ValueError``
101-
will be raised.
97+
default image for PyTorch will be used. If ``framework_version``
98+
or ``py_version`` are ``None``, then ``image`` is required. If
99+
also ``None``, then a ``ValueError`` will be raised.
102100
predictor_cls (callable[str, sagemaker.session.Session]): A function
103101
to call to create a predictor with an endpoint name and
104102
SageMaker ``Session``. If specified, ``deploy()`` returns the

tests/integ/test_airflow_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -608,13 +608,13 @@ def test_xgboost_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu
608608

609609
@pytest.mark.canary_quick
610610
def test_pytorch_airflow_config_uploads_data_source_to_s3_when_inputs_not_provided(
611-
sagemaker_session, cpu_instance_type
611+
sagemaker_session, cpu_instance_type, pytorch_full_version
612612
):
613613
with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
614614
estimator = PyTorch(
615615
entry_point=PYTORCH_MNIST_SCRIPT,
616616
role=ROLE,
617-
framework_version="1.3.1",
617+
framework_version=pytorch_full_version,
618618
py_version="py3",
619619
train_instance_count=2,
620620
train_instance_type=cpu_instance_type,

tests/integ/test_git.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
from tests.integ import lock as lock
2323
from sagemaker.mxnet.estimator import MXNet
24-
from sagemaker.pytorch.defaults import PYTORCH_VERSION
2524
from sagemaker.pytorch.estimator import PyTorch
2625
from sagemaker.sklearn.estimator import SKLearn
2726
from sagemaker.sklearn.model import SKLearnModel
@@ -63,7 +62,7 @@ def test_github(sagemaker_local_session):
6362
entry_point=script_path,
6463
role="SageMakerRole",
6564
source_dir="pytorch",
66-
framework_version=PYTORCH_VERSION,
65+
framework_version="0.4", # hard-code to last known good pytorch for now (see TODO above)
6766
py_version=PYTHON_VERSION,
6867
train_instance_count=1,
6968
train_instance_type="local",

tests/integ/test_source_dirs.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import tests.integ.lock as lock
2020
from tests.integ import DATA_DIR, PYTHON_VERSION
2121

22-
from sagemaker.pytorch.defaults import PYTORCH_VERSION
2322
from sagemaker.pytorch.estimator import PyTorch
2423

2524

@@ -38,7 +37,7 @@ def test_source_dirs(tmpdir, sagemaker_local_session):
3837
role="SageMakerRole",
3938
source_dir=source_dir,
4039
dependencies=[lib],
41-
framework_version=PYTORCH_VERSION,
40+
framework_version="0.4", # hard-code to last known good pytorch for now (see TODO above)
4241
py_version=PYTHON_VERSION,
4342
train_instance_count=1,
4443
train_instance_type="local",

0 commit comments

Comments
 (0)