Skip to content

Commit 1dc5157

Browse files
author
Balaji Veeramani
committed
Refactor code
1 parent 8021c8b commit 1dc5157

File tree

4 files changed

+10
-53
lines changed

4 files changed

+10
-53
lines changed

src/sagemaker/deserializers.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515

1616
import abc
1717

18-
from sagemaker.utils import parse_mime_type
19-
2018

2119
class BaseDeserializer(abc.ABC):
2220
"""Abstract base class for creation of new deserializers.
@@ -43,35 +41,30 @@ def ACCEPT(self):
4341
"""The content type that is expected from the inference endpoint."""
4442

4543

46-
class StringDeserializer(object):
44+
class StringDeserializer(BaseDeserializer):
4745
"""Deserialize data from an inference endpoint into a decoded string."""
4846

47+
ACCEPT = "application/json"
48+
4949
def __init__(self, encoding="UTF-8"):
50-
"""Initialize the default encoding.
50+
"""Initialize the string encoding.
5151
5252
Args:
53-
encoding (str): The string encoding to use, if a charset is not
54-
provided by the server (default: UTF-8).
53+
encoding (str): The string encoding to use (default: UTF-8).
5554
"""
5655
self.encoding = encoding
5756

5857
def deserialize(self, data, content_type):
5958
"""Deserialize data from an inference endpoint into a decoded string.
6059
6160
Args:
62-
data (object): A string or a byte stream.
61+
data (object): Data to be deserialized.
6362
content_type (str): The MIME type of the data.
6463
6564
Returns:
6665
str: The data deserialized into a decoded string.
6766
"""
68-
category, _, parameters = parse_mime_type(content_type)
69-
70-
if category == "text":
71-
return data
72-
7367
try:
74-
encoding = parameters.get("charset", self.encoding)
75-
return data.read().decode(encoding)
68+
return data.read().decode(self.encoding)
7669
finally:
7770
data.close()

src/sagemaker/utils.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -745,23 +745,3 @@ def _module_import_error(py_module, feature, extras):
745745
"to install all required dependencies."
746746
)
747747
return error_msg.format(py_module, feature, extras)
748-
749-
750-
def parse_mime_type(mime_type):
751-
"""Parse a MIME type and return the type, subtype, and parameters.
752-
753-
Args:
754-
mime_type (str): A MIME type.
755-
756-
Returns:
757-
tuple: A three-tuple containing the type, subtype, and parameters. The
758-
type and subtype are strings, and the parameters are stored in a
759-
dictionary.
760-
"""
761-
category, remaining = mime_type.split("/")
762-
subtype = remaining.split(";")[0]
763-
parameters = {}
764-
for parameter in remaining.split(";")[1:]:
765-
attribute, value = parameter.split("=")
766-
parameters[attribute] = value
767-
return category, subtype, parameters

tests/unit/sagemaker/test_deserializers.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,9 @@
1717
from sagemaker.deserializers import StringDeserializer
1818

1919

20-
def test_string_deserializer_plain_text():
20+
def test_string_deserializer():
2121
deserializer = StringDeserializer()
2222

23-
result = deserializer.deserialize("Hello, world!", "text/plain")
23+
result = deserializer.deserialize(io.BytesIO(b"[1, 2, 3]"), "application/json")
2424

25-
assert result == "Hello, world!"
26-
27-
28-
def test_string_deserializer_octet_stream():
29-
deserializer = StringDeserializer()
30-
31-
result = deserializer.deserialize(io.BytesIO(b"Hello, world!"), "application/octet-stream")
32-
33-
assert result == "Hello, world!"
25+
assert result == "[1, 2, 3]"

tests/unit/test_utils.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -779,11 +779,3 @@ def test_partition_by_region():
779779
assert sagemaker.utils._aws_partition("us-gov-east-1") == "aws-us-gov"
780780
assert sagemaker.utils._aws_partition("us-iso-east-1") == "aws-iso"
781781
assert sagemaker.utils._aws_partition("us-isob-east-1") == "aws-iso-b"
782-
783-
784-
def test_parse_mime_type():
785-
mime_type = "application/octet-stream;charset=UTF-8"
786-
category, subtype, parameters = sagemaker.utils.parse_mime_type(mime_type)
787-
assert category == "application"
788-
assert subtype == "octet-stream"
789-
assert parameters == {"charset": "UTF-8"}

0 commit comments

Comments
 (0)