|
21 | 21 | import numpy as np
|
22 | 22 |
|
23 | 23 | from sagemaker.content_types import CONTENT_TYPE_JSON, CONTENT_TYPE_CSV, CONTENT_TYPE_NPY
|
| 24 | +from sagemaker.deserializers import BaseDeserializer |
24 | 25 | from sagemaker.model_monitor import DataCaptureConfig
|
| 26 | +from sagemaker.serializers import BaseSerializer |
25 | 27 | from sagemaker.session import production_variant, Session
|
26 | 28 | from sagemaker.utils import name_from_base
|
27 | 29 |
|
@@ -59,27 +61,23 @@ def __init__(
|
59 | 61 | object, used for SageMaker interactions (default: None). If not
|
60 | 62 | specified, one is created using the default AWS configuration
|
61 | 63 | chain.
|
62 |
| - serializer (callable): Accepts a single argument, the input data, |
63 |
| - and returns a sequence of bytes. It may provide a |
64 |
| - ``content_type`` attribute that defines the endpoint request |
65 |
| - content type. If not specified, a sequence of bytes is expected |
66 |
| - for the data. |
67 |
| - deserializer (callable): Accepts two arguments, the result data and |
68 |
| - the response content type, and returns a sequence of bytes. It |
69 |
| - may provide a ``content_type`` attribute that defines the |
70 |
| - endpoint response's "Accept" content type. If not specified, a |
71 |
| - sequence of bytes is expected for the data. |
| 64 | + serializer (sagemaker.serializers.BaseSerializer): A serializer |
| 65 | + object, used to encode data for an inference endpoint |
| 66 | + (default: None). |
| 67 | + deserializer (sagemaker.deserializers.BaseDeserializer): A |
| 68 | + deserializer object, used to decode data from an inference |
| 69 | + endpoint (default: None). |
72 | 70 | content_type (str): The invocation's "ContentType", overriding any
|
73 |
| - ``content_type`` from the serializer (default: None). |
| 71 | + ``CONTENT_TYPE`` from the serializer (default: None). |
74 | 72 | accept (str): The invocation's "Accept", overriding any accept from
|
75 | 73 | the deserializer (default: None).
|
76 | 74 | """
|
77 | 75 | self.endpoint_name = endpoint_name
|
78 | 76 | self.sagemaker_session = sagemaker_session or Session()
|
79 | 77 | self.serializer = serializer
|
80 | 78 | self.deserializer = deserializer
|
81 |
| - self.content_type = content_type or getattr(serializer, "content_type", None) |
82 |
| - self.accept = accept or getattr(deserializer, "accept", None) |
| 79 | + self.content_type = content_type or getattr(serializer, "CONTENT_TYPE", None) |
| 80 | + self.accept = accept or getattr(deserializer, "ACCEPT", None) |
83 | 81 | self._endpoint_config_name = self._get_endpoint_config_name()
|
84 | 82 | self._model_names = self._get_model_names()
|
85 | 83 |
|
@@ -121,7 +119,7 @@ def _handle_response(self, response):
|
121 | 119 | response_body = response["Body"]
|
122 | 120 | if self.deserializer is not None:
|
123 | 121 | # It's the deserializer's responsibility to close the stream
|
124 |
| - return self.deserializer(response_body, response["ContentType"]) |
| 122 | + return self.deserializer.deserialize(response_body, response["ContentType"]) |
125 | 123 | data = response_body.read()
|
126 | 124 | response_body.close()
|
127 | 125 | return data
|
@@ -152,7 +150,7 @@ def _create_request_args(self, data, initial_args=None, target_model=None, targe
|
152 | 150 | args["TargetVariant"] = target_variant
|
153 | 151 |
|
154 | 152 | if self.serializer is not None:
|
155 |
| - data = self.serializer(data) |
| 153 | + data = self.serializer.serialize(data) |
156 | 154 |
|
157 | 155 | args["Body"] = data
|
158 | 156 | return args
|
@@ -406,6 +404,88 @@ def _get_model_names(self):
|
406 | 404 | return [d["ModelName"] for d in production_variants]
|
407 | 405 |
|
408 | 406 |
|
| 407 | +class LegacySerializer(BaseSerializer): |
| 408 | + """Wrapper that makes legacy serializers forward compatibile.""" |
| 409 | + |
| 410 | + def __init__(self, serializer): |
| 411 | + """Placeholder docstring. |
| 412 | +
|
| 413 | + Args: |
| 414 | + serializer (callable): A legacy serializer. |
| 415 | + """ |
| 416 | + self.serializer = serializer |
| 417 | + self.content_type = self.serializer.content_type |
| 418 | + |
| 419 | + def __call__(self, *args, **kwargs): |
| 420 | + """Wraps the call method of the legacy serializer. |
| 421 | +
|
| 422 | + Args: |
| 423 | + data (object): Data to be serialized. |
| 424 | +
|
| 425 | + Returns: |
| 426 | + object: Serialized data used for a request. |
| 427 | + """ |
| 428 | + return self.serializer(*args, **kwargs) |
| 429 | + |
| 430 | + def serialize(self, data): |
| 431 | + """Wraps the call method of the legacy serializer. |
| 432 | +
|
| 433 | + Args: |
| 434 | + data (object): Data to be serialized. |
| 435 | +
|
| 436 | + Returns: |
| 437 | + object: Serialized data used for a request. |
| 438 | + """ |
| 439 | + return self.serializer(data) |
| 440 | + |
| 441 | + @property |
| 442 | + def CONTENT_TYPE(self): |
| 443 | + """The MIME type of the data sent to the inference endpoint.""" |
| 444 | + return self.content_type |
| 445 | + |
| 446 | + |
| 447 | +class LegacyDeserializer(BaseDeserializer): |
| 448 | + """Wrapper that makes legacy deserializers forward compatibile.""" |
| 449 | + |
| 450 | + def __init__(self, deserializer): |
| 451 | + """Placeholder docstring. |
| 452 | +
|
| 453 | + Args: |
| 454 | + deserializer (callable): A legacy deserializer. |
| 455 | + """ |
| 456 | + self.deserializer = deserializer |
| 457 | + self.accept = deserializer.accept |
| 458 | + |
| 459 | + def __call__(self, *args, **kwargs): |
| 460 | + """Wraps the call method of the legacy deserializer. |
| 461 | +
|
| 462 | + Args: |
| 463 | + data (object): Data to be deserialized. |
| 464 | + content_type (str): The MIME type of the data. |
| 465 | +
|
| 466 | + Returns: |
| 467 | + object: The data deserialized into an object. |
| 468 | + """ |
| 469 | + return self.deserializer(*args, **kwargs) |
| 470 | + |
| 471 | + def deserialize(self, data, content_type): |
| 472 | + """Wraps the call method of the legacy deserializer. |
| 473 | +
|
| 474 | + Args: |
| 475 | + data (object): Data to be deserialized. |
| 476 | + content_type (str): The MIME type of the data. |
| 477 | +
|
| 478 | + Returns: |
| 479 | + object: The data deserialized into an object. |
| 480 | + """ |
| 481 | + return self.deserializer(data, content_type) |
| 482 | + |
| 483 | + @property |
| 484 | + def ACCEPT(self): |
| 485 | + """The content type that is expected from the inference endpoint.""" |
| 486 | + return self.accept |
| 487 | + |
| 488 | + |
409 | 489 | class _CsvSerializer(object):
|
410 | 490 | """Placeholder docstring"""
|
411 | 491 |
|
@@ -479,7 +559,7 @@ def _csv_serialize_object(data):
|
479 | 559 | return csv_buffer.getvalue().rstrip("\r\n")
|
480 | 560 |
|
481 | 561 |
|
482 |
| -csv_serializer = _CsvSerializer() |
| 562 | +csv_serializer = LegacySerializer(_CsvSerializer()) |
483 | 563 |
|
484 | 564 |
|
485 | 565 | def _is_mutable_sequence_like(obj):
|
@@ -531,7 +611,7 @@ def __call__(self, stream, content_type):
|
531 | 611 | stream.close()
|
532 | 612 |
|
533 | 613 |
|
534 |
| -csv_deserializer = _CsvDeserializer() |
| 614 | +csv_deserializer = LegacyDeserializer(_CsvDeserializer()) |
535 | 615 |
|
536 | 616 |
|
537 | 617 | class BytesDeserializer(object):
|
@@ -644,7 +724,7 @@ def __call__(self, data):
|
644 | 724 | return json.dumps(_ndarray_to_list(data))
|
645 | 725 |
|
646 | 726 |
|
647 |
| -json_serializer = _JsonSerializer() |
| 727 | +json_serializer = LegacySerializer(_JsonSerializer()) |
648 | 728 |
|
649 | 729 |
|
650 | 730 | def _ndarray_to_list(data):
|
@@ -686,7 +766,7 @@ def __call__(self, stream, content_type):
|
686 | 766 | stream.close()
|
687 | 767 |
|
688 | 768 |
|
689 |
| -json_deserializer = _JsonDeserializer() |
| 769 | +json_deserializer = LegacyDeserializer(_JsonDeserializer()) |
690 | 770 |
|
691 | 771 |
|
692 | 772 | class _NumpyDeserializer(object):
|
@@ -730,7 +810,7 @@ def __call__(self, stream, content_type=CONTENT_TYPE_NPY):
|
730 | 810 | )
|
731 | 811 |
|
732 | 812 |
|
733 |
| -numpy_deserializer = _NumpyDeserializer() |
| 813 | +numpy_deserializer = LegacyDeserializer(_NumpyDeserializer()) |
734 | 814 |
|
735 | 815 |
|
736 | 816 | class _NPYSerializer(object):
|
@@ -778,4 +858,4 @@ def _npy_serialize(data):
|
778 | 858 | return buffer.getvalue()
|
779 | 859 |
|
780 | 860 |
|
781 |
| -npy_serializer = _NPYSerializer() |
| 861 | +npy_serializer = LegacySerializer(_NPYSerializer()) |
0 commit comments