Skip to content

breaking: Move _NPYSerializer to sagemaker.serializers and rename to NumpySerializer #1691

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jul 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.chainer import defaults
from sagemaker.deserializers import NumpyDeserializer
from sagemaker.predictor import Predictor, npy_serializer
from sagemaker.predictor import Predictor
from sagemaker.serializers import NumpySerializer

logger = logging.getLogger("sagemaker")

Expand All @@ -49,7 +50,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
using the default AWS configuration chain.
"""
super(ChainerPredictor, self).__init__(
endpoint_name, sagemaker_session, npy_serializer, NumpyDeserializer()
endpoint_name, sagemaker_session, NumpySerializer(), NumpyDeserializer()
)


Expand Down
52 changes: 2 additions & 50 deletions src/sagemaker/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
import codecs
import csv
import json
from six import StringIO, BytesIO
from six import StringIO
import numpy as np

from sagemaker.content_types import CONTENT_TYPE_JSON, CONTENT_TYPE_CSV, CONTENT_TYPE_NPY
from sagemaker.content_types import CONTENT_TYPE_JSON, CONTENT_TYPE_CSV
from sagemaker.deserializers import BaseDeserializer
from sagemaker.model_monitor import DataCaptureConfig
from sagemaker.serializers import BaseSerializer
Expand Down Expand Up @@ -620,51 +620,3 @@ def __call__(self, stream, content_type):


json_deserializer = _JsonDeserializer()


class _NPYSerializer(object):
    """Callable that converts request data into NPY-formatted bytes."""

    def __init__(self):
        """Store the NPY content type on the instance."""
        self.content_type = CONTENT_TYPE_NPY

    def __call__(self, data, dtype=None):
        """Serialize data into the request body in NPY format.

        Args:
            data (object): Data to be serialized. Can be a numpy array, list,
                file, or buffer.
            dtype: Passed to ``np.array`` when ``data`` is a list; not used
                for any other kind of input.

        Returns:
            object: NPY serialized data used for the request.

        Raises:
            ValueError: If ``data`` is an empty numpy array or an empty list.
        """
        if isinstance(data, np.ndarray):
            if data.size <= 0:
                raise ValueError("empty array can't be serialized")
            payload = data
        elif isinstance(data, list):
            if len(data) <= 0:
                raise ValueError("empty array can't be serialized")
            payload = np.array(data, dtype)
        elif hasattr(data, "read"):
            # Files and buffers are assumed to already hold npy-formatted
            # data, so their contents are returned untouched.
            return data.read()
        else:
            payload = np.array(data)

        return _npy_serialize(payload)


def _npy_serialize(data):
    """Write ``data`` with ``np.save`` into an in-memory buffer.

    Args:
        data: The object handed to ``np.save``.

    Returns:
        bytes: The full contents of the buffer after saving.
    """
    with BytesIO() as stream:
        np.save(stream, data)
        return stream.getvalue()


# Module-level instance of _NPYSerializer, created once at import time.
npy_serializer = _NPYSerializer()
5 changes: 3 additions & 2 deletions src/sagemaker/pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
)
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.pytorch import defaults
from sagemaker.predictor import Predictor, npy_serializer
from sagemaker.predictor import Predictor
from sagemaker.serializers import NumpySerializer

logger = logging.getLogger("sagemaker")

Expand All @@ -50,7 +51,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
using the default AWS configuration chain.
"""
super(PyTorchPredictor, self).__init__(
endpoint_name, sagemaker_session, npy_serializer, NumpyDeserializer()
endpoint_name, sagemaker_session, NumpySerializer(), NumpyDeserializer()
)


Expand Down
54 changes: 54 additions & 0 deletions src/sagemaker/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from __future__ import absolute_import

import abc
import io
import json

import numpy as np
Expand Down Expand Up @@ -43,6 +44,59 @@ def CONTENT_TYPE(self):
"""The MIME type of the data sent to the inference endpoint."""


class NumpySerializer(BaseSerializer):
    """Serialize data to bytes using the .npy format."""

    CONTENT_TYPE = "application/x-npy"

    def __init__(self, dtype=None):
        """Initialize the dtype.

        Args:
            dtype (str): The dtype passed to ``np.array`` when the input is
                not already a NumPy array (default: None, which lets NumPy
                infer the dtype).
        """
        self.dtype = dtype

    def serialize(self, data):
        """Serialize data to bytes using the .npy format.

        Args:
            data (object): Data to be serialized. Can be a NumPy array, list,
                file, or buffer. Any other object is converted with
                ``np.array`` before being saved.

        Returns:
            bytes: The data serialized in the .npy format. For files and
                buffers, whatever ``data.read()`` returns.

        Raises:
            ValueError: If ``data`` is an empty NumPy array or an empty list.
        """
        if isinstance(data, np.ndarray):
            if data.size == 0:
                raise ValueError("Cannot serialize empty array.")
            # Existing arrays are saved as-is; ``self.dtype`` is not applied.
            return self._serialize_array(data)

        if isinstance(data, list):
            if not data:
                raise ValueError("Cannot serialize empty array.")
            return self._serialize_array(np.array(data, self.dtype))

        # Files and buffers are assumed to already hold npy-formatted data.
        if hasattr(data, "read"):
            return data.read()

        # Tuples, scalars and other objects honor the configured dtype just
        # like lists do, instead of silently ignoring it.
        return self._serialize_array(np.array(data, self.dtype))

    def _serialize_array(self, array):
        """Save a NumPy array into an in-memory buffer.

        Args:
            array (numpy.ndarray): The array to serialize.

        Returns:
            bytes: The contents of the buffer holding the serialized array.
        """
        buffer = io.BytesIO()
        np.save(buffer, array)
        return buffer.getvalue()


class JSONSerializer(BaseSerializer):
"""Serialize data to a JSON formatted string."""

Expand Down
5 changes: 3 additions & 2 deletions src/sagemaker/sklearn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from sagemaker.fw_registry import default_framework_uri
from sagemaker.fw_utils import model_code_key_prefix, validate_version_or_image_args
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.predictor import Predictor, npy_serializer
from sagemaker.predictor import Predictor
from sagemaker.serializers import NumpySerializer
from sagemaker.sklearn import defaults

logger = logging.getLogger("sagemaker")
Expand All @@ -45,7 +46,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
using the default AWS configuration chain.
"""
super(SKLearnPredictor, self).__init__(
endpoint_name, sagemaker_session, npy_serializer, NumpyDeserializer()
endpoint_name, sagemaker_session, NumpySerializer(), NumpyDeserializer()
)


Expand Down
5 changes: 3 additions & 2 deletions src/sagemaker/xgboost/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from sagemaker.fw_utils import model_code_key_prefix
from sagemaker.fw_registry import default_framework_uri
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.predictor import Predictor, npy_serializer
from sagemaker.predictor import Predictor
from sagemaker.serializers import NumpySerializer
from sagemaker.xgboost.defaults import XGBOOST_NAME

logger = logging.getLogger("sagemaker")
Expand All @@ -43,7 +44,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
chain.
"""
super(XGBoostPredictor, self).__init__(
endpoint_name, sagemaker_session, npy_serializer, CSVDeserializer()
endpoint_name, sagemaker_session, NumpySerializer(), CSVDeserializer()
)


Expand Down
17 changes: 9 additions & 8 deletions tests/integ/test_multidatamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from sagemaker.deserializers import StringDeserializer
from sagemaker.multidatamodel import MultiDataModel
from sagemaker.mxnet import MXNet
from sagemaker.predictor import Predictor, npy_serializer
from sagemaker.predictor import Predictor
from sagemaker.serializers import NumpySerializer
from sagemaker.utils import sagemaker_timestamp, unique_name_from_base, get_ecr_image_uri_prefix
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
from tests.integ.retry import retries
Expand Down Expand Up @@ -158,7 +159,7 @@ def test_multi_data_model_deploy_pretrained_models(
predictor = Predictor(
endpoint_name=endpoint_name,
sagemaker_session=sagemaker_session,
serializer=npy_serializer,
serializer=NumpySerializer(),
deserializer=string_deserializer,
)

Expand Down Expand Up @@ -216,7 +217,7 @@ def test_multi_data_model_deploy_pretrained_models_local_mode(container_image, s
predictor = Predictor(
endpoint_name=endpoint_name,
sagemaker_session=multi_data_model.sagemaker_session,
serializer=npy_serializer,
serializer=NumpySerializer(),
deserializer=string_deserializer,
)

Expand Down Expand Up @@ -289,13 +290,13 @@ def test_multi_data_model_deploy_trained_model_from_framework_estimator(
assert PRETRAINED_MODEL_PATH_1 in endpoint_models
assert PRETRAINED_MODEL_PATH_2 in endpoint_models

# Define a predictor to set `serializer` parameter with npy_serializer
# Define a predictor to set `serializer` parameter with `NumpySerializer`
# instead of `JSONSerializer` in the default predictor returned by `MXNetPredictor`
# Since we are using a placeholder container image the prediction results are not accurate.
predictor = Predictor(
endpoint_name=endpoint_name,
sagemaker_session=sagemaker_session,
serializer=npy_serializer,
serializer=NumpySerializer(),
deserializer=string_deserializer,
)

Expand Down Expand Up @@ -390,13 +391,13 @@ def test_multi_data_model_deploy_train_model_from_amazon_first_party_estimator(
assert PRETRAINED_MODEL_PATH_1 in endpoint_models
assert PRETRAINED_MODEL_PATH_2 in endpoint_models

# Define a predictor to set `serializer` parameter with npy_serializer
# Define a predictor to set `serializer` parameter with `NumpySerializer`
# instead of `JSONSerializer` in the default predictor returned by `MXNetPredictor`
# Since we are using a placeholder container image the prediction results are not accurate.
predictor = Predictor(
endpoint_name=endpoint_name,
sagemaker_session=sagemaker_session,
serializer=npy_serializer,
serializer=NumpySerializer(),
deserializer=string_deserializer,
)

Expand Down Expand Up @@ -486,7 +487,7 @@ def test_multi_data_model_deploy_pretrained_models_update_endpoint(
predictor = Predictor(
endpoint_name=endpoint_name,
sagemaker_session=sagemaker_session,
serializer=npy_serializer,
serializer=NumpySerializer(),
deserializer=string_deserializer,
)

Expand Down
86 changes: 85 additions & 1 deletion tests/unit/sagemaker/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,100 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import io
import json
import os

import numpy as np
import pytest

from sagemaker.serializers import JSONSerializer
from sagemaker.serializers import NumpySerializer, JSONSerializer
from tests.unit import DATA_DIR


@pytest.fixture
def numpy_serializer():
    """Provide a ``NumpySerializer`` built with its default arguments."""
    serializer = NumpySerializer()
    return serializer


def test_numpy_serializer_python_array(numpy_serializer):
    """A plain Python list round-trips through serialize and ``np.load``."""
    values = [1, 2, 3]

    payload = numpy_serializer.serialize(values)
    restored = np.load(io.BytesIO(payload))

    assert np.array_equal(values, restored)


def test_numpy_serializer_python_array_with_dtype():
    """The dtype given to the constructor is used when serializing a list."""
    values = [1, 2, 3]
    serializer = NumpySerializer(dtype="float16")

    payload = serializer.serialize(values)
    restored = np.load(io.BytesIO(payload))

    assert np.array_equal(values, restored)
    assert restored.dtype == "float16"


def test_numpy_serializer_numpy_valid_2_dimensional(numpy_serializer):
    """A two-dimensional NumPy array round-trips unchanged."""
    matrix = np.array([[1, 2, 3], [3, 4, 5]])

    payload = numpy_serializer.serialize(matrix)
    restored = np.load(io.BytesIO(payload))

    assert np.array_equal(matrix, restored)


def test_numpy_serializer_numpy_valid_multidimensional(numpy_serializer):
    """A four-dimensional NumPy array round-trips unchanged."""
    tensor = np.ones((10, 10, 10, 10))

    payload = numpy_serializer.serialize(tensor)
    restored = np.load(io.BytesIO(payload))

    assert np.array_equal(tensor, restored)


def test_numpy_serializer_numpy_valid_list_of_strings(numpy_serializer):
    """A NumPy array of strings round-trips unchanged."""
    words = np.array(["one", "two", "three"])

    payload = numpy_serializer.serialize(words)
    restored = np.load(io.BytesIO(payload))

    assert np.array_equal(words, restored)


def test_numpy_serializer_from_buffer_or_file(numpy_serializer):
    """Serializing a readable stream of .npy data yields loadable bytes."""
    expected = np.ones((2, 3))
    source = io.BytesIO()
    np.save(source, expected)
    source.seek(0)  # rewind so the stream is read from its beginning

    payload = numpy_serializer.serialize(source)
    restored = np.load(io.BytesIO(payload))

    assert np.array_equal(expected, restored)


def test_numpy_serializer_object(numpy_serializer):
    """A non-array, non-list object (here a set) is serialized and reloaded.

    The reload needs ``allow_pickle=True`` because the test loads the
    result that way to compare it against ``np.array`` of the same set.
    """
    # Named ``data`` rather than ``object`` so the builtin is not shadowed.
    data = {1, 2, 3}

    result = numpy_serializer.serialize(data)

    assert np.array_equal(np.array(data), np.load(io.BytesIO(result), allow_pickle=True))


def test_numpy_serializer_list_of_empty(numpy_serializer):
    """An array built from empty rows is rejected with a ValueError."""
    with pytest.raises(ValueError) as invalid_input:
        numpy_serializer.serialize(np.array([[], []]))

    # Inspect the raised exception itself; str() of the ExceptionInfo wrapper
    # is a repr-style string rather than the exception message.
    assert "empty array" in str(invalid_input.value)


def test_numpy_serializer_numpy_invalid_empty(numpy_serializer):
    """An empty NumPy array is rejected with a ValueError."""
    with pytest.raises(ValueError) as invalid_input:
        numpy_serializer.serialize(np.array([]))

    # Inspect the raised exception itself; str() of the ExceptionInfo wrapper
    # is a repr-style string rather than the exception message.
    assert "empty array" in str(invalid_input.value)


def test_numpy_serializer_python_invalid_empty(numpy_serializer):
    """An empty Python list is rejected with a ValueError."""
    with pytest.raises(ValueError) as error:
        numpy_serializer.serialize([])

    # The assertion sits after the ``with`` block so it actually runs, and it
    # checks the exception message rather than str() of the ExceptionInfo.
    assert "empty array" in str(error.value)


@pytest.fixture
def json_serializer():
    """Provide a ``JSONSerializer`` built with its default arguments."""
    serializer = JSONSerializer()
    return serializer
Expand Down
Loading