Skip to content

Commit a3f4f4b

Browse files
authored
change: add csv deserializer (#737)
1 parent 2d67ff5 commit a3f4f4b

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

src/sagemaker/predictor.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,21 @@ def _row_to_csv(obj):
229229
return ','.join(obj)
230230

231231

232+
class _CsvDeserializer(object):
233+
def __init__(self, encoding='utf-8'):
234+
self.accept = CONTENT_TYPE_CSV
235+
self.encoding = encoding
236+
237+
def __call__(self, stream, content_type):
238+
try:
239+
return list(csv.reader(stream.read().decode(self.encoding).splitlines()))
240+
finally:
241+
stream.close()
242+
243+
244+
csv_deserializer = _CsvDeserializer()
245+
246+
232247
class BytesDeserializer(object):
233248
"""Return the response as an undecoded array of bytes.
234249

tests/unit/test_predictor.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@
2121
import numpy as np
2222

2323
from sagemaker.predictor import RealTimePredictor
24-
from sagemaker.predictor import json_serializer, json_deserializer, csv_serializer, BytesDeserializer, \
25-
StringDeserializer, StreamDeserializer, numpy_deserializer, npy_serializer, _NumpyDeserializer
24+
from sagemaker.predictor import json_serializer, json_deserializer, csv_serializer, \
25+
csv_deserializer, BytesDeserializer, StringDeserializer, StreamDeserializer, \
26+
numpy_deserializer, npy_serializer, _NumpyDeserializer
2627
from tests.unit import DATA_DIR
2728

2829
# testing serialization functions
@@ -141,6 +142,21 @@ def test_csv_serializer_csv_reader():
141142
assert result == validation_data
142143

143144

145+
def test_csv_deserializer_single_element():
146+
result = csv_deserializer(io.BytesIO(b'1'), 'text/csv')
147+
assert result == [['1']]
148+
149+
150+
def test_csv_deserializer_array():
151+
result = csv_deserializer(io.BytesIO(b'1,2,3'), 'text/csv')
152+
assert result == [['1', '2', '3']]
153+
154+
155+
def test_csv_deserializer_2dimensional():
156+
result = csv_deserializer(io.BytesIO(b'1,2,3\n3,4,5'), 'text/csv')
157+
assert result == [['1', '2', '3'], ['3', '4', '5']]
158+
159+
144160
def test_json_deserializer_array():
145161
result = json_deserializer(io.BytesIO(b'[1, 2, 3]'), 'application/json')
146162

0 commit comments

Comments
 (0)