Skip to content

Commit 99773e5

Browse files
authored
change: Support multiple Accept types (#1794)
1 parent 804b713 commit 99773e5

File tree

4 files changed

+54
-20
lines changed

4 files changed

+54
-20
lines changed

src/sagemaker/amazon/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def serialize(self, data):
5858
class RecordDeserializer(BaseDeserializer):
5959
"""Deserialize RecordIO Protobuf data from an inference endpoint."""
6060

61-
ACCEPT = "application/x-recordio-protobuf"
61+
ACCEPT = ("application/x-recordio-protobuf",)
6262

6363
def deserialize(self, data, content_type):
6464
"""Deserialize RecordIO Protobuf data from an inference endpoint.

src/sagemaker/deserializers.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,13 @@ def deserialize(self, stream, content_type):
5252
@property
5353
@abc.abstractmethod
5454
def ACCEPT(self):
55-
"""The content type that is expected from the inference endpoint."""
55+
"""The content types that are expected from the inference endpoint."""
5656

5757

5858
class StringDeserializer(BaseDeserializer):
5959
"""Deserialize data from an inference endpoint into a decoded string."""
6060

61-
ACCEPT = "application/json"
61+
ACCEPT = ("application/json",)
6262

6363
def __init__(self, encoding="UTF-8"):
6464
"""Initialize the string encoding.
@@ -87,7 +87,7 @@ def deserialize(self, stream, content_type):
8787
class BytesDeserializer(BaseDeserializer):
8888
"""Deserialize a stream of bytes into a bytes object."""
8989

90-
ACCEPT = "*/*"
90+
ACCEPT = ("*/*",)
9191

9292
def deserialize(self, stream, content_type):
9393
"""Read a stream of bytes returned from an inference endpoint.
@@ -108,7 +108,7 @@ def deserialize(self, stream, content_type):
108108
class CSVDeserializer(BaseDeserializer):
109109
"""Deserialize a stream of bytes into a list of lists."""
110110

111-
ACCEPT = "text/csv"
111+
ACCEPT = ("text/csv",)
112112

113113
def __init__(self, encoding="utf-8"):
114114
"""Initialize the string encoding.
@@ -143,7 +143,7 @@ class StreamDeserializer(BaseDeserializer):
143143
reading it.
144144
"""
145145

146-
ACCEPT = "*/*"
146+
ACCEPT = ("*/*",)
147147

148148
def deserialize(self, stream, content_type):
149149
"""Returns a stream of the response body and the MIME type of the data.
@@ -161,16 +161,17 @@ def deserialize(self, stream, content_type):
161161
class NumpyDeserializer(BaseDeserializer):
162162
"""Deserialize a stream of data in the .npy format."""
163163

164-
ACCEPT = "application/x-npy"
165-
166-
def __init__(self, dtype=None, allow_pickle=True):
164+
def __init__(self, dtype=None, accept="application/x-npy", allow_pickle=True):
167165
"""Initialize the dtype and allow_pickle arguments.
168166
169167
Args:
170168
dtype (str): The dtype of the data (default: None).
169+
accept (str): The MIME type that is expected from the inference
170+
endpoint (default: "application/x-npy").
171171
allow_pickle (bool): Allow loading pickled object arrays (default: True).
172172
"""
173173
self.dtype = dtype
174+
self.accept = accept
174175
self.allow_pickle = allow_pickle
175176

176177
def deserialize(self, stream, content_type):
@@ -197,11 +198,21 @@ def deserialize(self, stream, content_type):
197198

198199
raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type))
199200

201+
@property
202+
def ACCEPT(self):
203+
"""The content types that are expected from the inference endpoint.
204+
205+
To maintain backwards compatability with legacy images, the
206+
NumpyDeserializer supports sending only one content type in the Accept
207+
header.
208+
"""
209+
return (self.accept,)
210+
200211

201212
class JSONDeserializer(BaseDeserializer):
202213
"""Deserialize JSON data from an inference endpoint into a Python object."""
203214

204-
ACCEPT = "application/json"
215+
ACCEPT = ("application/json",)
205216

206217
def deserialize(self, stream, content_type):
207218
"""Deserialize JSON data from an inference endpoint into a Python object.
@@ -222,7 +233,7 @@ def deserialize(self, stream, content_type):
222233
class PandasDeserializer(BaseDeserializer):
223234
"""Deserialize CSV or JSON data from an inference endpoint into a pandas dataframe."""
224235

225-
ACCEPT = "text/csv"
236+
ACCEPT = ("text/csv", "application/json")
226237

227238
def deserialize(self, stream, content_type):
228239
"""Deserialize CSV or JSON data from an inference endpoint into a pandas
@@ -250,7 +261,7 @@ def deserialize(self, stream, content_type):
250261
class JSONLinesDeserializer(BaseDeserializer):
251262
"""Deserialize JSON lines data from an inference endpoint."""
252263

253-
ACCEPT = "application/jsonlines"
264+
ACCEPT = ("application/jsonlines",)
254265

255266
def deserialize(self, stream, content_type):
256267
"""Deserialize JSON lines data from an inference endpoint.

src/sagemaker/predictor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def _create_request_args(self, data, initial_args=None, target_model=None, targe
121121
args["ContentType"] = self.content_type
122122

123123
if "Accept" not in args:
124-
args["Accept"] = self.accept
124+
args["Accept"] = ", ".join(self.accept)
125125

126126
if target_model:
127127
args["TargetModel"] = target_model

tests/unit/test_predictor.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import io
1516
import json
1617

1718
import pytest
1819
from mock import Mock, call, patch
1920

21+
from sagemaker.deserializers import CSVDeserializer, PandasDeserializer
2022
from sagemaker.predictor import Predictor
2123
from sagemaker.serializers import JSONSerializer, CSVSerializer
2224

@@ -132,7 +134,7 @@ def json_sagemaker_session():
132134
response_body.close = Mock("close", return_value=None)
133135
ims.sagemaker_runtime_client.invoke_endpoint = Mock(
134136
name="invoke_endpoint",
135-
return_value={"Body": response_body, "ContentType": DEFAULT_CONTENT_TYPE},
137+
return_value={"Body": response_body, "ContentType": "application/json"},
136138
)
137139
return ims
138140

@@ -168,9 +170,7 @@ def ret_csv_sagemaker_session():
168170
ims.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
169171
ims.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
170172

171-
response_body = Mock("body")
172-
response_body.read = Mock("read", return_value=CSV_RETURN_VALUE)
173-
response_body.close = Mock("close", return_value=None)
173+
response_body = io.BytesIO(bytes(CSV_RETURN_VALUE, "utf-8"))
174174
ims.sagemaker_runtime_client.invoke_endpoint = Mock(
175175
name="invoke_endpoint",
176176
return_value={"Body": response_body, "ContentType": CSV_CONTENT_TYPE},
@@ -180,23 +180,46 @@ def ret_csv_sagemaker_session():
180180

181181
def test_predict_call_with_csv():
182182
sagemaker_session = ret_csv_sagemaker_session()
183-
predictor = Predictor(ENDPOINT, sagemaker_session, serializer=CSVSerializer())
183+
predictor = Predictor(
184+
ENDPOINT, sagemaker_session, serializer=CSVSerializer(), deserializer=CSVDeserializer()
185+
)
184186

185187
data = [1, 2]
186188
result = predictor.predict(data)
187189

188190
assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint.called
189191

190192
expected_request_args = {
191-
"Accept": DEFAULT_ACCEPT,
193+
"Accept": CSV_CONTENT_TYPE,
192194
"Body": "1,2",
193195
"ContentType": CSV_CONTENT_TYPE,
194196
"EndpointName": ENDPOINT,
195197
}
196198
call_args, kwargs = sagemaker_session.sagemaker_runtime_client.invoke_endpoint.call_args
197199
assert kwargs == expected_request_args
198200

199-
assert result == CSV_RETURN_VALUE
201+
assert result == [["1", "2", "3"]]
202+
203+
204+
def test_predict_call_with_multiple_accept_types():
205+
sagemaker_session = ret_csv_sagemaker_session()
206+
predictor = Predictor(
207+
ENDPOINT, sagemaker_session, serializer=CSVSerializer(), deserializer=PandasDeserializer()
208+
)
209+
210+
data = [1, 2]
211+
predictor.predict(data)
212+
213+
assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint.called
214+
215+
expected_request_args = {
216+
"Accept": "text/csv, application/json",
217+
"Body": "1,2",
218+
"ContentType": CSV_CONTENT_TYPE,
219+
"EndpointName": ENDPOINT,
220+
}
221+
call_args, kwargs = sagemaker_session.sagemaker_runtime_client.invoke_endpoint.call_args
222+
assert kwargs == expected_request_args
200223

201224

202225
@patch("sagemaker.predictor.name_from_base")

0 commit comments

Comments
 (0)