-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Add support for TensorFlow script mode and Python 3 #475
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
429c474
12a258b
d14db32
2de88a7
e9e2592
b09867d
358a811
aca3958
311e164
cfea5ab
75028a0
8c7d644
9cd35fb
fea9cfc
2b693c2
b5057a7
b0794b8
c762318
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,8 +22,7 @@ | |
import time | ||
|
||
from sagemaker.estimator import Framework | ||
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, \ | ||
empty_framework_version_warning | ||
import sagemaker.fw_utils as fw | ||
from sagemaker.tensorflow.defaults import TF_VERSION | ||
from sagemaker.tensorflow.model import TensorFlowModel | ||
from sagemaker.tensorflow.serving import Model | ||
|
@@ -34,6 +33,16 @@ | |
LOGGER = logging.getLogger('sagemaker') | ||
|
||
|
||
_FRAMEWORK_MODE_ARGS = ('training_steps', 'evaluation_steps', 'requirements_file', 'checkpoint_path') | ||
_SCRIPT_MODE = 'tensorflow-scriptmode' | ||
_SCRIPT_MODE_SERVING_ERROR_MSG = 'Script mode containers does not support serving yet. ' \ | ||
'Please use our new tensorflow-serving container by creating the model ' \ | ||
'with \'endpoint_type\' set to \'tensorflow-serving\'.' | ||
_SCRIPT_MODE_TENSORBOARD_WARNING = 'Tensorboard is not supported with script mode. You can run the following ' \ | ||
'command: tensorboard --logdir {} --host localhost --port 6006 This can be ' \ | ||
'run from anywhere with access to the S3 URI used as the logdir.' | ||
|
||
|
||
class Tensorboard(threading.Thread): | ||
def __init__(self, estimator, logdir=None): | ||
"""Initialize ``Tensorboard`` instance. | ||
|
@@ -163,9 +172,9 @@ class TensorFlow(Framework): | |
|
||
__framework_name__ = 'tensorflow' | ||
|
||
def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None, | ||
py_version='py2', framework_version=None, requirements_file='', image_name=None, | ||
**kwargs): | ||
def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None, py_version='py2', | ||
framework_version=None, model_dir=None, requirements_file='', image_name=None, | ||
mvsusp marked this conversation as resolved.
Show resolved
Hide resolved
|
||
script_mode=False, distributions=None, **kwargs): | ||
"""Initialize an ``TensorFlow`` estimator. | ||
Args: | ||
training_steps (int): Perform this many steps of training. `None`, the default means train forever. | ||
|
@@ -176,6 +185,9 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N | |
py_version (str): Python version you want to use for executing your model training code (default: 'py2'). | ||
framework_version (str): TensorFlow version you want to use for executing your model training code. | ||
List of supported versions https://github.com/aws/sagemaker-python-sdk#tensorflow-sagemaker-estimators | ||
model_dir (str): S3 location where the checkpoint data and models can be exported to during training | ||
(default: None). If not specified a default S3 URI will be generated. It will be passed in the | ||
training script as one of the command line arguments. | ||
requirements_file (str): Path to a ``requirements.txt`` file (default: ''). The path should be within and | ||
relative to ``source_dir``. Details on the format can be found in the | ||
`Pip User Guide <https://pip.pypa.io/en/stable/reference/pip_install/#requirements-file-format>`_. | ||
|
@@ -185,21 +197,61 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N | |
Examples: | ||
123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0 | ||
custom-image:latest. | ||
script_mode (bool): If set to True, the estimator will use the Script Mode containers (default: False). | ||
This will be ignored if py_version is set to 'py3'. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm a little worried about the implicit script mode with Python 3, given the popularity of the Python 3 support for TF feature request. Maybe we should log a warning somewhere if [remainder of sentence lost in extraction]. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good point. I will do this in the follow-up PR with the doc update. |
||
distributions (dict): A dictionary with information on how to run distributed training | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please, explain the format of the dict or link to docs explaining it |
||
(default: None). Currently we only support distributed training with parameter servers. To enable it | ||
use the following setup: | ||
{ | ||
'parameter_server': | ||
{ | ||
'enabled': True | ||
} | ||
} | ||
**kwargs: Additional kwargs passed to the Framework constructor. | ||
""" | ||
if framework_version is None: | ||
LOGGER.warning(empty_framework_version_warning(TF_VERSION, TF_VERSION)) | ||
LOGGER.warning(fw.empty_framework_version_warning(TF_VERSION, TF_VERSION)) | ||
self.framework_version = framework_version or TF_VERSION | ||
|
||
super(TensorFlow, self).__init__(image_name=image_name, **kwargs) | ||
self.checkpoint_path = checkpoint_path | ||
self.py_version = py_version | ||
self.training_steps = training_steps | ||
self.evaluation_steps = evaluation_steps | ||
self.model_dir = model_dir | ||
self.script_mode = script_mode | ||
self.distributions = distributions or {} | ||
|
||
self._validate_args(py_version=py_version, script_mode=script_mode, framework_version=framework_version, | ||
training_steps=training_steps, evaluation_steps=evaluation_steps, | ||
requirements_file=requirements_file, checkpoint_path=checkpoint_path) | ||
self._validate_requirements_file(requirements_file) | ||
self.requirements_file = requirements_file | ||
|
||
def _validate_args(self, py_version, script_mode, framework_version, training_steps, | ||
evaluation_steps, requirements_file, checkpoint_path): | ||
|
||
if py_version == 'py3' or script_mode: | ||
|
||
if framework_version is None: | ||
raise AttributeError(fw.EMPTY_FRAMEWORK_VERSION_ERROR) | ||
|
||
found_args = [] | ||
if training_steps: | ||
found_args.append('training_steps') | ||
if evaluation_steps: | ||
found_args.append('evaluation_steps') | ||
if requirements_file: | ||
found_args.append('requirements_file') | ||
if checkpoint_path: | ||
found_args.append('checkpoint_path') | ||
if found_args: | ||
raise AttributeError( | ||
'{} are deprecated in script mode. Please do not set {}.' | ||
.format(', '.join(_FRAMEWORK_MODE_ARGS), ', '.join(found_args)) | ||
) | ||
icywang86rui marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def _validate_requirements_file(self, requirements_file): | ||
if not requirements_file: | ||
return | ||
|
@@ -245,7 +297,10 @@ def fit_super(): | |
if run_tensorboard_locally and wait is False: | ||
raise ValueError("Tensorboard is not supported with async fit") | ||
|
||
if run_tensorboard_locally: | ||
if self._script_mode_enabled() and run_tensorboard_locally: | ||
LOGGER.warning(_SCRIPT_MODE_TENSORBOARD_WARNING.format(self.model_dir)) | ||
fit_super() | ||
elif run_tensorboard_locally: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe this is simpler? if run_tensorboard_locally:
tensorboard = Tensorboard(self)
tensorboard.validate_requirements()
try:
tensorboard.start()
fit_super()
finally:
# sleep 20 secs for tensorboard start up if fit() quits instantly
time.sleep(20)
tensorboard.event.set()
tensorboard.join()
else:
if self._script_mode_enabled():
LOGGER.warning(_SCRIPT_MODE_TENSORBOARD_WARNING.format(self.model_dir))
fit_super() or even try:
if run_tensorboard_locally:
tensorboard = Tensorboard(self)
tensorboard.validate_requirements()
tensorboard.start()
finally:
if run_tensorboard_locally:
# sleep 20 secs for tensorboard start up if fit() quits instantly
time.sleep(20)
tensorboard.event.set()
tensorboard.join()
if self._script_mode_enabled():
LOGGER.warning(_SCRIPT_MODE_TENSORBOARD_WARNING.format(self.model_dir))
fit_super() There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. These two don't read much better to me. I have combined the first two ifs. Let me know if you feel strongly about this. :) |
||
tensorboard = Tensorboard(self) | ||
tensorboard.validate_requirements() | ||
|
||
|
@@ -275,13 +330,13 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na | |
model_channel_name) | ||
|
||
# Move some of the tensorflow specific init params from hyperparameters into the main init params. | ||
for argument in ['checkpoint_path', 'training_steps', 'evaluation_steps']: | ||
for argument in ('checkpoint_path', 'training_steps', 'evaluation_steps', 'model_dir'): | ||
value = init_params['hyperparameters'].pop(argument, None) | ||
if value is not None: | ||
init_params[argument] = value | ||
|
||
image_name = init_params.pop('image') | ||
framework, py_version, tag = framework_name_from_image(image_name) | ||
framework, py_version, tag = fw.framework_name_from_image(image_name) | ||
if not framework: | ||
# If we were unable to parse the framework name from the image it is not one of our | ||
# officially supported images, in this case just add the image to the init params. | ||
|
@@ -294,7 +349,7 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na | |
# containing framework version, device type and python version (e.g. '1.5-gpu-py2'). | ||
# For backward compatibility map deprecated image tag '1.0' to a '1.4' framework version | ||
# otherwise extract framework version from the tag itself. | ||
init_params['framework_version'] = '1.4' if tag == '1.0' else framework_version_from_tag( | ||
init_params['framework_version'] = '1.4' if tag == '1.0' else fw.framework_version_from_tag( | ||
tag) | ||
|
||
training_job_name = init_params['base_job_name'] | ||
|
@@ -328,7 +383,7 @@ def create_model(self, model_server_workers=None, role=None, | |
""" | ||
|
||
role = role or self.role | ||
if endpoint_type == 'tensorflow-serving': | ||
if endpoint_type == 'tensorflow-serving' or self._script_mode_enabled(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
return self._create_tfs_model(role=role, vpc_config_override=vpc_config_override) | ||
|
||
return self._create_default_model(model_server_workers=model_server_workers, role=role, | ||
|
@@ -362,18 +417,39 @@ def hyperparameters(self): | |
"""Return hyperparameters used by your custom TensorFlow code during model training.""" | ||
hyperparameters = super(TensorFlow, self).hyperparameters() | ||
|
||
if not self.checkpoint_path: | ||
local_code = get_config_value('local.local_code', self.sagemaker_session.config) | ||
if self.sagemaker_session.local_mode and local_code: | ||
self.checkpoint_path = '/opt/ml/shared/checkpoints' | ||
else: | ||
self.checkpoint_path = os.path.join(self.output_path, | ||
self._current_job_name, 'checkpoints') | ||
self.checkpoint_path = self.checkpoint_path or self._default_s3_path('checkpoints') | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this parameter not used in script mode? Let's not set it if that is the case to avoid future breaking changes. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This needs to be set for the framework mode containers for now. We will need to remove them in the future. |
||
|
||
additional_hyperparameters = {'checkpoint_path': self.checkpoint_path, | ||
'training_steps': self.training_steps, | ||
'evaluation_steps': self.evaluation_steps, | ||
'sagemaker_requirements': self.requirements_file} | ||
if self._script_mode_enabled(): | ||
self.model_dir = self.model_dir or self._default_s3_path('model') | ||
additional_hyperparameters = {'model_dir': self.model_dir} | ||
if 'parameter_server' in self.distributions: | ||
enabled = self.distributions['parameter_server'].get('enabled', False) | ||
additional_hyperparameters[self.LAUNCH_PS_ENV_NAME] = enabled | ||
else: | ||
additional_hyperparameters = {'checkpoint_path': self.checkpoint_path, | ||
'training_steps': self.training_steps, | ||
'evaluation_steps': self.evaluation_steps, | ||
'sagemaker_requirements': self.requirements_file} | ||
|
||
hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters)) | ||
return hyperparameters | ||
|
||
def _default_s3_path(self, directory): | ||
local_code = get_config_value('local.local_code', self.sagemaker_session.config) | ||
if self.sagemaker_session.local_mode and local_code: | ||
return '/opt/ml/shared/{}'.format(directory) | ||
else: | ||
return os.path.join(self.output_path, self._current_job_name, directory) | ||
|
||
def _script_mode_enabled(self): | ||
return self.py_version == 'py3' or self.script_mode | ||
|
||
def train_image(self): | ||
if self.image_name: | ||
return self.image_name | ||
|
||
if self._script_mode_enabled(): | ||
return fw.create_image_uri(self.sagemaker_session.boto_region_name, _SCRIPT_MODE, | ||
self.train_instance_type, self.framework_version, self.py_version) | ||
|
||
return super(TensorFlow, self).train_image() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
make sure this is in the correct changelog entry. maybe even warrants a bigger version bump?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can make this decision tomorrow with the pr to bump version?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure