Skip to content

breaking: Move _JsonDeserializer to sagemaker.deserializers.JSONDeserializer #1697

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions doc/frameworks/tensorflow/upgrade_from_legacy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,10 @@ For example, if you want to use JSON serialization and deserialization:

.. code:: python

from sagemaker.predictor import json_deserializer
from sagemaker.deserializers import JSONDeserializer
from sagemaker.serializers import JSONSerializer

predictor.serializer = JSONSerializer()
predictor.accept = "application/json"
predictor.deserializer = json_deserializer
predictor.deserializer = JSONDeserializer()

predictor.predict(data)
5 changes: 3 additions & 2 deletions src/sagemaker/amazon/ipinsights.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
from sagemaker.amazon.validation import ge, le
from sagemaker.predictor import Predictor, csv_serializer, json_deserializer
from sagemaker.deserializers import JSONDeserializer
from sagemaker.predictor import Predictor, csv_serializer
from sagemaker.model import Model
from sagemaker.session import Session
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
Expand Down Expand Up @@ -198,7 +199,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
endpoint_name,
sagemaker_session,
serializer=csv_serializer,
deserializer=json_deserializer,
deserializer=JSONDeserializer(),
)


Expand Down
21 changes: 21 additions & 0 deletions src/sagemaker/deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,24 @@ def deserialize(self, data, content_type):
data.close()

raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type))


class JSONDeserializer(BaseDeserializer):
"""Deserialize JSON data from an inference endpoint into a Python object."""

ACCEPT = "application/json"

def deserialize(self, data, content_type):
"""Deserialize JSON data from an inference endpoint into a Python object.

Args:
data (botocore.response.StreamingBody): Data to be deserialized.
content_type (str): The MIME type of the data.

Returns:
object: The JSON-formatted data deserialized into a Python object.
"""
try:
return json.load(codecs.getreader("utf-8")(data))
finally:
data.close()
5 changes: 3 additions & 2 deletions src/sagemaker/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import packaging.version

import sagemaker
from sagemaker.deserializers import JSONDeserializer
from sagemaker.fw_utils import (
create_image_uri,
model_code_key_prefix,
Expand All @@ -26,7 +27,7 @@
)
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.mxnet import defaults
from sagemaker.predictor import Predictor, json_deserializer
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer

logger = logging.getLogger("sagemaker")
Expand All @@ -51,7 +52,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
using the default AWS configuration chain.
"""
super(MXNetPredictor, self).__init__(
endpoint_name, sagemaker_session, JSONSerializer(), json_deserializer
endpoint_name, sagemaker_session, JSONSerializer(), JSONDeserializer()
)


Expand Down
30 changes: 1 addition & 29 deletions src/sagemaker/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@
"""Placeholder docstring"""
from __future__ import print_function, absolute_import

import codecs
import csv
import json
from six import StringIO
import numpy as np

from sagemaker.content_types import CONTENT_TYPE_JSON, CONTENT_TYPE_CSV
from sagemaker.content_types import CONTENT_TYPE_CSV
from sagemaker.deserializers import BaseDeserializer
from sagemaker.model_monitor import DataCaptureConfig
from sagemaker.serializers import BaseSerializer
Expand Down Expand Up @@ -594,29 +592,3 @@ def _row_to_csv(obj):
if isinstance(obj, str):
return obj
return ",".join(obj)


class _JsonDeserializer(object):
"""Placeholder docstring"""

def __init__(self):
"""Placeholder docstring"""
self.accept = CONTENT_TYPE_JSON

def __call__(self, stream, content_type):
"""Decode a JSON object into the corresponding Python object.

Args:
stream (stream): The response stream to be deserialized.
content_type (str): The content type of the response.

Returns:
object: Body of the response deserialized into a JSON object.
"""
try:
return json.load(codecs.getreader("utf-8")(stream))
finally:
stream.close()


json_deserializer = _JsonDeserializer()
5 changes: 3 additions & 2 deletions src/sagemaker/tensorflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@

import sagemaker
from sagemaker.content_types import CONTENT_TYPE_JSON
from sagemaker.deserializers import JSONDeserializer
from sagemaker.fw_utils import create_image_uri
from sagemaker.predictor import json_deserializer, Predictor
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer


Expand All @@ -32,7 +33,7 @@ def __init__(
endpoint_name,
sagemaker_session=None,
serializer=JSONSerializer(),
deserializer=json_deserializer,
deserializer=JSONDeserializer(),
content_type=None,
model_name=None,
model_version=None,
Expand Down
4 changes: 2 additions & 2 deletions tests/integ/test_byo_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_byo_estimator(sagemaker_session, region, cpu_instance_type, training_se
predictor = model.deploy(1, cpu_instance_type, endpoint_name=job_name)
predictor.serializer = fm_serializer
predictor.content_type = "application/json"
predictor.deserializer = sagemaker.predictor.json_deserializer
predictor.deserializer = sagemaker.deserializers.JSONDeserializer()

result = predictor.predict(training_set[0][:10])

Expand Down Expand Up @@ -132,7 +132,7 @@ def test_async_byo_estimator(sagemaker_session, region, cpu_instance_type, train
predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
predictor.serializer = fm_serializer
predictor.content_type = "application/json"
predictor.deserializer = sagemaker.predictor.json_deserializer
predictor.deserializer = sagemaker.deserializers.JSONDeserializer()

result = predictor.predict(training_set[0][:10])

Expand Down
6 changes: 3 additions & 3 deletions tests/integ/test_tfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def test_predict_jsons_json_content_type(tfs_predictor):
tfs_predictor.endpoint_name,
tfs_predictor.sagemaker_session,
serializer=None,
deserializer=sagemaker.predictor.json_deserializer,
deserializer=sagemaker.deserializers.JSONDeserializer(),
content_type="application/json",
accept="application/json",
)
Expand All @@ -205,7 +205,7 @@ def test_predict_jsons(tfs_predictor):
tfs_predictor.endpoint_name,
tfs_predictor.sagemaker_session,
serializer=None,
deserializer=sagemaker.predictor.json_deserializer,
deserializer=sagemaker.deserializers.JSONDeserializer(),
content_type="application/jsons",
accept="application/jsons",
)
Expand All @@ -222,7 +222,7 @@ def test_predict_jsonlines(tfs_predictor):
tfs_predictor.endpoint_name,
tfs_predictor.sagemaker_session,
serializer=None,
deserializer=sagemaker.predictor.json_deserializer,
deserializer=sagemaker.deserializers.JSONDeserializer(),
content_type="application/jsonlines",
accept="application/jsonlines",
)
Expand Down
4 changes: 2 additions & 2 deletions tests/integ/test_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
from sagemaker.amazon.amazon_estimator import get_image_uri
from sagemaker.amazon.common import read_records
from sagemaker.chainer import Chainer
from sagemaker.deserializers import JSONDeserializer
from sagemaker.estimator import Estimator
from sagemaker.mxnet.estimator import MXNet
from sagemaker.predictor import json_deserializer
from sagemaker.pytorch import PyTorch
from sagemaker.tensorflow import TensorFlow
from sagemaker.tuner import (
Expand Down Expand Up @@ -891,7 +891,7 @@ def test_tuning_byo_estimator(sagemaker_session, cpu_instance_type):
predictor = tuner.deploy(1, cpu_instance_type, endpoint_name=best_training_job)
predictor.serializer = _fm_serializer
predictor.content_type = "application/json"
predictor.deserializer = json_deserializer
predictor.deserializer = JSONDeserializer()

result = predictor.predict(datasets.one_p_mnist()[0][:10])

Expand Down
4 changes: 2 additions & 2 deletions tests/integ/test_tuner_multi_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from sagemaker.amazon.amazon_estimator import get_image_uri
from sagemaker.analytics import HyperparameterTuningJobAnalytics
from sagemaker.content_types import CONTENT_TYPE_JSON
from sagemaker.deserializers import JSONDeserializer
from sagemaker.estimator import Estimator
from sagemaker.predictor import json_deserializer
from sagemaker.tuner import ContinuousParameter, IntegerParameter, HyperparameterTuner
from tests.integ import datasets, DATA_DIR, TUNING_DEFAULT_TIMEOUT_MINUTES
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
Expand Down Expand Up @@ -219,7 +219,7 @@ def _create_training_inputs(sagemaker_session):
def _make_prediction(predictor, data):
predictor.serializer = _prediction_data_serializer
predictor.content_type = CONTENT_TYPE_JSON
predictor.deserializer = json_deserializer
predictor.deserializer = JSONDeserializer()
return predictor.predict(data)


Expand Down
26 changes: 26 additions & 0 deletions tests/unit/sagemaker/test_deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
CSVDeserializer,
StreamDeserializer,
NumpyDeserializer,
JSONDeserializer,
)


Expand Down Expand Up @@ -145,3 +146,28 @@ def test_numpy_deserializer_from_npy_object_array(numpy_deserializer):
result = numpy_deserializer.deserialize(stream, "application/x-npy")

assert np.array_equal(array, result)


@pytest.fixture
def json_deserializer():
return JSONDeserializer()


def test_json_deserializer_array(json_deserializer):
result = json_deserializer.deserialize(io.BytesIO(b"[1, 2, 3]"), "application/json")

assert result == [1, 2, 3]


def test_json_deserializer_2dimensional(json_deserializer):
result = json_deserializer.deserialize(
io.BytesIO(b"[[1, 2, 3], [3, 4, 5]]"), "application/json"
)

assert result == [[1, 2, 3], [3, 4, 5]]


def test_json_deserializer_invalid_data(json_deserializer):
with pytest.raises(ValueError) as error:
json_deserializer.deserialize(io.BytesIO(b"[[1]"), "application/json")
assert "column" in str(error)
24 changes: 1 addition & 23 deletions tests/unit/test_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import io
import json
import os

Expand All @@ -21,10 +20,7 @@
from mock import Mock, call, patch

from sagemaker.predictor import Predictor
from sagemaker.predictor import (
json_deserializer,
csv_serializer,
)
from sagemaker.predictor import csv_serializer
from sagemaker.serializers import JSONSerializer
from tests.unit import DATA_DIR

Expand Down Expand Up @@ -97,24 +93,6 @@ def test_csv_serializer_csv_reader():
assert result == validation_data


def test_json_deserializer_array():
result = json_deserializer(io.BytesIO(b"[1, 2, 3]"), "application/json")

assert result == [1, 2, 3]


def test_json_deserializer_2dimensional():
result = json_deserializer(io.BytesIO(b"[[1, 2, 3], [3, 4, 5]]"), "application/json")

assert result == [[1, 2, 3], [3, 4, 5]]


def test_json_deserializer_invalid_data():
with pytest.raises(ValueError) as error:
json_deserializer(io.BytesIO(b"[[1]"), "application/json")
assert "column" in str(error)


# testing 'predict' invocations


Expand Down