Skip to content

Commit 9fc8a46

Browse files
authored
breaking: Move StringDeserializer to sagemaker.deserializers (#1677)
1 parent 4f00176 commit 9fc8a46

File tree

5 files changed

+40
-38
lines changed

5 files changed

+40
-38
lines changed

src/sagemaker/deserializers.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,35 @@ def ACCEPT(self):
4141
"""The content type that is expected from the inference endpoint."""
4242

4343

44+
class StringDeserializer(BaseDeserializer):
45+
"""Deserialize data from an inference endpoint into a decoded string."""
46+
47+
ACCEPT = "application/json"
48+
49+
def __init__(self, encoding="UTF-8"):
50+
"""Initialize the string encoding.
51+
52+
Args:
53+
encoding (str): The string encoding to use (default: UTF-8).
54+
"""
55+
self.encoding = encoding
56+
57+
def deserialize(self, data, content_type):
58+
"""Deserialize data from an inference endpoint into a decoded string.
59+
60+
Args:
61+
data (object): Data to be deserialized.
62+
content_type (str): The MIME type of the data.
63+
64+
Returns:
65+
str: The data deserialized into a decoded string.
66+
"""
67+
try:
68+
return data.read().decode(self.encoding)
69+
finally:
70+
data.close()
71+
72+
4473
class BytesDeserializer(BaseDeserializer):
4574
"""Deserialize a stream of bytes into a bytes object."""
4675

src/sagemaker/predictor.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -623,35 +623,6 @@ def __call__(self, stream, content_type):
623623
csv_deserializer = _CsvDeserializer()
624624

625625

626-
class StringDeserializer(object):
627-
"""Return the response as a decoded string.
628-
629-
Args:
630-
encoding (str): The string encoding to use (default=utf-8).
631-
accept (str): The Accept header to send to the server (optional).
632-
"""
633-
634-
def __init__(self, encoding="utf-8", accept=None):
635-
"""
636-
Args:
637-
encoding:
638-
accept:
639-
"""
640-
self.encoding = encoding
641-
self.accept = accept
642-
643-
def __call__(self, stream, content_type):
644-
"""
645-
Args:
646-
stream:
647-
content_type:
648-
"""
649-
try:
650-
return stream.read().decode(self.encoding)
651-
finally:
652-
stream.close()
653-
654-
655626
class StreamDeserializer(object):
656627
"""Returns the tuple of the response stream and the content-type of the response.
657628
It is the receivers responsibility to close the stream when they're done

tests/integ/test_multidatamodel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@
2424

2525
from sagemaker import utils
2626
from sagemaker.amazon.randomcutforest import RandomCutForest
27+
from sagemaker.deserializers import StringDeserializer
2728
from sagemaker.multidatamodel import MultiDataModel
2829
from sagemaker.mxnet import MXNet
29-
from sagemaker.predictor import Predictor, StringDeserializer, npy_serializer
30+
from sagemaker.predictor import Predictor, npy_serializer
3031
from sagemaker.utils import sagemaker_timestamp, unique_name_from_base, get_ecr_image_uri_prefix
3132
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
3233
from tests.integ.retry import retries

tests/unit/sagemaker/test_deserializers.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,15 @@
1414

1515
import io
1616

17-
from sagemaker.deserializers import BytesDeserializer
17+
from sagemaker.deserializers import StringDeserializer, BytesDeserializer
18+
19+
20+
def test_string_deserializer():
21+
deserializer = StringDeserializer()
22+
23+
result = deserializer.deserialize(io.BytesIO(b"[1, 2, 3]"), "application/json")
24+
25+
assert result == "[1, 2, 3]"
1826

1927

2028
def test_bytes_deserializer():

tests/unit/test_predictor.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
json_deserializer,
2727
csv_serializer,
2828
csv_deserializer,
29-
StringDeserializer,
3029
StreamDeserializer,
3130
numpy_deserializer,
3231
npy_serializer,
@@ -183,12 +182,6 @@ def test_json_deserializer_invalid_data():
183182
assert "column" in str(error)
184183

185184

186-
def test_string_deserializer():
187-
result = StringDeserializer()(io.BytesIO(b"[1, 2, 3]"), "application/json")
188-
189-
assert result == "[1, 2, 3]"
190-
191-
192185
def test_stream_deserializer():
193186
stream, content_type = StreamDeserializer()(io.BytesIO(b"[1, 2, 3]"), "application/json")
194187
result = stream.read()

0 commit comments

Comments
 (0)