Skip to content

Commit cb249f8

Browse files
author
Balaji Veeramani
committed
Appease integration tests
1 parent bd909bb commit cb249f8

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

src/sagemaker/predictor.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@ def __init__(
7272
accept (str): The invocation's "Accept", overriding any accept from
7373
the deserializer (default: None).
7474
"""
75+
if serializer is not None and not isinstance(serializer, BaseSerializer):
76+
serializer = LegacySerializer(serializer)
77+
if deserializer is not None and not isinstance(deserializer, BaseDeserializer):
78+
deserializer = LegacyDeserializer(deserializer)
79+
7580
self.endpoint_name = endpoint_name
7681
self.sagemaker_session = sagemaker_session or Session()
7782
self.serializer = serializer
@@ -414,7 +419,7 @@ def __init__(self, serializer):
414419
serializer (callable): A legacy serializer.
415420
"""
416421
self.serializer = serializer
417-
self.content_type = self.serializer.content_type
422+
self.content_type = getattr(serializer, "content_type", None)
418423

419424
def __call__(self, *args, **kwargs):
420425
"""Wraps the call method of the legacy serializer.
@@ -454,7 +459,7 @@ def __init__(self, deserializer):
454459
deserializer (callable): A legacy deserializer.
455460
"""
456461
self.deserializer = deserializer
457-
self.accept = deserializer.accept
462+
self.accept = getattr(deserializer, "accept", None)
458463

459464
def __call__(self, *args, **kwargs):
460465
"""Wraps the call method of the legacy deserializer.
@@ -559,7 +564,7 @@ def _csv_serialize_object(data):
559564
return csv_buffer.getvalue().rstrip("\r\n")
560565

561566

562-
csv_serializer = LegacySerializer(_CsvSerializer())
567+
csv_serializer = _CsvSerializer()
563568

564569

565570
def _is_mutable_sequence_like(obj):
@@ -611,7 +616,7 @@ def __call__(self, stream, content_type):
611616
stream.close()
612617

613618

614-
csv_deserializer = LegacyDeserializer(_CsvDeserializer())
619+
csv_deserializer = _CsvDeserializer()
615620

616621

617622
class BytesDeserializer(object):
@@ -724,7 +729,7 @@ def __call__(self, data):
724729
return json.dumps(_ndarray_to_list(data))
725730

726731

727-
json_serializer = LegacySerializer(_JsonSerializer())
732+
json_serializer = _JsonSerializer()
728733

729734

730735
def _ndarray_to_list(data):
@@ -766,7 +771,7 @@ def __call__(self, stream, content_type):
766771
stream.close()
767772

768773

769-
json_deserializer = LegacyDeserializer(_JsonDeserializer())
774+
json_deserializer = _JsonDeserializer()
770775

771776

772777
class _NumpyDeserializer(object):
@@ -810,7 +815,7 @@ def __call__(self, stream, content_type=CONTENT_TYPE_NPY):
810815
)
811816

812817

813-
numpy_deserializer = LegacyDeserializer(_NumpyDeserializer())
818+
numpy_deserializer = _NumpyDeserializer()
814819

815820

816821
class _NPYSerializer(object):
@@ -858,4 +863,4 @@ def _npy_serialize(data):
858863
return buffer.getvalue()
859864

860865

861-
npy_serializer = LegacySerializer(_NPYSerializer())
866+
npy_serializer = _NPYSerializer()

0 commit comments

Comments
 (0)