Skip to content

Commit b983974

Browse files
authored
Merge pull request aws#353 from FarhanTejani/mvs-keras-notebook
Completed mvs-keras-notebook example
2 parents 8841a27 + 8cb842e commit b983974

File tree

4 files changed

+559
-0
lines changed

4 files changed

+559
-0
lines changed

sagemaker-python-sdk/tensorflow_keras_cifar10/__init__.py

Whitespace-only changes.
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# A copy of the License is located at
6+
#
7+
# https://aws.amazon.com/apache-2-0/
8+
#
9+
# or in the "license" file accompanying this file. This file is distributed
10+
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+
# express or implied. See the License for the specific language governing
12+
# permissions and limitations under the License.
13+
14+
from __future__ import absolute_import
15+
from __future__ import division
16+
from __future__ import print_function
17+
18+
import os
19+
20+
import tensorflow as tf
21+
from tensorflow.python.keras.layers import InputLayer, Conv2D, Activation, MaxPooling2D, Dropout, Flatten, Dense
22+
from tensorflow.python.keras.models import Sequential
23+
from tensorflow.python.keras.optimizers import RMSprop
24+
from tensorflow.python.saved_model.signature_constants import PREDICT_INPUTS
25+
26+
HEIGHT = 32
27+
WIDTH = 32
28+
DEPTH = 3
29+
NUM_CLASSES = 10
30+
NUM_DATA_BATCHES = 5
31+
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 10000 * NUM_DATA_BATCHES
32+
BATCH_SIZE = 128
33+
34+
35+
def keras_model_fn(hyperparameters):
36+
"""keras_model_fn receives hyperparameters from the training job and returns a compiled keras model.
37+
The model will be transformed into a TensorFlow Estimator before training and it will be saved in a
38+
TensorFlow Serving SavedModel at the end of training.
39+
40+
Args:
41+
hyperparameters: The hyperparameters passed to the SageMaker TrainingJob that runs your TensorFlow
42+
training script.
43+
Returns: A compiled Keras model
44+
"""
45+
model = Sequential()
46+
47+
# TensorFlow Serving default prediction input tensor name is PREDICT_INPUTS.
48+
# We must conform to this naming scheme.
49+
model.add(InputLayer(input_shape=(HEIGHT, WIDTH, DEPTH), name=PREDICT_INPUTS))
50+
model.add(Conv2D(32, (3, 3), padding='same'))
51+
model.add(Activation('relu'))
52+
model.add(Conv2D(32, (3, 3)))
53+
model.add(Activation('relu'))
54+
model.add(MaxPooling2D(pool_size=(2, 2)))
55+
model.add(Dropout(0.25))
56+
57+
model.add(Conv2D(64, (3, 3), padding='same'))
58+
model.add(Activation('relu'))
59+
model.add(Conv2D(64, (3, 3)))
60+
model.add(Activation('relu'))
61+
model.add(MaxPooling2D(pool_size=(2, 2)))
62+
model.add(Dropout(0.25))
63+
64+
model.add(Flatten())
65+
model.add(Dense(512))
66+
model.add(Activation('relu'))
67+
model.add(Dropout(0.5))
68+
model.add(Dense(NUM_CLASSES))
69+
model.add(Activation('softmax'))
70+
71+
_model = tf.keras.Model(inputs=model.input, outputs=model.output)
72+
73+
opt = RMSprop(lr=hyperparameters['learning_rate'], decay=hyperparameters['decay'])
74+
75+
_model.compile(loss='categorical_crossentropy',
76+
optimizer=opt,
77+
metrics=['accuracy'])
78+
79+
return _model
80+
81+
82+
def serving_input_fn(hyperparameters):
83+
"""This function defines the placeholders that will be added to the model during serving.
84+
The function returns a tf.estimator.export.ServingInputReceiver object, which packages the
85+
placeholders and the resulting feature Tensors together.
86+
For more information: https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/tensorflow/README.rst#creating-a-serving_input_fn
87+
88+
Args:
89+
hyperparameters: The hyperparameters passed to SageMaker TrainingJob that runs your TensorFlow
90+
training script.
91+
Returns: ServingInputReceiver or fn that returns a ServingInputReceiver
92+
"""
93+
94+
# Notice that the input placeholder has the same input shape as the Keras model input
95+
tensor = tf.placeholder(tf.float32, shape=[None, HEIGHT, WIDTH, DEPTH])
96+
97+
# The inputs key PREDICT_INPUTS matches the Keras InputLayer name
98+
inputs = {PREDICT_INPUTS: tensor}
99+
return tf.estimator.export.ServingInputReceiver(inputs, inputs)
100+
101+
102+
def train_input_fn(training_dir, hyperparameters):
103+
"""Returns input function that would feed the model during training"""
104+
return _input(tf.estimator.ModeKeys.TRAIN,
105+
batch_size=BATCH_SIZE, data_dir=training_dir)
106+
107+
108+
def eval_input_fn(training_dir, hyperparameters):
109+
"""Returns input function that would feed the model during evaluation"""
110+
return _input(tf.estimator.ModeKeys.EVAL,
111+
batch_size=BATCH_SIZE, data_dir=training_dir)
112+
113+
114+
def _input(mode, batch_size, data_dir):
115+
"""Uses the tf.data input pipeline for CIFAR-10 dataset.
116+
Args:
117+
mode: Standard names for model modes (tf.estimators.ModeKeys).
118+
batch_size: The number of samples per batch of input requested.
119+
"""
120+
dataset = _record_dataset(_filenames(mode, data_dir))
121+
122+
# For training repeat forever.
123+
if mode == tf.estimator.ModeKeys.TRAIN:
124+
dataset = dataset.repeat()
125+
126+
dataset = dataset.map(_dataset_parser)
127+
dataset.prefetch(2 * batch_size)
128+
129+
# For training, preprocess the image and shuffle.
130+
if mode == tf.estimator.ModeKeys.TRAIN:
131+
dataset = dataset.map(_train_preprocess_fn)
132+
dataset.prefetch(2 * batch_size)
133+
134+
# Ensure that the capacity is sufficiently large to provide good random
135+
# shuffling.
136+
buffer_size = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * 0.4) + 3 * batch_size
137+
dataset = dataset.shuffle(buffer_size=buffer_size)
138+
139+
# Subtract off the mean and divide by the variance of the pixels.
140+
dataset = dataset.map(
141+
lambda image, label: (tf.image.per_image_standardization(image), label))
142+
dataset.prefetch(2 * batch_size)
143+
144+
# Batch results by up to batch_size, and then fetch the tuple from the
145+
# iterator.
146+
iterator = dataset.batch(batch_size).make_one_shot_iterator()
147+
images, labels = iterator.get_next()
148+
149+
# We must use the default input tensor name PREDICT_INPUTS
150+
return {PREDICT_INPUTS: images}, labels
151+
152+
153+
def _train_preprocess_fn(image, label):
154+
"""Preprocess a single training image of layout [height, width, depth]."""
155+
# Resize the image to add four extra pixels on each side.
156+
image = tf.image.resize_image_with_crop_or_pad(image, HEIGHT + 8, WIDTH + 8)
157+
158+
# Randomly crop a [HEIGHT, WIDTH] section of the image.
159+
image = tf.random_crop(image, [HEIGHT, WIDTH, DEPTH])
160+
161+
# Randomly flip the image horizontally.
162+
image = tf.image.random_flip_left_right(image)
163+
164+
return image, label
165+
166+
167+
def _dataset_parser(value):
168+
"""Parse a CIFAR-10 record from value."""
169+
# Every record consists of a label followed by the image, with a fixed number
170+
# of bytes for each.
171+
label_bytes = 1
172+
image_bytes = HEIGHT * WIDTH * DEPTH
173+
record_bytes = label_bytes + image_bytes
174+
175+
# Convert from a string to a vector of uint8 that is record_bytes long.
176+
raw_record = tf.decode_raw(value, tf.uint8)
177+
178+
# The first byte represents the label, which we convert from uint8 to int32.
179+
label = tf.cast(raw_record[0], tf.int32)
180+
181+
# The remaining bytes after the label represent the image, which we reshape
182+
# from [depth * height * width] to [depth, height, width].
183+
depth_major = tf.reshape(raw_record[label_bytes:record_bytes],
184+
[DEPTH, HEIGHT, WIDTH])
185+
186+
# Convert from [depth, height, width] to [height, width, depth], and cast as
187+
# float32.
188+
image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)
189+
190+
return image, tf.one_hot(label, NUM_CLASSES)
191+
192+
193+
def _record_dataset(filenames):
194+
"""Returns an input pipeline Dataset from `filenames`."""
195+
record_bytes = HEIGHT * WIDTH * DEPTH + 1
196+
return tf.data.FixedLengthRecordDataset(filenames, record_bytes)
197+
198+
199+
def _filenames(mode, data_dir):
200+
"""Returns a list of filenames based on 'mode'."""
201+
data_dir = os.path.join(data_dir, 'cifar-10-batches-bin')
202+
203+
assert os.path.exists(data_dir), ('Run cifar10_download_and_extract.py first '
204+
'to download and extract the CIFAR-10 data.')
205+
206+
if mode == tf.estimator.ModeKeys.TRAIN:
207+
return [
208+
os.path.join(data_dir, 'data_batch_%d.bin' % i)
209+
for i in range(1, NUM_DATA_BATCHES + 1)
210+
]
211+
elif mode == tf.estimator.ModeKeys.EVAL:
212+
return [os.path.join(data_dir, 'test_batch.bin')]
213+
else:
214+
raise ValueError('Invalid mode: %s' % mode)

0 commit comments

Comments
 (0)