Skip to content

Commit c53472b

Browse files
authored
feature: support for TensorFlow 1.14 (#967)
1 parent e89b464 commit c53472b

File tree

9 files changed

+125
-8
lines changed

9 files changed

+125
-8
lines changed

README.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,9 @@ TensorFlow SageMaker Estimators
189189

190190
By using TensorFlow SageMaker Estimators, you can train and host TensorFlow models on Amazon SageMaker.
191191

192-
Supported versions of TensorFlow: ``1.4.1``, ``1.5.0``, ``1.6.0``, ``1.7.0``, ``1.8.0``, ``1.9.0``, ``1.10.0``, ``1.11.0``, ``1.12.0``, ``1.13.1``.
192+
Supported versions of TensorFlow: ``1.4.1``, ``1.5.0``, ``1.6.0``, ``1.7.0``, ``1.8.0``, ``1.9.0``, ``1.10.0``, ``1.11.0``, ``1.12.0``, ``1.13.1``, ``1.14``.
193193

194-
Supported versions of TensorFlow for Elastic Inference: ``1.11.0``, ``1.12.0``, ``1.13.0``
194+
Supported versions of TensorFlow for Elastic Inference: ``1.11.0``, ``1.12.0``, ``1.13.1``
195195

196196
We recommend that you use the latest supported version, because that's where we focus most of our development efforts.
197197

doc/using_tf.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ models on SageMaker Hosting.
88

99
For general information about using the SageMaker Python SDK, see :ref:`overview:Using the SageMaker Python SDK`.
1010

11+
.. warning::
12+
The TensorFlow estimator is available only for Python 3, starting by the TensorFlow version 1.14.
13+
1114
.. warning::
1215
We have added a new format of your TensorFlow training script with TensorFlow version 1.11.
1316
This new way gives the user script more flexibility.

src/sagemaker/tensorflow/estimator.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,17 +195,18 @@ class TensorFlow(Framework):
195195

196196
__framework_name__ = "tensorflow"
197197

198-
LATEST_VERSION = "1.13"
198+
LATEST_VERSION = "1.14"
199199
"""The latest version of TensorFlow included in the SageMaker pre-built Docker images."""
200200

201201
_LOWEST_SCRIPT_MODE_ONLY_VERSION = [1, 13]
202+
_LOWEST_PYTHON_2_ONLY_VERSION = [1, 14]
202203

203204
def __init__(
204205
self,
205206
training_steps=None,
206207
evaluation_steps=None,
207208
checkpoint_path=None,
208-
py_version="py2",
209+
py_version=None,
209210
framework_version=None,
210211
model_dir=None,
211212
requirements_file="",
@@ -279,6 +280,9 @@ def __init__(
279280
logger.warning(fw.empty_framework_version_warning(TF_VERSION, self.LATEST_VERSION))
280281
self.framework_version = framework_version or TF_VERSION
281282

283+
if not py_version:
284+
py_version = "py3" if self._only_python_3_supported() else "py2"
285+
282286
super(TensorFlow, self).__init__(image_name=image_name, **kwargs)
283287
self.checkpoint_path = checkpoint_path
284288

@@ -337,6 +341,13 @@ def _validate_args(
337341
)
338342
)
339343

344+
if py_version == "py2" and self._only_python_3_supported():
345+
msg = (
346+
"Python 2 containers are only available until TensorFlow version 1.13.1. "
347+
"Please use a Python 3 container."
348+
)
349+
raise AttributeError(msg)
350+
340351
if (not self._script_mode_enabled()) and self._only_script_mode_supported():
341352
logger.warning(
342353
"Legacy mode is deprecated in versions 1.13 and higher. Using script mode instead."
@@ -349,6 +360,12 @@ def _only_script_mode_supported(self):
349360
int(s) for s in self.framework_version.split(".")
350361
] >= self._LOWEST_SCRIPT_MODE_ONLY_VERSION
351362

363+
def _only_python_3_supported(self):
364+
"""Placeholder docstring"""
365+
return [
366+
int(s) for s in self.framework_version.split(".")
367+
] >= self._LOWEST_PYTHON_2_ONLY_VERSION
368+
352369
def _validate_requirements_file(self, requirements_file):
353370
"""Placeholder docstring"""
354371
if not requirements_file:

src/sagemaker/tensorflow/serving.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ class Model(sagemaker.model.FrameworkModel):
131131
logging.ERROR: "error",
132132
logging.CRITICAL: "crit",
133133
}
134+
LATEST_EIA_VERSION = [1, 13]
134135

135136
def __init__(
136137
self,
@@ -176,6 +177,37 @@ def __init__(
176177
self._framework_version = framework_version
177178
self._container_log_level = container_log_level
178179

180+
def deploy(
181+
self,
182+
initial_instance_count,
183+
instance_type,
184+
accelerator_type=None,
185+
endpoint_name=None,
186+
update_endpoint=False,
187+
tags=None,
188+
kms_key=None,
189+
wait=True,
190+
):
191+
192+
if accelerator_type and not self._eia_supported():
193+
msg = "The TensorFlow version %s doesn't support EIA." % self._framework_version
194+
195+
raise AttributeError(msg)
196+
return super(Model, self).deploy(
197+
initial_instance_count,
198+
instance_type,
199+
accelerator_type,
200+
endpoint_name,
201+
update_endpoint,
202+
tags,
203+
kms_key,
204+
wait,
205+
)
206+
207+
def _eia_supported(self):
208+
"""Return true if TF version is EIA enabled"""
209+
return [int(s) for s in self._framework_version.split(".")][:2] <= self.LATEST_EIA_VERSION
210+
179211
def prepare_container_def(self, instance_type, accelerator_type=None):
180212
"""
181213
Args:

tests/integ/test_tfs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def tfs_predictor_with_model_and_entry_point_and_dependencies(
111111

112112

113113
@pytest.fixture(scope="module")
114-
def tfs_predictor_with_accelerator(sagemaker_session, tf_full_version):
114+
def tfs_predictor_with_accelerator(sagemaker_session):
115115
endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")
116116
model_data = sagemaker_session.upload_data(
117117
path=os.path.join(tests.integ.DATA_DIR, "tensorflow-serving-test-model.tar.gz"),
@@ -121,7 +121,7 @@ def tfs_predictor_with_accelerator(sagemaker_session, tf_full_version):
121121
model = Model(
122122
model_data=model_data,
123123
role="SageMakerRole",
124-
framework_version=tf_full_version,
124+
framework_version="1.13",
125125
sagemaker_session=sagemaker_session,
126126
)
127127
predictor = model.deploy(

tests/unit/test_fw_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,13 @@ def test_create_image_uri_gov_cloud():
137137

138138

139139
def test_create_image_uri_merged():
140+
image_uri = fw_utils.create_image_uri(
141+
"us-west-2", "tensorflow-scriptmode", "ml.p3.2xlarge", "1.14", "py3"
142+
)
143+
assert (
144+
image_uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:1.14-gpu-py3"
145+
)
146+
140147
image_uri = fw_utils.create_image_uri(
141148
"us-west-2", "tensorflow-scriptmode", "ml.p3.2xlarge", "1.13.1", "py3"
142149
)

tests/unit/test_tf_estimator.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from sagemaker.tensorflow import defaults, TensorFlow, TensorFlowModel, TensorFlowPredictor
2626
import sagemaker.tensorflow.estimator as tfe
2727

28-
2928
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
3029
SCRIPT_FILE = "dummy_script.py"
3130
SCRIPT_PATH = os.path.join(DATA_DIR, SCRIPT_FILE)
@@ -956,6 +955,30 @@ def test_script_mode_deprecated_args(sagemaker_session):
956955
) in str(e.value)
957956

958957

958+
def test_py2_version_deprecated(sagemaker_session):
959+
with pytest.raises(AttributeError) as e:
960+
_build_tf(sagemaker_session=sagemaker_session, framework_version="1.14", py_version="py2")
961+
962+
msg = "Python 2 containers are only available until TensorFlow version 1.13.1. Please use a Python 3 container."
963+
assert msg in str(e.value)
964+
965+
966+
def test_py3_is_default_version_after_tf1_14(sagemaker_session):
967+
estimator = _build_tf(sagemaker_session=sagemaker_session, framework_version="1.14")
968+
969+
assert estimator.py_version == "py3"
970+
971+
972+
def test_py3_is_default_version_before_tf1_14(sagemaker_session):
973+
estimator = _build_tf(sagemaker_session=sagemaker_session, framework_version="1.13")
974+
975+
assert estimator.py_version == "py2"
976+
977+
estimator = _build_tf(sagemaker_session=sagemaker_session, framework_version="1.10")
978+
979+
assert estimator.py_version == "py2"
980+
981+
959982
def test_legacy_mode_deprecated(sagemaker_session):
960983
tf = _build_tf(
961984
sagemaker_session=sagemaker_session,

tests/unit/test_tfs.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,41 @@ def test_tfs_model_image_accelerator(sagemaker_session, tf_version):
9696
assert isinstance(predictor, Predictor)
9797

9898

99+
def test_tfs_model_image_accelerator_not_supported(sagemaker_session):
100+
model = Model(
101+
"s3://some/data.tar.gz",
102+
role=ROLE,
103+
framework_version="1.13.1",
104+
sagemaker_session=sagemaker_session,
105+
)
106+
107+
# assert error is not raised
108+
109+
model.deploy(
110+
instance_type="ml.c4.xlarge", initial_instance_count=1, accelerator_type="ml.eia1.medium"
111+
)
112+
113+
model = Model(
114+
"s3://some/data.tar.gz",
115+
role=ROLE,
116+
framework_version="1.14",
117+
sagemaker_session=sagemaker_session,
118+
)
119+
120+
# assert error is not raised
121+
122+
model.deploy(instance_type="ml.c4.xlarge", initial_instance_count=1)
123+
124+
with pytest.raises(AttributeError) as e:
125+
model.deploy(
126+
instance_type="ml.c4.xlarge",
127+
accelerator_type="ml.eia1.medium",
128+
initial_instance_count=1,
129+
)
130+
131+
assert str(e.value) == "The TensorFlow version 1.14 doesn't support EIA."
132+
133+
99134
def test_tfs_model_with_log_level(sagemaker_session, tf_version):
100135
model = Model(
101136
"s3://some/data.tar.gz",

tox.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ ignore =
3333
FI15,
3434
FI16,
3535
FI17,
36-
FI18,
36+
FI18, # __future__ import "annotations" missing -> check only Python 3.7 compatible
3737
FI50,
3838
FI51,
3939
FI52,

0 commit comments

Comments
 (0)