Skip to content

Commit 98ddbbe

Browse files
author
Balaji Veeramani
committed
Add multiple Accept types
1 parent 6031035 commit 98ddbbe

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

src/sagemaker/deserializers.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def ACCEPT(self):
5858
class StringDeserializer(BaseDeserializer):
5959
"""Deserialize data from an inference endpoint into a decoded string."""
6060

61-
ACCEPT = "application/json"
61+
ACCEPT = ["application/json", "text/csv"]
6262

6363
def __init__(self, encoding="UTF-8"):
6464
"""Initialize the string encoding.
@@ -87,7 +87,7 @@ def deserialize(self, stream, content_type):
8787
class BytesDeserializer(BaseDeserializer):
8888
"""Deserialize a stream of bytes into a bytes object."""
8989

90-
ACCEPT = "*/*"
90+
ACCEPT = ["*/*"]
9191

9292
def deserialize(self, stream, content_type):
9393
"""Read a stream of bytes returned from an inference endpoint.
@@ -108,7 +108,7 @@ def deserialize(self, stream, content_type):
108108
class CSVDeserializer(BaseDeserializer):
109109
"""Deserialize a stream of bytes into a list of lists."""
110110

111-
ACCEPT = "text/csv"
111+
ACCEPT = ["text/csv"]
112112

113113
def __init__(self, encoding="utf-8"):
114114
"""Initialize the string encoding.
@@ -143,7 +143,7 @@ class StreamDeserializer(BaseDeserializer):
143143
reading it.
144144
"""
145145

146-
ACCEPT = "*/*"
146+
ACCEPT = ["*/*"]
147147

148148
def deserialize(self, stream, content_type):
149149
"""Returns a stream of the response body and the MIME type of the data.
@@ -161,7 +161,7 @@ def deserialize(self, stream, content_type):
161161
class NumpyDeserializer(BaseDeserializer):
162162
"""Deserialize a stream of data in the .npy format."""
163163

164-
ACCEPT = "application/x-npy"
164+
ACCEPT = ["application/x-npy"]
165165

166166
def __init__(self, dtype=None, allow_pickle=True):
167167
"""Initialize the dtype and allow_pickle arguments.
@@ -201,7 +201,7 @@ def deserialize(self, stream, content_type):
201201
class JSONDeserializer(BaseDeserializer):
202202
"""Deserialize JSON data from an inference endpoint into a Python object."""
203203

204-
ACCEPT = "application/json"
204+
ACCEPT = ["application/json"]
205205

206206
def deserialize(self, stream, content_type):
207207
"""Deserialize JSON data from an inference endpoint into a Python object.
@@ -222,7 +222,7 @@ def deserialize(self, stream, content_type):
222222
class PandasDeserializer(BaseDeserializer):
223223
"""Deserialize CSV or JSON data from an inference endpoint into a pandas dataframe."""
224224

225-
ACCEPT = "text/csv"
225+
ACCEPT = ["text/csv", "application/json"]
226226

227227
def deserialize(self, stream, content_type):
228228
"""Deserialize CSV or JSON data from an inference endpoint into a pandas

src/sagemaker/predictor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def _create_request_args(self, data, initial_args=None, target_model=None, targe
131131
args["ContentType"] = self.content_type
132132

133133
if self.accept and "Accept" not in args:
134-
args["Accept"] = self.accept
134+
args["Accept"] = ", ".join(self.accept)
135135

136136
if target_model:
137137
args["TargetModel"] = target_model

0 commit comments

Comments
 (0)