Skip to content

Commit da1fcb8

Browse files
authored
Merge branch 'zwei' into add-pandas-deserializer
2 parents 58903f9 + d4f7ce8 commit da1fcb8

File tree

1 file changed

+31
-31
lines changed

1 file changed

+31
-31
lines changed

src/sagemaker/deserializers.py

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@ class BaseDeserializer(abc.ABC):
3838
"""
3939

4040
@abc.abstractmethod
41-
def deserialize(self, data, content_type):
41+
def deserialize(self, stream, content_type):
4242
"""Deserialize data received from an inference endpoint.
4343
4444
Args:
45-
data (object): Data to be deserialized.
45+
stream (botocore.response.StreamingBody): Data to be deserialized.
4646
content_type (str): The MIME type of the data.
4747
4848
Returns:
@@ -68,41 +68,41 @@ def __init__(self, encoding="UTF-8"):
6868
"""
6969
self.encoding = encoding
7070

71-
def deserialize(self, stream, content_type):
    """Read the whole response stream and decode it into a string.

    Args:
        stream (botocore.response.StreamingBody): Data to be deserialized.
        content_type (str): The MIME type of the data. It is accepted for
            interface compatibility and is not inspected by this method.

    Returns:
        str: The payload decoded using ``self.encoding``.
    """
    try:
        raw_bytes = stream.read()
        return raw_bytes.decode(self.encoding)
    finally:
        # The stream is closed whether or not reading/decoding succeeded.
        stream.close()
8585

8686

8787
class BytesDeserializer(BaseDeserializer):
    """Deserialize a stream of bytes into a bytes object."""

    ACCEPT = "*/*"

    def deserialize(self, stream, content_type):
        """Return the full contents of the response stream as bytes.

        Args:
            stream (botocore.response.StreamingBody): A stream of bytes.
            content_type (str): The MIME type of the data. It is not
                inspected by this method.

        Returns:
            bytes: Everything read from the stream.
        """
        try:
            payload = stream.read()
        finally:
            # Close the stream even when the read itself raises.
            stream.close()
        return payload
106106

107107

108108
class CSVDeserializer(BaseDeserializer):
@@ -118,22 +118,22 @@ def __init__(self, encoding="utf-8"):
118118
"""
119119
self.encoding = encoding
120120

121-
def deserialize(self, stream, content_type):
    """Deserialize data from an inference endpoint into a list of lists.

    Args:
        stream (botocore.response.StreamingBody): Data to be deserialized.
        content_type (str): The MIME type of the data. It is not inspected
            by this method.

    Returns:
        list: The data deserialized into a list of lists representing the
            contents of a CSV file. An empty payload yields an empty list.
    """
    try:
        decoded_string = stream.read().decode(self.encoding)
        # Hand csv.reader a text stream opened with newline="" instead of
        # str.splitlines(): splitlines() strips the line breaks that belong
        # to quoted multi-line fields and also splits rows on characters
        # such as "\x0b", "\x1c" or "\u2028" that are ordinary CSV data.
        return list(csv.reader(io.StringIO(decoded_string, newline="")))
    finally:
        # The stream is closed whether or not parsing succeeded.
        stream.close()
137137

138138

139139
class StreamDeserializer(BaseDeserializer):
@@ -145,17 +145,17 @@ class StreamDeserializer(BaseDeserializer):
145145

146146
ACCEPT = "*/*"
147147

148-
def deserialize(self, stream, content_type):
    """Hand back the response stream together with its MIME type.

    The stream is returned as-is: this method neither reads nor closes it.

    Args:
        stream (botocore.response.StreamingBody): A stream of bytes.
        content_type (str): The MIME type of the data.

    Returns:
        tuple: A two-tuple containing the stream and content-type.
    """
    response = (stream, content_type)
    return response
159159

160160

161161
class NumpyDeserializer(BaseDeserializer):
@@ -171,11 +171,11 @@ def __init__(self, dtype=None):
171171
"""
172172
self.dtype = dtype
173173

174-
def deserialize(self, data, content_type):
174+
def deserialize(self, stream, content_type):
175175
"""Deserialize data from an inference endpoint into a NumPy array.
176176
177177
Args:
178-
data (botocore.response.StreamingBody): Data to be deserialized.
178+
stream (botocore.response.StreamingBody): Data to be deserialized.
179179
content_type (str): The MIME type of the data.
180180
181181
Returns:
@@ -184,14 +184,14 @@ def deserialize(self, data, content_type):
184184
try:
185185
if content_type == "text/csv":
186186
return np.genfromtxt(
187-
codecs.getreader("utf-8")(data), delimiter=",", dtype=self.dtype
187+
codecs.getreader("utf-8")(stream), delimiter=",", dtype=self.dtype
188188
)
189189
if content_type == "application/json":
190-
return np.array(json.load(codecs.getreader("utf-8")(data)), dtype=self.dtype)
190+
return np.array(json.load(codecs.getreader("utf-8")(stream)), dtype=self.dtype)
191191
if content_type == "application/x-npy":
192-
return np.load(io.BytesIO(data.read()))
192+
return np.load(io.BytesIO(stream.read()))
193193
finally:
194-
data.close()
194+
stream.close()
195195

196196
raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type))
197197

@@ -201,45 +201,45 @@ class JSONDeserializer(BaseDeserializer):
201201

202202
ACCEPT = "application/json"
203203

204-
def deserialize(self, stream, content_type):
    """Parse a JSON response body into the corresponding Python object.

    Args:
        stream (botocore.response.StreamingBody): Data to be deserialized.
        content_type (str): The MIME type of the data. It is not inspected
            by this method.

    Returns:
        object: The JSON-formatted data deserialized into a Python object.
    """
    try:
        # Wrap the byte stream so json.load receives UTF-8 decoded text.
        utf8_reader = codecs.getreader("utf-8")
        return json.load(utf8_reader(stream))
    finally:
        # The stream is closed whether or not parsing succeeded.
        stream.close()
218218

219219

220220
class PandasDeserializer(BaseDeserializer):
    """Deserialize CSV or JSON data from an inference endpoint into a pandas dataframe."""

    ACCEPT = "text/csv"

    def deserialize(self, stream, content_type):
        """Deserialize CSV or JSON data from an inference endpoint into a pandas
        dataframe.

        If the data is JSON, the data should be formatted in the 'columns' orient.
        See https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_json.html

        Args:
            stream (botocore.response.StreamingBody): Data to be deserialized.
            content_type (str): The MIME type of the data.

        Returns:
            pandas.DataFrame: The data deserialized into a pandas DataFrame.

        Raises:
            ValueError: If ``content_type`` is neither "text/csv" nor
                "application/json".
        """
        try:
            if content_type == "text/csv":
                return pandas.read_csv(stream)

            if content_type == "application/json":
                return pandas.read_json(stream)
        finally:
            # Always release the response stream, including on parse errors
            # and unsupported content types; previously it was never closed.
            stream.close()

        raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type))

0 commit comments

Comments
 (0)