Skip to content

Commit bf6a3d6

Browse files
authored
Merge branch 'zwei' into add-json-serializer
2 parents 90b7c01 + cc2d047 commit bf6a3d6

File tree

5 files changed

+62
-44
lines changed

5 files changed

+62
-44
lines changed

src/sagemaker/deserializers.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""Implements methods for deserializing data returned from an inference endpoint."""
1414
from __future__ import absolute_import
1515

16+
import csv
17+
1618
import abc
1719
import codecs
1820
import io
@@ -96,6 +98,37 @@ def deserialize(self, data, content_type):
9698
data.close()
9799

98100

101+
class CSVDeserializer(BaseDeserializer):
102+
"""Deserialize a stream of bytes into a list of lists."""
103+
104+
ACCEPT = "text/csv"
105+
106+
def __init__(self, encoding="utf-8"):
107+
"""Initialize the string encoding.
108+
109+
Args:
110+
encoding (str): The string encoding to use (default: "utf-8").
111+
"""
112+
self.encoding = encoding
113+
114+
def deserialize(self, data, content_type):
115+
"""Deserialize data from an inference endpoint into a list of lists.
116+
117+
Args:
118+
data (botocore.response.StreamingBody): Data to be deserialized.
119+
content_type (str): The MIME type of the data.
120+
121+
Returns:
122+
list: The data deserialized into a list of lists representing the
123+
contents of a CSV file.
124+
"""
125+
try:
126+
decoded_string = data.read().decode(self.encoding)
127+
return list(csv.reader(decoded_string.splitlines()))
128+
finally:
129+
data.close()
130+
131+
99132
class StreamDeserializer(BaseDeserializer):
100133
"""Returns the data and content-type received from an inference endpoint.
101134

src/sagemaker/predictor.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -596,32 +596,6 @@ def _row_to_csv(obj):
596596
return ",".join(obj)
597597

598598

599-
class _CsvDeserializer(object):
600-
"""Placeholder docstring"""
601-
602-
def __init__(self, encoding="utf-8"):
603-
"""
604-
Args:
605-
encoding:
606-
"""
607-
self.accept = CONTENT_TYPE_CSV
608-
self.encoding = encoding
609-
610-
def __call__(self, stream, content_type):
611-
"""
612-
Args:
613-
stream:
614-
content_type:
615-
"""
616-
try:
617-
return list(csv.reader(stream.read().decode(self.encoding).splitlines()))
618-
finally:
619-
stream.close()
620-
621-
622-
csv_deserializer = _CsvDeserializer()
623-
624-
625599
class _JsonDeserializer(object):
626600
"""Placeholder docstring"""
627601

src/sagemaker/xgboost/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@
1616
import logging
1717

1818
import sagemaker
19+
from sagemaker.deserializers import CSVDeserializer
1920
from sagemaker.fw_utils import model_code_key_prefix
2021
from sagemaker.fw_registry import default_framework_uri
2122
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
22-
from sagemaker.predictor import Predictor, npy_serializer, csv_deserializer
23+
from sagemaker.predictor import Predictor, npy_serializer
2324
from sagemaker.xgboost.defaults import XGBOOST_NAME
2425

2526
logger = logging.getLogger("sagemaker")
@@ -42,7 +43,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
4243
chain.
4344
"""
4445
super(XGBoostPredictor, self).__init__(
45-
endpoint_name, sagemaker_session, npy_serializer, csv_deserializer
46+
endpoint_name, sagemaker_session, npy_serializer, CSVDeserializer()
4647
)
4748

4849

tests/unit/sagemaker/test_deserializers.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from sagemaker.deserializers import (
2121
StringDeserializer,
2222
BytesDeserializer,
23+
CSVDeserializer,
2324
StreamDeserializer,
2425
NumpyDeserializer,
2526
)
@@ -41,6 +42,31 @@ def test_bytes_deserializer():
4142
assert result == b"[1, 2, 3]"
4243

4344

45+
@pytest.fixture
46+
def csv_deserializer():
47+
return CSVDeserializer()
48+
49+
50+
def test_csv_deserializer_single_element(csv_deserializer):
51+
result = csv_deserializer.deserialize(io.BytesIO(b"1"), "text/csv")
52+
assert result == [["1"]]
53+
54+
55+
def test_csv_deserializer_array(csv_deserializer):
56+
result = csv_deserializer.deserialize(io.BytesIO(b"1,2,3"), "text/csv")
57+
assert result == [["1", "2", "3"]]
58+
59+
60+
def test_csv_deserializer_2dimensional(csv_deserializer):
61+
result = csv_deserializer.deserialize(io.BytesIO(b"1,2,3\n3,4,5"), "text/csv")
62+
assert result == [["1", "2", "3"], ["3", "4", "5"]]
63+
64+
65+
def test_csv_deserializer_posix_compliant(csv_deserializer):
66+
result = csv_deserializer.deserialize(io.BytesIO(b"1,2,3\n3,4,5\n"), "text/csv")
67+
assert result == [["1", "2", "3"], ["3", "4", "5"]]
68+
69+
4470
def test_stream_deserializer():
4571
deserializer = StreamDeserializer()
4672

tests/unit/test_predictor.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from sagemaker.predictor import (
2525
json_deserializer,
2626
csv_serializer,
27-
csv_deserializer,
2827
npy_serializer,
2928
)
3029
from sagemaker.serializers import JSONSerializer
@@ -99,21 +98,6 @@ def test_csv_serializer_csv_reader():
9998
assert result == validation_data
10099

101100

102-
def test_csv_deserializer_single_element():
103-
result = csv_deserializer(io.BytesIO(b"1"), "text/csv")
104-
assert result == [["1"]]
105-
106-
107-
def test_csv_deserializer_array():
108-
result = csv_deserializer(io.BytesIO(b"1,2,3"), "text/csv")
109-
assert result == [["1", "2", "3"]]
110-
111-
112-
def test_csv_deserializer_2dimensional():
113-
result = csv_deserializer(io.BytesIO(b"1,2,3\n3,4,5"), "text/csv")
114-
assert result == [["1", "2", "3"], ["3", "4", "5"]]
115-
116-
117101
def test_json_deserializer_array():
118102
result = json_deserializer(io.BytesIO(b"[1, 2, 3]"), "application/json")
119103

0 commit comments

Comments
 (0)