Skip to content

Commit 1dd0cd4

Browse files
committed
Completed Keras example
1 parent d50de21 commit 1dd0cd4

File tree

2 files changed

+167
-88
lines changed

2 files changed

+167
-88
lines changed

sagemaker-python-sdk/tensorflow_keras_cifar10/cifar10_cnn.py

Lines changed: 36 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,9 @@
55
import os
66

77
import tensorflow as tf
8-
from tensorflow.python.estimator.export.export import build_raw_serving_input_receiver_fn
9-
from tensorflow.python.keras._impl.keras.engine.topology import InputLayer
10-
from tensorflow.python.keras._impl.keras.layers import Conv2D, Activation, MaxPooling2D, Dropout, Flatten, Dense
11-
from tensorflow.python.keras._impl.keras.models import Sequential
12-
from tensorflow.python.keras._impl.keras.optimizers import rmsprop
8+
from tensorflow.python.keras.layers import InputLayer, Conv2D, Activation, MaxPooling2D, Dropout, Flatten, Dense
9+
from tensorflow.python.keras.models import Sequential
10+
from tensorflow.python.keras.optimizers import RMSprop
1311
from tensorflow.python.saved_model.signature_constants import PREDICT_INPUTS
1412

1513
HEIGHT = 32
@@ -18,23 +16,23 @@
1816
NUM_CLASSES = 10
1917
NUM_DATA_BATCHES = 5
2018
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 10000 * NUM_DATA_BATCHES
21-
BATCH_SIZE = 1
19+
BATCH_SIZE = 128
2220

2321

2422
def keras_model_fn(hyperparameters):
2523
"""keras_model_fn receives hyperparameters from the training job and returns a compiled keras model.
26-
The model will transformed in a TensorFlow Estimator before training and it will saved in a TensorFlow Serving
27-
SavedModel in the end of training.
24+
The model will be transformed into a TensorFlow Estimator before training and it will be saved in a
25+
TensorFlow Serving SavedModel at the end of training.
2826
2927
Args:
30-
hyperparameters: The hyperparameters passed to SageMaker TrainingJob that runs your TensorFlow training
31-
script.
28+
hyperparameters: The hyperparameters passed to the SageMaker TrainingJob that runs your TensorFlow
29+
training script.
3230
Returns: A compiled Keras model
3331
"""
3432
model = Sequential()
3533

36-
# TensorFlow Serving default prediction input tensor name is PREDICT_INPUTS. I will keep the same name for the
37-
# InputLayer
34+
# TensorFlow Serving default prediction input tensor name is PREDICT_INPUTS.
35+
# We must conform to this naming scheme.
3836
model.add(InputLayer(input_shape=(HEIGHT, WIDTH, DEPTH), name=PREDICT_INPUTS))
3937
model.add(Conv2D(32, (3, 3), padding='same'))
4038
model.add(Activation('relu'))
@@ -56,69 +54,56 @@ def keras_model_fn(hyperparameters):
5654
model.add(Dropout(0.5))
5755
model.add(Dense(NUM_CLASSES))
5856
model.add(Activation('softmax'))
57+
58+
_model = tf.keras.Model(inputs=model.input, outputs=model.output)
5959

60-
opt = rmsprop(lr=0.0001, decay=1e-6)
60+
opt = RMSprop(lr=hyperparameters['learning_rate'], decay=hyperparameters['decay'])
6161

62-
model.compile(loss='categorical_crossentropy',
62+
_model.compile(loss='categorical_crossentropy',
6363
optimizer=opt,
6464
metrics=['accuracy'])
6565

66-
print(model.summary())
66+
return _model
6767

68-
return model
69-
70-
71-
def serving_input_fn(hyperparameters):
72-
"""This function defines the placeholders that will be added to the model during serving.
73-
The function returns a tf.estimator.export.ServingInputReceiver object, which packages the placeholders and the
74-
resulting feature Tensors together.
75-
76-
For more information: https://github.com/aws/sagemaker-python-sdk#creating-a-serving_input_fn
77-
78-
Args:
79-
hyperparameters: The hyperparameters passed to SageMaker TrainingJob that runs your TensorFlow training
80-
script.
81-
Returns: ServingInputReceiver or fn that returns a ServingInputReceiver
82-
"""
8368

69+
def serving_input_fn(params):
8470
# Notice that the input placeholder has the same input shape as the Keras model input
8571
tensor = tf.placeholder(tf.float32, shape=[None, HEIGHT, WIDTH, DEPTH])
86-
87-
# the features key PREDICT_INPUTS matches the Keras Input Layer name
88-
features = {PREDICT_INPUTS: tensor}
89-
return build_raw_serving_input_receiver_fn(features)
72+
73+
# The inputs key PREDICT_INPUTS matches the Keras InputLayer name
74+
inputs = {PREDICT_INPUTS: tensor}
75+
return tf.estimator.export.ServingInputReceiver(inputs, inputs)
9076

9177

92-
def train_input_fn(training_dir, hyperparameters):
78+
def train_input_fn(training_dir, params):
9379
return _input(tf.estimator.ModeKeys.TRAIN,
94-
batch_size=BATCH_SIZE, data_dir=training_dir)
80+
batch_size=BATCH_SIZE, data_dir=training_dir)
9581

9682

97-
def eval_input_fn(training_dir, hyperparameters):
83+
def eval_input_fn(training_dir, params):
9884
return _input(tf.estimator.ModeKeys.EVAL,
99-
batch_size=BATCH_SIZE, data_dir=training_dir)
85+
batch_size=BATCH_SIZE, data_dir=training_dir)
10086

10187

10288
def _input(mode, batch_size, data_dir):
103-
"""Input_fn using the contrib.data input pipeline for CIFAR-10 dataset.
104-
105-
Args:
106-
mode: Standard names for model modes (tf.estimators.ModeKeys).
107-
batch_size: The number of samples per batch of input requested.
108-
"""
89+
"""Uses the tf.data input pipeline for CIFAR-10 dataset.
90+
Args:
91+
mode: Standard names for model modes (tf.estimators.ModeKeys).
92+
batch_size: The number of samples per batch of input requested.
93+
"""
10994
dataset = _record_dataset(_filenames(mode, data_dir))
11095

11196
# For training repeat forever.
11297
if mode == tf.estimator.ModeKeys.TRAIN:
11398
dataset = dataset.repeat()
11499

115-
dataset = dataset.map(_dataset_parser, num_threads=1,
116-
output_buffer_size=2 * batch_size)
100+
dataset = dataset.map(_dataset_parser)
101+
dataset.prefetch(2 * batch_size)
117102

118103
# For training, preprocess the image and shuffle.
119104
if mode == tf.estimator.ModeKeys.TRAIN:
120-
dataset = dataset.map(_train_preprocess_fn, num_threads=1,
121-
output_buffer_size=2 * batch_size)
105+
dataset = dataset.map(_train_preprocess_fn)
106+
dataset.prefetch(2 * batch_size)
122107

123108
# Ensure that the capacity is sufficiently large to provide good random
124109
# shuffling.
@@ -127,9 +112,8 @@ def _input(mode, batch_size, data_dir):
127112

128113
# Subtract off the mean and divide by the variance of the pixels.
129114
dataset = dataset.map(
130-
lambda image, label: (tf.image.per_image_standardization(image), label),
131-
num_threads=1,
132-
output_buffer_size=2 * batch_size)
115+
lambda image, label: (tf.image.per_image_standardization(image), label))
116+
dataset.prefetch(2 * batch_size)
133117

134118
# Batch results by up to batch_size, and then fetch the tuple from the
135119
# iterator.
@@ -182,7 +166,7 @@ def _dataset_parser(value):
182166
def _record_dataset(filenames):
183167
"""Returns an input pipeline Dataset from `filenames`."""
184168
record_bytes = HEIGHT * WIDTH * DEPTH + 1
185-
return tf.contrib.data.FixedLengthRecordDataset(filenames, record_bytes)
169+
return tf.data.FixedLengthRecordDataset(filenames, record_bytes)
186170

187171

188172
def _filenames(mode, data_dir):

0 commit comments

Comments
 (0)