Skip to content

Commit 6ac82e9

Browse files
authored
breaking: Move _NPYSerializer to sagemaker.serializers and rename to NumpySerializer (#1691)
1 parent 24b2ab9 commit 6ac82e9

File tree

9 files changed

+162
-146
lines changed

9 files changed

+162
-146
lines changed

src/sagemaker/chainer/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2626
from sagemaker.chainer import defaults
2727
from sagemaker.deserializers import NumpyDeserializer
28-
from sagemaker.predictor import Predictor, npy_serializer
28+
from sagemaker.predictor import Predictor
29+
from sagemaker.serializers import NumpySerializer
2930

3031
logger = logging.getLogger("sagemaker")
3132

@@ -49,7 +50,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
4950
using the default AWS configuration chain.
5051
"""
5152
super(ChainerPredictor, self).__init__(
52-
endpoint_name, sagemaker_session, npy_serializer, NumpyDeserializer()
53+
endpoint_name, sagemaker_session, NumpySerializer(), NumpyDeserializer()
5354
)
5455

5556

src/sagemaker/predictor.py

Lines changed: 2 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
import codecs
1717
import csv
1818
import json
19-
from six import StringIO, BytesIO
19+
from six import StringIO
2020
import numpy as np
2121

22-
from sagemaker.content_types import CONTENT_TYPE_JSON, CONTENT_TYPE_CSV, CONTENT_TYPE_NPY
22+
from sagemaker.content_types import CONTENT_TYPE_JSON, CONTENT_TYPE_CSV
2323
from sagemaker.deserializers import BaseDeserializer
2424
from sagemaker.model_monitor import DataCaptureConfig
2525
from sagemaker.serializers import BaseSerializer
@@ -620,51 +620,3 @@ def __call__(self, stream, content_type):
620620

621621

622622
json_deserializer = _JsonDeserializer()
623-
624-
625-
class _NPYSerializer(object):
626-
"""Placeholder docstring"""
627-
628-
def __init__(self):
629-
"""Placeholder docstring"""
630-
self.content_type = CONTENT_TYPE_NPY
631-
632-
def __call__(self, data, dtype=None):
633-
"""Serialize data into the request body in NPY format.
634-
635-
Args:
636-
data (object): Data to be serialized. Can be a numpy array, list,
637-
file, or buffer.
638-
dtype:
639-
640-
Returns:
641-
object: NPY serialized data used for the request.
642-
"""
643-
if isinstance(data, np.ndarray):
644-
if not data.size > 0:
645-
raise ValueError("empty array can't be serialized")
646-
return _npy_serialize(data)
647-
648-
if isinstance(data, list):
649-
if not len(data) > 0:
650-
raise ValueError("empty array can't be serialized")
651-
return _npy_serialize(np.array(data, dtype))
652-
653-
# files and buffers. Assumed to hold npy-formatted data.
654-
if hasattr(data, "read"):
655-
return data.read()
656-
657-
return _npy_serialize(np.array(data))
658-
659-
660-
def _npy_serialize(data):
661-
"""
662-
Args:
663-
data:
664-
"""
665-
buffer = BytesIO()
666-
np.save(buffer, data)
667-
return buffer.getvalue()
668-
669-
670-
npy_serializer = _NPYSerializer()

src/sagemaker/pytorch/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
)
2727
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2828
from sagemaker.pytorch import defaults
29-
from sagemaker.predictor import Predictor, npy_serializer
29+
from sagemaker.predictor import Predictor
30+
from sagemaker.serializers import NumpySerializer
3031

3132
logger = logging.getLogger("sagemaker")
3233

@@ -50,7 +51,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
5051
using the default AWS configuration chain.
5152
"""
5253
super(PyTorchPredictor, self).__init__(
53-
endpoint_name, sagemaker_session, npy_serializer, NumpyDeserializer()
54+
endpoint_name, sagemaker_session, NumpySerializer(), NumpyDeserializer()
5455
)
5556

5657

src/sagemaker/serializers.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
import abc
17+
import io
1718
import json
1819

1920
import numpy as np
@@ -43,6 +44,59 @@ def CONTENT_TYPE(self):
4344
"""The MIME type of the data sent to the inference endpoint."""
4445

4546

47+
class NumpySerializer(BaseSerializer):
48+
"""Serialize data to a buffer using the .npy format."""
49+
50+
CONTENT_TYPE = "application/x-npy"
51+
52+
def __init__(self, dtype=None):
53+
"""Initialize the dtype.
54+
55+
Args:
56+
dtype (str): The dtype of the data.
57+
"""
58+
self.dtype = dtype
59+
60+
def serialize(self, data):
61+
"""Serialize data to a buffer using the .npy format.
62+
63+
Args:
64+
data (object): Data to be serialized. Can be a NumPy array, list,
65+
file, or buffer.
66+
67+
Returns:
68+
io.BytesIO: A buffer containing data serialzied in the .npy format.
69+
"""
70+
if isinstance(data, np.ndarray):
71+
if data.size == 0:
72+
raise ValueError("Cannot serialize empty array.")
73+
return self._serialize_array(data)
74+
75+
if isinstance(data, list):
76+
if len(data) == 0:
77+
raise ValueError("Cannot serialize empty array.")
78+
return self._serialize_array(np.array(data, self.dtype))
79+
80+
# files and buffers. Assumed to hold npy-formatted data.
81+
if hasattr(data, "read"):
82+
return data.read()
83+
84+
return self._serialize_array(np.array(data))
85+
86+
def _serialize_array(self, array):
87+
"""Saves a NumPy array in a buffer.
88+
89+
Args:
90+
array (numpy.ndarray): The array to serialize.
91+
92+
Returns:
93+
io.BytesIO: A buffer containing the serialized array.
94+
"""
95+
buffer = io.BytesIO()
96+
np.save(buffer, array)
97+
return buffer.getvalue()
98+
99+
46100
class JSONSerializer(BaseSerializer):
47101
"""Serialize data to a JSON formatted string."""
48102

src/sagemaker/sklearn/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
from sagemaker.fw_registry import default_framework_uri
2121
from sagemaker.fw_utils import model_code_key_prefix, validate_version_or_image_args
2222
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
23-
from sagemaker.predictor import Predictor, npy_serializer
23+
from sagemaker.predictor import Predictor
24+
from sagemaker.serializers import NumpySerializer
2425
from sagemaker.sklearn import defaults
2526

2627
logger = logging.getLogger("sagemaker")
@@ -45,7 +46,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
4546
using the default AWS configuration chain.
4647
"""
4748
super(SKLearnPredictor, self).__init__(
48-
endpoint_name, sagemaker_session, npy_serializer, NumpyDeserializer()
49+
endpoint_name, sagemaker_session, NumpySerializer(), NumpyDeserializer()
4950
)
5051

5152

src/sagemaker/xgboost/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
from sagemaker.fw_utils import model_code_key_prefix
2121
from sagemaker.fw_registry import default_framework_uri
2222
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
23-
from sagemaker.predictor import Predictor, npy_serializer
23+
from sagemaker.predictor import Predictor
24+
from sagemaker.serializers import NumpySerializer
2425
from sagemaker.xgboost.defaults import XGBOOST_NAME
2526

2627
logger = logging.getLogger("sagemaker")
@@ -43,7 +44,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
4344
chain.
4445
"""
4546
super(XGBoostPredictor, self).__init__(
46-
endpoint_name, sagemaker_session, npy_serializer, CSVDeserializer()
47+
endpoint_name, sagemaker_session, NumpySerializer(), CSVDeserializer()
4748
)
4849

4950

tests/integ/test_multidatamodel.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
from sagemaker.deserializers import StringDeserializer
2828
from sagemaker.multidatamodel import MultiDataModel
2929
from sagemaker.mxnet import MXNet
30-
from sagemaker.predictor import Predictor, npy_serializer
30+
from sagemaker.predictor import Predictor
31+
from sagemaker.serializers import NumpySerializer
3132
from sagemaker.utils import sagemaker_timestamp, unique_name_from_base, get_ecr_image_uri_prefix
3233
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
3334
from tests.integ.retry import retries
@@ -158,7 +159,7 @@ def test_multi_data_model_deploy_pretrained_models(
158159
predictor = Predictor(
159160
endpoint_name=endpoint_name,
160161
sagemaker_session=sagemaker_session,
161-
serializer=npy_serializer,
162+
serializer=NumpySerializer(),
162163
deserializer=string_deserializer,
163164
)
164165

@@ -216,7 +217,7 @@ def test_multi_data_model_deploy_pretrained_models_local_mode(container_image, s
216217
predictor = Predictor(
217218
endpoint_name=endpoint_name,
218219
sagemaker_session=multi_data_model.sagemaker_session,
219-
serializer=npy_serializer,
220+
serializer=NumpySerializer(),
220221
deserializer=string_deserializer,
221222
)
222223

@@ -289,13 +290,13 @@ def test_multi_data_model_deploy_trained_model_from_framework_estimator(
289290
assert PRETRAINED_MODEL_PATH_1 in endpoint_models
290291
assert PRETRAINED_MODEL_PATH_2 in endpoint_models
291292

292-
# Define a predictor to set `serializer` parameter with npy_serializer
293+
# Define a predictor to set `serializer` parameter with `NumpySerializer`
293294
# instead of `JSONSerializer` in the default predictor returned by `MXNetPredictor`
294295
# Since we are using a placeholder container image the prediction results are not accurate.
295296
predictor = Predictor(
296297
endpoint_name=endpoint_name,
297298
sagemaker_session=sagemaker_session,
298-
serializer=npy_serializer,
299+
serializer=NumpySerializer(),
299300
deserializer=string_deserializer,
300301
)
301302

@@ -390,13 +391,13 @@ def test_multi_data_model_deploy_train_model_from_amazon_first_party_estimator(
390391
assert PRETRAINED_MODEL_PATH_1 in endpoint_models
391392
assert PRETRAINED_MODEL_PATH_2 in endpoint_models
392393

393-
# Define a predictor to set `serializer` parameter with npy_serializer
394+
# Define a predictor to set `serializer` parameter with `NumpySerializer`
394395
# instead of `JSONSerializer` in the default predictor returned by `MXNetPredictor`
395396
# Since we are using a placeholder container image the prediction results are not accurate.
396397
predictor = Predictor(
397398
endpoint_name=endpoint_name,
398399
sagemaker_session=sagemaker_session,
399-
serializer=npy_serializer,
400+
serializer=NumpySerializer(),
400401
deserializer=string_deserializer,
401402
)
402403

@@ -486,7 +487,7 @@ def test_multi_data_model_deploy_pretrained_models_update_endpoint(
486487
predictor = Predictor(
487488
endpoint_name=endpoint_name,
488489
sagemaker_session=sagemaker_session,
489-
serializer=npy_serializer,
490+
serializer=NumpySerializer(),
490491
deserializer=string_deserializer,
491492
)
492493

tests/unit/sagemaker/test_serializers.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,100 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import io
1516
import json
1617
import os
1718

1819
import numpy as np
1920
import pytest
2021

21-
from sagemaker.serializers import JSONSerializer
22+
from sagemaker.serializers import NumpySerializer, JSONSerializer
2223
from tests.unit import DATA_DIR
2324

2425

26+
@pytest.fixture
27+
def numpy_serializer():
28+
return NumpySerializer()
29+
30+
31+
def test_numpy_serializer_python_array(numpy_serializer):
32+
array = [1, 2, 3]
33+
result = numpy_serializer.serialize(array)
34+
35+
assert np.array_equal(array, np.load(io.BytesIO(result)))
36+
37+
38+
def test_numpy_serializer_python_array_with_dtype():
39+
numpy_serializer = NumpySerializer(dtype="float16")
40+
array = [1, 2, 3]
41+
42+
result = numpy_serializer.serialize(array)
43+
44+
deserialized = np.load(io.BytesIO(result))
45+
assert np.array_equal(array, deserialized)
46+
assert deserialized.dtype == "float16"
47+
48+
49+
def test_numpy_serializer_numpy_valid_2_dimensional(numpy_serializer):
50+
array = np.array([[1, 2, 3], [3, 4, 5]])
51+
result = numpy_serializer.serialize(array)
52+
53+
assert np.array_equal(array, np.load(io.BytesIO(result)))
54+
55+
56+
def test_numpy_serializer_numpy_valid_multidimensional(numpy_serializer):
57+
array = np.ones((10, 10, 10, 10))
58+
result = numpy_serializer.serialize(array)
59+
60+
assert np.array_equal(array, np.load(io.BytesIO(result)))
61+
62+
63+
def test_numpy_serializer_numpy_valid_list_of_strings(numpy_serializer):
64+
array = np.array(["one", "two", "three"])
65+
result = numpy_serializer.serialize(array)
66+
67+
assert np.array_equal(array, np.load(io.BytesIO(result)))
68+
69+
70+
def test_numpy_serializer_from_buffer_or_file(numpy_serializer):
71+
array = np.ones((2, 3))
72+
stream = io.BytesIO()
73+
np.save(stream, array)
74+
stream.seek(0)
75+
76+
result = numpy_serializer.serialize(stream)
77+
78+
assert np.array_equal(array, np.load(io.BytesIO(result)))
79+
80+
81+
def test_numpy_serializer_object(numpy_serializer):
82+
object = {1, 2, 3}
83+
84+
result = numpy_serializer.serialize(object)
85+
86+
assert np.array_equal(np.array(object), np.load(io.BytesIO(result), allow_pickle=True))
87+
88+
89+
def test_numpy_serializer_list_of_empty(numpy_serializer):
90+
with pytest.raises(ValueError) as invalid_input:
91+
numpy_serializer.serialize(np.array([[], []]))
92+
93+
assert "empty array" in str(invalid_input)
94+
95+
96+
def test_numpy_serializer_numpy_invalid_empty(numpy_serializer):
97+
with pytest.raises(ValueError) as invalid_input:
98+
numpy_serializer.serialize(np.array([]))
99+
100+
assert "empty array" in str(invalid_input)
101+
102+
103+
def test_numpy_serializer_python_invalid_empty(numpy_serializer):
104+
with pytest.raises(ValueError) as error:
105+
numpy_serializer.serialize([])
106+
assert "empty array" in str(error)
107+
108+
25109
@pytest.fixture
26110
def json_serializer():
27111
return JSONSerializer()

0 commit comments

Comments
 (0)