Skip to content

Commit 3d6f5df

Browse files
yangawsjesterhazy
authored andcommitted
Fix prediction data in tensorflow keras notebook (aws#462)
1 parent 768bf64 commit 3d6f5df

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

sagemaker-python-sdk/tensorflow_keras_cifar10/cifar10_cnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def serving_input_fn(hyperparameters):
9090
# Notice that the input placeholder has the same input shape as the Keras model input
9191
tensor = tf.placeholder(tf.float32, shape=[None, HEIGHT, WIDTH, DEPTH])
9292

93-
# The inputs key PREDICT_INPUTS matches the Keras InputLayer name
93+
# The inputs key INPUT_TENSOR_NAME matches the Keras InputLayer name
9494
inputs = {INPUT_TENSOR_NAME: tensor}
9595
return tf.estimator.export.ServingInputReceiver(inputs, inputs)
9696

sagemaker-python-sdk/tensorflow_keras_cifar10/tensorflow_keras_CIFAR10.ipynb

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,8 @@
186186
" # Notice that the input placeholder has the same input shape as the Keras model input\n",
187187
" tensor = tf.placeholder(tf.float32, shape=[None, HEIGHT, WIDTH, DEPTH])\n",
188188
" \n",
189-
" # The inputs key PREDICT_INPUTS matches the Keras InputLayer name\n",
190-
" inputs = {PREDICT_INPUTS: tensor}\n",
189+
" # The inputs key INPUT_TENSOR_NAME matches the Keras InputLayer name\n",
190+
" inputs = {INPUT_TENSOR_NAME: tensor}\n",
191191
" return tf.estimator.export.ServingInputReceiver(inputs, inputs)\n",
192192
"\n",
193193
"\n",
@@ -269,7 +269,8 @@
269269
"import numpy as np\n",
270270
"data = np.random.randn(1, 32, 32, 3)\n",
271271
"\n",
272-
"predictor.predict(data)"
272+
"# The inputs key 'inputs_input' matches the Keras InputLayer name\n",
273+
"predictor.predict({'inputs_input': data}) "
273274
]
274275
},
275276
{

0 commit comments

Comments
 (0)