Skip to content

change: use image_uris.retrieve instead of fw_utils.create_image_uri for DLC frameworks #1724

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
import logging

import sagemaker
from sagemaker import image_uris
from sagemaker.fw_utils import (
create_image_uri,
model_code_key_prefix,
python_deprecation_warning,
validate_version_or_image_args,
Expand Down Expand Up @@ -175,11 +175,12 @@ def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
str: The appropriate image URI based on the given parameters.

"""
return create_image_uri(
region_name,
return image_uris.retrieve(
self.__framework_name__,
instance_type,
self.framework_version,
self.py_version,
region_name,
version=self.framework_version,
py_version=self.py_version,
instance_type=instance_type,
accelerator_type=accelerator_type,
image_scope="inference",
)
13 changes: 9 additions & 4 deletions src/sagemaker/cli/compatibility/v2/modifiers/tf_legacy_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from sagemaker.cli.compatibility.v2.modifiers import framework_version, matching
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
from sagemaker import fw_utils
from sagemaker import image_uris

TF_NAMESPACES = ("sagemaker.tensorflow", "sagemaker.tensorflow.estimator")
LEGACY_MODE_PARAMETERS = (
Expand Down Expand Up @@ -169,9 +169,14 @@ def _image_uri_from_args(self, keywords):
instance_type = kw.value.s if isinstance(kw.value, ast.Str) else None

if tf_version and instance_type:
return fw_utils.create_image_uri(
self.region, "tensorflow", instance_type, tf_version, "py2"
)
return image_uris.retrieve(
"tensorflow",
self.region,
version=tf_version,
py_version="py2",
instance_type=instance_type,
image_scope="training",
).replace("-scriptmode", "")

return None

Expand Down
12 changes: 6 additions & 6 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,14 @@
from six import string_types
from six.moves.urllib.parse import urlparse
import sagemaker
from sagemaker import git_utils
from sagemaker import git_utils, image_uris
from sagemaker.analytics import TrainingJobAnalytics
from sagemaker.debugger import DebuggerHookConfig
from sagemaker.debugger import TensorBoardOutputConfig # noqa: F401 # pylint: disable=unused-import
from sagemaker.debugger import get_rule_container_image_uri
from sagemaker.s3 import S3Uploader

from sagemaker.fw_utils import (
create_image_uri,
tar_and_upload_dir,
parse_s3_url,
UploadedCode,
Expand Down Expand Up @@ -1832,12 +1831,13 @@ def train_image(self):
"""
if self.image_uri:
return self.image_uri
return create_image_uri(
self.sagemaker_session.boto_region_name,
return image_uris.retrieve(
self.__framework_name__,
self.instance_type,
self.framework_version, # pylint: disable=no-member
self.sagemaker_session.boto_region_name,
instance_type=self.instance_type,
version=self.framework_version, # pylint: disable=no-member
py_version=self.py_version, # pylint: disable=no-member
image_scope="training",
)

@classmethod
Expand Down
17 changes: 7 additions & 10 deletions src/sagemaker/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
import packaging.version

import sagemaker
from sagemaker import image_uris
from sagemaker.deserializers import JSONDeserializer
from sagemaker.fw_utils import (
create_image_uri,
model_code_key_prefix,
python_deprecation_warning,
validate_version_or_image_args,
Expand Down Expand Up @@ -183,17 +183,14 @@ def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
str: The appropriate image URI based on the given parameters.

"""
framework_name = self.__framework_name__
if self._is_mms_version():
framework_name = "{}-serving".format(framework_name)

return create_image_uri(
return image_uris.retrieve(
self.__framework_name__,
region_name,
framework_name,
instance_type,
self.framework_version,
self.py_version,
version=self.framework_version,
py_version=self.py_version,
instance_type=instance_type,
accelerator_type=accelerator_type,
image_scope="inference",
)

def _is_mms_version(self):
Expand Down
17 changes: 7 additions & 10 deletions src/sagemaker/pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
import packaging.version

import sagemaker
from sagemaker import image_uris
from sagemaker.deserializers import NumpyDeserializer
from sagemaker.fw_utils import (
create_image_uri,
model_code_key_prefix,
python_deprecation_warning,
validate_version_or_image_args,
Expand Down Expand Up @@ -182,17 +182,14 @@ def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
str: The appropriate image URI based on the given parameters.

"""
framework_name = self.__framework_name__
if self._is_mms_version():
framework_name = "{}-serving".format(framework_name)

return create_image_uri(
return image_uris.retrieve(
self.__framework_name__,
region_name,
framework_name,
instance_type,
self.framework_version,
self.py_version,
version=self.framework_version,
py_version=self.py_version,
instance_type=instance_type,
accelerator_type=accelerator_type,
image_scope="inference",
)

def _is_mms_version(self):
Expand Down
27 changes: 7 additions & 20 deletions src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from packaging import version

from sagemaker import utils
from sagemaker import image_uris, utils
from sagemaker.debugger import DebuggerHookConfig
from sagemaker.estimator import Framework
import sagemaker.fw_utils as fw
Expand All @@ -34,7 +34,6 @@ class TensorFlow(Framework):
"""Handle end-to-end training and deployment of user-provided TensorFlow code."""

__framework_name__ = "tensorflow"
_ECR_REPO_NAME = "tensorflow-scriptmode"

_HIGHEST_LEGACY_MODE_ONLY_VERSION = version.Version("1.10.0")
_HIGHEST_PYTHON_2_VERSION = version.Version("2.1.0")
Expand Down Expand Up @@ -151,12 +150,13 @@ def _validate_args(self, py_version):
raise AttributeError(msg)

if self.image_uri is None and self._only_legacy_mode_supported():
legacy_image_uri = fw.create_image_uri(
self.sagemaker_session.boto_region_name,
legacy_image_uri = image_uris.retrieve(
"tensorflow",
self.instance_type,
self.framework_version,
self.py_version,
self.sagemaker_session.boto_region_name,
instance_type=self.instance_type,
version=self.framework_version,
py_version=self.py_version,
image_scope="training",
)

msg = (
Expand Down Expand Up @@ -355,19 +355,6 @@ def _validate_and_set_debugger_configs(self):
# Set defaults for debugging.
self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)

def train_image(self):
"""Placeholder docstring"""
if self.image_uri:
return self.image_uri

return fw.create_image_uri(
self.sagemaker_session.boto_region_name,
self._ECR_REPO_NAME,
self.instance_type,
self.framework_version,
self.py_version,
)

def transformer(
self,
instance_count,
Expand Down
14 changes: 7 additions & 7 deletions src/sagemaker/tensorflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
import logging

import sagemaker
from sagemaker import image_uris
from sagemaker.content_types import CONTENT_TYPE_JSON
from sagemaker.deserializers import JSONDeserializer
from sagemaker.fw_utils import create_image_uri
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer

Expand Down Expand Up @@ -122,7 +122,7 @@ def predict(self, data, initial_args=None):
class TensorFlowModel(sagemaker.model.FrameworkModel):
"""A ``FrameworkModel`` implementation for inference with TensorFlow Serving."""

__framework_name__ = "tensorflow-serving"
__framework_name__ = "tensorflow"
LOG_LEVEL_PARAM_NAME = "SAGEMAKER_TFS_NGINX_LOGLEVEL"
LOG_LEVEL_MAP = {
logging.DEBUG: "debug",
Expand Down Expand Up @@ -286,13 +286,13 @@ def _get_image_uri(self, instance_type, accelerator_type=None):
if self.image_uri:
return self.image_uri

region_name = self.sagemaker_session.boto_region_name
return create_image_uri(
region_name,
return image_uris.retrieve(
self.__framework_name__,
instance_type,
self.framework_version,
self.sagemaker_session.boto_region_name,
version=self.framework_version,
instance_type=instance_type,
accelerator_type=accelerator_type,
image_scope="inference",
)

def serving_image_uri(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def test_node_should_be_modified_random_function_call():


@patch("boto3.Session")
@patch("sagemaker.fw_utils.create_image_uri", return_value=IMAGE_URI)
def test_modify_node_set_model_dir_and_image_name(create_image_uri, boto_session):
@patch("sagemaker.image_uris.retrieve", return_value=IMAGE_URI)
def test_modify_node_set_model_dir_and_image_name(retrieve_image_uri, boto_session):
boto_session.return_value.region_name = REGION_NAME

tf_constructors = (
Expand All @@ -97,14 +97,19 @@ def test_modify_node_set_model_dir_and_image_name(create_image_uri, boto_session
modifier.modify_node(node)

assert "TensorFlow(image_uri='{}', model_dir=False)".format(IMAGE_URI) == pasta.dump(node)
create_image_uri.assert_called_with(
REGION_NAME, "tensorflow", "ml.m4.xlarge", "1.11.0", "py2"
retrieve_image_uri.assert_called_with(
"tensorflow",
REGION_NAME,
instance_type="ml.m4.xlarge",
version="1.11.0",
py_version="py2",
image_scope="training",
)


@patch("boto3.Session")
@patch("sagemaker.fw_utils.create_image_uri", return_value=IMAGE_URI)
def test_modify_node_set_image_name_from_args(create_image_uri, boto_session):
@patch("sagemaker.image_uris.retrieve", return_value=IMAGE_URI)
def test_modify_node_set_image_name_from_args(retrieve_image_uri, boto_session):
boto_session.return_value.region_name = REGION_NAME

tf_constructor = "TensorFlow(train_instance_type='ml.p2.xlarge', framework_version='1.4.0')"
Expand All @@ -113,7 +118,14 @@ def test_modify_node_set_image_name_from_args(create_image_uri, boto_session):
modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
modifier.modify_node(node)

create_image_uri.assert_called_with(REGION_NAME, "tensorflow", "ml.p2.xlarge", "1.4.0", "py2")
retrieve_image_uri.assert_called_with(
"tensorflow",
REGION_NAME,
instance_type="ml.p2.xlarge",
version="1.4.0",
py_version="py2",
image_scope="training",
)

expected_string = (
"TensorFlow(train_instance_type='ml.p2.xlarge', framework_version='1.4.0', "
Expand All @@ -123,8 +135,8 @@ def test_modify_node_set_image_name_from_args(create_image_uri, boto_session):


@patch("boto3.Session", MagicMock())
@patch("sagemaker.fw_utils.create_image_uri", return_value=IMAGE_URI)
def test_modify_node_set_hyperparameters(create_image_uri):
@patch("sagemaker.image_uris.retrieve", return_value=IMAGE_URI)
def test_modify_node_set_hyperparameters(retrieve_image_uri):
tf_constructor = """TensorFlow(
checkpoint_path='s3://foo/bar',
training_steps=100,
Expand All @@ -147,8 +159,8 @@ def test_modify_node_set_hyperparameters(create_image_uri):


@patch("boto3.Session", MagicMock())
@patch("sagemaker.fw_utils.create_image_uri", return_value=IMAGE_URI)
def test_modify_node_preserve_other_hyperparameters(create_image_uri):
@patch("sagemaker.image_uris.retrieve", return_value=IMAGE_URI)
def test_modify_node_preserve_other_hyperparameters(retrieve_image_uri):
tf_constructor = """sagemaker.tensorflow.TensorFlow(
training_steps=100,
evaluation_steps=10,
Expand All @@ -173,8 +185,8 @@ def test_modify_node_preserve_other_hyperparameters(create_image_uri):


@patch("boto3.Session", MagicMock())
@patch("sagemaker.fw_utils.create_image_uri", return_value=IMAGE_URI)
def test_modify_node_prefer_param_over_hyperparameter(create_image_uri):
@patch("sagemaker.image_uris.retrieve", return_value=IMAGE_URI)
def test_modify_node_prefer_param_over_hyperparameter(retrieve_image_uri):
tf_constructor = """sagemaker.tensorflow.TensorFlow(
training_steps=100,
requirements_file='source/requirements.txt',
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/sagemaker/tensorflow/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,10 @@ def test_create_model(
container_log_level=container_log_level,
base_job_name=base_job_name,
enable_network_isolation=True,
output_path="s3://mybucket/output",
)

job_name = "doing something"
tf.fit(inputs="s3://mybucket/train", job_name=job_name)
tf._current_job_name = "doing something"

model_name = "doing something else"
name_from_base.return_value = model_name
Expand Down Expand Up @@ -233,10 +233,10 @@ def test_create_model_with_optional_params(
base_job_name="job",
source_dir=source_dir,
enable_cloudwatch_metrics=enable_cloudwatch_metrics,
output_path="s3://mybucket/output",
)

job_name = "doing something"
tf.fit(inputs="s3://mybucket/train", job_name=job_name)
tf._current_job_name = "doing something"

new_role = "role"
vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
Expand Down
Loading