Commit d50de21

Add keras notebook example
1 parent b93d177 commit d50de21

File tree

4 files changed: 453 additions, 0 deletions


sagemaker-python-sdk/tensorflow_keras_cifar10/__init__.py

Whitespace-only changes.
Lines changed: 203 additions & 0 deletions
@@ -0,0 +1,203 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import tensorflow as tf
from tensorflow.python.estimator.export.export import build_raw_serving_input_receiver_fn
from tensorflow.python.keras._impl.keras.engine.topology import InputLayer
from tensorflow.python.keras._impl.keras.layers import Conv2D, Activation, MaxPooling2D, Dropout, Flatten, Dense
from tensorflow.python.keras._impl.keras.models import Sequential
from tensorflow.python.keras._impl.keras.optimizers import rmsprop
from tensorflow.python.saved_model.signature_constants import PREDICT_INPUTS

HEIGHT = 32
WIDTH = 32
DEPTH = 3
NUM_CLASSES = 10
NUM_DATA_BATCHES = 5
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 10000 * NUM_DATA_BATCHES
BATCH_SIZE = 1


def keras_model_fn(hyperparameters):
    """keras_model_fn receives hyperparameters from the training job and returns a compiled Keras model.
    The model will be transformed into a TensorFlow Estimator before training, and it will be saved in a
    TensorFlow Serving SavedModel at the end of training.

    Args:
        hyperparameters: The hyperparameters passed to the SageMaker TrainingJob that runs your TensorFlow
            training script.
    Returns: A compiled Keras model
    """
    model = Sequential()

    # TensorFlow Serving's default prediction input tensor name is PREDICT_INPUTS. We keep the same name for
    # the InputLayer.
    model.add(InputLayer(input_shape=(HEIGHT, WIDTH, DEPTH), name=PREDICT_INPUTS))
    model.add(Conv2D(32, (3, 3), padding='same'))
    model.add(Activation('relu'))
    model.add(Conv2D(32, (3, 3)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))

    model.add(Conv2D(64, (3, 3), padding='same'))
    model.add(Activation('relu'))
    model.add(Conv2D(64, (3, 3)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))

    model.add(Flatten())
    model.add(Dense(512))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(Dense(NUM_CLASSES))
    model.add(Activation('softmax'))

    opt = rmsprop(lr=0.0001, decay=1e-6)

    model.compile(loss='categorical_crossentropy',
                  optimizer=opt,
                  metrics=['accuracy'])

    print(model.summary())

    return model


def serving_input_fn(hyperparameters):
    """This function defines the placeholders that will be added to the model during serving.
    The function returns a tf.estimator.export.ServingInputReceiver object, which packages the placeholders
    and the resulting feature Tensors together.

    For more information: https://github.com/aws/sagemaker-python-sdk#creating-a-serving_input_fn

    Args:
        hyperparameters: The hyperparameters passed to the SageMaker TrainingJob that runs your TensorFlow
            training script.
    Returns: ServingInputReceiver or fn that returns a ServingInputReceiver
    """

    # Notice that the input placeholder has the same input shape as the Keras model input.
    tensor = tf.placeholder(tf.float32, shape=[None, HEIGHT, WIDTH, DEPTH])

    # The features key PREDICT_INPUTS matches the Keras InputLayer name.
    features = {PREDICT_INPUTS: tensor}
    return build_raw_serving_input_receiver_fn(features)


def train_input_fn(training_dir, hyperparameters):
    return _input(tf.estimator.ModeKeys.TRAIN,
                  batch_size=BATCH_SIZE, data_dir=training_dir)


def eval_input_fn(training_dir, hyperparameters):
    return _input(tf.estimator.ModeKeys.EVAL,
                  batch_size=BATCH_SIZE, data_dir=training_dir)


def _input(mode, batch_size, data_dir):
    """Input_fn using the contrib.data input pipeline for the CIFAR-10 dataset.

    Args:
        mode: Standard names for model modes (tf.estimator.ModeKeys).
        batch_size: The number of samples per batch of input requested.
    """
    dataset = _record_dataset(_filenames(mode, data_dir))

    # For training, repeat forever.
    if mode == tf.estimator.ModeKeys.TRAIN:
        dataset = dataset.repeat()

    dataset = dataset.map(_dataset_parser, num_threads=1,
                          output_buffer_size=2 * batch_size)

    # For training, preprocess the image and shuffle.
    if mode == tf.estimator.ModeKeys.TRAIN:
        dataset = dataset.map(_train_preprocess_fn, num_threads=1,
                              output_buffer_size=2 * batch_size)

        # Ensure that the capacity is sufficiently large to provide good random
        # shuffling.
        buffer_size = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * 0.4) + 3 * batch_size
        dataset = dataset.shuffle(buffer_size=buffer_size)

    # Subtract off the mean and divide by the variance of the pixels.
    dataset = dataset.map(
        lambda image, label: (tf.image.per_image_standardization(image), label),
        num_threads=1,
        output_buffer_size=2 * batch_size)

    # Batch results by up to batch_size, and then fetch the tuple from the
    # iterator.
    iterator = dataset.batch(batch_size).make_one_shot_iterator()
    images, labels = iterator.get_next()

    return {PREDICT_INPUTS: images}, labels


def _train_preprocess_fn(image, label):
    """Preprocess a single training image of layout [height, width, depth]."""
    # Resize the image to add four extra pixels on each side.
    image = tf.image.resize_image_with_crop_or_pad(image, HEIGHT + 8, WIDTH + 8)

    # Randomly crop a [HEIGHT, WIDTH] section of the image.
    image = tf.random_crop(image, [HEIGHT, WIDTH, DEPTH])

    # Randomly flip the image horizontally.
    image = tf.image.random_flip_left_right(image)

    return image, label


def _dataset_parser(value):
    """Parse a CIFAR-10 record from value."""
    # Every record consists of a label followed by the image, with a fixed number
    # of bytes for each.
    label_bytes = 1
    image_bytes = HEIGHT * WIDTH * DEPTH
    record_bytes = label_bytes + image_bytes

    # Convert from a string to a vector of uint8 that is record_bytes long.
    raw_record = tf.decode_raw(value, tf.uint8)

    # The first byte represents the label, which we convert from uint8 to int32.
    label = tf.cast(raw_record[0], tf.int32)

    # The remaining bytes after the label represent the image, which we reshape
    # from [depth * height * width] to [depth, height, width].
    depth_major = tf.reshape(raw_record[label_bytes:record_bytes],
                             [DEPTH, HEIGHT, WIDTH])

    # Convert from [depth, height, width] to [height, width, depth], and cast as
    # float32.
    image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)

    return image, tf.one_hot(label, NUM_CLASSES)


def _record_dataset(filenames):
    """Returns an input pipeline Dataset from `filenames`."""
    record_bytes = HEIGHT * WIDTH * DEPTH + 1
    return tf.contrib.data.FixedLengthRecordDataset(filenames, record_bytes)


def _filenames(mode, data_dir):
    """Returns a list of filenames based on 'mode'."""
    data_dir = os.path.join(data_dir, 'cifar-10-batches-bin')

    assert os.path.exists(data_dir), ('Run cifar10_download_and_extract.py first '
                                      'to download and extract the CIFAR-10 data.')

    if mode == tf.estimator.ModeKeys.TRAIN:
        return [
            os.path.join(data_dir, 'data_batch_%d.bin' % i)
            for i in range(1, NUM_DATA_BATCHES + 1)
        ]
    elif mode == tf.estimator.ModeKeys.EVAL:
        return [os.path.join(data_dir, 'test_batch.bin')]
    else:
        raise ValueError('Invalid mode: %s' % mode)
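
The notebook added by this commit is among the four changed files but does not appear in this diff. For orientation, here is a minimal sketch of how such a notebook would typically launch this script with the SageMaker Python SDK's legacy-mode TensorFlow estimator; the entry-point filename, IAM role, S3 location, instance types, and step counts below are illustrative assumptions, not values taken from this commit.

# Sketch only: names and values below are assumptions, not part of this commit.
import sagemaker
from sagemaker.tensorflow import TensorFlow

role = sagemaker.get_execution_role()  # assumes a SageMaker notebook environment

estimator = TensorFlow(entry_point='cifar10_cnn.py',  # hypothetical name for the script above
                       role=role,
                       training_steps=1000,
                       evaluation_steps=100,
                       train_instance_count=1,
                       train_instance_type='ml.c4.xlarge')

# The input channel should point at the extracted CIFAR-10 binaries
# (cifar-10-batches-bin), matching what _filenames() expects under training_dir.
estimator.fit('s3://<bucket>/<prefix>/cifar10')

# The SavedModel exported via keras_model_fn/serving_input_fn can then be hosted
# behind a TensorFlow Serving endpoint.
predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')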
