Skip to content

Commit d8a6038

Browse files
authored
Merge branch 'zwei' into no-more-trains-in-docs
2 parents caef915 + d26d3f6 commit d8a6038

File tree

4 files changed

+36
-34
lines changed

4 files changed

+36
-34
lines changed

src/sagemaker/deserializers.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,25 @@ def deserialize(self, data, content_type):
8989
return data.read()
9090
finally:
9191
data.close()
92+
93+
94+
class StreamDeserializer(BaseDeserializer):
95+
"""Returns the data and content-type received from an inference endpoint.
96+
97+
It is the user's responsibility to close the data stream once they're done
98+
reading it.
99+
"""
100+
101+
ACCEPT = "*/*"
102+
103+
def deserialize(self, data, content_type):
104+
"""Returns a stream of the response body and the MIME type of the data.
105+
106+
Args:
107+
data (object): A stream of bytes.
108+
content_type (str): The MIME type of the data.
109+
110+
Returns:
111+
tuple: A two-tuple containing the stream and content-type.
112+
"""
113+
return data, content_type

src/sagemaker/predictor.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -623,31 +623,6 @@ def __call__(self, stream, content_type):
623623
csv_deserializer = _CsvDeserializer()
624624

625625

626-
class StreamDeserializer(object):
627-
"""Returns the tuple of the response stream and the content-type of the response.
628-
It is the receivers responsibility to close the stream when they're done
629-
reading the stream.
630-
631-
Args:
632-
accept (str): The Accept header to send to the server (optional).
633-
"""
634-
635-
def __init__(self, accept=None):
636-
"""
637-
Args:
638-
accept:
639-
"""
640-
self.accept = accept
641-
642-
def __call__(self, stream, content_type):
643-
"""
644-
Args:
645-
stream:
646-
content_type:
647-
"""
648-
return (stream, content_type)
649-
650-
651626
class _JsonSerializer(object):
652627
"""Placeholder docstring"""
653628

tests/unit/sagemaker/test_deserializers.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import io
1616

17-
from sagemaker.deserializers import StringDeserializer, BytesDeserializer
17+
from sagemaker.deserializers import StringDeserializer, BytesDeserializer, StreamDeserializer
1818

1919

2020
def test_string_deserializer():
@@ -31,3 +31,16 @@ def test_bytes_deserializer():
3131
result = deserializer.deserialize(io.BytesIO(b"[1, 2, 3]"), "application/json")
3232

3333
assert result == b"[1, 2, 3]"
34+
35+
36+
def test_stream_deserializer():
37+
deserializer = StreamDeserializer()
38+
39+
stream, content_type = deserializer.deserialize(io.BytesIO(b"[1, 2, 3]"), "application/json")
40+
try:
41+
result = stream.read()
42+
finally:
43+
stream.close()
44+
45+
assert result == b"[1, 2, 3]"
46+
assert content_type == "application/json"

tests/unit/test_predictor.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
json_deserializer,
2727
csv_serializer,
2828
csv_deserializer,
29-
StreamDeserializer,
3029
numpy_deserializer,
3130
npy_serializer,
3231
_NumpyDeserializer,
@@ -182,13 +181,6 @@ def test_json_deserializer_invalid_data():
182181
assert "column" in str(error)
183182

184183

185-
def test_stream_deserializer():
186-
stream, content_type = StreamDeserializer()(io.BytesIO(b"[1, 2, 3]"), "application/json")
187-
result = stream.read()
188-
assert result == b"[1, 2, 3]"
189-
assert content_type == "application/json"
190-
191-
192184
def test_npy_serializer_python_array():
193185
array = [1, 2, 3]
194186
result = npy_serializer(array)

0 commit comments

Comments
 (0)