Skip to content

Add support for TensorFlow script mode and Python 3 #475

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Nov 16, 2018
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ CHANGELOG
* build: added pylint
* build: upgrade docker-compose to 1.23
* enhancement: Frameworks: update warning for not setting framework_version as we aren't planning a breaking change anymore
* feature: Estimator: add script mode and py3 support for TensorFlow

1.14.1
======
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
'If you would like to use version {latest}, ' \
'please add framework_version={latest} to your constructor.'

EMPTY_FRAMEWORK_VERSION_ERROR = 'framework_version is required for this estimator. ' \
'Please add framework_version={} to your constructor to avoid this error.'

VALID_PY_VERSIONS = ['py2', 'py3']


Expand Down
89 changes: 72 additions & 17 deletions src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import time

from sagemaker.estimator import Framework
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, \
empty_framework_version_warning
from sagemaker.fw_utils import create_image_uri, framework_name_from_image, framework_version_from_tag, \
empty_framework_version_warning, EMPTY_FRAMEWORK_VERSION_ERROR
from sagemaker.tensorflow.defaults import TF_VERSION
from sagemaker.tensorflow.model import TensorFlowModel
from sagemaker.tensorflow.serving import Model
Expand Down Expand Up @@ -163,9 +163,19 @@ class TensorFlow(Framework):

__framework_name__ = 'tensorflow'

def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None,
py_version='py2', framework_version=None, requirements_file='', image_name=None,
**kwargs):
_DEPRECATED_ARGS = ['training_steps', 'evaluation_steps', 'requirements_file', 'checkpoint_path']
_SCRIPT_MODE = 'tensorflow-scriptmode'
_SCRIPT_MODE_SERVING_ERROR_MSG = 'Script mode containers does not support serving yet. ' \
'Please use our new tensorflow-serving container by creating the model ' \
'with \'endpoint_type\' set to \'tensorflow-serving\'.'
_SCRIPT_MODE_TENSORBOARD_WARNING = 'Tensorboard is not supported with script mode. You can run the following ' \
'command: tensorboard --logdir {} --host localhost --port 6006 This can be ' \
'run from anywhere with access to the s3 uri used as the logdir.'
LAUNCH_PS_ENV_NAME = 'sagemaker_parameter_server_enabled'

def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None, py_version='py2',
framework_version=None, model_dir=None, requirements_file='', image_name=None,
script_mode=False, distributions=None, **kwargs):
"""Initialize an ``TensorFlow`` estimator.
Args:
training_steps (int): Perform this many steps of training. `None`, the default means train forever.
Expand Down Expand Up @@ -196,6 +206,19 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N
self.py_version = py_version
self.training_steps = training_steps
self.evaluation_steps = evaluation_steps
self.model_dir = model_dir
self.script_mode = script_mode
self.distributions = distributions

if py_version == 'py3' or script_mode:
if framework_version is None:
raise ValueError(EMPTY_FRAMEWORK_VERSION_ERROR)

if training_steps or evaluation_steps or requirements_file or checkpoint_path:
raise ValueError(
'{} are deprecated in script mode. Please do not set these arguments.'
.format(', '.join(self._DEPRECATED_ARGS))
)

self._validate_requirements_file(requirements_file)
self.requirements_file = requirements_file
Expand Down Expand Up @@ -246,6 +269,11 @@ def fit_super():
raise ValueError("Tensorboard is not supported with async fit")

if run_tensorboard_locally:

if self.script_mode_enabled():
LOGGER.warning(self._SCRIPT_MODE_TENSORBOARD_WARNING.format(self.model_dir))
return
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return here will not end up no calling super in the end of the constructor. Returning a conditional branch in the middle of the function like that can cause errors like that. You missed the conditional here https://github.com/aws/sagemaker-python-sdk/pull/475/files#diff-74b724644c87245b6cbecfb6e0bb6da2L245 as well.

What about:

if self.script_mode_enabled():
    if run_tensorboard_locally:
        LOGGER.warning(self._SCRIPT_MODE_TENSORBOARD_WARNING.format(self.model_dir))
    fit_super()
elif run_tensorboard_locally:
   ...
else:
    ...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This error still happens.


tensorboard = Tensorboard(self)
tensorboard.validate_requirements()

Expand Down Expand Up @@ -275,7 +303,7 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
model_channel_name)

# Move some of the tensorflow specific init params from hyperparameters into the main init params.
for argument in ['checkpoint_path', 'training_steps', 'evaluation_steps']:
for argument in ['checkpoint_path', 'training_steps', 'evaluation_steps', 'model_dir']:
value = init_params['hyperparameters'].pop(argument, None)
if value is not None:
init_params[argument] = value
Expand Down Expand Up @@ -331,6 +359,9 @@ def create_model(self, model_server_workers=None, role=None,
if endpoint_type == 'tensorflow-serving':
return self._create_tfs_model(role=role, vpc_config_override=vpc_config_override)

if self.script_mode_enabled():
raise ValueError(self._SCRIPT_MODE_SERVING_ERROR_MSG)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not setting the model here as TFS? What is the best for the customer here?


return self._create_default_model(model_server_workers=model_server_workers, role=role,
vpc_config_override=vpc_config_override)

Expand Down Expand Up @@ -363,17 +394,41 @@ def hyperparameters(self):
hyperparameters = super(TensorFlow, self).hyperparameters()

if not self.checkpoint_path:
local_code = get_config_value('local.local_code', self.sagemaker_session.config)
if self.sagemaker_session.local_mode and local_code:
self.checkpoint_path = '/opt/ml/shared/checkpoints'
else:
self.checkpoint_path = os.path.join(self.output_path,
self._current_job_name, 'checkpoints')

additional_hyperparameters = {'checkpoint_path': self.checkpoint_path,
'training_steps': self.training_steps,
'evaluation_steps': self.evaluation_steps,
'sagemaker_requirements': self.requirements_file}
self.checkpoint_path = self._default_s3_path('checkpoints')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: self.checkpoint_path = self.checkpoint_path or self._default_s3_path('checkpoints')


if self.script_mode_enabled():
if not self.model_dir:
self.model_dir = self._default_s3_path('model')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.model_dir = self.model_dir or self._default_s3_path('model')

additional_hyperparameters = {'model_dir': self.model_dir}
if self.distributions:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make self.distributions default empty instead of none so you do not need to check it.

if 'parameter_server' in self.distributions:
enabled = self.distributions['parameter_server'].get('enabled', False)
additional_hyperparameters[self.LAUNCH_PS_ENV_NAME] = enabled
else:
additional_hyperparameters = {'checkpoint_path': self.checkpoint_path,
'training_steps': self.training_steps,
'evaluation_steps': self.evaluation_steps,
'sagemaker_requirements': self.requirements_file}

hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
return hyperparameters

def _default_s3_path(self, directory):
local_code = get_config_value('local.local_code', self.sagemaker_session.config)
if self.sagemaker_session.local_mode and local_code:
return '/opt/ml/shared/{}'.format(directory)
else:
return os.path.join(self.output_path, self._current_job_name, directory)

def script_mode_enabled(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a docstring or consider making it private.

return self.py_version == 'py3' or self.script_mode

def train_image(self):
if self.image_name:
return self.image_name

if self.script_mode_enabled():
return create_image_uri(self.sagemaker_session.boto_region_name, self._SCRIPT_MODE,
self.train_instance_type, self.framework_version, self.py_version)
else:
return super(TensorFlow, self).train_image()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's avoid returning conditional branches here too.

Binary file added tests/data/tensorflow_mnist/data/eval_data.npy
Binary file not shown.
Binary file not shown.
Binary file added tests/data/tensorflow_mnist/data/train_data.npy
Binary file not shown.
Binary file not shown.
184 changes: 184 additions & 0 deletions tests/data/tensorflow_mnist/mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
"""Convolutional Neural Network Estimator for MNIST, built with tf.layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
import os
import json
import argparse
from tensorflow.python.platform import tf_logging
import logging as _logging
import sys as _sys
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: organize imports



def cnn_model_fn(features, labels, mode):
"""Model function for CNN."""
# Input Layer
# Reshape X to 4-D tensor: [batch_size, width, height, channels]
# MNIST images are 28x28 pixels, and have one color channel
input_layer = tf.reshape(features["x"], [-1, 28, 28, 1])

# Convolutional Layer #1
# Computes 32 features using a 5x5 filter with ReLU activation.
# Padding is added to preserve width and height.
# Input Tensor Shape: [batch_size, 28, 28, 1]
# Output Tensor Shape: [batch_size, 28, 28, 32]
conv1 = tf.layers.conv2d(
inputs=input_layer,
filters=32,
kernel_size=[5, 5],
padding="same",
activation=tf.nn.relu)

# Pooling Layer #1
# First max pooling layer with a 2x2 filter and stride of 2
# Input Tensor Shape: [batch_size, 28, 28, 32]
# Output Tensor Shape: [batch_size, 14, 14, 32]
pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)

# Convolutional Layer #2
# Computes 64 features using a 5x5 filter.
# Padding is added to preserve width and height.
# Input Tensor Shape: [batch_size, 14, 14, 32]
# Output Tensor Shape: [batch_size, 14, 14, 64]
conv2 = tf.layers.conv2d(
inputs=pool1,
filters=64,
kernel_size=[5, 5],
padding="same",
activation=tf.nn.relu)

# Pooling Layer #2
# Second max pooling layer with a 2x2 filter and stride of 2
# Input Tensor Shape: [batch_size, 14, 14, 64]
# Output Tensor Shape: [batch_size, 7, 7, 64]
pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice comments.


# Flatten tensor into a batch of vectors
# Input Tensor Shape: [batch_size, 7, 7, 64]
# Output Tensor Shape: [batch_size, 7 * 7 * 64]
pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64])

# Dense Layer
# Densely connected layer with 1024 neurons
# Input Tensor Shape: [batch_size, 7 * 7 * 64]
# Output Tensor Shape: [batch_size, 1024]
dense = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu)

# Add dropout operation; 0.6 probability that element will be kept
dropout = tf.layers.dropout(
inputs=dense, rate=0.4, training=mode == tf.estimator.ModeKeys.TRAIN)

# Logits layer
# Input Tensor Shape: [batch_size, 1024]
# Output Tensor Shape: [batch_size, 10]
logits = tf.layers.dense(inputs=dropout, units=10)

predictions = {
# Generate predictions (for PREDICT and EVAL mode)
"classes": tf.argmax(input=logits, axis=1),
# Add `softmax_tensor` to the graph. It is used for PREDICT and by the
# `logging_hook`.
"probabilities": tf.nn.softmax(logits, name="softmax_tensor")
}
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

# Calculate Loss (for both TRAIN and EVAL modes)
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

# Configure the Training Op (for TRAIN mode)
if mode == tf.estimator.ModeKeys.TRAIN:
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
train_op = optimizer.minimize(
loss=loss,
global_step=tf.train.get_global_step())
return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

# Add evaluation metrics (for EVAL mode)
eval_metric_ops = {
"accuracy": tf.metrics.accuracy(
labels=labels, predictions=predictions["classes"])}
return tf.estimator.EstimatorSpec(
mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

def _load_training_data(base_dir):
x_train = np.load(os.path.join(base_dir, 'train_data.npy'))
y_train = np.load(os.path.join(base_dir, 'train_labels.npy'))
return x_train, y_train

def _load_testing_data(base_dir):
x_test = np.load(os.path.join(base_dir, 'eval_data.npy'))
y_test = np.load(os.path.join(base_dir, 'eval_labels.npy'))
return x_test, y_test

def _parse_args():

parser = argparse.ArgumentParser()

# hyperparameters sent by the client are passed as command-line arguments to the script.
parser.add_argument('--epochs', type=int, default=1)
# Data, model, and output directories
parser.add_argument('--output-data-dir', type=str, default=os.environ['SM_OUTPUT_DATA_DIR'])
parser.add_argument('--model_dir', type=str, default=os.environ['SM_MODEL_DIR'])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove the default parameter here given that you are passing the model dir

parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAINING'])
parser.add_argument('--hosts', type=list, default=json.loads(os.environ['SM_HOSTS']))
parser.add_argument('--current-host', type=str, default=os.environ['SM_CURRENT_HOST'])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think about os.environ.get('SM_CURRENT_HOST') instead so the script does not fail outside a container.


return parser.parse_known_args()

def serving_input_fn():
inputs = {'x': tf.placeholder(tf.float32, [None, 784])}
return tf.estimator.export.ServingInputReceiver(inputs, inputs)

if __name__ == "__main__":
args, unknown = _parse_args()
tf.logging.set_verbosity(tf.logging.DEBUG)
_handler = _logging.StreamHandler(_sys.stdout)
tf_logger = tf_logging._get_logger()
tf_logger.handlers = [_handler]

if args.model_dir.startswith('s3://'):
os.environ['S3_REGION'] = 'us-west-2'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
os.environ['S3_USE_HTTPS'] = '1'

train_data, train_labels = _load_training_data(args.train)
eval_data, eval_labels = _load_testing_data(args.train)

# Create the Estimator
mnist_classifier = tf.estimator.Estimator(
model_fn=cnn_model_fn, model_dir=args.model_dir)

# Set up logging for predictions
# Log the values in the "Softmax" tensor with label "probabilities"
tensors_to_log = {"probabilities": "softmax_tensor"}
logging_hook = tf.train.LoggingTensorHook(
tensors=tensors_to_log, every_n_iter=50)

# Train the model
train_input_fn = tf.estimator.inputs.numpy_input_fn(
x={"x": train_data},
y=train_labels,
batch_size=100,
num_epochs=None,
shuffle=True)

# Evaluate the model and print results
eval_input_fn = tf.estimator.inputs.numpy_input_fn(
x={"x": eval_data},
y=eval_labels,
num_epochs=1,
shuffle=False)

train_spec = tf.estimator.TrainSpec(train_input_fn, max_steps=500)
eval_spec = tf.estimator.EvalSpec(eval_input_fn)
tf.estimator.train_and_evaluate(mnist_classifier, train_spec, eval_spec)

if args.current_host == args.hosts[0]:
mnist_classifier.export_savedmodel(args.model_dir, serving_input_fn)

tf_logger.info('====== Training finished =========')
Loading