Skip to content

feature: support TFS 2.2 #1595

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/sagemaker/tensorflow/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
LATEST_VERSION = "2.2.0"
"""The latest version of TensorFlow included in the SageMaker pre-built Docker images."""

LATEST_SERVING_VERSION = "2.1.0"
LATEST_SERVING_VERSION = "2.2.0"
"""The latest version of TensorFlow Serving included in the SageMaker pre-built Docker images."""

LATEST_PY2_VERSION = "2.1.0"
9 changes: 1 addition & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from sagemaker.rl import RLEstimator
from sagemaker.sklearn.defaults import SKLEARN_VERSION
from sagemaker.tensorflow import TensorFlow
from sagemaker.tensorflow.defaults import LATEST_VERSION, LATEST_SERVING_VERSION
from sagemaker.tensorflow.defaults import LATEST_VERSION

DEFAULT_REGION = "us-west-2"
CUSTOM_BUCKET_NAME_PREFIX = "sagemaker-custom-bucket"
Expand Down Expand Up @@ -336,10 +336,3 @@ def pytest_generate_tests(metafunc):
@pytest.fixture(scope="module")
def xgboost_full_version(request):
return request.config.getoption("--xgboost-full-version")


@pytest.fixture(scope="module")
def tf_serving_version(tf_full_version):
if tf_full_version == LATEST_VERSION:
return LATEST_SERVING_VERSION
return tf_full_version
12 changes: 6 additions & 6 deletions tests/integ/test_data_capture_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@


def test_enabling_data_capture_on_endpoint_shows_correct_data_capture_status(
sagemaker_session, tf_serving_version
sagemaker_session, tf_full_version
):
endpoint_name = unique_name_from_base("sagemaker-tensorflow-serving")
model_data = sagemaker_session.upload_data(
Expand All @@ -52,7 +52,7 @@ def test_enabling_data_capture_on_endpoint_shows_correct_data_capture_status(
model = Model(
model_data=model_data,
role=ROLE,
framework_version=tf_serving_version,
framework_version=tf_full_version,
sagemaker_session=sagemaker_session,
)
predictor = model.deploy(
Expand Down Expand Up @@ -98,7 +98,7 @@ def test_enabling_data_capture_on_endpoint_shows_correct_data_capture_status(


def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status(
sagemaker_session, tf_serving_version
sagemaker_session, tf_full_version
):
endpoint_name = unique_name_from_base("sagemaker-tensorflow-serving")
model_data = sagemaker_session.upload_data(
Expand All @@ -109,7 +109,7 @@ def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status(
model = Model(
model_data=model_data,
role=ROLE,
framework_version=tf_serving_version,
framework_version=tf_full_version,
sagemaker_session=sagemaker_session,
)
destination_s3_uri = os.path.join(
Expand Down Expand Up @@ -184,7 +184,7 @@ def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status(


def test_updating_data_capture_on_endpoint_shows_correct_data_capture_status(
sagemaker_session, tf_serving_version
sagemaker_session, tf_full_version
):
endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")
model_data = sagemaker_session.upload_data(
Expand All @@ -195,7 +195,7 @@ def test_updating_data_capture_on_endpoint_shows_correct_data_capture_status(
model = Model(
model_data=model_data,
role=ROLE,
framework_version=tf_serving_version,
framework_version=tf_full_version,
sagemaker_session=sagemaker_session,
)
destination_s3_uri = os.path.join(
Expand Down
4 changes: 2 additions & 2 deletions tests/integ/test_model_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@


@pytest.fixture(scope="module")
def predictor(sagemaker_session, tf_serving_version):
def predictor(sagemaker_session, tf_full_version):
endpoint_name = unique_name_from_base("sagemaker-tensorflow-serving")
model_data = sagemaker_session.upload_data(
path=os.path.join(tests.integ.DATA_DIR, "tensorflow-serving-test-model.tar.gz"),
Expand All @@ -100,7 +100,7 @@ def predictor(sagemaker_session, tf_serving_version):
model = Model(
model_data=model_data,
role=ROLE,
framework_version=tf_serving_version,
framework_version=tf_full_version,
sagemaker_session=sagemaker_session,
)
predictor = model.deploy(
Expand Down
24 changes: 11 additions & 13 deletions tests/integ/test_tf_script_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import pytest

from sagemaker.tensorflow import TensorFlow
from sagemaker.tensorflow.defaults import LATEST_SERVING_VERSION
from sagemaker.tensorflow.defaults import LATEST_VERSION
from sagemaker.utils import unique_name_from_base, sagemaker_timestamp

import tests.integ
Expand All @@ -41,8 +41,8 @@


@pytest.fixture(scope="module")
def py_version(tf_full_version, tf_serving_version):
return "py37" if tf_full_version == tf_serving_version else tests.integ.PYTHON_VERSION
def py_version(tf_full_version):
return "py37" if tf_full_version == LATEST_VERSION else tests.integ.PYTHON_VERSION


def test_mnist_with_checkpoint_config(
Expand All @@ -60,7 +60,7 @@ def test_mnist_with_checkpoint_config(
sagemaker_session=sagemaker_session,
script_mode=True,
framework_version=tf_full_version,
py_version="py37",
py_version=py_version,
metric_definitions=[{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}],
checkpoint_s3_uri=checkpoint_s3_uri,
checkpoint_local_path=checkpoint_local_path,
Expand Down Expand Up @@ -90,7 +90,7 @@ def test_mnist_with_checkpoint_config(
assert actual_training_checkpoint_config == expected_training_checkpoint_config


def test_server_side_encryption(sagemaker_session, tf_serving_version, py_version):
def test_server_side_encryption(sagemaker_session, tf_full_version, py_version):
with kms_utils.bucket_with_encryption(sagemaker_session, ROLE) as (bucket_with_kms, kms_key):
output_path = os.path.join(
bucket_with_kms, "test-server-side-encryption", time.strftime("%y%m%d-%H%M")
Expand All @@ -104,7 +104,7 @@ def test_server_side_encryption(sagemaker_session, tf_serving_version, py_versio
train_instance_type="ml.c5.xlarge",
sagemaker_session=sagemaker_session,
script_mode=True,
framework_version=tf_serving_version,
framework_version=tf_full_version,
py_version=py_version,
code_location=output_path,
output_path=output_path,
Expand Down Expand Up @@ -139,7 +139,7 @@ def test_mnist_distributed(sagemaker_session, instance_type, tf_full_version, py
train_instance_count=2,
train_instance_type=instance_type,
sagemaker_session=sagemaker_session,
py_version="py37",
py_version=py_version,
script_mode=True,
framework_version=tf_full_version,
distributions=PARAMETER_SERVER_DISTRIBUTION,
Expand All @@ -163,11 +163,11 @@ def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version, py_v
role=ROLE,
train_instance_count=1,
train_instance_type="ml.c5.4xlarge",
py_version=tests.integ.PYTHON_VERSION,
py_version=py_version,
sagemaker_session=sagemaker_session,
script_mode=True,
# testing py-sdk functionality, no need to run against all TF versions
framework_version=LATEST_SERVING_VERSION,
framework_version=tf_full_version,
tags=TAGS,
)
inputs = estimator.sagemaker_session.upload_data(
Expand Down Expand Up @@ -199,9 +199,7 @@ def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version, py_v
_assert_model_name_match(sagemaker_session.sagemaker_client, endpoint_name, model_name)


def test_deploy_with_input_handlers(
sagemaker_session, instance_type, tf_serving_version, py_version
):
def test_deploy_with_input_handlers(sagemaker_session, instance_type, tf_full_version, py_version):
estimator = TensorFlow(
entry_point="training.py",
source_dir=TFS_RESOURCE_PATH,
Expand All @@ -211,7 +209,7 @@ def test_deploy_with_input_handlers(
py_version=py_version,
sagemaker_session=sagemaker_session,
script_mode=True,
framework_version=tf_serving_version,
framework_version=tf_full_version,
tags=TAGS,
)

Expand Down
12 changes: 6 additions & 6 deletions tests/integ/test_tfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


@pytest.fixture(scope="module")
def tfs_predictor(sagemaker_session, tf_serving_version):
def tfs_predictor(sagemaker_session, tf_full_version):
endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")
model_data = sagemaker_session.upload_data(
path=os.path.join(tests.integ.DATA_DIR, "tensorflow-serving-test-model.tar.gz"),
Expand All @@ -37,7 +37,7 @@ def tfs_predictor(sagemaker_session, tf_serving_version):
model = Model(
model_data=model_data,
role="SageMakerRole",
framework_version=tf_serving_version,
framework_version=tf_full_version,
sagemaker_session=sagemaker_session,
)
predictor = model.deploy(1, "ml.c5.xlarge", endpoint_name=endpoint_name)
Expand All @@ -54,7 +54,7 @@ def tar_dir(directory, tmpdir):

@pytest.fixture
def tfs_predictor_with_model_and_entry_point_same_tar(
sagemaker_local_session, tf_serving_version, tmpdir
sagemaker_local_session, tf_full_version, tmpdir
):
endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")

Expand All @@ -65,7 +65,7 @@ def tfs_predictor_with_model_and_entry_point_same_tar(
model = Model(
model_data="file://" + model_tar,
role="SageMakerRole",
framework_version=tf_serving_version,
framework_version=tf_full_version,
sagemaker_session=sagemaker_local_session,
)
predictor = model.deploy(1, "local", endpoint_name=endpoint_name)
Expand All @@ -78,7 +78,7 @@ def tfs_predictor_with_model_and_entry_point_same_tar(

@pytest.fixture(scope="module")
def tfs_predictor_with_model_and_entry_point_and_dependencies(
sagemaker_local_session, tf_serving_version
sagemaker_local_session, tf_full_version
):
endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")

Expand All @@ -98,7 +98,7 @@ def tfs_predictor_with_model_and_entry_point_and_dependencies(
model_data=model_data,
role="SageMakerRole",
dependencies=dependencies,
framework_version=tf_serving_version,
framework_version=tf_full_version,
sagemaker_session=sagemaker_local_session,
)

Expand Down
9 changes: 5 additions & 4 deletions tests/integ/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from sagemaker.mxnet import MXNet
from sagemaker.pytorch import PyTorchModel
from sagemaker.tensorflow import TensorFlow
from sagemaker.tensorflow.defaults import LATEST_SERVING_VERSION
from sagemaker.transformer import Transformer
from sagemaker.estimator import Estimator
from sagemaker.utils import unique_name_from_base
Expand Down Expand Up @@ -344,17 +343,19 @@ def test_transform_mxnet_logs(
transformer.wait()


def test_transform_tf_kms_network_isolation(sagemaker_session, cpu_instance_type, tmpdir):
def test_transform_tf_kms_network_isolation(
sagemaker_session, cpu_instance_type, tmpdir, tf_full_version, py_version
):
data_path = os.path.join(DATA_DIR, "tensorflow_mnist")

tf = TensorFlow(
entry_point=os.path.join(data_path, "mnist.py"),
role="SageMakerRole",
train_instance_count=1,
train_instance_type=cpu_instance_type,
framework_version=LATEST_SERVING_VERSION,
framework_version=tf_full_version,
script_mode=True,
py_version=PYTHON_VERSION,
py_version=py_version,
sagemaker_session=sagemaker_session,
)

Expand Down