Skip to content

Commit d38f803

Browse files
author
Balaji Veeramani
committed
Add BytesDeserializer
1 parent 1487b22 commit d38f803

File tree

4 files changed

+27
-30
lines changed

4 files changed

+27
-30
lines changed

src/sagemaker/deserializers.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,24 @@ def deserialize(self, data, content_type):
3939
@abc.abstractmethod
4040
def ACCEPT(self):
4141
"""The content type that is expected from the inference endpoint."""
42+
43+
44+
class BytesDeserializer(BaseDeserializer):
45+
"""Deserialize a stream of bytes into a bytes object."""
46+
47+
ACCEPT = "application/octet-stream"
48+
49+
def deserialize(self, data, content_type):
50+
"""Read a stream of bytes returned from an inference endpoint.
51+
52+
Args:
53+
data (object): A stream of bytes.
54+
content_type (str): The MIME type of the data.
55+
56+
Returns:
57+
bytes: The bytes object read from the stream.
58+
"""
59+
try:
60+
return data.read()
61+
finally:
62+
data.close()

src/sagemaker/predictor.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -623,32 +623,6 @@ def __call__(self, stream, content_type):
623623
csv_deserializer = _CsvDeserializer()
624624

625625

626-
class BytesDeserializer(object):
627-
"""Return the response as an undecoded array of bytes.
628-
629-
Args:
630-
accept (str): The Accept header to send to the server (optional).
631-
"""
632-
633-
def __init__(self, accept=None):
634-
"""
635-
Args:
636-
accept:
637-
"""
638-
self.accept = accept
639-
640-
def __call__(self, stream, content_type):
641-
"""
642-
Args:
643-
stream:
644-
content_type:
645-
"""
646-
try:
647-
return stream.read()
648-
finally:
649-
stream.close()
650-
651-
652626
class StringDeserializer(object):
653627
"""Return the response as a decoded string.
654628

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_predictors.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,11 @@ def test_import_from_node_should_be_modified_random_import():
117117
def test_import_from_modify_node():
118118
modifier = predictors.PredictorImportFromRenamer()
119119

120-
node = ast_import("from sagemaker.predictor import BytesDeserializer, RealTimePredictor")
120+
node = ast_import(
121+
"from sagemaker.predictor import ClassThatHasntBeenRenamed, RealTimePredictor"
122+
)
121123
modifier.modify_node(node)
122-
expected_result = "from sagemaker.predictor import BytesDeserializer, Predictor"
124+
expected_result = "from sagemaker.predictor import ClassThatHasntBeenRenamed, Predictor"
123125
assert expected_result == pasta.dump(node)
124126

125127
node = ast_import("from sagemaker.predictor import RealTimePredictor as RTP")

tests/unit/test_predictor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@
2020
import pytest
2121
from mock import Mock, call, patch
2222

23+
from sagemaker.deserializers import BytesDeserializer
2324
from sagemaker.predictor import Predictor
2425
from sagemaker.predictor import (
2526
json_serializer,
2627
json_deserializer,
2728
csv_serializer,
2829
csv_deserializer,
29-
BytesDeserializer,
3030
StringDeserializer,
3131
StreamDeserializer,
3232
numpy_deserializer,
@@ -185,7 +185,7 @@ def test_json_deserializer_invalid_data():
185185

186186

187187
def test_bytes_deserializer():
188-
result = BytesDeserializer()(io.BytesIO(b"[1, 2, 3]"), "application/json")
188+
result = BytesDeserializer().deserialize(io.BytesIO(b"[1, 2, 3]"), "application/json")
189189

190190
assert result == b"[1, 2, 3]"
191191

0 commit comments

Comments
 (0)