Skip to content

Commit 1f7f103

Browse files
author
Balaji Veeramani
committed
Move _NPYSerializer
1 parent 9fc8a46 commit 1f7f103

File tree

9 files changed

+181
-143
lines changed

9 files changed

+181
-143
lines changed

src/sagemaker/chainer/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
)
2525
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2626
from sagemaker.chainer import defaults
27-
from sagemaker.predictor import Predictor, npy_serializer, numpy_deserializer
27+
from sagemaker.predictor import Predictor, numpy_deserializer
28+
from sagemaker.serializers import NumpySerializer
2829

2930
logger = logging.getLogger("sagemaker")
3031

@@ -48,7 +49,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
4849
using the default AWS configuration chain.
4950
"""
5051
super(ChainerPredictor, self).__init__(
51-
endpoint_name, sagemaker_session, npy_serializer, numpy_deserializer
52+
endpoint_name, sagemaker_session, NumpySerializer(), numpy_deserializer
5253
)
5354

5455

src/sagemaker/predictor.py

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -765,51 +765,3 @@ def __call__(self, stream, content_type=CONTENT_TYPE_NPY):
765765

766766

767767
numpy_deserializer = _NumpyDeserializer()
768-
769-
770-
class _NPYSerializer(object):
771-
"""Placeholder docstring"""
772-
773-
def __init__(self):
774-
"""Placeholder docstring"""
775-
self.content_type = CONTENT_TYPE_NPY
776-
777-
def __call__(self, data, dtype=None):
778-
"""Serialize data into the request body in NPY format.
779-
780-
Args:
781-
data (object): Data to be serialized. Can be a numpy array, list,
782-
file, or buffer.
783-
dtype:
784-
785-
Returns:
786-
object: NPY serialized data used for the request.
787-
"""
788-
if isinstance(data, np.ndarray):
789-
if not data.size > 0:
790-
raise ValueError("empty array can't be serialized")
791-
return _npy_serialize(data)
792-
793-
if isinstance(data, list):
794-
if not len(data) > 0:
795-
raise ValueError("empty array can't be serialized")
796-
return _npy_serialize(np.array(data, dtype))
797-
798-
# files and buffers. Assumed to hold npy-formatted data.
799-
if hasattr(data, "read"):
800-
return data.read()
801-
802-
return _npy_serialize(np.array(data))
803-
804-
805-
def _npy_serialize(data):
806-
"""
807-
Args:
808-
data:
809-
"""
810-
buffer = BytesIO()
811-
np.save(buffer, data)
812-
return buffer.getvalue()
813-
814-
815-
npy_serializer = _NPYSerializer()

src/sagemaker/pytorch/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
)
2626
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2727
from sagemaker.pytorch import defaults
28-
from sagemaker.predictor import Predictor, npy_serializer, numpy_deserializer
28+
from sagemaker.predictor import Predictor, numpy_deserializer
29+
from sagemaker.serializers import NumpySerializer
2930

3031
logger = logging.getLogger("sagemaker")
3132

@@ -49,7 +50,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
4950
using the default AWS configuration chain.
5051
"""
5152
super(PyTorchPredictor, self).__init__(
52-
endpoint_name, sagemaker_session, npy_serializer, numpy_deserializer
53+
endpoint_name, sagemaker_session, NumpySerializer(), numpy_deserializer
5354
)
5455

5556

src/sagemaker/serializers.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
from __future__ import absolute_import
1515

1616
import abc
17+
import io
18+
19+
import numpy as np
1720

1821

1922
class BaseSerializer(abc.ABC):
@@ -38,3 +41,57 @@ def serialize(self, data):
3841
@abc.abstractmethod
3942
def CONTENT_TYPE(self):
4043
"""The MIME type of the data sent to the inference endpoint."""
44+
45+
46+
class NumpySerializer(BaseSerializer):
47+
"""Serialize data to a buffer using the .npy format."""
48+
49+
CONTENT_TYPE = "application/x-npy"
50+
51+
def __init__(self, dtype=None):
52+
"""Initialize the dtype.
53+
54+
Args:
55+
dtype (str): The dtype of the data.
56+
"""
57+
self.dtype = dtype
58+
59+
def serialize(self, data):
60+
"""Serialize data to a buffer using the .npy format.
61+
62+
Args:
63+
data (object): Data to be serialized. Can be a NumPy array, list,
64+
file, or buffer.
65+
66+
Returns:
67+
bytes: The data serialized in the .npy format.
68+
"""
69+
if isinstance(data, np.ndarray):
70+
if not data.size > 0:
71+
raise ValueError("Cannot serialize empty array.")
72+
return self.serialize_array(data)
73+
74+
if isinstance(data, list):
75+
if not len(data) > 0:
76+
raise ValueError("empty array can't be serialized")
77+
return self.serialize_array(np.array(data, self.dtype))
78+
79+
# files and buffers. Assumed to hold npy-formatted data.
80+
if hasattr(data, "read"):
81+
return data.read()
82+
83+
return self.serialize_array(np.array(data))
84+
85+
@staticmethod
86+
def serialize_array(array):
87+
"""Saves a NumPy array in a buffer.
88+
89+
Args:
90+
array (numpy.ndarray): The array to serialize.
91+
92+
Returns:
93+
bytes: The array serialized in the .npy format.
94+
"""
95+
buffer = io.BytesIO()
96+
np.save(buffer, array)
97+
return buffer.getvalue()

src/sagemaker/sklearn/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
from sagemaker.fw_registry import default_framework_uri
2020
from sagemaker.fw_utils import model_code_key_prefix, validate_version_or_image_args
2121
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
22-
from sagemaker.predictor import Predictor, npy_serializer, numpy_deserializer
22+
from sagemaker.predictor import Predictor, numpy_deserializer
23+
from sagemaker.serializers import NumpySerializer
2324
from sagemaker.sklearn import defaults
2425

2526
logger = logging.getLogger("sagemaker")
@@ -44,7 +45,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
4445
using the default AWS configuration chain.
4546
"""
4647
super(SKLearnPredictor, self).__init__(
47-
endpoint_name, sagemaker_session, npy_serializer, numpy_deserializer
48+
endpoint_name, sagemaker_session, NumpySerializer(), numpy_deserializer
4849
)
4950

5051

src/sagemaker/xgboost/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
from sagemaker.fw_utils import model_code_key_prefix
2020
from sagemaker.fw_registry import default_framework_uri
2121
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
22-
from sagemaker.predictor import Predictor, npy_serializer, csv_deserializer
22+
from sagemaker.predictor import Predictor, csv_deserializer
23+
from sagemaker.serializers import NumpySerializer
2324
from sagemaker.xgboost.defaults import XGBOOST_NAME
2425

2526
logger = logging.getLogger("sagemaker")
@@ -42,7 +43,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
4243
chain.
4344
"""
4445
super(XGBoostPredictor, self).__init__(
45-
endpoint_name, sagemaker_session, npy_serializer, csv_deserializer
46+
endpoint_name, sagemaker_session, NumpySerializer(), csv_deserializer
4647
)
4748

4849

tests/integ/test_multidatamodel.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
from sagemaker.deserializers import StringDeserializer
2828
from sagemaker.multidatamodel import MultiDataModel
2929
from sagemaker.mxnet import MXNet
30-
from sagemaker.predictor import Predictor, npy_serializer
30+
from sagemaker.predictor import Predictor
31+
from sagemaker.serializers import NumpySerializer
3132
from sagemaker.utils import sagemaker_timestamp, unique_name_from_base, get_ecr_image_uri_prefix
3233
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
3334
from tests.integ.retry import retries
@@ -158,7 +159,7 @@ def test_multi_data_model_deploy_pretrained_models(
158159
predictor = Predictor(
159160
endpoint_name=endpoint_name,
160161
sagemaker_session=sagemaker_session,
161-
serializer=npy_serializer,
162+
serializer=NumpySerializer(),
162163
deserializer=string_deserializer,
163164
)
164165

@@ -216,7 +217,7 @@ def test_multi_data_model_deploy_pretrained_models_local_mode(container_image, s
216217
predictor = Predictor(
217218
endpoint_name=endpoint_name,
218219
sagemaker_session=multi_data_model.sagemaker_session,
219-
serializer=npy_serializer,
220+
serializer=NumpySerializer(),
220221
deserializer=string_deserializer,
221222
)
222223

@@ -289,13 +290,13 @@ def test_multi_data_model_deploy_trained_model_from_framework_estimator(
289290
assert PRETRAINED_MODEL_PATH_1 in endpoint_models
290291
assert PRETRAINED_MODEL_PATH_2 in endpoint_models
291292

292-
# Define a predictor to set `serializer` parameter with npy_serializer
293+
# Define a predictor to set `serializer` parameter with NumpySerializer
293294
# instead of `json_serializer` in the default predictor returned by `MXNetPredictor`
294295
# Since we are using a placeholder container image the prediction results are not accurate.
295296
predictor = Predictor(
296297
endpoint_name=endpoint_name,
297298
sagemaker_session=sagemaker_session,
298-
serializer=npy_serializer,
299+
serializer=NumpySerializer(),
299300
deserializer=string_deserializer,
300301
)
301302

@@ -390,13 +391,13 @@ def test_multi_data_model_deploy_train_model_from_amazon_first_party_estimator(
390391
assert PRETRAINED_MODEL_PATH_1 in endpoint_models
391392
assert PRETRAINED_MODEL_PATH_2 in endpoint_models
392393

393-
# Define a predictor to set `serializer` parameter with npy_serializer
394+
# Define a predictor to set `serializer` parameter with NumpySerializer
394395
# instead of `json_serializer` in the default predictor returned by `MXNetPredictor`
395396
# Since we are using a placeholder container image the prediction results are not accurate.
396397
predictor = Predictor(
397398
endpoint_name=endpoint_name,
398399
sagemaker_session=sagemaker_session,
399-
serializer=npy_serializer,
400+
serializer=NumpySerializer(),
400401
deserializer=string_deserializer,
401402
)
402403

@@ -486,7 +487,7 @@ def test_multi_data_model_deploy_pretrained_models_update_endpoint(
486487
predictor = Predictor(
487488
endpoint_name=endpoint_name,
488489
sagemaker_session=sagemaker_session,
489-
serializer=npy_serializer,
490+
serializer=NumpySerializer(),
490491
deserializer=string_deserializer,
491492
)
492493

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import io
16+
17+
import numpy as np
18+
import pytest
19+
20+
from sagemaker.serializers import NumpySerializer
21+
22+
23+
@pytest.fixture
24+
def numpy_serializer():
25+
return NumpySerializer()
26+
27+
28+
def test_numpy_serializer_python_array(numpy_serializer):
29+
array = [1, 2, 3]
30+
result = numpy_serializer.serialize(array)
31+
32+
assert np.array_equal(array, np.load(io.BytesIO(result)))
33+
34+
35+
def test_numpy_serializer_python_array_with_dtype():
36+
numpy_serializer = NumpySerializer(dtype="float16")
37+
array = [1, 2, 3]
38+
39+
result = numpy_serializer.serialize(array)
40+
41+
deserialized = np.load(io.BytesIO(result))
42+
assert np.array_equal(array, deserialized)
43+
assert deserialized.dtype == "float16"
44+
45+
46+
def test_numpy_serializer_numpy_valid_2_dimensional(numpy_serializer):
47+
array = np.array([[1, 2, 3], [3, 4, 5]])
48+
result = numpy_serializer.serialize(array)
49+
50+
assert np.array_equal(array, np.load(io.BytesIO(result)))
51+
52+
53+
def test_numpy_serializer_numpy_valid_multidimensional(numpy_serializer):
54+
array = np.ones((10, 10, 10, 10))
55+
result = numpy_serializer.serialize(array)
56+
57+
assert np.array_equal(array, np.load(io.BytesIO(result)))
58+
59+
60+
def test_numpy_serializer_numpy_valid_list_of_strings(numpy_serializer):
61+
array = np.array(["one", "two", "three"])
62+
result = numpy_serializer.serialize(array)
63+
64+
assert np.array_equal(array, np.load(io.BytesIO(result)))
65+
66+
67+
def test_numpy_serializer_from_buffer_or_file(numpy_serializer):
68+
array = np.ones((2, 3))
69+
stream = io.BytesIO()
70+
np.save(stream, array)
71+
stream.seek(0)
72+
73+
result = numpy_serializer.serialize(stream)
74+
75+
assert np.array_equal(array, np.load(io.BytesIO(result)))
76+
77+
78+
def test_numpy_serializer_object(numpy_serializer):
79+
object = {1, 2, 3}
80+
81+
result = numpy_serializer.serialize(object)
82+
83+
assert np.array_equal(np.array(object), np.load(io.BytesIO(result), allow_pickle=True))
84+
85+
86+
def test_numpy_serializer_list_of_empty(numpy_serializer):
87+
with pytest.raises(ValueError) as invalid_input:
88+
numpy_serializer.serialize(np.array([[], []]))
89+
90+
assert "empty array" in str(invalid_input)
91+
92+
93+
def test_numpy_serializer_numpy_invalid_empty(numpy_serializer):
94+
with pytest.raises(ValueError) as invalid_input:
95+
numpy_serializer.serialize(np.array([]))
96+
97+
assert "empty array" in str(invalid_input)
98+
99+
100+
def test_numpy_serializer_python_invalid_empty(numpy_serializer):
101+
with pytest.raises(ValueError) as error:
102+
numpy_serializer.serialize([])
103+
assert "empty array" in str(error)

0 commit comments

Comments
 (0)