Skip to content

breaking: Move _NumpyDeserializer to sagemaker.deserializers.NumpyDeserializer #1695

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
)
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.chainer import defaults
from sagemaker.predictor import Predictor, npy_serializer, numpy_deserializer
from sagemaker.deserializers import NumpyDeserializer
from sagemaker.predictor import Predictor, npy_serializer

logger = logging.getLogger("sagemaker")

Expand All @@ -48,7 +49,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
using the default AWS configuration chain.
"""
super(ChainerPredictor, self).__init__(
endpoint_name, sagemaker_session, npy_serializer, numpy_deserializer
endpoint_name, sagemaker_session, npy_serializer, NumpyDeserializer()
)


Expand Down
43 changes: 43 additions & 0 deletions src/sagemaker/deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
from __future__ import absolute_import

import abc
import codecs
import io
import json

import numpy as np


class BaseDeserializer(abc.ABC):
Expand Down Expand Up @@ -111,3 +116,41 @@ def deserialize(self, data, content_type):
tuple: A two-tuple containing the stream and content-type.
"""
return data, content_type


class NumpyDeserializer(BaseDeserializer):
    """Deserialize CSV, JSON, or NPY data from an inference endpoint into a NumPy array."""

    # MIME type advertised to the endpoint in the Accept header.
    ACCEPT = "application/x-npy"

    def __init__(self, dtype=None):
        """Initialize the dtype.

        Args:
            dtype (str): The dtype of the data (default: None, which lets
                NumPy infer the dtype for CSV and JSON payloads).
        """
        self.dtype = dtype

    def deserialize(self, data, content_type):
        """Deserialize data from an inference endpoint into a NumPy array.

        Args:
            data (botocore.response.StreamingBody): Data to be deserialized.
            content_type (str): The MIME type of the data.

        Returns:
            numpy.ndarray: The data deserialized into a NumPy array.

        Raises:
            ValueError: If ``content_type`` is not CSV, JSON, or NPY.
        """
        try:
            if content_type == "text/csv":
                text_stream = codecs.getreader("utf-8")(data)
                return np.genfromtxt(text_stream, delimiter=",", dtype=self.dtype)
            if content_type == "application/json":
                text_stream = codecs.getreader("utf-8")(data)
                return np.array(json.load(text_stream), dtype=self.dtype)
            if content_type == "application/x-npy":
                return np.load(io.BytesIO(data.read()))
        finally:
            # Close the response stream whether a branch matched or not.
            data.close()

        raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type))
44 changes: 0 additions & 44 deletions src/sagemaker/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,50 +698,6 @@ def __call__(self, stream, content_type):
json_deserializer = _JsonDeserializer()


class _NumpyDeserializer(object):
    """Deserialize CSV, JSON, or NPY response data into a NumPy array.

    NOTE(review): this appears superseded by
    ``sagemaker.deserializers.NumpyDeserializer`` per this change set.
    """

    def __init__(self, accept=CONTENT_TYPE_NPY, dtype=None):
        """
        Args:
            accept (str): The MIME type this deserializer advertises
                (default: the NPY content type).
            dtype (str): The dtype used when reading CSV or JSON data
                (default: None, which lets NumPy infer the dtype).
        """
        self.accept = accept
        self.dtype = dtype

    def __call__(self, stream, content_type=CONTENT_TYPE_NPY):
        """Decode from serialized data into a Numpy array.

        Args:
            stream (stream): The response stream to be deserialized.
            content_type (str): The content type of the response. Can accept
                CSV, JSON, or NPY data.

        Returns:
            object: Body of the response deserialized into a Numpy array.

        Raises:
            ValueError: If content_type is not CSV, JSON, or NPY.
        """
        try:
            if content_type == CONTENT_TYPE_CSV:
                return np.genfromtxt(
                    codecs.getreader("utf-8")(stream), delimiter=",", dtype=self.dtype
                )
            if content_type == CONTENT_TYPE_JSON:
                return np.array(json.load(codecs.getreader("utf-8")(stream)), dtype=self.dtype)
            if content_type == CONTENT_TYPE_NPY:
                return np.load(BytesIO(stream.read()))
        finally:
            # The stream is closed even when no branch matched or a branch raised.
            stream.close()
        raise ValueError(
            "content_type must be one of the following: CSV, JSON, NPY. content_type: {}".format(
                content_type
            )
        )


# Module-level default instance used by framework predictors.
numpy_deserializer = _NumpyDeserializer()


class _NPYSerializer(object):
"""Placeholder docstring"""

Expand Down
5 changes: 3 additions & 2 deletions src/sagemaker/pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import packaging.version

import sagemaker
from sagemaker.deserializers import NumpyDeserializer
from sagemaker.fw_utils import (
create_image_uri,
model_code_key_prefix,
Expand All @@ -25,7 +26,7 @@
)
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.pytorch import defaults
from sagemaker.predictor import Predictor, npy_serializer, numpy_deserializer
from sagemaker.predictor import Predictor, npy_serializer

logger = logging.getLogger("sagemaker")

Expand All @@ -49,7 +50,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
using the default AWS configuration chain.
"""
super(PyTorchPredictor, self).__init__(
endpoint_name, sagemaker_session, npy_serializer, numpy_deserializer
endpoint_name, sagemaker_session, npy_serializer, NumpyDeserializer()
)


Expand Down
5 changes: 3 additions & 2 deletions src/sagemaker/sklearn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
import logging

import sagemaker
from sagemaker.deserializers import NumpyDeserializer
from sagemaker.fw_registry import default_framework_uri
from sagemaker.fw_utils import model_code_key_prefix, validate_version_or_image_args
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.predictor import Predictor, npy_serializer, numpy_deserializer
from sagemaker.predictor import Predictor, npy_serializer
from sagemaker.sklearn import defaults

logger = logging.getLogger("sagemaker")
Expand All @@ -44,7 +45,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
using the default AWS configuration chain.
"""
super(SKLearnPredictor, self).__init__(
endpoint_name, sagemaker_session, npy_serializer, numpy_deserializer
endpoint_name, sagemaker_session, npy_serializer, NumpyDeserializer()
)


Expand Down
77 changes: 76 additions & 1 deletion tests/unit/sagemaker/test_deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@

import io

from sagemaker.deserializers import StringDeserializer, BytesDeserializer, StreamDeserializer
import numpy as np
import pytest

from sagemaker.deserializers import (
StringDeserializer,
BytesDeserializer,
StreamDeserializer,
NumpyDeserializer,
)


def test_string_deserializer():
Expand Down Expand Up @@ -44,3 +52,70 @@ def test_stream_deserializer():

assert result == b"[1, 2, 3]"
assert content_type == "application/json"


@pytest.fixture
def numpy_deserializer():
    """Provide a default-configured NumpyDeserializer for the tests below."""
    deserializer = NumpyDeserializer()
    return deserializer


def test_numpy_deserializer_from_csv(numpy_deserializer):
    # Two CSV rows should become a 2x3 integer array.
    data = io.BytesIO(b"1,2,3\n4,5,6")

    result = numpy_deserializer.deserialize(data, "text/csv")

    assert np.array_equal(result, np.array([[1, 2, 3], [4, 5, 6]]))


def test_numpy_deserializer_from_csv_ragged(numpy_deserializer):
    # Rows of unequal length are rejected by np.genfromtxt.
    data = io.BytesIO(b"1,2,3\n4,5,6,7")

    with pytest.raises(ValueError) as error:
        numpy_deserializer.deserialize(data, "text/csv")

    assert "errors were detected" in str(error)


def test_numpy_deserializer_from_csv_alpha():
    # A string dtype lets non-numeric CSV cells deserialize cleanly.
    deserializer = NumpyDeserializer(dtype="U5")
    data = io.BytesIO(b"hello,2,3\n4,5,6")

    result = deserializer.deserialize(data, "text/csv")

    assert np.array_equal(result, np.array([["hello", 2, 3], [4, 5, 6]]))


def test_numpy_deserializer_from_json(numpy_deserializer):
    # A nested JSON list should become a 2x3 integer array.
    data = io.BytesIO(b"[[1,2,3],\n[4,5,6]]")

    result = numpy_deserializer.deserialize(data, "application/json")

    assert np.array_equal(result, np.array([[1, 2, 3], [4, 5, 6]]))


# Sadly, ragged arrays work fine in JSON (giving us a 1D array of Python lists).
def test_numpy_deserializer_from_json_ragged(numpy_deserializer):
    data = io.BytesIO(b"[[1,2,3],\n[4,5,6,7]]")

    result = numpy_deserializer.deserialize(data, "application/json")

    assert np.array_equal(result, np.array([[1, 2, 3], [4, 5, 6, 7]]))


def test_numpy_deserializer_from_json_alpha():
    # A string dtype lets non-numeric JSON values deserialize cleanly.
    deserializer = NumpyDeserializer(dtype="U5")
    data = io.BytesIO(b'[["hello",2,3],\n[4,5,6]]')

    result = deserializer.deserialize(data, "application/json")

    assert np.array_equal(result, np.array([["hello", 2, 3], [4, 5, 6]]))


def test_numpy_deserializer_from_npy(numpy_deserializer):
    # Round-trip an array through the NPY wire format.
    expected = np.ones((2, 3))
    buffer = io.BytesIO()
    np.save(buffer, expected)
    buffer.seek(0)

    result = numpy_deserializer.deserialize(buffer, "application/x-npy")

    assert np.array_equal(expected, result)


def test_numpy_deserializer_from_npy_object_array(numpy_deserializer):
    # Round-trip a string array through the NPY wire format.
    expected = np.array(["one", "two"])
    buffer = io.BytesIO()
    np.save(buffer, expected)
    buffer.seek(0)

    result = numpy_deserializer.deserialize(buffer, "application/x-npy")

    assert np.array_equal(expected, result)
58 changes: 0 additions & 58 deletions tests/unit/test_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@
json_deserializer,
csv_serializer,
csv_deserializer,
numpy_deserializer,
npy_serializer,
_NumpyDeserializer,
)
from tests.unit import DATA_DIR

Expand Down Expand Up @@ -259,62 +257,6 @@ def test_npy_serializer_python_invalid_empty():
assert "empty array" in str(error)


def test_numpy_deser_from_csv():
    # Two CSV rows should become a 2x3 integer array.
    result = numpy_deserializer(io.BytesIO(b"1,2,3\n4,5,6"), "text/csv")
    assert np.array_equal(result, np.array([[1, 2, 3], [4, 5, 6]]))


def test_numpy_deser_from_csv_ragged():
    # Rows of unequal length are rejected by np.genfromtxt.
    stream = io.BytesIO(b"1,2,3\n4,5,6,7")
    with pytest.raises(ValueError) as error:
        numpy_deserializer(stream, "text/csv")
    assert "errors were detected" in str(error)


def test_numpy_deser_from_csv_alpha():
    # A string dtype lets non-numeric CSV cells deserialize cleanly.
    deserializer = _NumpyDeserializer(dtype="U5")
    result = deserializer(io.BytesIO(b"hello,2,3\n4,5,6"), "text/csv")
    assert np.array_equal(result, np.array([["hello", 2, 3], [4, 5, 6]]))


def test_numpy_deser_from_json():
    # A nested JSON list should become a 2x3 integer array.
    result = numpy_deserializer(io.BytesIO(b"[[1,2,3],\n[4,5,6]]"), "application/json")
    assert np.array_equal(result, np.array([[1, 2, 3], [4, 5, 6]]))


# Sadly, ragged arrays work fine in JSON (giving us a 1D array of Python lists).
def test_numpy_deser_from_json_ragged():
    result = numpy_deserializer(io.BytesIO(b"[[1,2,3],\n[4,5,6,7]]"), "application/json")
    assert np.array_equal(result, np.array([[1, 2, 3], [4, 5, 6, 7]]))


def test_numpy_deser_from_json_alpha():
    # A string dtype lets non-numeric JSON values deserialize cleanly.
    deserializer = _NumpyDeserializer(dtype="U5")
    result = deserializer(io.BytesIO(b'[["hello",2,3],\n[4,5,6]]'), "application/json")
    assert np.array_equal(result, np.array([["hello", 2, 3], [4, 5, 6]]))


def test_numpy_deser_from_npy():
    # Round-trip an array through the NPY wire format (default content type).
    expected = np.ones((2, 3))
    buffer = io.BytesIO()
    np.save(buffer, expected)
    buffer.seek(0)

    result = numpy_deserializer(buffer)

    assert np.array_equal(expected, result)


def test_numpy_deser_from_npy_object_array():
    # Round-trip a string array through the NPY wire format (default content type).
    expected = np.array(["one", "two"])
    buffer = io.BytesIO()
    np.save(buffer, expected)
    buffer.seek(0)

    result = numpy_deserializer(buffer)

    assert np.array_equal(expected, result)


# testing 'predict' invocations


Expand Down