Skip to content

Commit 8cb842e

Browse files
committed
Addressed PR comments
1 parent 1dd0cd4 commit 8cb842e

File tree

2 files changed

+36
-9
lines changed

2 files changed

+36
-9
lines changed

sagemaker-python-sdk/tensorflow_keras_cifar10/cifar10_cnn.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# A copy of the License is located at
6+
#
7+
# https://aws.amazon.com/apache-2-0/
8+
#
9+
# or in the "license" file accompanying this file. This file is distributed
10+
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+
# express or implied. See the License for the specific language governing
12+
# permissions and limitations under the License.
13+
114
from __future__ import absolute_import
215
from __future__ import division
316
from __future__ import print_function
@@ -66,7 +79,18 @@ def keras_model_fn(hyperparameters):
6679
return _model
6780

6881

69-
def serving_input_fn(params):
82+
def serving_input_fn(hyperparameters):
83+
"""This function defines the placeholders that will be added to the model during serving.
84+
The function returns a tf.estimator.export.ServingInputReceiver object, which packages the
85+
placeholders and the resulting feature Tensors together.
86+
For more information: https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/tensorflow/README.rst#creating-a-serving_input_fn
87+
88+
Args:
89+
hyperparameters: The hyperparameters passed to SageMaker TrainingJob that runs your TensorFlow
90+
training script.
91+
Returns: ServingInputReceiver or fn that returns a ServingInputReceiver
92+
"""
93+
7094
# Notice that the input placeholder has the same input shape as the Keras model input
7195
tensor = tf.placeholder(tf.float32, shape=[None, HEIGHT, WIDTH, DEPTH])
7296

@@ -75,12 +99,14 @@ def serving_input_fn(params):
7599
return tf.estimator.export.ServingInputReceiver(inputs, inputs)
76100

77101

78-
def train_input_fn(training_dir, params):
102+
def train_input_fn(training_dir, hyperparameters):
103+
"""Returns input function that would feed the model during training"""
79104
return _input(tf.estimator.ModeKeys.TRAIN,
80105
batch_size=BATCH_SIZE, data_dir=training_dir)
81106

82107

83-
def eval_input_fn(training_dir, params):
108+
def eval_input_fn(training_dir, hyperparameters):
109+
"""Returns input function that would feed the model during evaluation"""
84110
return _input(tf.estimator.ModeKeys.EVAL,
85111
batch_size=BATCH_SIZE, data_dir=training_dir)
86112

@@ -120,6 +146,7 @@ def _input(mode, batch_size, data_dir):
120146
iterator = dataset.batch(batch_size).make_one_shot_iterator()
121147
images, labels = iterator.get_next()
122148

149+
# We must use the default input tensor name PREDICT_INPUTS
123150
return {PREDICT_INPUTS: images}, labels
124151

125152

sagemaker-python-sdk/tensorflow_keras_cifar10/tensorflow_keras_CIFAR10.ipynb

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@
164164
"source": [
165165
"This function builds and returns a compiled Keras model.\n",
166166
"\n",
167-
"**Note:** The first layer is named `PREDICT_INPUTS`. This serves as a workaround for a known issue where SageMaker does not recognize the default (or any custom) name for the first layer of Keras models. Furthermore, note that we are wrapping our model in a `tf.keras.Model` before returning it. This serves as a workaround for a known issue where a Sequential model cannot be directly converted into an estimator. See [here](https://github.com/tensorflow/tensorflow/issues/20552) for more information about the issue.\n",
167+
"**Note:** The first layer is named `PREDICT_INPUTS`. This serves as a workaround for a known issue where TensorFlow does not recognize the default (or any custom) name for the first layer of Keras models. Furthermore, note that we are wrapping our model in a `tf.keras.Model` before returning it. This serves as a workaround for a known issue where a Sequential model cannot be directly converted into an Estimator. See [here](https://github.com/tensorflow/tensorflow/issues/20552) for more information about the issue.\n",
168168
"\n",
169169
"### Input functions\n",
170170
"These functions are similar to those required by any other model using the TensorFlow Estimator API."
@@ -285,21 +285,21 @@
285285
],
286286
"metadata": {
287287
"kernelspec": {
288-
"display_name": "conda_tensorflow_p27",
288+
"display_name": "conda_tensorflow_p36",
289289
"language": "python",
290-
"name": "conda_tensorflow_p27"
290+
"name": "conda_tensorflow_p36"
291291
},
292292
"language_info": {
293293
"codemirror_mode": {
294294
"name": "ipython",
295-
"version": 2
295+
"version": 3
296296
},
297297
"file_extension": ".py",
298298
"mimetype": "text/x-python",
299299
"name": "python",
300300
"nbconvert_exporter": "python",
301-
"pygments_lexer": "ipython2",
302-
"version": "2.7.14"
301+
"pygments_lexer": "ipython3",
302+
"version": "3.6.4"
303303
},
304304
"notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
305305
},

0 commit comments

Comments
 (0)