Skip to content

Commit 210247c

Browse files
committed
fix: updates based on PR feedback
1 parent 735c5ef commit 210247c

File tree

14 files changed

+112
-100
lines changed

14 files changed

+112
-100
lines changed

doc/frameworks/tensorflow/upgrade_from_legacy.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ the difference in code would be as follows:
104104
...
105105
source_dir="code",
106106
framework_version="1.10.0",
107-
py_version="py3",
107+
py_version="py2",
108108
train_instance_type="ml.m4.xlarge",
109109
image_name="520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow:1.10.0-cpu-py2",
110110
hyperparameters={

src/sagemaker/rl/estimator.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -254,9 +254,7 @@ def create_model(
254254
)
255255

256256
if self.framework == RLFramework.TENSORFLOW.value:
257-
return TensorFlowModel(
258-
framework_version=self.framework_version, py_version=PYTHON_VERSION, **base_args
259-
)
257+
return TensorFlowModel(framework_version=self.framework_version, **base_args)
260258
if self.framework == RLFramework.MXNET.value:
261259
return MXNetModel(
262260
framework_version=self.framework_version, py_version=PYTHON_VERSION, **extended_args

src/sagemaker/tensorflow/estimator.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,7 @@ def __init__(
5656
5757
Args:
5858
py_version (str): Python version you want to use for executing your model training
59-
code. One of 'py2', 'py3', or 'py37'. Defaults to ``None``. Required unless
60-
``image_name`` is provided.
59+
code. Defaults to ``None``. Required unless ``image_name`` is provided.
6160
framework_version (str): TensorFlow version you want to use for executing your model
6261
training code. Defaults to ``None``. Required unless ``image_name`` is provided.
6362
List of supported versions:
@@ -144,15 +143,11 @@ def __init__(
144143
self.model_dir = model_dir
145144
self.distributions = distributions or {}
146145

147-
self._validate_args(py_version=py_version, framework_version=self.framework_version)
146+
self._validate_args(py_version=py_version)
148147

149-
def _validate_args(self, py_version, framework_version):
148+
def _validate_args(self, py_version):
150149
"""Placeholder docstring"""
151150

152-
if py_version:
153-
if framework_version is None:
154-
raise AttributeError(fw.EMPTY_FRAMEWORK_VERSION_ERROR)
155-
156151
if py_version == "py2" and self._only_python_3_supported():
157152
msg = (
158153
"Python 2 containers are only available with {} and lower versions. "
@@ -293,7 +288,6 @@ def create_model(
293288
role=role or self.role,
294289
container_log_level=self.container_log_level,
295290
framework_version=self.framework_version,
296-
py_version=self.py_version,
297291
sagemaker_session=self.sagemaker_session,
298292
vpc_config=self.get_vpc_config(vpc_config_override),
299293
entry_point=entry_point,

src/sagemaker/tensorflow/model.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import sagemaker
1919
from sagemaker.content_types import CONTENT_TYPE_JSON
20-
from sagemaker.fw_utils import create_image_uri, validate_version_or_image_args
20+
from sagemaker.fw_utils import create_image_uri
2121
from sagemaker.predictor import json_serializer, json_deserializer
2222

2323

@@ -138,7 +138,6 @@ def __init__(
138138
entry_point=None,
139139
image=None,
140140
framework_version=None,
141-
py_version=None,
142141
container_log_level=None,
143142
predictor_cls=TensorFlowPredictor,
144143
**kwargs
@@ -159,15 +158,11 @@ def __init__(
159158
must point to a file located at the root of ``source_dir``.
160159
image (str): A Docker image URI (default: None). If not specified, a
161160
default image for TensorFlow Serving will be used. If
162-
``framework_version`` or ``py_version`` are ``None``, then
163-
``image`` is required. If also ``None``, then a ``ValueError``
164-
will be raised.
161+
``framework_version`` is ``None``, then ``image`` is required.
162+
If also ``None``, then a ``ValueError`` will be raised.
165163
framework_version (str): Optional. TensorFlow Serving version you
166164
want to use. Defaults to ``None``. Required unless ``image`` is
167165
provided.
168-
py_version (str): Python version you want to use for executing your
169-
model training code. One of 'py2', 'py3', or 'py37'. Defaults to
170-
``None``. Required unless ``image`` is provided.
171166
container_log_level (int): Log level to use within the container
172167
(default: logging.ERROR). Valid values are defined in the Python
173168
logging module.
@@ -183,9 +178,12 @@ def __init__(
183178
:class:`~sagemaker.model.FrameworkModel` and
184179
:class:`~sagemaker.model.Model`.
185180
"""
186-
validate_version_or_image_args(framework_version, py_version, image)
181+
if framework_version is None and image is None:
182+
raise ValueError(
183+
"Both framework_version and image were None. "
184+
"Either specify framework_version or specify image_name."
185+
)
187186
self.framework_version = framework_version
188-
self.py_version = py_version
189187

190188
super(TensorFlowModel, self).__init__(
191189
model_data=model_data,

tests/conftest.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,6 @@ def xgboost_version(request):
209209
"1.11.0",
210210
"1.12",
211211
"1.12.0",
212-
"1.13",
213212
"1.14",
214213
"1.14.0",
215214
"1.15",
@@ -226,6 +225,16 @@ def tf_version(request):
226225
return request.param
227226

228227

228+
@pytest.fixture(scope="module", params=["py2", "py3"])
229+
def tf_py_version(tf_version, request):
230+
version = [int(val) for val in tf_version.split(".")]
231+
if version < [1, 13]:
232+
return "py2"
233+
if version < [2, 2]:
234+
return request.param
235+
return "py37"
236+
237+
229238
@pytest.fixture(scope="module", params=["0.10.1", "0.10.1", "0.11", "0.11.0", "0.11.1"])
230239
def rl_coach_tf_version(request):
231240
return request.param
@@ -290,6 +299,23 @@ def tf_full_version(request):
290299
return tf_version
291300

292301

302+
@pytest.fixture(scope="module")
303+
def tf_full_py_version(tf_full_version, request):
304+
"""fixture to match tf_full_version
305+
306+
Fixture exists as such, since tf_full_version may be overridden --tf-full-version.
307+
Otherwise, this would simply be py37 to match the latest version support.
308+
309+
TODO: Evaluate use of --tf-full-version with possible eye to remove and simplify code.
310+
"""
311+
version = [int(val) for val in tf_full_version.split(".")]
312+
if version < [1, 13]:
313+
return "py2"
314+
if tf_full_version in [TensorFlow._LATEST_1X_VERSION, LATEST_VERSION]:
315+
return "py37"
316+
return "py3"
317+
318+
293319
@pytest.fixture(scope="module", params=["1.15.0", "2.0.0"])
294320
def ei_tf_full_version(request):
295321
tf_ei_version = request.config.getoption("--ei-tf-full-version")

tests/integ/test_airflow_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,7 @@ def test_tf_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_inst
562562
train_instance_type=cpu_instance_type,
563563
sagemaker_session=sagemaker_session,
564564
framework_version=TensorFlow.LATEST_VERSION,
565-
py_version=PYTHON_VERSION,
565+
py_version="py37", # only version available with 2.2.0
566566
metric_definitions=[
567567
{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}
568568
],

tests/integ/test_data_capture_config.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from sagemaker.model_monitor import DataCaptureConfig, NetworkConfig
2121
from sagemaker.tensorflow.model import TensorFlowModel
2222
from sagemaker.utils import unique_name_from_base
23-
from tests.integ import PYTHON_VERSION
2423
from tests.integ.retry import retries
2524

2625
ROLE = "SageMakerRole"
@@ -54,7 +53,6 @@ def test_enabling_data_capture_on_endpoint_shows_correct_data_capture_status(
5453
model_data=model_data,
5554
role=ROLE,
5655
framework_version=tf_serving_version,
57-
py_version=PYTHON_VERSION,
5856
sagemaker_session=sagemaker_session,
5957
)
6058
predictor = model.deploy(
@@ -112,7 +110,6 @@ def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status(
112110
model_data=model_data,
113111
role=ROLE,
114112
framework_version=tf_serving_version,
115-
py_version=PYTHON_VERSION,
116113
sagemaker_session=sagemaker_session,
117114
)
118115
destination_s3_uri = os.path.join(
@@ -199,7 +196,6 @@ def test_updating_data_capture_on_endpoint_shows_correct_data_capture_status(
199196
model_data=model_data,
200197
role=ROLE,
201198
framework_version=tf_serving_version,
202-
py_version=PYTHON_VERSION,
203199
sagemaker_session=sagemaker_session,
204200
)
205201
destination_s3_uri = os.path.join(

tests/integ/test_model_monitor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from sagemaker.s3 import S3Uploader
2525
from datetime import datetime, timedelta
2626

27-
from tests.integ import DATA_DIR, PYTHON_VERSION
27+
from tests.integ import DATA_DIR
2828
from sagemaker.model_monitor import DatasetFormat
2929
from sagemaker.model_monitor import NetworkConfig, Statistics, Constraints
3030
from sagemaker.model_monitor import ModelMonitor
@@ -101,7 +101,6 @@ def predictor(sagemaker_session, tf_serving_version):
101101
model_data=model_data,
102102
role=ROLE,
103103
framework_version=tf_serving_version,
104-
py_version=PYTHON_VERSION,
105104
sagemaker_session=sagemaker_session,
106105
)
107106
predictor = model.deploy(

tests/integ/test_tf.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import pytest
2020

2121
from sagemaker.tensorflow import TensorFlow
22-
from sagemaker.tensorflow.defaults import LATEST_VERSION, LATEST_SERVING_VERSION
22+
from sagemaker.tensorflow.defaults import LATEST_SERVING_VERSION
2323
from sagemaker.utils import unique_name_from_base, sagemaker_timestamp
2424

2525
import tests.integ
@@ -38,16 +38,9 @@
3838
MPI_DISTRIBUTION = {"mpi": {"enabled": True}}
3939
TAGS = [{"Key": "some-key", "Value": "some-value"}]
4040

41-
PY37_SUPPORTED_FRAMEWORK_VERSION = [TensorFlow._LATEST_1X_VERSION, LATEST_VERSION]
42-
43-
44-
@pytest.fixture(scope="module")
45-
def py_version(tf_full_version):
46-
return "py37" if tf_full_version in PY37_SUPPORTED_FRAMEWORK_VERSION else PYTHON_VERSION
47-
4841

4942
def test_mnist_with_checkpoint_config(
50-
sagemaker_session, instance_type, tf_full_version, py_version
43+
sagemaker_session, instance_type, tf_full_version, tf_full_py_version
5144
):
5245
checkpoint_s3_uri = "s3://{}/checkpoints/tf-{}".format(
5346
sagemaker_session.default_bucket(), sagemaker_timestamp()
@@ -60,7 +53,7 @@ def test_mnist_with_checkpoint_config(
6053
train_instance_type=instance_type,
6154
sagemaker_session=sagemaker_session,
6255
framework_version=tf_full_version,
63-
py_version=py_version,
56+
py_version=tf_full_py_version,
6457
metric_definitions=[{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}],
6558
checkpoint_s3_uri=checkpoint_s3_uri,
6659
checkpoint_local_path=checkpoint_local_path,
@@ -131,15 +124,15 @@ def test_server_side_encryption(sagemaker_session, tf_serving_version):
131124

132125

133126
@pytest.mark.canary_quick
134-
def test_mnist_distributed(sagemaker_session, instance_type, tf_full_version, py_version):
127+
def test_mnist_distributed(sagemaker_session, instance_type, tf_full_version, tf_full_py_version):
135128
estimator = TensorFlow(
136129
entry_point=SCRIPT,
137130
role=ROLE,
138131
train_instance_count=2,
139132
train_instance_type=instance_type,
140133
sagemaker_session=sagemaker_session,
141134
framework_version=tf_full_version,
142-
py_version=py_version,
135+
py_version=tf_full_py_version,
143136
distributions=PARAMETER_SERVER_DISTRIBUTION,
144137
)
145138
inputs = estimator.sagemaker_session.upload_data(

tests/integ/test_tfs.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import tests.integ
2525
import tests.integ.timeout
2626
from sagemaker.tensorflow.model import TensorFlowModel, TensorFlowPredictor
27-
from tests.integ import PYTHON_VERSION
2827

2928

3029
@pytest.fixture(scope="module")
@@ -39,7 +38,6 @@ def tfs_predictor(sagemaker_session, tf_serving_version):
3938
model_data=model_data,
4039
role="SageMakerRole",
4140
framework_version=tf_serving_version,
42-
py_version=PYTHON_VERSION,
4341
sagemaker_session=sagemaker_session,
4442
)
4543
predictor = model.deploy(1, "ml.c5.xlarge", endpoint_name=endpoint_name)
@@ -68,7 +66,6 @@ def tfs_predictor_with_model_and_entry_point_same_tar(
6866
model_data="file://" + model_tar,
6967
role="SageMakerRole",
7068
framework_version=tf_serving_version,
71-
py_version=PYTHON_VERSION,
7269
sagemaker_session=sagemaker_local_session,
7370
)
7471
predictor = model.deploy(1, "local", endpoint_name=endpoint_name)
@@ -102,7 +99,6 @@ def tfs_predictor_with_model_and_entry_point_and_dependencies(
10299
role="SageMakerRole",
103100
dependencies=dependencies,
104101
framework_version=tf_serving_version,
105-
py_version=PYTHON_VERSION,
106102
sagemaker_session=sagemaker_local_session,
107103
)
108104

@@ -126,7 +122,6 @@ def tfs_predictor_with_accelerator(sagemaker_session, ei_tf_full_version, cpu_in
126122
model_data=model_data,
127123
role="SageMakerRole",
128124
framework_version=ei_tf_full_version,
129-
py_version=PYTHON_VERSION,
130125
sagemaker_session=sagemaker_session,
131126
)
132127
predictor = model.deploy(

tests/integ/test_tuner.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
from sagemaker.predictor import json_deserializer
3838
from sagemaker.pytorch import PyTorch
3939
from sagemaker.tensorflow import TensorFlow
40-
from sagemaker.tensorflow.defaults import LATEST_VERSION
4140
from sagemaker.tuner import (
4241
IntegerParameter,
4342
ContinuousParameter,
@@ -52,13 +51,6 @@
5251

5352
DATA_PATH = os.path.join(DATA_DIR, "iris", "data")
5453

55-
PY37_SUPPORTED_FRAMEWORK_VERSION = [TensorFlow._LATEST_1X_VERSION, LATEST_VERSION]
56-
57-
58-
@pytest.fixture(scope="module")
59-
def py_version(tf_full_version):
60-
return "py37" if tf_full_version in PY37_SUPPORTED_FRAMEWORK_VERSION else PYTHON_VERSION
61-
6254

6355
@pytest.fixture(scope="module")
6456
def kmeans_train_set(sagemaker_session):
@@ -598,7 +590,9 @@ def test_tuning_mxnet(sagemaker_session, mxnet_full_version, cpu_instance_type):
598590

599591

600592
@pytest.mark.canary_quick
601-
def test_tuning_tf_script_mode(sagemaker_session, cpu_instance_type, tf_full_version, py_version):
593+
def test_tuning_tf_script_mode(
594+
sagemaker_session, cpu_instance_type, tf_full_version, tf_full_py_version
595+
):
602596
resource_path = os.path.join(DATA_DIR, "tensorflow_mnist")
603597
script_path = os.path.join(resource_path, "mnist.py")
604598

@@ -609,7 +603,7 @@ def test_tuning_tf_script_mode(sagemaker_session, cpu_instance_type, tf_full_ver
609603
train_instance_type=cpu_instance_type,
610604
sagemaker_session=sagemaker_session,
611605
framework_version=tf_full_version,
612-
py_version=py_version,
606+
py_version=tf_full_py_version,
613607
)
614608

615609
hyperparameter_ranges = {"epochs": IntegerParameter(1, 2)}
@@ -639,9 +633,10 @@ def test_tuning_tf_script_mode(sagemaker_session, cpu_instance_type, tf_full_ver
639633
tuner.wait()
640634

641635

636+
# TODO: evaluate skip mark and default framework_version 1.11
642637
@pytest.mark.canary_quick
643638
@pytest.mark.skipif(PYTHON_VERSION != "py2", reason="TensorFlow image supports only python 2.")
644-
def test_tuning_tf(sagemaker_session, cpu_instance_type, tf_full_version, py_version):
639+
def test_tuning_tf(sagemaker_session, cpu_instance_type):
645640
with timeout(minutes=TUNING_DEFAULT_TIMEOUT_MINUTES):
646641
script_path = os.path.join(DATA_DIR, "iris", "iris-dnn-classifier.py")
647642

@@ -654,8 +649,8 @@ def test_tuning_tf(sagemaker_session, cpu_instance_type, tf_full_version, py_ver
654649
train_instance_count=1,
655650
train_instance_type=cpu_instance_type,
656651
sagemaker_session=sagemaker_session,
657-
framework_version=tf_full_version,
658-
py_version=py_version,
652+
framework_version="1.11",
653+
py_version=PYTHON_VERSION,
659654
)
660655

661656
inputs = sagemaker_session.upload_data(path=DATA_PATH, key_prefix="integ-test-data/tf_iris")
@@ -695,8 +690,9 @@ def test_tuning_tf(sagemaker_session, cpu_instance_type, tf_full_version, py_ver
695690
assert dict_result == list_result
696691

697692

693+
# TODO: evaluate skip mark and default framework_version 1.11
698694
@pytest.mark.skipif(PYTHON_VERSION != "py2", reason="TensorFlow image supports only python 2.")
699-
def test_tuning_tf_vpc_multi(sagemaker_session, cpu_instance_type, tf_full_version, py_version):
695+
def test_tuning_tf_vpc_multi(sagemaker_session, cpu_instance_type):
700696
"""Test Tensorflow multi-instance using the same VpcConfig for training and inference"""
701697
instance_type = cpu_instance_type
702698
instance_count = 2
@@ -720,8 +716,8 @@ def test_tuning_tf_vpc_multi(sagemaker_session, cpu_instance_type, tf_full_versi
720716
subnets=subnet_ids,
721717
security_group_ids=[security_group_id],
722718
encrypt_inter_container_traffic=True,
723-
framework_version=tf_full_version,
724-
py_version=py_version,
719+
framework_version="1.11",
720+
py_version=PYTHON_VERSION,
725721
)
726722

727723
inputs = sagemaker_session.upload_data(path=DATA_PATH, key_prefix="integ-test-data/tf_iris")

0 commit comments

Comments
 (0)