Commit f298f54

Improves documentation and tests about input and output functions (#17)
* capture ValidationException errors when deleting endpoints to not override previous exceptions
* adding custom prediction to cifar 10 test
* improved documentation about input and output functions
1 parent 3462496 commit f298f54
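The first change listed above (capturing ValidationException when deleting endpoints) is not part of the three diffs below. A minimal sketch of that pattern, assuming boto3 reports the failure as a ClientError; delete_endpoint_quietly is a hypothetical helper name, not the repository's actual code:

    import botocore.exceptions


    def delete_endpoint_quietly(sagemaker_session, endpoint_name):
        """Delete an endpoint without masking an exception that is already
        propagating from the calling code (hypothetical helper)."""
        try:
            # sagemaker.session.Session.delete_endpoint wraps the DeleteEndpoint API.
            sagemaker_session.delete_endpoint(endpoint_name)
        except botocore.exceptions.ClientError as error:
            # boto3 surfaces ValidationException as a ClientError with this code,
            # e.g. when the endpoint was never created because training failed.
            if error.response['Error']['Code'] != 'ValidationException':
                raise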

File tree

3 files changed: +43 -17 lines changed


README.rst

Lines changed: 10 additions & 7 deletions
@@ -1365,11 +1365,11 @@ An example of ``input_fn`` for the content-type "application/python-pickle" can

     import numpy as np

-    def input_fn(data, content_type):
-        """An input_fn that loads a pickled numpy array"""
+    def input_fn(serialized_input, content_type):
+        """An input_fn that loads a pickled object"""
         if request_content_type == "application/python-pickle":
-            array = np.load(StringIO(request_body))
-            return array.reshape(model.data_shpaes[0])
+            deserialized_input = pickle.loads(serialized_input)
+            return deserialized_input
         else:
             # Handle other content-types here or raise an Exception
             # if the content type is not supported.

@@ -1384,15 +1384,18 @@ An example of ``output_fn`` for the accept type "application/python-pickle" can

     import numpy as np

-    def output_fn(data, accepts):
-        """An output_fn that dumps a pickled numpy as response"""
+    def output_fn(prediction_result, accepts):
+        """An output_fn that dumps a pickled object as response"""
         if request_content_type == "application/python-pickle":
-            return np.dumps(data)
+            return np.dumps(prediction_result)
         else:
             # Handle other content-types here or raise an Exception
             # if the content type is not supported.
             pass

+An example with ``input_fn`` and ``output_fn`` above can be found
+`here <https://github.com/aws/sagemaker-python-sdk/blob/master/tests/data/cifar_10/source/resnet_cifar_10.py#L143>`_.
+

 SageMaker TensorFlow Docker containers
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
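To see the documented round trip end to end, here is a self-contained sketch: the client pickles a numpy array, input_fn deserializes it, and output_fn pickles the prediction for the response. Note that the README snippet above compares request_content_type, which does not match its content_type parameter; the sketch uses consistent names, and fake_predict is a hypothetical stand-in for the model:

    import pickle

    import numpy as np

    PICKLE_CONTENT_TYPE = 'application/python-pickle'


    def input_fn(serialized_input, content_type):
        """Deserialize a pickled request body."""
        if content_type == PICKLE_CONTENT_TYPE:
            return pickle.loads(serialized_input)
        raise ValueError('Unsupported content type: {}'.format(content_type))


    def output_fn(prediction_result, accepts):
        """Pickle the prediction result for the response."""
        if accepts == PICKLE_CONTENT_TYPE:
            return pickle.dumps(prediction_result, protocol=2)
        raise ValueError('Unsupported accept type: {}'.format(accepts))


    def fake_predict(data):
        # Hypothetical stand-in for the model: a uniform 10-class distribution.
        return np.ones(10) / 10.0


    # Simulate one request/response cycle locally.
    body = pickle.dumps(np.random.randn(32, 32, 3), protocol=2)
    prediction = fake_predict(input_fn(body, PICKLE_CONTENT_TYPE))
    assert np.allclose(pickle.loads(output_fn(prediction, PICKLE_CONTENT_TYPE)), 0.1)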

tests/data/cifar_10/source/resnet_cifar_10.py

Lines changed: 13 additions & 9 deletions
@@ -2,6 +2,8 @@
 from __future__ import division
 from __future__ import print_function

+import pickle
+
 import resnet_model
 import tensorflow as tf

@@ -106,21 +108,19 @@ def model_fn(features, labels, mode, params):


 def serving_input_fn(hyperpameters):
-    feature_spec = {INPUT_TENSOR_NAME: tf.FixedLenFeature(dtype=tf.float32, shape=(32, 32, 3))}
-    return tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)()
+    inputs = {INPUT_TENSOR_NAME: tf.placeholder(tf.float32, [None, 32, 32, 3])}
+    return tf.estimator.export.ServingInputReceiver(inputs, inputs)


-def train_input_fn(training_dir, hyperpameters):
-    return input_fn(tf.estimator.ModeKeys.TRAIN,
-                    batch_size=BATCH_SIZE, data_dir=training_dir)
+def train_input_fn(training_dir, hyperparameters):
+    return _generate_synthetic_data(tf.estimator.ModeKeys.TRAIN, batch_size=BATCH_SIZE)


-def eval_input_fn(training_dir, hyperpameters):
-    return input_fn(tf.estimator.ModeKeys.EVAL,
-                    batch_size=BATCH_SIZE, data_dir=training_dir)
+def eval_input_fn(training_dir, hyperparameters):
+    return _generate_synthetic_data(tf.estimator.ModeKeys.EVAL, batch_size=BATCH_SIZE)


-def input_fn(mode, batch_size, data_dir):
+def _generate_synthetic_data(mode, batch_size):
     input_shape = [batch_size, HEIGHT, WIDTH, DEPTH]
     images = tf.truncated_normal(
         input_shape,

@@ -138,3 +138,7 @@ def input_fn(mode, batch_size, data_dir):
     labels = tf.contrib.framework.local_variable(labels, name='labels')

     return {INPUT_TENSOR_NAME: images}, labels
+
+
+def input_fn(serialized_data, content_type):
+    return pickle.loads(serialized_data)
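The input_fn added at the bottom of this file is small enough to exercise locally before deploying; a minimal sketch, where the 32x32x3 shape mirrors the CIFAR-10 input and the payload dtype is my assumption:

    import pickle

    import numpy as np


    # The function added in this commit (tests/data/cifar_10/source/resnet_cifar_10.py).
    def input_fn(serialized_data, content_type):
        return pickle.loads(serialized_data)


    # Round-trip a CIFAR-10-shaped array the way the serving container would.
    image = np.random.randn(32, 32, 3).astype(np.float32)
    deserialized = input_fn(pickle.dumps(image, protocol=2), 'application/python-pickle')
    assert deserialized.shape == (32, 32, 3)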

tests/integ/test_tf_cifar.py

Lines changed: 20 additions & 1 deletion
@@ -10,7 +10,10 @@
 # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
 # ANY KIND, either express or implied. See the License for the specific
 # language governing permissions and limitations under the License.
+import pickle
+
 import boto3
+import numpy as np
 import os
 import pytest

@@ -19,12 +22,22 @@
 from tests.integ import DATA_DIR, REGION
 from tests.integ.timeout import timeout_and_delete_endpoint, timeout

+PICKLE_CONTENT_TYPE = 'application/python-pickle'
+

 @pytest.fixture(scope='module')
 def sagemaker_session():
     return Session(boto_session=boto3.Session(region_name=REGION))


+class PickleSerializer(object):
+    def __init__(self):
+        self.content_type = PICKLE_CONTENT_TYPE
+
+    def __call__(self, data):
+        return pickle.dumps(data, protocol=2)
+
+
 def test_cifar(sagemaker_session):
     with timeout(minutes=15):
         script_path = os.path.join(DATA_DIR, 'cifar_10', 'source')

@@ -42,4 +55,10 @@ def test_cifar(sagemaker_session):
         print('job succeeded: {}'.format(estimator.latest_training_job.name))

     with timeout_and_delete_endpoint(estimator=estimator, minutes=20):
-        estimator.deploy(initial_instance_count=1, instance_type='ml.c4.xlarge')
+        predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.p2.xlarge')
+        predictor.serializer = PickleSerializer()
+        predictor.content_type = PICKLE_CONTENT_TYPE
+
+        data = np.random.randn(32, 32, 3)
+        predict_response = predictor.predict(data)
+        assert len(predict_response['outputs']['probabilities']['floatVal']) == 10
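The test drives the endpoint through the SDK's predictor, but the same request can be made with the low-level runtime client. A hedged sketch: the endpoint name and region are placeholders, and the response body is whatever the TensorFlow serving container returns (JSON by default):

    import pickle

    import boto3
    import numpy as np

    runtime = boto3.client('sagemaker-runtime', region_name='us-west-2')

    response = runtime.invoke_endpoint(
        EndpointName='my-cifar-endpoint',  # placeholder, not a real endpoint
        ContentType='application/python-pickle',
        Body=pickle.dumps(np.random.randn(32, 32, 3), protocol=2),
    )
    print(response['Body'].read())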
