Skip to content

Commit bfdd6aa

Browse files
author
Balaji Veeramani
committed
Update test_predictor.py
1 parent a249739 commit bfdd6aa

File tree

1 file changed

+6
-9
lines changed

1 file changed

+6
-9
lines changed

tests/unit/test_predictor.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import io
1516
import json
1617

1718
import pytest
1819
from mock import Mock, call, patch
1920

20-
from sagemaker.deserializers import CSVDeserializer, StringDeserializer
21+
from sagemaker.deserializers import CSVDeserializer, PandasDeserializer
2122
from sagemaker.predictor import Predictor
2223
from sagemaker.serializers import JSONSerializer, CSVSerializer
2324

@@ -169,9 +170,7 @@ def ret_csv_sagemaker_session():
169170
ims.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
170171
ims.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
171172

172-
response_body = Mock("body")
173-
response_body.read = Mock("read", return_value=bytes(CSV_RETURN_VALUE, "utf-8"))
174-
response_body.close = Mock("close", return_value=None)
173+
response_body = io.BytesIO(bytes(CSV_RETURN_VALUE, "utf-8"))
175174
ims.sagemaker_runtime_client.invoke_endpoint = Mock(
176175
name="invoke_endpoint",
177176
return_value={"Body": response_body, "ContentType": CSV_CONTENT_TYPE},
@@ -205,25 +204,23 @@ def test_predict_call_with_csv():
205204
def test_predict_call_with_multiple_accept_types():
206205
sagemaker_session = ret_csv_sagemaker_session()
207206
predictor = Predictor(
208-
ENDPOINT, sagemaker_session, serializer=CSVSerializer(), deserializer=StringDeserializer()
207+
ENDPOINT, sagemaker_session, serializer=CSVSerializer(), deserializer=PandasDeserializer()
209208
)
210209

211210
data = [1, 2]
212-
result = predictor.predict(data)
211+
predictor.predict(data)
213212

214213
assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint.called
215214

216215
expected_request_args = {
217-
"Accept": "application/json",
216+
"Accept": "text/csv, application/json",
218217
"Body": "1,2",
219218
"ContentType": CSV_CONTENT_TYPE,
220219
"EndpointName": ENDPOINT,
221220
}
222221
call_args, kwargs = sagemaker_session.sagemaker_runtime_client.invoke_endpoint.call_args
223222
assert kwargs == expected_request_args
224223

225-
assert result == "1,2,3\r\n"
226-
227224

228225
@patch("sagemaker.predictor.name_from_base")
229226
def test_update_endpoint_no_args(name_from_base):

0 commit comments

Comments
 (0)