@@ -72,6 +72,11 @@ def __init__(
72
72
accept (str): The invocation's "Accept", overriding any accept from
73
73
the deserializer (default: None).
74
74
"""
75
+ if serializer is not None and not isinstance (serializer , BaseSerializer ):
76
+ serializer = LegacySerializer (serializer )
77
+ if deserializer is not None and not isinstance (deserializer , BaseDeserializer ):
78
+ deserializer = LegacyDeserializer (deserializer )
79
+
75
80
self .endpoint_name = endpoint_name
76
81
self .sagemaker_session = sagemaker_session or Session ()
77
82
self .serializer = serializer
@@ -414,7 +419,7 @@ def __init__(self, serializer):
414
419
serializer (callable): A legacy serializer.
415
420
"""
416
421
self .serializer = serializer
417
- self .content_type = self . serializer . content_type
422
+ self .content_type = getattr ( serializer , " content_type" , None )
418
423
419
424
def __call__ (self , * args , ** kwargs ):
420
425
"""Wraps the call method of the legacy serializer.
@@ -454,7 +459,7 @@ def __init__(self, deserializer):
454
459
deserializer (callable): A legacy deserializer.
455
460
"""
456
461
self .deserializer = deserializer
457
- self .accept = deserializer . accept
462
+ self .accept = getattr ( deserializer , " accept" , None )
458
463
459
464
def __call__ (self , * args , ** kwargs ):
460
465
"""Wraps the call method of the legacy deserializer.
@@ -559,7 +564,7 @@ def _csv_serialize_object(data):
559
564
return csv_buffer .getvalue ().rstrip ("\r \n " )
560
565
561
566
562
- csv_serializer = LegacySerializer ( _CsvSerializer () )
567
+ csv_serializer = _CsvSerializer ()
563
568
564
569
565
570
def _is_mutable_sequence_like (obj ):
@@ -611,7 +616,7 @@ def __call__(self, stream, content_type):
611
616
stream .close ()
612
617
613
618
614
- csv_deserializer = LegacyDeserializer ( _CsvDeserializer () )
619
+ csv_deserializer = _CsvDeserializer ()
615
620
616
621
617
622
class BytesDeserializer (object ):
@@ -724,7 +729,7 @@ def __call__(self, data):
724
729
return json .dumps (_ndarray_to_list (data ))
725
730
726
731
727
- json_serializer = LegacySerializer ( _JsonSerializer () )
732
+ json_serializer = _JsonSerializer ()
728
733
729
734
730
735
def _ndarray_to_list (data ):
@@ -766,7 +771,7 @@ def __call__(self, stream, content_type):
766
771
stream .close ()
767
772
768
773
769
- json_deserializer = LegacyDeserializer ( _JsonDeserializer () )
774
+ json_deserializer = _JsonDeserializer ()
770
775
771
776
772
777
class _NumpyDeserializer (object ):
@@ -810,7 +815,7 @@ def __call__(self, stream, content_type=CONTENT_TYPE_NPY):
810
815
)
811
816
812
817
813
- numpy_deserializer = LegacyDeserializer ( _NumpyDeserializer () )
818
+ numpy_deserializer = _NumpyDeserializer ()
814
819
815
820
816
821
class _NPYSerializer (object ):
@@ -858,4 +863,4 @@ def _npy_serialize(data):
858
863
return buffer .getvalue ()
859
864
860
865
861
- npy_serializer = LegacySerializer ( _NPYSerializer () )
866
+ npy_serializer = _NPYSerializer ()
0 commit comments