Skip to content

Commit d65a296

Browse files
author
Balaji Veeramani
committed
Add JSONDeserializer
1 parent f0d34cc commit d65a296

File tree

12 files changed

+68
-63
lines changed

12 files changed

+68
-63
lines changed

doc/frameworks/tensorflow/upgrade_from_legacy.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -245,11 +245,11 @@ For example, if you want to use JSON serialization and deserialization:
245245

246246
.. code:: python
247247
248-
from sagemaker.predictor import json_deserializer, json_serializer
248+
from sagemaker.predictor import json_serializer
249+
from sagemaker.deserializers import JSONDeserializer
249250
250251
predictor.content_type = "application/json"
251252
predictor.serializer = json_serializer
252-
predictor.accept = "application/json"
253-
predictor.deserializer = json_deserializer
253+
predictor.deserializer = JSONDeserializer()
254254
255255
predictor.predict(data)

src/sagemaker/amazon/ipinsights.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
1717
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1818
from sagemaker.amazon.validation import ge, le
19-
from sagemaker.predictor import Predictor, csv_serializer, json_deserializer
19+
from sagemaker.deserializers import JSONDeserializer
20+
from sagemaker.predictor import Predictor, csv_serializer
2021
from sagemaker.model import Model
2122
from sagemaker.session import Session
2223
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
@@ -198,7 +199,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
198199
endpoint_name,
199200
sagemaker_session,
200201
serializer=csv_serializer,
201-
deserializer=json_deserializer,
202+
deserializer=JSONDeserializer(),
202203
)
203204

204205

src/sagemaker/deserializers.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,3 +154,24 @@ def deserialize(self, data, content_type):
154154
data.close()
155155

156156
raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type))
157+
158+
159+
class JSONDeserializer(BaseDeserializer):
160+
"""Deserialize JSON data from an inference endpoint into a Python object."""
161+
162+
ACCEPT = "application/json"
163+
164+
def deserialize(self, data, content_type):
165+
"""Deserialize JSON data from an inference endpoint into a Python object.
166+
167+
Args:
168+
data (botocore.response.StreamingBody): Data to be deserialized.
169+
content_type (str): The MIME type of the data.
170+
171+
Returns:
172+
object: The data deserialized into the corresponding Python object (for example, a dict or a list).
173+
"""
174+
try:
175+
return json.load(codecs.getreader("utf-8")(data))
176+
finally:
177+
data.close()

src/sagemaker/mxnet/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import packaging.version
1919

2020
import sagemaker
21+
from sagemaker.deserializers import JSONDeserializer
2122
from sagemaker.fw_utils import (
2223
create_image_uri,
2324
model_code_key_prefix,
@@ -26,7 +27,7 @@
2627
)
2728
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2829
from sagemaker.mxnet import defaults
29-
from sagemaker.predictor import Predictor, json_serializer, json_deserializer
30+
from sagemaker.predictor import Predictor, json_serializer
3031

3132
logger = logging.getLogger("sagemaker")
3233

@@ -50,7 +51,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
5051
using the default AWS configuration chain.
5152
"""
5253
super(MXNetPredictor, self).__init__(
53-
endpoint_name, sagemaker_session, json_serializer, json_deserializer
54+
endpoint_name, sagemaker_session, json_serializer, JSONDeserializer()
5455
)
5556

5657

src/sagemaker/predictor.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -672,32 +672,6 @@ def _json_serialize_from_buffer(buff):
672672
return buff.read()
673673

674674

675-
class _JsonDeserializer(object):
676-
"""Placeholder docstring"""
677-
678-
def __init__(self):
679-
"""Placeholder docstring"""
680-
self.accept = CONTENT_TYPE_JSON
681-
682-
def __call__(self, stream, content_type):
683-
"""Decode a JSON object into the corresponding Python object.
684-
685-
Args:
686-
stream (stream): The response stream to be deserialized.
687-
content_type (str): The content type of the response.
688-
689-
Returns:
690-
object: Body of the response deserialized into a JSON object.
691-
"""
692-
try:
693-
return json.load(codecs.getreader("utf-8")(stream))
694-
finally:
695-
stream.close()
696-
697-
698-
json_deserializer = _JsonDeserializer()
699-
700-
701675
class _NPYSerializer(object):
702676
"""Placeholder docstring"""
703677

src/sagemaker/tensorflow/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717

1818
import sagemaker
1919
from sagemaker.content_types import CONTENT_TYPE_JSON
20+
from sagemaker.deserializers import JSONDeserializer
2021
from sagemaker.fw_utils import create_image_uri
21-
from sagemaker.predictor import json_serializer, json_deserializer, Predictor
22+
from sagemaker.predictor import json_serializer, Predictor
2223

2324

2425
class TensorFlowPredictor(Predictor):
@@ -31,7 +32,7 @@ def __init__(
3132
endpoint_name,
3233
sagemaker_session=None,
3334
serializer=json_serializer,
34-
deserializer=json_deserializer,
35+
deserializer=JSONDeserializer(),
3536
content_type=None,
3637
model_name=None,
3738
model_version=None,

tests/integ/test_byo_estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def test_byo_estimator(sagemaker_session, region, cpu_instance_type, training_se
8686
predictor = model.deploy(1, cpu_instance_type, endpoint_name=job_name)
8787
predictor.serializer = fm_serializer
8888
predictor.content_type = "application/json"
89-
predictor.deserializer = sagemaker.predictor.json_deserializer
89+
predictor.deserializer = sagemaker.deserializers.JSONDeserializer()
9090

9191
result = predictor.predict(training_set[0][:10])
9292

@@ -132,7 +132,7 @@ def test_async_byo_estimator(sagemaker_session, region, cpu_instance_type, train
132132
predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
133133
predictor.serializer = fm_serializer
134134
predictor.content_type = "application/json"
135-
predictor.deserializer = sagemaker.predictor.json_deserializer
135+
predictor.deserializer = sagemaker.deserializers.JSONDeserializer()
136136

137137
result = predictor.predict(training_set[0][:10])
138138

tests/integ/test_tfs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def test_predict_jsons_json_content_type(tfs_predictor):
188188
tfs_predictor.endpoint_name,
189189
tfs_predictor.sagemaker_session,
190190
serializer=None,
191-
deserializer=sagemaker.predictor.json_deserializer,
191+
deserializer=sagemaker.deserializers.JSONDeserializer(),
192192
content_type="application/json",
193193
accept="application/json",
194194
)
@@ -205,7 +205,7 @@ def test_predict_jsons(tfs_predictor):
205205
tfs_predictor.endpoint_name,
206206
tfs_predictor.sagemaker_session,
207207
serializer=None,
208-
deserializer=sagemaker.predictor.json_deserializer,
208+
deserializer=sagemaker.deserializers.JSONDeserializer(),
209209
content_type="application/jsons",
210210
accept="application/jsons",
211211
)
@@ -222,7 +222,7 @@ def test_predict_jsonlines(tfs_predictor):
222222
tfs_predictor.endpoint_name,
223223
tfs_predictor.sagemaker_session,
224224
serializer=None,
225-
deserializer=sagemaker.predictor.json_deserializer,
225+
deserializer=sagemaker.deserializers.JSONDeserializer(),
226226
content_type="application/jsonlines",
227227
accept="application/jsonlines",
228228
)

tests/integ/test_tuner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
from sagemaker.amazon.amazon_estimator import get_image_uri
2626
from sagemaker.amazon.common import read_records
2727
from sagemaker.chainer import Chainer
28+
from sagemaker.deserializers import JSONDeserializer
2829
from sagemaker.estimator import Estimator
2930
from sagemaker.mxnet.estimator import MXNet
30-
from sagemaker.predictor import json_deserializer
3131
from sagemaker.pytorch import PyTorch
3232
from sagemaker.tensorflow import TensorFlow
3333
from sagemaker.tuner import (
@@ -891,7 +891,7 @@ def test_tuning_byo_estimator(sagemaker_session, cpu_instance_type):
891891
predictor = tuner.deploy(1, cpu_instance_type, endpoint_name=best_training_job)
892892
predictor.serializer = _fm_serializer
893893
predictor.content_type = "application/json"
894-
predictor.deserializer = json_deserializer
894+
predictor.deserializer = JSONDeserializer()
895895

896896
result = predictor.predict(datasets.one_p_mnist()[0][:10])
897897

tests/integ/test_tuner_multi_algo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
from sagemaker.amazon.amazon_estimator import get_image_uri
2222
from sagemaker.analytics import HyperparameterTuningJobAnalytics
2323
from sagemaker.content_types import CONTENT_TYPE_JSON
24+
from sagemaker.deserializers import JSONDeserializer
2425
from sagemaker.estimator import Estimator
25-
from sagemaker.predictor import json_deserializer
2626
from sagemaker.tuner import ContinuousParameter, IntegerParameter, HyperparameterTuner
2727
from tests.integ import datasets, DATA_DIR, TUNING_DEFAULT_TIMEOUT_MINUTES
2828
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
@@ -219,7 +219,7 @@ def _create_training_inputs(sagemaker_session):
219219
def _make_prediction(predictor, data):
220220
predictor.serializer = _prediction_data_serializer
221221
predictor.content_type = CONTENT_TYPE_JSON
222-
predictor.deserializer = json_deserializer
222+
predictor.deserializer = JSONDeserializer()
223223
return predictor.predict(data)
224224

225225

tests/unit/sagemaker/test_deserializers.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
BytesDeserializer,
2323
StreamDeserializer,
2424
NumpyDeserializer,
25+
JSONDeserializer,
2526
)
2627

2728

@@ -119,3 +120,28 @@ def test_numpy_deserializer_from_npy_object_array(numpy_deserializer):
119120
result = numpy_deserializer.deserialize(stream, "application/x-npy")
120121

121122
assert np.array_equal(array, result)
123+
124+
125+
@pytest.fixture
126+
def json_deserializer():
127+
return JSONDeserializer()
128+
129+
130+
def test_json_deserializer_array(json_deserializer):
131+
result = json_deserializer.deserialize(io.BytesIO(b"[1, 2, 3]"), "application/json")
132+
133+
assert result == [1, 2, 3]
134+
135+
136+
def test_json_deserializer_2dimensional(json_deserializer):
137+
result = json_deserializer.deserialize(
138+
io.BytesIO(b"[[1, 2, 3], [3, 4, 5]]"), "application/json"
139+
)
140+
141+
assert result == [[1, 2, 3], [3, 4, 5]]
142+
143+
144+
def test_json_deserializer_invalid_data(json_deserializer):
145+
with pytest.raises(ValueError) as error:
146+
json_deserializer.deserialize(io.BytesIO(b"[[1]"), "application/json")
147+
assert "column" in str(error)

tests/unit/test_predictor.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from sagemaker.predictor import Predictor
2424
from sagemaker.predictor import (
2525
json_serializer,
26-
json_deserializer,
2726
csv_serializer,
2827
csv_deserializer,
2928
npy_serializer,
@@ -161,24 +160,6 @@ def test_csv_deserializer_2dimensional():
161160
assert result == [["1", "2", "3"], ["3", "4", "5"]]
162161

163162

164-
def test_json_deserializer_array():
165-
result = json_deserializer(io.BytesIO(b"[1, 2, 3]"), "application/json")
166-
167-
assert result == [1, 2, 3]
168-
169-
170-
def test_json_deserializer_2dimensional():
171-
result = json_deserializer(io.BytesIO(b"[[1, 2, 3], [3, 4, 5]]"), "application/json")
172-
173-
assert result == [[1, 2, 3], [3, 4, 5]]
174-
175-
176-
def test_json_deserializer_invalid_data():
177-
with pytest.raises(ValueError) as error:
178-
json_deserializer(io.BytesIO(b"[[1]"), "application/json")
179-
assert "column" in str(error)
180-
181-
182163
def test_npy_serializer_python_array():
183164
array = [1, 2, 3]
184165
result = npy_serializer(array)

0 commit comments

Comments
 (0)