|
12 | 12 | # language governing permissions and limitations under the License.
|
13 | 13 | from __future__ import absolute_import
|
14 | 14 |
|
| 15 | +import io |
15 | 16 | import json
|
16 | 17 |
|
17 | 18 | import pytest
|
18 | 19 | from mock import Mock, call, patch
|
19 | 20 |
|
20 |
| -from sagemaker.deserializers import CSVDeserializer, StringDeserializer |
| 21 | +from sagemaker.deserializers import CSVDeserializer, PandasDeserializer |
21 | 22 | from sagemaker.predictor import Predictor
|
22 | 23 | from sagemaker.serializers import JSONSerializer, CSVSerializer
|
23 | 24 |
|
@@ -169,9 +170,7 @@ def ret_csv_sagemaker_session():
|
169 | 170 | ims.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
|
170 | 171 | ims.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
|
171 | 172 |
|
172 |
| - response_body = Mock("body") |
173 |
| - response_body.read = Mock("read", return_value=bytes(CSV_RETURN_VALUE, "utf-8")) |
174 |
| - response_body.close = Mock("close", return_value=None) |
| 173 | + response_body = io.BytesIO(bytes(CSV_RETURN_VALUE, "utf-8")) |
175 | 174 | ims.sagemaker_runtime_client.invoke_endpoint = Mock(
|
176 | 175 | name="invoke_endpoint",
|
177 | 176 | return_value={"Body": response_body, "ContentType": CSV_CONTENT_TYPE},
|
@@ -205,25 +204,23 @@ def test_predict_call_with_csv():
|
205 | 204 | def test_predict_call_with_multiple_accept_types():
|
206 | 205 | sagemaker_session = ret_csv_sagemaker_session()
|
207 | 206 | predictor = Predictor(
|
208 |
| - ENDPOINT, sagemaker_session, serializer=CSVSerializer(), deserializer=StringDeserializer() |
| 207 | + ENDPOINT, sagemaker_session, serializer=CSVSerializer(), deserializer=PandasDeserializer() |
209 | 208 | )
|
210 | 209 |
|
211 | 210 | data = [1, 2]
|
212 |
| - result = predictor.predict(data) |
| 211 | + predictor.predict(data) |
213 | 212 |
|
214 | 213 | assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint.called
|
215 | 214 |
|
216 | 215 | expected_request_args = {
|
217 |
| - "Accept": "application/json", |
| 216 | + "Accept": "text/csv, application/json", |
218 | 217 | "Body": "1,2",
|
219 | 218 | "ContentType": CSV_CONTENT_TYPE,
|
220 | 219 | "EndpointName": ENDPOINT,
|
221 | 220 | }
|
222 | 221 | call_args, kwargs = sagemaker_session.sagemaker_runtime_client.invoke_endpoint.call_args
|
223 | 222 | assert kwargs == expected_request_args
|
224 | 223 |
|
225 |
| - assert result == "1,2,3\r\n" |
226 |
| - |
227 | 224 |
|
228 | 225 | @patch("sagemaker.predictor.name_from_base")
|
229 | 226 | def test_update_endpoint_no_args(name_from_base):
|
|
0 commit comments