Skip to content

Commit ae974f8

Browse files
authored
Merge branch 'zwei' into remove-legacy-serde
2 parents a023a23 + d46bd52 commit ae974f8

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed

src/sagemaker/deserializers.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@
2222

2323
import numpy as np
2424

25+
from sagemaker.utils import DeferredError
26+
27+
try:
28+
import pandas
29+
except ImportError as e:
30+
pandas = DeferredError(e)
31+
2532

2633
class BaseDeserializer(abc.ABC):
2734
"""Abstract base class for creation of new deserializers.
@@ -208,3 +215,31 @@ def deserialize(self, stream, content_type):
208215
return json.load(codecs.getreader("utf-8")(stream))
209216
finally:
210217
stream.close()
218+
219+
220+
class PandasDeserializer(BaseDeserializer):
221+
"""Deserialize CSV or JSON data from an inference endpoint into a pandas dataframe."""
222+
223+
ACCEPT = "text/csv"
224+
225+
def deserialize(self, stream, content_type):
226+
"""Deserialize CSV or JSON data from an inference endpoint into a pandas
227+
dataframe.
228+
229+
If the data is JSON, the data should be formatted in the 'columns' orient.
230+
See https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_json.html
231+
232+
Args:
233+
stream (botocore.response.StreamingBody): Data to be deserialized.
234+
content_type (str): The MIME type of the data.
235+
236+
Returns:
237+
pandas.DataFrame: The data deserialized into a pandas DataFrame.
238+
"""
239+
if content_type == "text/csv":
240+
return pandas.read_csv(stream)
241+
242+
if content_type == "application/json":
243+
return pandas.read_json(stream)
244+
245+
raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type))

tests/unit/sagemaker/test_deserializers.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
from __future__ import absolute_import
1414

1515
import io
16+
import json
1617

1718
import numpy as np
19+
import pandas as pd
1820
import pytest
1921

2022
from sagemaker.deserializers import (
@@ -24,6 +26,7 @@
2426
StreamDeserializer,
2527
NumpyDeserializer,
2628
JSONDeserializer,
29+
PandasDeserializer,
2730
)
2831

2932

@@ -171,3 +174,25 @@ def test_json_deserializer_invalid_data(json_deserializer):
171174
with pytest.raises(ValueError) as error:
172175
json_deserializer.deserialize(io.BytesIO(b"[[1]"), "application/json")
173176
assert "column" in str(error)
177+
178+
179+
@pytest.fixture
180+
def pandas_deserializer():
181+
return PandasDeserializer()
182+
183+
184+
def test_pandas_deserializer_json(pandas_deserializer):
185+
data = {"col 1": {"row 1": "a", "row 2": "c"}, "col 2": {"row 1": "b", "row 2": "d"}}
186+
stream = io.StringIO(json.dumps(data))
187+
result = pandas_deserializer.deserialize(stream, "application/json")
188+
expected = pd.DataFrame(
189+
[["a", "b"], ["c", "d"]], index=["row 1", "row 2"], columns=["col 1", "col 2"]
190+
)
191+
assert result.equals(expected)
192+
193+
194+
def test_pandas_deserializer_csv(pandas_deserializer):
195+
stream = io.StringIO("col 1,col 2\na,b\nc,d")
196+
result = pandas_deserializer.deserialize(stream, "text/csv")
197+
expected = pd.DataFrame([["a", "b"], ["c", "d"]], columns=["col 1", "col 2"])
198+
assert result.equals(expected)

0 commit comments

Comments
 (0)