Skip to content

breaking: force image_uri to be passed for legacy TF images #1539

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 2, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 53 additions & 62 deletions src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,27 +29,18 @@

logger = logging.getLogger("sagemaker")

# TODO: consider creating a function for generating this command before removing this constant
_SCRIPT_MODE_TENSORBOARD_WARNING = (
"Tensorboard is not supported with script mode. You can run the following "
"command: tensorboard --logdir %s --host localhost --port 6006 This can be "
"run from anywhere with access to the S3 URI used as the logdir."
)


class TensorFlow(Framework):
"""Handle end-to-end training and deployment of user-provided TensorFlow code."""

__framework_name__ = "tensorflow"
_SCRIPT_MODE_REPO_NAME = "tensorflow-scriptmode"
_ECR_REPO_NAME = "tensorflow-scriptmode"

LATEST_VERSION = defaults.LATEST_VERSION

_LATEST_1X_VERSION = "1.15.2"

_HIGHEST_LEGACY_MODE_ONLY_VERSION = version.Version("1.10.0")
_LOWEST_SCRIPT_MODE_ONLY_VERSION = version.Version("1.13.1")

_HIGHEST_PYTHON_2_VERSION = version.Version("2.1.0")

def __init__(
Expand All @@ -59,7 +50,6 @@ def __init__(
model_dir=None,
image_name=None,
distributions=None,
script_mode=True,
**kwargs
):
"""Initialize a ``TensorFlow`` estimator.
Expand All @@ -82,6 +72,8 @@ def __init__(
* *Local Mode with local sources (file:// instead of s3://)* - \
``/opt/ml/shared/model``
To disable having ``model_dir`` passed to your training script,
set ``model_dir=False``.
image_name (str): If specified, the estimator will use this image for training and
hosting, instead of selecting the appropriate SageMaker official image based on
framework_version and py_version. It can be an ECR url or dockerhub image and tag.
Expand Down Expand Up @@ -114,8 +106,6 @@ def __init__(
}
}
script_mode (bool): Whether or not to use the Script Mode TensorFlow images
(default: True).
**kwargs: Additional kwargs passed to the Framework constructor.
.. tip::
Expand Down Expand Up @@ -154,7 +144,6 @@ def __init__(
self.model_dir = model_dir
self.distributions = distributions or {}

self._script_mode_enabled = script_mode
self._validate_args(py_version=py_version, framework_version=self.framework_version)

def _validate_args(self, py_version, framework_version):
Expand All @@ -171,30 +160,33 @@ def _validate_args(self, py_version, framework_version):
)
raise AttributeError(msg)

if (not self._script_mode_enabled) and self._only_script_mode_supported():
logger.warning(
"Legacy mode is deprecated in versions 1.13 and higher. Using script mode instead."
if self._only_legacy_mode_supported() and self.image_name is None:
additional_instructions = ""
if self.model_dir is not False:
additional_instructions = " and set 'model_dir=False'"

legacy_image_uri = fw.create_image_uri(
self.sagemaker_session.boto_region_name,
"tensorflow",
self.train_instance_type,
self.framework_version,
self.py_version,
)
self._script_mode_enabled = True

if self._only_legacy_mode_supported():
# TODO: add link to docs to explain how to use legacy mode with v2
logger.warning(
"TF %s supports only legacy mode. If you were using any legacy mode parameters "
msg = (
"TF {} supports only legacy mode. Please supply the image URI directly with "
"'image_name={}'{}. If you were using any legacy mode parameters "
"(training_steps, evaluation_steps, checkpoint_path, requirements_file), "
"make sure to pass them directly as hyperparameters instead.",
self.framework_version,
)
self._script_mode_enabled = False
"make sure to pass them directly as hyperparameters instead."
).format(self.framework_version, legacy_image_uri, additional_instructions)

raise ValueError(msg)

def _only_legacy_mode_supported(self):
"""Return True if ``framework_version`` is at or below ``_HIGHEST_LEGACY_MODE_ONLY_VERSION``."""
return version.Version(self.framework_version) <= self._HIGHEST_LEGACY_MODE_ONLY_VERSION

def _only_script_mode_supported(self):
"""Return True if ``framework_version`` is at or above ``_LOWEST_SCRIPT_MODE_ONLY_VERSION``."""
return version.Version(self.framework_version) >= self._LOWEST_SCRIPT_MODE_ONLY_VERSION

def _only_python_3_supported(self):
"""Return True if ``framework_version`` is strictly above ``_HIGHEST_PYTHON_2_VERSION``."""
return version.Version(self.framework_version) > self._HIGHEST_PYTHON_2_VERSION
Expand All @@ -214,10 +206,6 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
job_details, model_channel_name
)

model_dir = init_params["hyperparameters"].pop("model_dir", None)
if model_dir is not None:
init_params["model_dir"] = model_dir

image_name = init_params.pop("image")
framework, py_version, tag, script_mode = fw.framework_name_from_image(image_name)
if not framework:
Expand All @@ -226,8 +214,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
init_params["image_name"] = image_name
return init_params

if script_mode is None:
init_params["script_mode"] = False
model_dir = init_params["hyperparameters"].pop("model_dir", None)
if model_dir:
init_params["model_dir"] = model_dir
elif script_mode is None:
init_params["model_dir"] = False

init_params["py_version"] = py_version

Expand All @@ -239,6 +230,10 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
"1.4" if tag == "1.0" else fw.framework_version_from_tag(tag)
)

# Legacy images are required to be passed in explicitly.
if not script_mode:
init_params["image_name"] = image_name

training_job_name = init_params["base_job_name"]
if framework != cls.__framework_name__:
raise ValueError(
Expand Down Expand Up @@ -309,27 +304,26 @@ def hyperparameters(self):
hyperparameters = super(TensorFlow, self).hyperparameters()
additional_hyperparameters = {}

if self._script_mode_enabled:
mpi_enabled = False

if "parameter_server" in self.distributions:
ps_enabled = self.distributions["parameter_server"].get("enabled", False)
additional_hyperparameters[self.LAUNCH_PS_ENV_NAME] = ps_enabled
if "parameter_server" in self.distributions:
ps_enabled = self.distributions["parameter_server"].get("enabled", False)
additional_hyperparameters[self.LAUNCH_PS_ENV_NAME] = ps_enabled

if "mpi" in self.distributions:
mpi_dict = self.distributions["mpi"]
mpi_enabled = mpi_dict.get("enabled", False)
additional_hyperparameters[self.LAUNCH_MPI_ENV_NAME] = mpi_enabled
mpi_enabled = False
if "mpi" in self.distributions:
mpi_dict = self.distributions["mpi"]
mpi_enabled = mpi_dict.get("enabled", False)
additional_hyperparameters[self.LAUNCH_MPI_ENV_NAME] = mpi_enabled

if mpi_dict.get("processes_per_host"):
additional_hyperparameters[self.MPI_NUM_PROCESSES_PER_HOST] = mpi_dict.get(
"processes_per_host"
)

additional_hyperparameters[self.MPI_CUSTOM_MPI_OPTIONS] = mpi_dict.get(
"custom_mpi_options", ""
if mpi_dict.get("processes_per_host"):
additional_hyperparameters[self.MPI_NUM_PROCESSES_PER_HOST] = mpi_dict.get(
"processes_per_host"
)

additional_hyperparameters[self.MPI_CUSTOM_MPI_OPTIONS] = mpi_dict.get(
"custom_mpi_options", ""
)

if self.model_dir is not False:
self.model_dir = self.model_dir or self._default_s3_path("model", mpi=mpi_enabled)
additional_hyperparameters["model_dir"] = self.model_dir

Expand Down Expand Up @@ -375,16 +369,13 @@ def train_image(self):
if self.image_name:
return self.image_name

if self._script_mode_enabled:
return fw.create_image_uri(
self.sagemaker_session.boto_region_name,
self._SCRIPT_MODE_REPO_NAME,
self.train_instance_type,
self.framework_version,
self.py_version,
)

return super(TensorFlow, self).train_image()
return fw.create_image_uri(
self.sagemaker_session.boto_region_name,
self._ECR_REPO_NAME,
self.train_instance_type,
self.framework_version,
self.py_version,
)

def transformer(
self,
Expand Down
15 changes: 15 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,21 @@ def xgboost_version(request):
"1.9.0",
"1.10",
"1.10.0",
"1.11",
"1.11.0",
"1.12",
"1.12.0",
"1.13",
"1.14",
"1.14.0",
"1.15",
"1.15.0",
"1.15.2",
"2.0",
"2.0.0",
"2.0.1",
"2.1",
"2.1.0",
],
)
def tf_version(request):
Expand Down
Loading