Skip to content

Commit 7c1762c

Browse files
authored
Merge branch 'zwei' into add-numpy-serializer
2 parents 091a7ca + b837dc2 commit 7c1762c

File tree

17 files changed: +176 additions, -151 deletions

17 files changed: +176 additions, -151 deletions

src/sagemaker/amazon/common.py

Lines changed: 24 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -22,31 +22,37 @@
2222

2323
from sagemaker.amazon.record_pb2 import Record
2424
from sagemaker.deserializers import BaseDeserializer
25+
from sagemaker.serializers import BaseSerializer
2526
from sagemaker.utils import DeferredError
2627

2728

28-
class numpy_to_record_serializer(object):
29-
"""Placeholder docstring"""
29+
class RecordSerializer(BaseSerializer):
30+
"""Serialize a NumPy array for an inference request."""
3031

31-
def __init__(self, content_type="application/x-recordio-protobuf"):
32-
"""
33-
Args:
34-
content_type:
35-
"""
36-
self.content_type = content_type
32+
CONTENT_TYPE = "application/x-recordio-protobuf"
33+
34+
def serialize(self, data):
35+
"""Serialize a NumPy array into a buffer containing RecordIO records.
3736
38-
def __call__(self, array):
39-
"""
4037
Args:
41-
array:
38+
data (numpy.ndarray): The data to serialize.
39+
40+
Returns:
41+
io.BytesIO: A buffer containing the data serialized as records.
4242
"""
43-
if len(array.shape) == 1:
44-
array = array.reshape(1, array.shape[0])
45-
assert len(array.shape) == 2, "Expecting a 1 or 2 dimensional array"
46-
buf = io.BytesIO()
47-
write_numpy_to_dense_tensor(buf, array)
48-
buf.seek(0)
49-
return buf
43+
if len(data.shape) == 1:
44+
data = data.reshape(1, data.shape[0])
45+
46+
if len(data.shape) != 2:
47+
raise ValueError(
48+
"Expected a 1D or 2D array, but got a %dD array instead." % len(data.shape)
49+
)
50+
51+
buffer = io.BytesIO()
52+
write_numpy_to_dense_tensor(buffer, data)
53+
buffer.seek(0)
54+
55+
return buffer
5056

5157

5258
class RecordDeserializer(BaseDeserializer):

src/sagemaker/amazon/factorization_machines.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
17+
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import gt, isin, ge
2020
from sagemaker.predictor import Predictor
@@ -289,7 +289,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
289289
super(FactorizationMachinesPredictor, self).__init__(
290290
endpoint_name,
291291
sagemaker_session,
292-
serializer=numpy_to_record_serializer(),
292+
serializer=RecordSerializer(),
293293
deserializer=RecordDeserializer(),
294294
)
295295

src/sagemaker/amazon/kmeans.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
17+
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import gt, isin, ge, le
2020
from sagemaker.predictor import Predictor
@@ -222,7 +222,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
222222
super(KMeansPredictor, self).__init__(
223223
endpoint_name,
224224
sagemaker_session,
225-
serializer=numpy_to_record_serializer(),
225+
serializer=RecordSerializer(),
226226
deserializer=RecordDeserializer(),
227227
)
228228

src/sagemaker/amazon/knn.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
17+
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import ge, isin
2020
from sagemaker.predictor import Predictor
@@ -210,7 +210,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
210210
super(KNNPredictor, self).__init__(
211211
endpoint_name,
212212
sagemaker_session,
213-
serializer=numpy_to_record_serializer(),
213+
serializer=RecordSerializer(),
214214
deserializer=RecordDeserializer(),
215215
)
216216

src/sagemaker/amazon/lda.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
17+
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import gt
2020
from sagemaker.predictor import Predictor
@@ -194,7 +194,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
194194
super(LDAPredictor, self).__init__(
195195
endpoint_name,
196196
sagemaker_session,
197-
serializer=numpy_to_record_serializer(),
197+
serializer=RecordSerializer(),
198198
deserializer=RecordDeserializer(),
199199
)
200200

src/sagemaker/amazon/linear_learner.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
17+
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import isin, gt, lt, ge, le
2020
from sagemaker.predictor import Predictor
@@ -453,7 +453,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
453453
super(LinearLearnerPredictor, self).__init__(
454454
endpoint_name,
455455
sagemaker_session,
456-
serializer=numpy_to_record_serializer(),
456+
serializer=RecordSerializer(),
457457
deserializer=RecordDeserializer(),
458458
)
459459

src/sagemaker/amazon/ntm.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
17+
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import ge, le, isin
2020
from sagemaker.predictor import Predictor
@@ -224,7 +224,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
224224
super(NTMPredictor, self).__init__(
225225
endpoint_name,
226226
sagemaker_session,
227-
serializer=numpy_to_record_serializer(),
227+
serializer=RecordSerializer(),
228228
deserializer=RecordDeserializer(),
229229
)
230230

src/sagemaker/amazon/pca.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
17+
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import gt, isin
2020
from sagemaker.predictor import Predictor
@@ -206,7 +206,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
206206
super(PCAPredictor, self).__init__(
207207
endpoint_name,
208208
sagemaker_session,
209-
serializer=numpy_to_record_serializer(),
209+
serializer=RecordSerializer(),
210210
deserializer=RecordDeserializer(),
211211
)
212212

src/sagemaker/amazon/randomcutforest.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
17+
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import ge, le
2020
from sagemaker.predictor import Predictor
@@ -183,7 +183,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
183183
super(RandomCutForestPredictor, self).__init__(
184184
endpoint_name,
185185
sagemaker_session,
186-
serializer=numpy_to_record_serializer(),
186+
serializer=RecordSerializer(),
187187
deserializer=RecordDeserializer(),
188188
)
189189

src/sagemaker/chainer/model.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,8 @@
2424
)
2525
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2626
from sagemaker.chainer import defaults
27-
from sagemaker.predictor import Predictor, numpy_deserializer
27+
from sagemaker.deserializers import NumpyDeserializer
28+
from sagemaker.predictor import Predictor
2829
from sagemaker.serializers import NumpySerializer
2930

3031
logger = logging.getLogger("sagemaker")
@@ -49,7 +50,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
4950
using the default AWS configuration chain.
5051
"""
5152
super(ChainerPredictor, self).__init__(
52-
endpoint_name, sagemaker_session, NumpySerializer(), numpy_deserializer
53+
endpoint_name, sagemaker_session, NumpySerializer(), NumpyDeserializer()
5354
)
5455

5556

src/sagemaker/deserializers.py

Lines changed: 43 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,11 @@
1414
from __future__ import absolute_import
1515

1616
import abc
17+
import codecs
18+
import io
19+
import json
20+
21+
import numpy as np
1722

1823

1924
class BaseDeserializer(abc.ABC):
@@ -111,3 +116,41 @@ def deserialize(self, data, content_type):
111116
tuple: A two-tuple containing the stream and content-type.
112117
"""
113118
return data, content_type
119+
120+
121+
class NumpyDeserializer(BaseDeserializer):
122+
"""Deserialize a stream of data in the .npy format."""
123+
124+
ACCEPT = "application/x-npy"
125+
126+
def __init__(self, dtype=None):
127+
"""Initialize the dtype.
128+
129+
Args:
130+
dtype (str): The dtype of the data.
131+
"""
132+
self.dtype = dtype
133+
134+
def deserialize(self, data, content_type):
135+
"""Deserialize data from an inference endpoint into a NumPy array.
136+
137+
Args:
138+
data (botocore.response.StreamingBody): Data to be deserialized.
139+
content_type (str): The MIME type of the data.
140+
141+
Returns:
142+
numpy.ndarray: The data deserialized into a NumPy array.
143+
"""
144+
try:
145+
if content_type == "text/csv":
146+
return np.genfromtxt(
147+
codecs.getreader("utf-8")(data), delimiter=",", dtype=self.dtype
148+
)
149+
if content_type == "application/json":
150+
return np.array(json.load(codecs.getreader("utf-8")(data)), dtype=self.dtype)
151+
if content_type == "application/x-npy":
152+
return np.load(io.BytesIO(data.read()))
153+
finally:
154+
data.close()
155+
156+
raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type))

src/sagemaker/predictor.py

Lines changed: 1 addition & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -695,48 +695,4 @@ def __call__(self, stream, content_type):
695695
stream.close()
696696

697697

698-
json_deserializer = _JsonDeserializer()
699-
700-
701-
class _NumpyDeserializer(object):
702-
"""Placeholder docstring"""
703-
704-
def __init__(self, accept=CONTENT_TYPE_NPY, dtype=None):
705-
"""
706-
Args:
707-
accept:
708-
dtype:
709-
"""
710-
self.accept = accept
711-
self.dtype = dtype
712-
713-
def __call__(self, stream, content_type=CONTENT_TYPE_NPY):
714-
"""Decode from serialized data into a Numpy array.
715-
716-
Args:
717-
stream (stream): The response stream to be deserialized.
718-
content_type (str): The content type of the response. Can accept
719-
CSV, JSON, or NPY data.
720-
721-
Returns:
722-
object: Body of the response deserialized into a Numpy array.
723-
"""
724-
try:
725-
if content_type == CONTENT_TYPE_CSV:
726-
return np.genfromtxt(
727-
codecs.getreader("utf-8")(stream), delimiter=",", dtype=self.dtype
728-
)
729-
if content_type == CONTENT_TYPE_JSON:
730-
return np.array(json.load(codecs.getreader("utf-8")(stream)), dtype=self.dtype)
731-
if content_type == CONTENT_TYPE_NPY:
732-
return np.load(BytesIO(stream.read()))
733-
finally:
734-
stream.close()
735-
raise ValueError(
736-
"content_type must be one of the following: CSV, JSON, NPY. content_type: {}".format(
737-
content_type
738-
)
739-
)
740-
741-
742-
numpy_deserializer = _NumpyDeserializer()
698+
json_deserializer = _JsonDeserializer()

src/sagemaker/pytorch/model.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
import packaging.version
1818

1919
import sagemaker
20+
from sagemaker.deserializers import NumpyDeserializer
2021
from sagemaker.fw_utils import (
2122
create_image_uri,
2223
model_code_key_prefix,
@@ -25,7 +26,7 @@
2526
)
2627
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2728
from sagemaker.pytorch import defaults
28-
from sagemaker.predictor import Predictor, numpy_deserializer
29+
from sagemaker.predictor import Predictor
2930
from sagemaker.serializers import NumpySerializer
3031

3132
logger = logging.getLogger("sagemaker")
@@ -50,7 +51,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
5051
using the default AWS configuration chain.
5152
"""
5253
super(PyTorchPredictor, self).__init__(
53-
endpoint_name, sagemaker_session, NumpySerializer(), numpy_deserializer
54+
endpoint_name, sagemaker_session, NumpySerializer(), NumpyDeserializer()
5455
)
5556

5657

src/sagemaker/sklearn/model.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -16,10 +16,11 @@
1616
import logging
1717

1818
import sagemaker
19+
from sagemaker.deserializers import NumpyDeserializer
1920
from sagemaker.fw_registry import default_framework_uri
2021
from sagemaker.fw_utils import model_code_key_prefix, validate_version_or_image_args
2122
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
22-
from sagemaker.predictor import Predictor, numpy_deserializer
23+
from sagemaker.predictor import Predictor
2324
from sagemaker.serializers import NumpySerializer
2425
from sagemaker.sklearn import defaults
2526

@@ -45,7 +46,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
4546
using the default AWS configuration chain.
4647
"""
4748
super(SKLearnPredictor, self).__init__(
48-
endpoint_name, sagemaker_session, NumpySerializer(), numpy_deserializer
49+
endpoint_name, sagemaker_session, NumpySerializer(), NumpyDeserializer()
4950
)
5051

5152

0 commit comments

Comments
 (0)