Skip to content

Commit 44bd99e

Browse files
authored
feature: support TensorFlow training 2.2 (#1521)
* feature: TensorFlow 2.2 support
1 parent 78495df commit 44bd99e

File tree

9 files changed

+53
-32
lines changed

9 files changed

+53
-32
lines changed

doc/using_tf.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ For general information about using the SageMaker Python SDK, see :ref:`overview
2020

2121
.. contents::
2222

23-
Supported versions of TensorFlow for Elastic Inference: ``1.11``, ``1.12``, ``1.13``, ``1.14``.
23+
Supported versions of TensorFlow for Elastic Inference: ``1.11``, ``1.12``, ``1.13``, ``1.14``, ``1.15``, ``2.0``.
2424

2525

2626
*****************************

src/sagemaker/tensorflow/defaults.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
This is no longer updated so as to not break existing workflows.
1919
"""
2020

21-
LATEST_VERSION = "2.1.0"
21+
LATEST_VERSION = "2.2.0"
2222
"""The latest version of TensorFlow included in the SageMaker pre-built Docker images."""
2323

24+
LATEST_SERVING_VERSION = "2.1.0"
25+
"""The latest version of TensorFlow Serving included in the SageMaker pre-built Docker images."""
26+
2427
LATEST_PY2_VERSION = "2.1.0"

tests/conftest.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
from sagemaker.pytorch import PyTorch
2828
from sagemaker.rl import RLEstimator
2929
from sagemaker.sklearn.defaults import SKLEARN_VERSION
30-
from sagemaker.tensorflow.estimator import TensorFlow
30+
from sagemaker.tensorflow import TensorFlow
31+
from sagemaker.tensorflow.defaults import LATEST_VERSION, LATEST_SERVING_VERSION
3132

3233
DEFAULT_REGION = "us-west-2"
3334
CUSTOM_BUCKET_NAME_PREFIX = "sagemaker-custom-bucket"
@@ -259,7 +260,7 @@ def sklearn_full_version(request):
259260
return request.config.getoption("--sklearn-full-version")
260261

261262

262-
@pytest.fixture(scope="module", params=[TensorFlow._LATEST_1X_VERSION, TensorFlow.LATEST_VERSION])
263+
@pytest.fixture(scope="module", params=[TensorFlow._LATEST_1X_VERSION, LATEST_VERSION])
263264
def tf_full_version(request):
264265
tf_version = request.config.getoption("--tf-full-version")
265266
if tf_version is None:
@@ -335,3 +336,10 @@ def pytest_generate_tests(metafunc):
335336
@pytest.fixture(scope="module")
336337
def xgboost_full_version(request):
337338
return request.config.getoption("--xgboost-full-version")
339+
340+
341+
@pytest.fixture(scope="module")
342+
def tf_serving_version(tf_full_version):
343+
if tf_full_version == LATEST_VERSION:
344+
return LATEST_SERVING_VERSION
345+
return tf_full_version

tests/integ/test_data_capture_config.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141

4242

4343
def test_enabling_data_capture_on_endpoint_shows_correct_data_capture_status(
44-
sagemaker_session, tf_full_version
44+
sagemaker_session, tf_serving_version
4545
):
4646
endpoint_name = unique_name_from_base("sagemaker-tensorflow-serving")
4747
model_data = sagemaker_session.upload_data(
@@ -52,7 +52,7 @@ def test_enabling_data_capture_on_endpoint_shows_correct_data_capture_status(
5252
model = Model(
5353
model_data=model_data,
5454
role=ROLE,
55-
framework_version=tf_full_version,
55+
framework_version=tf_serving_version,
5656
sagemaker_session=sagemaker_session,
5757
)
5858
predictor = model.deploy(
@@ -98,7 +98,7 @@ def test_enabling_data_capture_on_endpoint_shows_correct_data_capture_status(
9898

9999

100100
def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status(
101-
sagemaker_session, tf_full_version
101+
sagemaker_session, tf_serving_version
102102
):
103103
endpoint_name = unique_name_from_base("sagemaker-tensorflow-serving")
104104
model_data = sagemaker_session.upload_data(
@@ -109,7 +109,7 @@ def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status(
109109
model = Model(
110110
model_data=model_data,
111111
role=ROLE,
112-
framework_version=tf_full_version,
112+
framework_version=tf_serving_version,
113113
sagemaker_session=sagemaker_session,
114114
)
115115
destination_s3_uri = os.path.join(
@@ -184,7 +184,7 @@ def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status(
184184

185185

186186
def test_updating_data_capture_on_endpoint_shows_correct_data_capture_status(
187-
sagemaker_session, tf_full_version
187+
sagemaker_session, tf_serving_version
188188
):
189189
endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")
190190
model_data = sagemaker_session.upload_data(
@@ -195,7 +195,7 @@ def test_updating_data_capture_on_endpoint_shows_correct_data_capture_status(
195195
model = Model(
196196
model_data=model_data,
197197
role=ROLE,
198-
framework_version=tf_full_version,
198+
framework_version=tf_serving_version,
199199
sagemaker_session=sagemaker_session,
200200
)
201201
destination_s3_uri = os.path.join(

tests/integ/test_model_monitor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@
8888

8989

9090
@pytest.fixture(scope="module")
91-
def predictor(sagemaker_session, tf_full_version):
91+
def predictor(sagemaker_session, tf_serving_version):
9292
endpoint_name = unique_name_from_base("sagemaker-tensorflow-serving")
9393
model_data = sagemaker_session.upload_data(
9494
path=os.path.join(tests.integ.DATA_DIR, "tensorflow-serving-test-model.tar.gz"),
@@ -100,7 +100,7 @@ def predictor(sagemaker_session, tf_full_version):
100100
model = Model(
101101
model_data=model_data,
102102
role=ROLE,
103-
framework_version=tf_full_version,
103+
framework_version=tf_serving_version,
104104
sagemaker_session=sagemaker_session,
105105
)
106106
predictor = model.deploy(

tests/integ/test_tf_script_mode.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import pytest
2020

2121
from sagemaker.tensorflow import TensorFlow
22+
from sagemaker.tensorflow.defaults import LATEST_SERVING_VERSION
2223
from sagemaker.utils import unique_name_from_base, sagemaker_timestamp
2324

2425
import tests.integ
@@ -40,10 +41,8 @@
4041

4142

4243
@pytest.fixture(scope="module")
43-
def py_version(tf_full_version):
44-
return (
45-
"py37" if tf_full_version == TensorFlow._LATEST_1X_VERSION else tests.integ.PYTHON_VERSION
46-
)
44+
def py_version(tf_full_version, tf_serving_version):
45+
return "py37" if tf_full_version == tf_serving_version else tests.integ.PYTHON_VERSION
4746

4847

4948
def test_mnist_with_checkpoint_config(
@@ -61,7 +60,7 @@ def test_mnist_with_checkpoint_config(
6160
sagemaker_session=sagemaker_session,
6261
script_mode=True,
6362
framework_version=tf_full_version,
64-
py_version=py_version,
63+
py_version="py37",
6564
metric_definitions=[{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}],
6665
checkpoint_s3_uri=checkpoint_s3_uri,
6766
checkpoint_local_path=checkpoint_local_path,
@@ -91,7 +90,7 @@ def test_mnist_with_checkpoint_config(
9190
assert actual_training_checkpoint_config == expected_training_checkpoint_config
9291

9392

94-
def test_server_side_encryption(sagemaker_session, tf_full_version, py_version):
93+
def test_server_side_encryption(sagemaker_session, tf_serving_version, py_version):
9594
with kms_utils.bucket_with_encryption(sagemaker_session, ROLE) as (bucket_with_kms, kms_key):
9695
output_path = os.path.join(
9796
bucket_with_kms, "test-server-side-encryption", time.strftime("%y%m%d-%H%M")
@@ -105,7 +104,7 @@ def test_server_side_encryption(sagemaker_session, tf_full_version, py_version):
105104
train_instance_type="ml.c5.xlarge",
106105
sagemaker_session=sagemaker_session,
107106
script_mode=True,
108-
framework_version=tf_full_version,
107+
framework_version=tf_serving_version,
109108
py_version=py_version,
110109
code_location=output_path,
111110
output_path=output_path,
@@ -140,7 +139,7 @@ def test_mnist_distributed(sagemaker_session, instance_type, tf_full_version, py
140139
train_instance_count=2,
141140
train_instance_type=instance_type,
142141
sagemaker_session=sagemaker_session,
143-
py_version=py_version,
142+
py_version="py37",
144143
script_mode=True,
145144
framework_version=tf_full_version,
146145
distributions=PARAMETER_SERVER_DISTRIBUTION,
@@ -168,7 +167,7 @@ def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version, py_v
168167
sagemaker_session=sagemaker_session,
169168
script_mode=True,
170169
# testing py-sdk functionality, no need to run against all TF versions
171-
framework_version=TensorFlow.LATEST_VERSION,
170+
framework_version=LATEST_SERVING_VERSION,
172171
tags=TAGS,
173172
)
174173
inputs = estimator.sagemaker_session.upload_data(
@@ -200,7 +199,9 @@ def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version, py_v
200199
_assert_model_name_match(sagemaker_session.sagemaker_client, endpoint_name, model_name)
201200

202201

203-
def test_deploy_with_input_handlers(sagemaker_session, instance_type, tf_full_version, py_version):
202+
def test_deploy_with_input_handlers(
203+
sagemaker_session, instance_type, tf_serving_version, py_version
204+
):
204205
estimator = TensorFlow(
205206
entry_point="training.py",
206207
source_dir=TFS_RESOURCE_PATH,
@@ -210,7 +211,7 @@ def test_deploy_with_input_handlers(sagemaker_session, instance_type, tf_full_ve
210211
py_version=py_version,
211212
sagemaker_session=sagemaker_session,
212213
script_mode=True,
213-
framework_version=tf_full_version,
214+
framework_version=tf_serving_version,
214215
tags=TAGS,
215216
)
216217

tests/integ/test_tfs.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828

2929
@pytest.fixture(scope="module")
30-
def tfs_predictor(sagemaker_session, tf_full_version):
30+
def tfs_predictor(sagemaker_session, tf_serving_version):
3131
endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")
3232
model_data = sagemaker_session.upload_data(
3333
path=os.path.join(tests.integ.DATA_DIR, "tensorflow-serving-test-model.tar.gz"),
@@ -37,7 +37,7 @@ def tfs_predictor(sagemaker_session, tf_full_version):
3737
model = Model(
3838
model_data=model_data,
3939
role="SageMakerRole",
40-
framework_version=tf_full_version,
40+
framework_version=tf_serving_version,
4141
sagemaker_session=sagemaker_session,
4242
)
4343
predictor = model.deploy(1, "ml.c5.xlarge", endpoint_name=endpoint_name)
@@ -54,7 +54,7 @@ def tar_dir(directory, tmpdir):
5454

5555
@pytest.fixture
5656
def tfs_predictor_with_model_and_entry_point_same_tar(
57-
sagemaker_local_session, tf_full_version, tmpdir
57+
sagemaker_local_session, tf_serving_version, tmpdir
5858
):
5959
endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")
6060

@@ -65,7 +65,7 @@ def tfs_predictor_with_model_and_entry_point_same_tar(
6565
model = Model(
6666
model_data="file://" + model_tar,
6767
role="SageMakerRole",
68-
framework_version=tf_full_version,
68+
framework_version=tf_serving_version,
6969
sagemaker_session=sagemaker_local_session,
7070
)
7171
predictor = model.deploy(1, "local", endpoint_name=endpoint_name)
@@ -78,7 +78,7 @@ def tfs_predictor_with_model_and_entry_point_same_tar(
7878

7979
@pytest.fixture(scope="module")
8080
def tfs_predictor_with_model_and_entry_point_and_dependencies(
81-
sagemaker_local_session, tf_full_version
81+
sagemaker_local_session, tf_serving_version
8282
):
8383
endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")
8484

@@ -98,7 +98,7 @@ def tfs_predictor_with_model_and_entry_point_and_dependencies(
9898
model_data=model_data,
9999
role="SageMakerRole",
100100
dependencies=dependencies,
101-
framework_version=tf_full_version,
101+
framework_version=tf_serving_version,
102102
sagemaker_session=sagemaker_local_session,
103103
)
104104

tests/integ/test_transformer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from sagemaker.mxnet import MXNet
2626
from sagemaker.pytorch import PyTorchModel
2727
from sagemaker.tensorflow import TensorFlow
28+
from sagemaker.tensorflow.defaults import LATEST_SERVING_VERSION
2829
from sagemaker.transformer import Transformer
2930
from sagemaker.estimator import Estimator
3031
from sagemaker.utils import unique_name_from_base
@@ -351,7 +352,7 @@ def test_transform_tf_kms_network_isolation(sagemaker_session, cpu_instance_type
351352
role="SageMakerRole",
352353
train_instance_count=1,
353354
train_instance_type=cpu_instance_type,
354-
framework_version=TensorFlow.LATEST_VERSION,
355+
framework_version=LATEST_SERVING_VERSION,
355356
script_mode=True,
356357
py_version=PYTHON_VERSION,
357358
sagemaker_session=sagemaker_session,

tests/integ/test_tuner.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from sagemaker.predictor import json_deserializer
3838
from sagemaker.pytorch import PyTorch
3939
from sagemaker.tensorflow import TensorFlow
40+
from sagemaker.tensorflow.defaults import LATEST_VERSION
4041
from sagemaker.tuner import (
4142
IntegerParameter,
4243
ContinuousParameter,
@@ -51,6 +52,13 @@
5152

5253
DATA_PATH = os.path.join(DATA_DIR, "iris", "data")
5354

55+
PY37_SUPPORTED_FRAMEWORK_VERSION = [TensorFlow._LATEST_1X_VERSION, LATEST_VERSION]
56+
57+
58+
@pytest.fixture(scope="module")
59+
def py_version(tf_full_version):
60+
return "py37" if tf_full_version in PY37_SUPPORTED_FRAMEWORK_VERSION else PYTHON_VERSION
61+
5462

5563
@pytest.fixture(scope="module")
5664
def kmeans_train_set(sagemaker_session):
@@ -590,7 +598,7 @@ def test_tuning_mxnet(sagemaker_session, mxnet_full_version, cpu_instance_type):
590598

591599

592600
@pytest.mark.canary_quick
593-
def test_tuning_tf_script_mode(sagemaker_session, cpu_instance_type, tf_full_version):
601+
def test_tuning_tf_script_mode(sagemaker_session, cpu_instance_type, tf_full_version, py_version):
594602
resource_path = os.path.join(DATA_DIR, "tensorflow_mnist")
595603
script_path = os.path.join(resource_path, "mnist.py")
596604

@@ -601,7 +609,7 @@ def test_tuning_tf_script_mode(sagemaker_session, cpu_instance_type, tf_full_ver
601609
train_instance_type=cpu_instance_type,
602610
script_mode=True,
603611
sagemaker_session=sagemaker_session,
604-
py_version=PYTHON_VERSION,
612+
py_version=py_version,
605613
framework_version=tf_full_version,
606614
)
607615

0 commit comments

Comments
 (0)