Skip to content

Commit fe14388

Browse files
authored
Merge branch 'zwei' into tf-image-uri
2 parents 5ac66e7 + 218d786 commit fe14388

File tree

12 files changed: +69 additions, -70 deletions (lines changed)

12 files changed: +69 additions, -70 deletions (lines changed)

doc/frameworks/tensorflow/upgrade_from_legacy.rst

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -245,11 +245,10 @@ For example, if you want to use JSON serialization and deserialization:
245245

246246
.. code:: python
247247
248-
from sagemaker.predictor import json_deserializer
248+
from sagemaker.deserializers import JSONDeserializer
249249
from sagemaker.serializers import JSONSerializer
250250
251251
predictor.serializer = JSONSerializer()
252-
predictor.accept = "application/json"
253-
predictor.deserializer = json_deserializer
252+
predictor.deserializer = JSONDeserializer()
254253
255254
predictor.predict(data)

src/sagemaker/amazon/ipinsights.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,8 @@
1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
1717
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1818
from sagemaker.amazon.validation import ge, le
19-
from sagemaker.predictor import Predictor, csv_serializer, json_deserializer
19+
from sagemaker.deserializers import JSONDeserializer
20+
from sagemaker.predictor import Predictor, csv_serializer
2021
from sagemaker.model import Model
2122
from sagemaker.session import Session
2223
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
@@ -198,7 +199,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
198199
endpoint_name,
199200
sagemaker_session,
200201
serializer=csv_serializer,
201-
deserializer=json_deserializer,
202+
deserializer=JSONDeserializer(),
202203
)
203204

204205

src/sagemaker/deserializers.py

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -187,3 +187,24 @@ def deserialize(self, data, content_type):
187187
data.close()
188188

189189
raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type))
190+
191+
192+
class JSONDeserializer(BaseDeserializer):
193+
"""Deserialize JSON data from an inference endpoint into a Python object."""
194+
195+
ACCEPT = "application/json"
196+
197+
def deserialize(self, data, content_type):
198+
"""Deserialize JSON data from an inference endpoint into a Python object.
199+
200+
Args:
201+
data (botocore.response.StreamingBody): Data to be deserialized.
202+
content_type (str): The MIME type of the data.
203+
204+
Returns:
205+
object: The JSON-formatted data deserialized into a Python object.
206+
"""
207+
try:
208+
return json.load(codecs.getreader("utf-8")(data))
209+
finally:
210+
data.close()

src/sagemaker/mxnet/model.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
import packaging.version
1919

2020
import sagemaker
21+
from sagemaker.deserializers import JSONDeserializer
2122
from sagemaker.fw_utils import (
2223
create_image_uri,
2324
model_code_key_prefix,
@@ -26,7 +27,7 @@
2627
)
2728
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2829
from sagemaker.mxnet import defaults
29-
from sagemaker.predictor import Predictor, json_deserializer
30+
from sagemaker.predictor import Predictor
3031
from sagemaker.serializers import JSONSerializer
3132

3233
logger = logging.getLogger("sagemaker")
@@ -51,7 +52,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
5152
using the default AWS configuration chain.
5253
"""
5354
super(MXNetPredictor, self).__init__(
54-
endpoint_name, sagemaker_session, JSONSerializer(), json_deserializer
55+
endpoint_name, sagemaker_session, JSONSerializer(), JSONDeserializer()
5556
)
5657

5758

src/sagemaker/predictor.py

Lines changed: 1 addition & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -13,13 +13,11 @@
1313
"""Placeholder docstring"""
1414
from __future__ import print_function, absolute_import
1515

16-
import codecs
1716
import csv
18-
import json
1917
from six import StringIO
2018
import numpy as np
2119

22-
from sagemaker.content_types import CONTENT_TYPE_JSON, CONTENT_TYPE_CSV
20+
from sagemaker.content_types import CONTENT_TYPE_CSV
2321
from sagemaker.deserializers import BaseDeserializer
2422
from sagemaker.model_monitor import DataCaptureConfig
2523
from sagemaker.serializers import BaseSerializer
@@ -594,29 +592,3 @@ def _row_to_csv(obj):
594592
if isinstance(obj, str):
595593
return obj
596594
return ",".join(obj)
597-
598-
599-
class _JsonDeserializer(object):
600-
"""Placeholder docstring"""
601-
602-
def __init__(self):
603-
"""Placeholder docstring"""
604-
self.accept = CONTENT_TYPE_JSON
605-
606-
def __call__(self, stream, content_type):
607-
"""Decode a JSON object into the corresponding Python object.
608-
609-
Args:
610-
stream (stream): The response stream to be deserialized.
611-
content_type (str): The content type of the response.
612-
613-
Returns:
614-
object: Body of the response deserialized into a JSON object.
615-
"""
616-
try:
617-
return json.load(codecs.getreader("utf-8")(stream))
618-
finally:
619-
stream.close()
620-
621-
622-
json_deserializer = _JsonDeserializer()

src/sagemaker/tensorflow/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717

1818
import sagemaker
1919
from sagemaker.content_types import CONTENT_TYPE_JSON
20+
from sagemaker.deserializers import JSONDeserializer
2021
from sagemaker.fw_utils import create_image_uri
21-
from sagemaker.predictor import json_deserializer, Predictor
22+
from sagemaker.predictor import Predictor
2223
from sagemaker.serializers import JSONSerializer
2324

2425

@@ -32,7 +33,7 @@ def __init__(
3233
endpoint_name,
3334
sagemaker_session=None,
3435
serializer=JSONSerializer(),
35-
deserializer=json_deserializer,
36+
deserializer=JSONDeserializer(),
3637
content_type=None,
3738
model_name=None,
3839
model_version=None,

tests/integ/test_byo_estimator.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -86,7 +86,7 @@ def test_byo_estimator(sagemaker_session, region, cpu_instance_type, training_se
8686
predictor = model.deploy(1, cpu_instance_type, endpoint_name=job_name)
8787
predictor.serializer = fm_serializer
8888
predictor.content_type = "application/json"
89-
predictor.deserializer = sagemaker.predictor.json_deserializer
89+
predictor.deserializer = sagemaker.deserializers.JSONDeserializer()
9090

9191
result = predictor.predict(training_set[0][:10])
9292

@@ -132,7 +132,7 @@ def test_async_byo_estimator(sagemaker_session, region, cpu_instance_type, train
132132
predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
133133
predictor.serializer = fm_serializer
134134
predictor.content_type = "application/json"
135-
predictor.deserializer = sagemaker.predictor.json_deserializer
135+
predictor.deserializer = sagemaker.deserializers.JSONDeserializer()
136136

137137
result = predictor.predict(training_set[0][:10])
138138

tests/integ/test_tfs.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -188,7 +188,7 @@ def test_predict_jsons_json_content_type(tfs_predictor):
188188
tfs_predictor.endpoint_name,
189189
tfs_predictor.sagemaker_session,
190190
serializer=None,
191-
deserializer=sagemaker.predictor.json_deserializer,
191+
deserializer=sagemaker.deserializers.JSONDeserializer(),
192192
content_type="application/json",
193193
accept="application/json",
194194
)
@@ -205,7 +205,7 @@ def test_predict_jsons(tfs_predictor):
205205
tfs_predictor.endpoint_name,
206206
tfs_predictor.sagemaker_session,
207207
serializer=None,
208-
deserializer=sagemaker.predictor.json_deserializer,
208+
deserializer=sagemaker.deserializers.JSONDeserializer(),
209209
content_type="application/jsons",
210210
accept="application/jsons",
211211
)
@@ -222,7 +222,7 @@ def test_predict_jsonlines(tfs_predictor):
222222
tfs_predictor.endpoint_name,
223223
tfs_predictor.sagemaker_session,
224224
serializer=None,
225-
deserializer=sagemaker.predictor.json_deserializer,
225+
deserializer=sagemaker.deserializers.JSONDeserializer(),
226226
content_type="application/jsonlines",
227227
accept="application/jsonlines",
228228
)

tests/integ/test_tuner.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,9 +25,9 @@
2525
from sagemaker.amazon.amazon_estimator import get_image_uri
2626
from sagemaker.amazon.common import read_records
2727
from sagemaker.chainer import Chainer
28+
from sagemaker.deserializers import JSONDeserializer
2829
from sagemaker.estimator import Estimator
2930
from sagemaker.mxnet.estimator import MXNet
30-
from sagemaker.predictor import json_deserializer
3131
from sagemaker.pytorch import PyTorch
3232
from sagemaker.tensorflow import TensorFlow
3333
from sagemaker.tuner import (
@@ -891,7 +891,7 @@ def test_tuning_byo_estimator(sagemaker_session, cpu_instance_type):
891891
predictor = tuner.deploy(1, cpu_instance_type, endpoint_name=best_training_job)
892892
predictor.serializer = _fm_serializer
893893
predictor.content_type = "application/json"
894-
predictor.deserializer = json_deserializer
894+
predictor.deserializer = JSONDeserializer()
895895

896896
result = predictor.predict(datasets.one_p_mnist()[0][:10])
897897

tests/integ/test_tuner_multi_algo.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -21,8 +21,8 @@
2121
from sagemaker.amazon.amazon_estimator import get_image_uri
2222
from sagemaker.analytics import HyperparameterTuningJobAnalytics
2323
from sagemaker.content_types import CONTENT_TYPE_JSON
24+
from sagemaker.deserializers import JSONDeserializer
2425
from sagemaker.estimator import Estimator
25-
from sagemaker.predictor import json_deserializer
2626
from sagemaker.tuner import ContinuousParameter, IntegerParameter, HyperparameterTuner
2727
from tests.integ import datasets, DATA_DIR, TUNING_DEFAULT_TIMEOUT_MINUTES
2828
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
@@ -219,7 +219,7 @@ def _create_training_inputs(sagemaker_session):
219219
def _make_prediction(predictor, data):
220220
predictor.serializer = _prediction_data_serializer
221221
predictor.content_type = CONTENT_TYPE_JSON
222-
predictor.deserializer = json_deserializer
222+
predictor.deserializer = JSONDeserializer()
223223
return predictor.predict(data)
224224

225225

tests/unit/sagemaker/test_deserializers.py

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
CSVDeserializer,
2424
StreamDeserializer,
2525
NumpyDeserializer,
26+
JSONDeserializer,
2627
)
2728

2829

@@ -145,3 +146,28 @@ def test_numpy_deserializer_from_npy_object_array(numpy_deserializer):
145146
result = numpy_deserializer.deserialize(stream, "application/x-npy")
146147

147148
assert np.array_equal(array, result)
149+
150+
151+
@pytest.fixture
152+
def json_deserializer():
153+
return JSONDeserializer()
154+
155+
156+
def test_json_deserializer_array(json_deserializer):
157+
result = json_deserializer.deserialize(io.BytesIO(b"[1, 2, 3]"), "application/json")
158+
159+
assert result == [1, 2, 3]
160+
161+
162+
def test_json_deserializer_2dimensional(json_deserializer):
163+
result = json_deserializer.deserialize(
164+
io.BytesIO(b"[[1, 2, 3], [3, 4, 5]]"), "application/json"
165+
)
166+
167+
assert result == [[1, 2, 3], [3, 4, 5]]
168+
169+
170+
def test_json_deserializer_invalid_data(json_deserializer):
171+
with pytest.raises(ValueError) as error:
172+
json_deserializer.deserialize(io.BytesIO(b"[[1]"), "application/json")
173+
assert "column" in str(error)

tests/unit/test_predictor.py

Lines changed: 1 addition & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,6 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15-
import io
1615
import json
1716
import os
1817

@@ -21,10 +20,7 @@
2120
from mock import Mock, call, patch
2221

2322
from sagemaker.predictor import Predictor
24-
from sagemaker.predictor import (
25-
json_deserializer,
26-
csv_serializer,
27-
)
23+
from sagemaker.predictor import csv_serializer
2824
from sagemaker.serializers import JSONSerializer
2925
from tests.unit import DATA_DIR
3026

@@ -97,24 +93,6 @@ def test_csv_serializer_csv_reader():
9793
assert result == validation_data
9894

9995

100-
def test_json_deserializer_array():
101-
result = json_deserializer(io.BytesIO(b"[1, 2, 3]"), "application/json")
102-
103-
assert result == [1, 2, 3]
104-
105-
106-
def test_json_deserializer_2dimensional():
107-
result = json_deserializer(io.BytesIO(b"[[1, 2, 3], [3, 4, 5]]"), "application/json")
108-
109-
assert result == [[1, 2, 3], [3, 4, 5]]
110-
111-
112-
def test_json_deserializer_invalid_data():
113-
with pytest.raises(ValueError) as error:
114-
json_deserializer(io.BytesIO(b"[[1]"), "application/json")
115-
assert "column" in str(error)
116-
117-
11896
# testing 'predict' invocations
11997

12098

0 commit comments

Comments
 (0)