Skip to content

Commit f857c33

Browse files
committed
breaking: force image_uri to be passed for legacy TF images
This change also reorganizes the TF unit tests a bit, and updates the tf_version fixture to include recent versions.
1 parent 6eeca73 commit f857c33

File tree

8 files changed

+1019
-979
lines changed

8 files changed

+1019
-979
lines changed

src/sagemaker/tensorflow/estimator.py

Lines changed: 53 additions & 62 deletions
Original file line number | Diff line number | Diff line change
@@ -29,27 +29,18 @@
2929

3030
logger = logging.getLogger("sagemaker")
3131

32-
# TODO: consider creating a function for generating this command before removing this constant
33-
_SCRIPT_MODE_TENSORBOARD_WARNING = (
34-
"Tensorboard is not supported with script mode. You can run the following "
35-
"command: tensorboard --logdir %s --host localhost --port 6006 This can be "
36-
"run from anywhere with access to the S3 URI used as the logdir."
37-
)
38-
3932

4033
class TensorFlow(Framework):
4134
"""Handle end-to-end training and deployment of user-provided TensorFlow code."""
4235

4336
__framework_name__ = "tensorflow"
44-
_SCRIPT_MODE_REPO_NAME = "tensorflow-scriptmode"
37+
_ECR_REPO_NAME = "tensorflow-scriptmode"
4538

4639
LATEST_VERSION = defaults.LATEST_VERSION
4740

4841
_LATEST_1X_VERSION = "1.15.2"
4942

5043
_HIGHEST_LEGACY_MODE_ONLY_VERSION = version.Version("1.10.0")
51-
_LOWEST_SCRIPT_MODE_ONLY_VERSION = version.Version("1.13.1")
52-
5344
_HIGHEST_PYTHON_2_VERSION = version.Version("2.1.0")
5445

5546
def __init__(
@@ -59,7 +50,6 @@ def __init__(
5950
model_dir=None,
6051
image_name=None,
6152
distributions=None,
62-
script_mode=True,
6353
**kwargs
6454
):
6555
"""Initialize a ``TensorFlow`` estimator.
@@ -82,6 +72,8 @@ def __init__(
8272
* *Local Mode with local sources (file:// instead of s3://)* - \
8373
``/opt/ml/shared/model``
8474
75+
To disable having ``model_dir`` passed to your training script,
76+
set ``model_dir=False``.
8577
image_name (str): If specified, the estimator will use this image for training and
8678
hosting, instead of selecting the appropriate SageMaker official image based on
8779
framework_version and py_version. It can be an ECR url or dockerhub image and tag.
@@ -114,8 +106,6 @@ def __init__(
114106
}
115107
}
116108
117-
script_mode (bool): Whether or not to use the Script Mode TensorFlow images
118-
(default: True).
119109
**kwargs: Additional kwargs passed to the Framework constructor.
120110
121111
.. tip::
@@ -154,7 +144,6 @@ def __init__(
154144
self.model_dir = model_dir
155145
self.distributions = distributions or {}
156146

157-
self._script_mode_enabled = script_mode
158147
self._validate_args(py_version=py_version, framework_version=self.framework_version)
159148

160149
def _validate_args(self, py_version, framework_version):
@@ -171,30 +160,33 @@ def _validate_args(self, py_version, framework_version):
171160
)
172161
raise AttributeError(msg)
173162

174-
if (not self._script_mode_enabled) and self._only_script_mode_supported():
175-
logger.warning(
176-
"Legacy mode is deprecated in versions 1.13 and higher. Using script mode instead."
163+
if self._only_legacy_mode_supported() and self.image_name is None:
164+
additional_instructions = ""
165+
if self.model_dir is not False:
166+
additional_instructions = " and set 'model_dir=False'"
167+
168+
legacy_image_uri = fw.create_image_uri(
169+
self.sagemaker_session.boto_region_name,
170+
"tensorflow",
171+
self.train_instance_type,
172+
self.framework_version,
173+
self.py_version,
177174
)
178-
self._script_mode_enabled = True
179175

180-
if self._only_legacy_mode_supported():
181176
# TODO: add link to docs to explain how to use legacy mode with v2
182-
logger.warning(
183-
"TF %s supports only legacy mode. If you were using any legacy mode parameters "
177+
msg = (
178+
"TF {} supports only legacy mode. Please supply the image URI directly with "
179+
"'image_name={}'{}. If you were using any legacy mode parameters "
184180
"(training_steps, evaluation_steps, checkpoint_path, requirements_file), "
185-
"make sure to pass them directly as hyperparameters instead.",
186-
self.framework_version,
187-
)
188-
self._script_mode_enabled = False
181+
"make sure to pass them directly as hyperparameters instead."
182+
).format(self.framework_version, legacy_image_uri, additional_instructions)
183+
184+
raise ValueError(msg)
189185

190186
def _only_legacy_mode_supported(self):
191187
"""Placeholder docstring"""
192188
return version.Version(self.framework_version) <= self._HIGHEST_LEGACY_MODE_ONLY_VERSION
193189

194-
def _only_script_mode_supported(self):
195-
"""Placeholder docstring"""
196-
return version.Version(self.framework_version) >= self._LOWEST_SCRIPT_MODE_ONLY_VERSION
197-
198190
def _only_python_3_supported(self):
199191
"""Placeholder docstring"""
200192
return version.Version(self.framework_version) > self._HIGHEST_PYTHON_2_VERSION
@@ -214,10 +206,6 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
214206
job_details, model_channel_name
215207
)
216208

217-
model_dir = init_params["hyperparameters"].pop("model_dir", None)
218-
if model_dir is not None:
219-
init_params["model_dir"] = model_dir
220-
221209
image_name = init_params.pop("image")
222210
framework, py_version, tag, script_mode = fw.framework_name_from_image(image_name)
223211
if not framework:
@@ -226,8 +214,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
226214
init_params["image_name"] = image_name
227215
return init_params
228216

229-
if script_mode is None:
230-
init_params["script_mode"] = False
217+
model_dir = init_params["hyperparameters"].pop("model_dir", None)
218+
if model_dir:
219+
init_params["model_dir"] = model_dir
220+
elif script_mode is None:
221+
init_params["model_dir"] = False
231222

232223
init_params["py_version"] = py_version
233224

@@ -239,6 +230,10 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
239230
"1.4" if tag == "1.0" else fw.framework_version_from_tag(tag)
240231
)
241232

233+
# Legacy images are required to be passed in explicitly.
234+
if not script_mode:
235+
init_params["image_name"] = image_name
236+
242237
training_job_name = init_params["base_job_name"]
243238
if framework != cls.__framework_name__:
244239
raise ValueError(
@@ -309,27 +304,26 @@ def hyperparameters(self):
309304
hyperparameters = super(TensorFlow, self).hyperparameters()
310305
additional_hyperparameters = {}
311306

312-
if self._script_mode_enabled:
313-
mpi_enabled = False
314-
315-
if "parameter_server" in self.distributions:
316-
ps_enabled = self.distributions["parameter_server"].get("enabled", False)
317-
additional_hyperparameters[self.LAUNCH_PS_ENV_NAME] = ps_enabled
307+
if "parameter_server" in self.distributions:
308+
ps_enabled = self.distributions["parameter_server"].get("enabled", False)
309+
additional_hyperparameters[self.LAUNCH_PS_ENV_NAME] = ps_enabled
318310

319-
if "mpi" in self.distributions:
320-
mpi_dict = self.distributions["mpi"]
321-
mpi_enabled = mpi_dict.get("enabled", False)
322-
additional_hyperparameters[self.LAUNCH_MPI_ENV_NAME] = mpi_enabled
311+
mpi_enabled = False
312+
if "mpi" in self.distributions:
313+
mpi_dict = self.distributions["mpi"]
314+
mpi_enabled = mpi_dict.get("enabled", False)
315+
additional_hyperparameters[self.LAUNCH_MPI_ENV_NAME] = mpi_enabled
323316

324-
if mpi_dict.get("processes_per_host"):
325-
additional_hyperparameters[self.MPI_NUM_PROCESSES_PER_HOST] = mpi_dict.get(
326-
"processes_per_host"
327-
)
328-
329-
additional_hyperparameters[self.MPI_CUSTOM_MPI_OPTIONS] = mpi_dict.get(
330-
"custom_mpi_options", ""
317+
if mpi_dict.get("processes_per_host"):
318+
additional_hyperparameters[self.MPI_NUM_PROCESSES_PER_HOST] = mpi_dict.get(
319+
"processes_per_host"
331320
)
332321

322+
additional_hyperparameters[self.MPI_CUSTOM_MPI_OPTIONS] = mpi_dict.get(
323+
"custom_mpi_options", ""
324+
)
325+
326+
if self.model_dir is not False:
333327
self.model_dir = self.model_dir or self._default_s3_path("model", mpi=mpi_enabled)
334328
additional_hyperparameters["model_dir"] = self.model_dir
335329

@@ -375,16 +369,13 @@ def train_image(self):
375369
if self.image_name:
376370
return self.image_name
377371

378-
if self._script_mode_enabled:
379-
return fw.create_image_uri(
380-
self.sagemaker_session.boto_region_name,
381-
self._SCRIPT_MODE_REPO_NAME,
382-
self.train_instance_type,
383-
self.framework_version,
384-
self.py_version,
385-
)
386-
387-
return super(TensorFlow, self).train_image()
372+
return fw.create_image_uri(
373+
self.sagemaker_session.boto_region_name,
374+
self._ECR_REPO_NAME,
375+
self.train_instance_type,
376+
self.framework_version,
377+
self.py_version,
378+
)
388379

389380
def transformer(
390381
self,

tests/conftest.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -194,6 +194,21 @@ def xgboost_version(request):
194194
"1.9.0",
195195
"1.10",
196196
"1.10.0",
197+
"1.11",
198+
"1.11.0",
199+
"1.12",
200+
"1.12.0",
201+
"1.13",
202+
"1.14",
203+
"1.14.0",
204+
"1.15",
205+
"1.15.0",
206+
"1.15.2",
207+
"2.0",
208+
"2.0.0",
209+
"2.0.1",
210+
"2.1",
211+
"2.1.0",
197212
],
198213
)
199214
def tf_version(request):

0 commit comments

Comments
 (0)