Skip to content

breaking: Move _NPYSerializer to sagemaker.serializers and rename to NumpySerializer #1691

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jul 13, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
)
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.chainer import defaults
from sagemaker.predictor import Predictor, npy_serializer, numpy_deserializer
from sagemaker.predictor import Predictor, numpy_deserializer
from sagemaker.serializers import NumpySerializer

logger = logging.getLogger("sagemaker")

Expand All @@ -48,7 +49,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
using the default AWS configuration chain.
"""
super(ChainerPredictor, self).__init__(
endpoint_name, sagemaker_session, npy_serializer, numpy_deserializer
endpoint_name, sagemaker_session, NumpySerializer(), numpy_deserializer
)


Expand Down
48 changes: 0 additions & 48 deletions src/sagemaker/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,51 +765,3 @@ def __call__(self, stream, content_type=CONTENT_TYPE_NPY):


numpy_deserializer = _NumpyDeserializer()


class _NPYSerializer(object):
"""Placeholder docstring"""

def __init__(self):
"""Placeholder docstring"""
self.content_type = CONTENT_TYPE_NPY

def __call__(self, data, dtype=None):
"""Serialize data into the request body in NPY format.

Args:
data (object): Data to be serialized. Can be a numpy array, list,
file, or buffer.
dtype:

Returns:
object: NPY serialized data used for the request.
"""
if isinstance(data, np.ndarray):
if not data.size > 0:
raise ValueError("empty array can't be serialized")
return _npy_serialize(data)

if isinstance(data, list):
if not len(data) > 0:
raise ValueError("empty array can't be serialized")
return _npy_serialize(np.array(data, dtype))

# files and buffers. Assumed to hold npy-formatted data.
if hasattr(data, "read"):
return data.read()

return _npy_serialize(np.array(data))


def _npy_serialize(data):
"""
Args:
data:
"""
buffer = BytesIO()
np.save(buffer, data)
return buffer.getvalue()


npy_serializer = _NPYSerializer()
5 changes: 3 additions & 2 deletions src/sagemaker/pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
)
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.pytorch import defaults
from sagemaker.predictor import Predictor, npy_serializer, numpy_deserializer
from sagemaker.predictor import Predictor, numpy_deserializer
from sagemaker.serializers import NumpySerializer

logger = logging.getLogger("sagemaker")

Expand All @@ -49,7 +50,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
using the default AWS configuration chain.
"""
super(PyTorchPredictor, self).__init__(
endpoint_name, sagemaker_session, npy_serializer, numpy_deserializer
endpoint_name, sagemaker_session, NumpySerializer(), numpy_deserializer
)


Expand Down
57 changes: 57 additions & 0 deletions src/sagemaker/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from __future__ import absolute_import

import abc
import io

import numpy as np


class BaseSerializer(abc.ABC):
Expand All @@ -38,3 +41,57 @@ def serialize(self, data):
@abc.abstractmethod
def CONTENT_TYPE(self):
"""The MIME type of the data sent to the inference endpoint."""


class NumpySerializer(BaseSerializer):
    """Serialize data to bytes using the .npy format."""

    CONTENT_TYPE = "application/x-npy"

    def __init__(self, dtype=None):
        """Initialize the dtype.

        Args:
            dtype (str): The dtype used when a Python list is converted to a
                NumPy array. It is not applied to any other kind of input.
        """
        self.dtype = dtype

    def serialize(self, data):
        """Serialize data to bytes using the .npy format.

        Args:
            data (object): Data to be serialized. Can be a NumPy array, list,
                file, or buffer.

        Returns:
            bytes: The data serialized in the .npy format. For a file or
                buffer, whatever its ``read()`` returns is passed through
                unchanged.

        Raises:
            ValueError: If ``data`` is an empty NumPy array or an empty list.
        """
        if isinstance(data, np.ndarray):
            if data.size == 0:
                raise ValueError("Cannot serialize empty array.")
            return self.serialize_array(data)

        if isinstance(data, list):
            if not data:
                raise ValueError("Cannot serialize empty array.")
            return self.serialize_array(np.array(data, self.dtype))

        # Files and buffers are assumed to already hold npy-formatted data.
        if hasattr(data, "read"):
            return data.read()

        # Any other object is wrapped in an array without an explicit dtype.
        return self.serialize_array(np.array(data))

    @staticmethod
    def serialize_array(array):
        """Save a NumPy array to an in-memory buffer and return its contents.

        Args:
            array (numpy.ndarray): The array to serialize.

        Returns:
            bytes: The contents of the buffer after ``np.save`` wrote the
                array to it.
        """
        buffer = io.BytesIO()
        np.save(buffer, array)
        return buffer.getvalue()
5 changes: 3 additions & 2 deletions src/sagemaker/sklearn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from sagemaker.fw_registry import default_framework_uri
from sagemaker.fw_utils import model_code_key_prefix, validate_version_or_image_args
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.predictor import Predictor, npy_serializer, numpy_deserializer
from sagemaker.predictor import Predictor, numpy_deserializer
from sagemaker.serializers import NumpySerializer
from sagemaker.sklearn import defaults

logger = logging.getLogger("sagemaker")
Expand All @@ -44,7 +45,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
using the default AWS configuration chain.
"""
super(SKLearnPredictor, self).__init__(
endpoint_name, sagemaker_session, npy_serializer, numpy_deserializer
endpoint_name, sagemaker_session, NumpySerializer(), numpy_deserializer
)


Expand Down
5 changes: 3 additions & 2 deletions src/sagemaker/xgboost/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from sagemaker.fw_utils import model_code_key_prefix
from sagemaker.fw_registry import default_framework_uri
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.predictor import Predictor, npy_serializer, csv_deserializer
from sagemaker.predictor import Predictor, csv_deserializer
from sagemaker.serializers import NumpySerializer
from sagemaker.xgboost.defaults import XGBOOST_NAME

logger = logging.getLogger("sagemaker")
Expand All @@ -42,7 +43,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
chain.
"""
super(XGBoostPredictor, self).__init__(
endpoint_name, sagemaker_session, npy_serializer, csv_deserializer
endpoint_name, sagemaker_session, NumpySerializer(), csv_deserializer
)


Expand Down
17 changes: 9 additions & 8 deletions tests/integ/test_multidatamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from sagemaker.deserializers import StringDeserializer
from sagemaker.multidatamodel import MultiDataModel
from sagemaker.mxnet import MXNet
from sagemaker.predictor import Predictor, npy_serializer
from sagemaker.predictor import Predictor
from sagemaker.serializers import NumpySerializer
from sagemaker.utils import sagemaker_timestamp, unique_name_from_base, get_ecr_image_uri_prefix
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
from tests.integ.retry import retries
Expand Down Expand Up @@ -158,7 +159,7 @@ def test_multi_data_model_deploy_pretrained_models(
predictor = Predictor(
endpoint_name=endpoint_name,
sagemaker_session=sagemaker_session,
serializer=npy_serializer,
serializer=NumpySerializer(),
deserializer=string_deserializer,
)

Expand Down Expand Up @@ -216,7 +217,7 @@ def test_multi_data_model_deploy_pretrained_models_local_mode(container_image, s
predictor = Predictor(
endpoint_name=endpoint_name,
sagemaker_session=multi_data_model.sagemaker_session,
serializer=npy_serializer,
serializer=NumpySerializer(),
deserializer=string_deserializer,
)

Expand Down Expand Up @@ -289,13 +290,13 @@ def test_multi_data_model_deploy_trained_model_from_framework_estimator(
assert PRETRAINED_MODEL_PATH_1 in endpoint_models
assert PRETRAINED_MODEL_PATH_2 in endpoint_models

# Define a predictor to set `serializer` parameter with npy_serializer
# Define a predictor to set `serializer` parameter with NumpySerializer
# instead of `json_serializer` in the default predictor returned by `MXNetPredictor`
# Since we are using a placeholder container image the prediction results are not accurate.
predictor = Predictor(
endpoint_name=endpoint_name,
sagemaker_session=sagemaker_session,
serializer=npy_serializer,
serializer=NumpySerializer(),
deserializer=string_deserializer,
)

Expand Down Expand Up @@ -390,13 +391,13 @@ def test_multi_data_model_deploy_train_model_from_amazon_first_party_estimator(
assert PRETRAINED_MODEL_PATH_1 in endpoint_models
assert PRETRAINED_MODEL_PATH_2 in endpoint_models

# Define a predictor to set `serializer` parameter with npy_serializer
# Define a predictor to set `serializer` parameter with NumpySerializer
# instead of `json_serializer` in the default predictor returned by `MXNetPredictor`
# Since we are using a placeholder container image the prediction results are not accurate.
predictor = Predictor(
endpoint_name=endpoint_name,
sagemaker_session=sagemaker_session,
serializer=npy_serializer,
serializer=NumpySerializer(),
deserializer=string_deserializer,
)

Expand Down Expand Up @@ -486,7 +487,7 @@ def test_multi_data_model_deploy_pretrained_models_update_endpoint(
predictor = Predictor(
endpoint_name=endpoint_name,
sagemaker_session=sagemaker_session,
serializer=npy_serializer,
serializer=NumpySerializer(),
deserializer=string_deserializer,
)

Expand Down
103 changes: 103 additions & 0 deletions tests/unit/sagemaker/test_serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import io

import numpy as np
import pytest

from sagemaker.serializers import NumpySerializer


@pytest.fixture
def numpy_serializer():
    """Provide a NumpySerializer built with its default arguments."""
    serializer = NumpySerializer()
    return serializer


def test_numpy_serializer_python_array(numpy_serializer):
    """A plain Python list survives a serialize / np.load round trip."""
    values = [1, 2, 3]

    payload = numpy_serializer.serialize(values)
    restored = np.load(io.BytesIO(payload))

    assert np.array_equal(values, restored)


def test_numpy_serializer_python_array_with_dtype():
    """The dtype given to the constructor is applied to list input."""
    values = [1, 2, 3]
    serializer = NumpySerializer(dtype="float16")

    restored = np.load(io.BytesIO(serializer.serialize(values)))

    assert np.array_equal(values, restored)
    assert restored.dtype == "float16"


def test_numpy_serializer_numpy_valid_2_dimensional(numpy_serializer):
    """A two-dimensional NumPy array survives the round trip."""
    matrix = np.array([[1, 2, 3], [3, 4, 5]])

    payload = numpy_serializer.serialize(matrix)
    restored = np.load(io.BytesIO(payload))

    assert np.array_equal(matrix, restored)


def test_numpy_serializer_numpy_valid_multidimensional(numpy_serializer):
    """A four-dimensional NumPy array survives the round trip."""
    tensor = np.ones((10, 10, 10, 10))

    payload = numpy_serializer.serialize(tensor)
    restored = np.load(io.BytesIO(payload))

    assert np.array_equal(tensor, restored)


def test_numpy_serializer_numpy_valid_list_of_strings(numpy_serializer):
    """A NumPy array of strings survives the round trip."""
    words = np.array(["one", "two", "three"])

    payload = numpy_serializer.serialize(words)
    restored = np.load(io.BytesIO(payload))

    assert np.array_equal(words, restored)


def test_numpy_serializer_from_buffer_or_file(numpy_serializer):
    """A readable stream holding .npy data can be given to serialize()."""
    expected = np.ones((2, 3))
    source = io.BytesIO()
    np.save(source, expected)
    source.seek(0)  # rewind so the stream is read from its beginning

    payload = numpy_serializer.serialize(source)
    restored = np.load(io.BytesIO(payload))

    assert np.array_equal(expected, restored)


def test_numpy_serializer_object(numpy_serializer):
    """An input that is neither array, list, nor stream is still serialized."""
    # Named ``obj`` so the builtin ``object`` is not shadowed.
    obj = {1, 2, 3}

    result = numpy_serializer.serialize(obj)

    # The test loads the payload with allow_pickle=True to read it back.
    assert np.array_equal(np.array(obj), np.load(io.BytesIO(result), allow_pickle=True))


def test_numpy_serializer_list_of_empty(numpy_serializer):
    """An array built from empty rows is rejected with a ValueError."""
    with pytest.raises(ValueError) as invalid_input:
        numpy_serializer.serialize(np.array([[], []]))

    # Check the exception's own message rather than the ExceptionInfo wrapper.
    assert "empty array" in str(invalid_input.value)


def test_numpy_serializer_numpy_invalid_empty(numpy_serializer):
    """An empty NumPy array is rejected with a ValueError."""
    with pytest.raises(ValueError) as invalid_input:
        numpy_serializer.serialize(np.array([]))

    # Check the exception's own message rather than the ExceptionInfo wrapper.
    assert "empty array" in str(invalid_input.value)


def test_numpy_serializer_python_invalid_empty(numpy_serializer):
    """An empty Python list is rejected with a ValueError."""
    with pytest.raises(ValueError) as error:
        numpy_serializer.serialize([])

    # The assertion sits outside the ``with`` block: code placed after the
    # raising call inside the block would never execute.
    assert "empty array" in str(error.value)
Loading