Skip to content

Commit 99a9652

Browse files
author
Chuyang Deng
committed
Merge branch 'rename-s3-input' of github.com:ChuyangDeng/sagemaker-python-sdk into rename-s3-input
2 parents cfb7c2e + 083da46 commit 99a9652

File tree

18 files changed

+237
-194
lines changed

18 files changed

+237
-194
lines changed

src/sagemaker/amazon/common.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,31 +22,37 @@
2222

2323
from sagemaker.amazon.record_pb2 import Record
2424
from sagemaker.deserializers import BaseDeserializer
25+
from sagemaker.serializers import BaseSerializer
2526
from sagemaker.utils import DeferredError
2627

2728

28-
class numpy_to_record_serializer(object):
29-
"""Placeholder docstring"""
29+
class RecordSerializer(BaseSerializer):
30+
"""Serialize a NumPy array for an inference request."""
3031

31-
def __init__(self, content_type="application/x-recordio-protobuf"):
32-
"""
33-
Args:
34-
content_type:
35-
"""
36-
self.content_type = content_type
32+
CONTENT_TYPE = "application/x-recordio-protobuf"
33+
34+
def serialize(self, data):
35+
"""Serialize a NumPy array into a buffer containing RecordIO records.
3736
38-
def __call__(self, array):
39-
"""
4037
Args:
41-
array:
38+
data (numpy.ndarray): The data to serialize.
39+
40+
Returns:
41+
io.BytesIO: A buffer containing the data serialized as records.
4242
"""
43-
if len(array.shape) == 1:
44-
array = array.reshape(1, array.shape[0])
45-
assert len(array.shape) == 2, "Expecting a 1 or 2 dimensional array"
46-
buf = io.BytesIO()
47-
write_numpy_to_dense_tensor(buf, array)
48-
buf.seek(0)
49-
return buf
43+
if len(data.shape) == 1:
44+
data = data.reshape(1, data.shape[0])
45+
46+
if len(data.shape) != 2:
47+
raise ValueError(
48+
"Expected a 1D or 2D array, but got a %dD array instead." % len(data.shape)
49+
)
50+
51+
buffer = io.BytesIO()
52+
write_numpy_to_dense_tensor(buffer, data)
53+
buffer.seek(0)
54+
55+
return buffer
5056

5157

5258
class RecordDeserializer(BaseDeserializer):

src/sagemaker/amazon/factorization_machines.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
17+
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import gt, isin, ge
2020
from sagemaker.predictor import Predictor
@@ -289,7 +289,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
289289
super(FactorizationMachinesPredictor, self).__init__(
290290
endpoint_name,
291291
sagemaker_session,
292-
serializer=numpy_to_record_serializer(),
292+
serializer=RecordSerializer(),
293293
deserializer=RecordDeserializer(),
294294
)
295295

src/sagemaker/amazon/kmeans.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
17+
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import gt, isin, ge, le
2020
from sagemaker.predictor import Predictor
@@ -222,7 +222,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
222222
super(KMeansPredictor, self).__init__(
223223
endpoint_name,
224224
sagemaker_session,
225-
serializer=numpy_to_record_serializer(),
225+
serializer=RecordSerializer(),
226226
deserializer=RecordDeserializer(),
227227
)
228228

src/sagemaker/amazon/knn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
17+
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import ge, isin
2020
from sagemaker.predictor import Predictor
@@ -210,7 +210,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
210210
super(KNNPredictor, self).__init__(
211211
endpoint_name,
212212
sagemaker_session,
213-
serializer=numpy_to_record_serializer(),
213+
serializer=RecordSerializer(),
214214
deserializer=RecordDeserializer(),
215215
)
216216

src/sagemaker/amazon/lda.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
17+
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import gt
2020
from sagemaker.predictor import Predictor
@@ -194,7 +194,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
194194
super(LDAPredictor, self).__init__(
195195
endpoint_name,
196196
sagemaker_session,
197-
serializer=numpy_to_record_serializer(),
197+
serializer=RecordSerializer(),
198198
deserializer=RecordDeserializer(),
199199
)
200200

src/sagemaker/amazon/linear_learner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
17+
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import isin, gt, lt, ge, le
2020
from sagemaker.predictor import Predictor
@@ -453,7 +453,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
453453
super(LinearLearnerPredictor, self).__init__(
454454
endpoint_name,
455455
sagemaker_session,
456-
serializer=numpy_to_record_serializer(),
456+
serializer=RecordSerializer(),
457457
deserializer=RecordDeserializer(),
458458
)
459459

src/sagemaker/amazon/ntm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
17+
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import ge, le, isin
2020
from sagemaker.predictor import Predictor
@@ -224,7 +224,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
224224
super(NTMPredictor, self).__init__(
225225
endpoint_name,
226226
sagemaker_session,
227-
serializer=numpy_to_record_serializer(),
227+
serializer=RecordSerializer(),
228228
deserializer=RecordDeserializer(),
229229
)
230230

src/sagemaker/amazon/pca.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
17+
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import gt, isin
2020
from sagemaker.predictor import Predictor
@@ -206,7 +206,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
206206
super(PCAPredictor, self).__init__(
207207
endpoint_name,
208208
sagemaker_session,
209-
serializer=numpy_to_record_serializer(),
209+
serializer=RecordSerializer(),
210210
deserializer=RecordDeserializer(),
211211
)
212212

src/sagemaker/amazon/randomcutforest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
17-
from sagemaker.amazon.common import numpy_to_record_serializer, RecordDeserializer
17+
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
1818
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1919
from sagemaker.amazon.validation import ge, le
2020
from sagemaker.predictor import Predictor
@@ -183,7 +183,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
183183
super(RandomCutForestPredictor, self).__init__(
184184
endpoint_name,
185185
sagemaker_session,
186-
serializer=numpy_to_record_serializer(),
186+
serializer=RecordSerializer(),
187187
deserializer=RecordDeserializer(),
188188
)
189189

src/sagemaker/chainer/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
)
2525
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2626
from sagemaker.chainer import defaults
27-
from sagemaker.predictor import Predictor, npy_serializer, numpy_deserializer
27+
from sagemaker.deserializers import NumpyDeserializer
28+
from sagemaker.predictor import Predictor, npy_serializer
2829

2930
logger = logging.getLogger("sagemaker")
3031

@@ -48,7 +49,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
4849
using the default AWS configuration chain.
4950
"""
5051
super(ChainerPredictor, self).__init__(
51-
endpoint_name, sagemaker_session, npy_serializer, numpy_deserializer
52+
endpoint_name, sagemaker_session, npy_serializer, NumpyDeserializer()
5253
)
5354

5455

src/sagemaker/deserializers.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,14 @@
1313
"""Implements methods for deserializing data returned from an inference endpoint."""
1414
from __future__ import absolute_import
1515

16+
import csv
17+
1618
import abc
19+
import codecs
20+
import io
21+
import json
22+
23+
import numpy as np
1724

1825

1926
class BaseDeserializer(abc.ABC):
@@ -91,6 +98,37 @@ def deserialize(self, data, content_type):
9198
data.close()
9299

93100

101+
class CSVDeserializer(BaseDeserializer):
102+
"""Deserialize a stream of bytes into a list of lists."""
103+
104+
ACCEPT = "text/csv"
105+
106+
def __init__(self, encoding="utf-8"):
107+
"""Initialize the string encoding.
108+
109+
Args:
110+
encoding (str): The string encoding to use (default: "utf-8").
111+
"""
112+
self.encoding = encoding
113+
114+
def deserialize(self, data, content_type):
115+
"""Deserialize data from an inference endpoint into a list of lists.
116+
117+
Args:
118+
data (botocore.response.StreamingBody): Data to be deserialized.
119+
content_type (str): The MIME type of the data.
120+
121+
Returns:
122+
list: The data deserialized into a list of lists representing the
123+
contents of a CSV file.
124+
"""
125+
try:
126+
decoded_string = data.read().decode(self.encoding)
127+
return list(csv.reader(decoded_string.splitlines()))
128+
finally:
129+
data.close()
130+
131+
94132
class StreamDeserializer(BaseDeserializer):
95133
"""Returns the data and content-type received from an inference endpoint.
96134
@@ -111,3 +149,41 @@ def deserialize(self, data, content_type):
111149
tuple: A two-tuple containing the stream and content-type.
112150
"""
113151
return data, content_type
152+
153+
154+
class NumpyDeserializer(BaseDeserializer):
155+
"""Deserialize a stream of data in the .npy format."""
156+
157+
ACCEPT = "application/x-npy"
158+
159+
def __init__(self, dtype=None):
160+
"""Initialize the dtype.
161+
162+
Args:
163+
dtype (str): The dtype of the data.
164+
"""
165+
self.dtype = dtype
166+
167+
def deserialize(self, data, content_type):
168+
"""Deserialize data from an inference endpoint into a NumPy array.
169+
170+
Args:
171+
data (botocore.response.StreamingBody): Data to be deserialized.
172+
content_type (str): The MIME type of the data.
173+
174+
Returns:
175+
numpy.ndarray: The data deserialized into a NumPy array.
176+
"""
177+
try:
178+
if content_type == "text/csv":
179+
return np.genfromtxt(
180+
codecs.getreader("utf-8")(data), delimiter=",", dtype=self.dtype
181+
)
182+
if content_type == "application/json":
183+
return np.array(json.load(codecs.getreader("utf-8")(data)), dtype=self.dtype)
184+
if content_type == "application/x-npy":
185+
return np.load(io.BytesIO(data.read()))
186+
finally:
187+
data.close()
188+
189+
raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type))

src/sagemaker/predictor.py

Lines changed: 0 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -597,32 +597,6 @@ def _row_to_csv(obj):
597597
return ",".join(obj)
598598

599599

600-
class _CsvDeserializer(object):
601-
"""Placeholder docstring"""
602-
603-
def __init__(self, encoding="utf-8"):
604-
"""
605-
Args:
606-
encoding:
607-
"""
608-
self.accept = CONTENT_TYPE_CSV
609-
self.encoding = encoding
610-
611-
def __call__(self, stream, content_type):
612-
"""
613-
Args:
614-
stream:
615-
content_type:
616-
"""
617-
try:
618-
return list(csv.reader(stream.read().decode(self.encoding).splitlines()))
619-
finally:
620-
stream.close()
621-
622-
623-
csv_deserializer = _CsvDeserializer()
624-
625-
626600
class _JsonSerializer(object):
627601
"""Placeholder docstring"""
628602

@@ -698,50 +672,6 @@ def __call__(self, stream, content_type):
698672
json_deserializer = _JsonDeserializer()
699673

700674

701-
class _NumpyDeserializer(object):
702-
"""Placeholder docstring"""
703-
704-
def __init__(self, accept=CONTENT_TYPE_NPY, dtype=None):
705-
"""
706-
Args:
707-
accept:
708-
dtype:
709-
"""
710-
self.accept = accept
711-
self.dtype = dtype
712-
713-
def __call__(self, stream, content_type=CONTENT_TYPE_NPY):
714-
"""Decode from serialized data into a Numpy array.
715-
716-
Args:
717-
stream (stream): The response stream to be deserialized.
718-
content_type (str): The content type of the response. Can accept
719-
CSV, JSON, or NPY data.
720-
721-
Returns:
722-
object: Body of the response deserialized into a Numpy array.
723-
"""
724-
try:
725-
if content_type == CONTENT_TYPE_CSV:
726-
return np.genfromtxt(
727-
codecs.getreader("utf-8")(stream), delimiter=",", dtype=self.dtype
728-
)
729-
if content_type == CONTENT_TYPE_JSON:
730-
return np.array(json.load(codecs.getreader("utf-8")(stream)), dtype=self.dtype)
731-
if content_type == CONTENT_TYPE_NPY:
732-
return np.load(BytesIO(stream.read()))
733-
finally:
734-
stream.close()
735-
raise ValueError(
736-
"content_type must be one of the following: CSV, JSON, NPY. content_type: {}".format(
737-
content_type
738-
)
739-
)
740-
741-
742-
numpy_deserializer = _NumpyDeserializer()
743-
744-
745675
class _NPYSerializer(object):
746676
"""Placeholder docstring"""
747677

0 commit comments

Comments
 (0)