Skip to content

Commit db562fa

Browse files
author
Balaji Veeramani
committed
Add CSV serializer
1 parent 4ffa222 commit db562fa

File tree

12 files changed

+218
-194
lines changed

12 files changed

+218
-194
lines changed

doc/frameworks/tensorflow/deploying_tensorflow_serving.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ your input data to CSV format:
236236
237237
# create a Predictor with CSV serialization
238238
239-
predictor = Predictor('endpoint-name', serializer=sagemaker.predictor.csv_serializer)
239+
predictor = Predictor('endpoint-name', serializer=sagemaker.serializers.CSVSerializer())
240240
241241
# CSV-formatted string input
242242
input = '1.0,2.0,5.0\n1.0,2.0,5.0\n1.0,2.0,5.0'
@@ -252,7 +252,7 @@ your input data to CSV format:
252252
]
253253
}
254254
255-
You can also use python arrays or numpy arrays as input and let the `csv_serializer` object
255+
You can also use python arrays or numpy arrays as input and let the `CSVSerializer` object
256256
convert them to CSV, but the client-side CSV conversion is more sophisticated than the
257257
CSV parsing on the Endpoint, so if you encounter conversion problems, try using one of the
258258
JSON options instead.

doc/frameworks/tensorflow/using_tf.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -685,7 +685,7 @@ your input data to CSV format:
685685
686686
# create a Predictor with CSV serialization
687687
688-
predictor = Predictor('endpoint-name', serializer=sagemaker.predictor.csv_serializer)
688+
predictor = Predictor('endpoint-name', serializer=sagemaker.serializers.CSVSerializer())
689689
690690
# CSV-formatted string input
691691
input = '1.0,2.0,5.0\n1.0,2.0,5.0\n1.0,2.0,5.0'
@@ -701,7 +701,7 @@ your input data to CSV format:
701701
]
702702
}
703703
704-
You can also use python arrays or numpy arrays as input and let the `csv_serializer` object
704+
You can also use python arrays or numpy arrays as input and let the `CSVSerializer` object
705705
convert them to CSV, but the client-side CSV conversion is more sophisticated than the
706706
CSV parsing on the Endpoint, so if you encounter conversion problems, try using one of the
707707
JSON options instead.

src/sagemaker/amazon/ipinsights.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@
1616
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
1717
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1818
from sagemaker.amazon.validation import ge, le
19-
from sagemaker.predictor import Predictor, csv_serializer, json_deserializer
19+
from sagemaker.predictor import Predictor, json_deserializer
2020
from sagemaker.model import Model
21+
from sagemaker.serializers import CSVSerializer
2122
from sagemaker.session import Session
2223
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2324

@@ -197,7 +198,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
197198
super(IPInsightsPredictor, self).__init__(
198199
endpoint_name,
199200
sagemaker_session,
200-
serializer=csv_serializer,
201+
serializer=CSVSerializer(),
201202
deserializer=json_deserializer,
202203
)
203204

src/sagemaker/predictor.py

Lines changed: 1 addition & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import csv
1818
import json
1919
import six
20-
from six import StringIO, BytesIO
20+
from six import BytesIO
2121
import numpy as np
2222

2323
from sagemaker.content_types import CONTENT_TYPE_JSON, CONTENT_TYPE_CSV, CONTENT_TYPE_NPY
@@ -495,108 +495,6 @@ def ACCEPT(self):
495495
return self.accept
496496

497497

498-
class _CsvSerializer(object):
499-
"""Placeholder docstring"""
500-
501-
def __init__(self):
502-
"""Placeholder docstring"""
503-
self.content_type = CONTENT_TYPE_CSV
504-
505-
def __call__(self, data):
506-
"""Take data of various data formats and serialize them into CSV.
507-
508-
Args:
509-
data (object): Data to be serialized.
510-
511-
Returns:
512-
object: Sequence of bytes to be used for the request body.
513-
"""
514-
# For inputs which represent multiple "rows", the result should be newline-separated CSV
515-
# rows
516-
if _is_mutable_sequence_like(data) and len(data) > 0 and _is_sequence_like(data[0]):
517-
return "\n".join([_CsvSerializer._serialize_row(row) for row in data])
518-
return _CsvSerializer._serialize_row(data)
519-
520-
@staticmethod
521-
def _serialize_row(data):
522-
# Don't attempt to re-serialize a string
523-
"""
524-
Args:
525-
data:
526-
"""
527-
if isinstance(data, str):
528-
return data
529-
if isinstance(data, np.ndarray):
530-
data = np.ndarray.flatten(data)
531-
if hasattr(data, "__len__"):
532-
if len(data) == 0:
533-
raise ValueError("Cannot serialize empty array")
534-
return _csv_serialize_python_array(data)
535-
536-
# files and buffers
537-
if hasattr(data, "read"):
538-
return _csv_serialize_from_buffer(data)
539-
540-
raise ValueError("Unable to handle input format: ", type(data))
541-
542-
543-
def _csv_serialize_python_array(data):
544-
"""
545-
Args:
546-
data:
547-
"""
548-
return _csv_serialize_object(data)
549-
550-
551-
def _csv_serialize_from_buffer(buff):
552-
"""
553-
Args:
554-
buff:
555-
"""
556-
return buff.read()
557-
558-
559-
def _csv_serialize_object(data):
560-
"""
561-
Args:
562-
data:
563-
"""
564-
csv_buffer = StringIO()
565-
566-
csv_writer = csv.writer(csv_buffer, delimiter=",")
567-
csv_writer.writerow(data)
568-
return csv_buffer.getvalue().rstrip("\r\n")
569-
570-
571-
csv_serializer = _CsvSerializer()
572-
573-
574-
def _is_mutable_sequence_like(obj):
575-
"""
576-
Args:
577-
obj:
578-
"""
579-
return _is_sequence_like(obj) and hasattr(obj, "__setitem__")
580-
581-
582-
def _is_sequence_like(obj):
583-
"""
584-
Args:
585-
obj:
586-
"""
587-
return hasattr(obj, "__iter__") and hasattr(obj, "__getitem__")
588-
589-
590-
def _row_to_csv(obj):
591-
"""
592-
Args:
593-
obj:
594-
"""
595-
if isinstance(obj, str):
596-
return obj
597-
return ",".join(obj)
598-
599-
600498
class _CsvDeserializer(object):
601499
"""Placeholder docstring"""
602500

src/sagemaker/serializers.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
from __future__ import absolute_import
1515

1616
import abc
17+
import csv
18+
import io
19+
20+
import numpy as np
1721

1822

1923
class BaseSerializer(abc.ABC):
@@ -38,3 +42,100 @@ def serialize(self, data):
3842
@abc.abstractmethod
3943
def CONTENT_TYPE(self):
4044
"""The MIME type of the data sent to the inference endpoint."""
45+
46+
47+
class CSVSerializer(BaseSerializer):
48+
"""Placeholder docstring"""
49+
50+
CONTENT_TYPE = "text/csv"
51+
52+
def serialize(self, data):
53+
"""Take data of various data formats and serialize them into CSV.
54+
55+
Args:
56+
data (object): Data to be serialized.
57+
58+
Returns:
59+
object: Sequence of bytes to be used for the request body.
60+
"""
61+
# For inputs which represent multiple "rows", the result should be newline-separated CSV
62+
# rows
63+
if _is_mutable_sequence_like(data) and len(data) > 0 and _is_sequence_like(data[0]):
64+
return "\n".join([CSVSerializer._serialize_row(row) for row in data])
65+
return CSVSerializer._serialize_row(data)
66+
67+
@staticmethod
68+
def _serialize_row(data):
69+
# Don't attempt to re-serialize a string
70+
"""
71+
Args:
72+
data:
73+
"""
74+
if isinstance(data, str):
75+
return data
76+
if isinstance(data, np.ndarray):
77+
data = np.ndarray.flatten(data)
78+
if hasattr(data, "__len__"):
79+
if len(data) == 0:
80+
raise ValueError("Cannot serialize empty array")
81+
return _csv_serialize_python_array(data)
82+
83+
# files and buffers
84+
if hasattr(data, "read"):
85+
return _csv_serialize_from_buffer(data)
86+
87+
raise ValueError("Unable to handle input format: ", type(data))
88+
89+
90+
def _csv_serialize_python_array(data):
91+
"""
92+
Args:
93+
data:
94+
"""
95+
return _csv_serialize_object(data)
96+
97+
98+
def _csv_serialize_from_buffer(buff):
99+
"""
100+
Args:
101+
buff:
102+
"""
103+
return buff.read()
104+
105+
106+
def _csv_serialize_object(data):
107+
"""
108+
Args:
109+
data:
110+
"""
111+
csv_buffer = io.StringIO()
112+
113+
csv_writer = csv.writer(csv_buffer, delimiter=",")
114+
csv_writer.writerow(data)
115+
return csv_buffer.getvalue().rstrip("\r\n")
116+
117+
118+
def _is_mutable_sequence_like(obj):
119+
"""
120+
Args:
121+
obj:
122+
"""
123+
return _is_sequence_like(obj) and hasattr(obj, "__setitem__")
124+
125+
126+
def _is_sequence_like(obj):
127+
"""
128+
Args:
129+
obj:
130+
"""
131+
return hasattr(obj, "__iter__") and hasattr(obj, "__getitem__")
132+
133+
134+
def _row_to_csv(obj):
135+
"""
136+
Args:
137+
obj:
138+
"""
139+
if isinstance(obj, str):
140+
return obj
141+
return ",".join(obj)

src/sagemaker/sparkml/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from sagemaker import Model, Predictor, Session
1717
from sagemaker.content_types import CONTENT_TYPE_CSV
1818
from sagemaker.fw_registry import registry
19-
from sagemaker.predictor import csv_serializer
19+
from sagemaker.serializers import CSVSerializer
2020

2121
framework_name = "sparkml-serving"
2222
repo_name = "sagemaker-sparkml-serving"
@@ -51,7 +51,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
5151
super(SparkMLPredictor, self).__init__(
5252
endpoint_name=endpoint_name,
5353
sagemaker_session=sagemaker_session,
54-
serializer=csv_serializer,
54+
serializer=CSVSerializer(),
5555
content_type=CONTENT_TYPE_CSV,
5656
)
5757

tests/integ/test_marketplace.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import sagemaker
2323
import tests.integ
2424
from sagemaker import AlgorithmEstimator, ModelPackage
25+
from sagemaker.serializers import CSVSerializer
2526
from sagemaker.tuner import IntegerParameter, HyperparameterTuner
2627
from sagemaker.utils import sagemaker_timestamp
2728
from sagemaker.utils import _aws_partition
@@ -136,10 +137,7 @@ def test_marketplace_attach(sagemaker_session, cpu_instance_type):
136137
training_job_name=training_job_name, sagemaker_session=sagemaker_session
137138
)
138139
predictor = estimator.deploy(
139-
1,
140-
cpu_instance_type,
141-
endpoint_name=endpoint_name,
142-
serializer=sagemaker.predictor.csv_serializer,
140+
1, cpu_instance_type, endpoint_name=endpoint_name, serializer=CSVSerializer()
143141
)
144142
shape = pandas.read_csv(os.path.join(data_path, "iris.csv"), header=None)
145143
a = [50 * i for i in range(3)]
@@ -165,7 +163,7 @@ def test_marketplace_model(sagemaker_session, cpu_instance_type):
165163
)
166164

167165
def predict_wrapper(endpoint, session):
168-
return sagemaker.Predictor(endpoint, session, serializer=sagemaker.predictor.csv_serializer)
166+
return sagemaker.Predictor(endpoint, session, serializer=CSVSerializer())
169167

170168
model = ModelPackage(
171169
role="SageMakerRole",

tests/integ/test_multi_variant_endpoint.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
from sagemaker.content_types import CONTENT_TYPE_CSV
2626
from sagemaker.utils import unique_name_from_base
2727
from sagemaker.amazon.amazon_estimator import get_image_uri
28-
from sagemaker.predictor import csv_serializer, Predictor
28+
from sagemaker.predictor import Predictor
29+
from sagemaker.serializers import CSVSerializer
2930

3031

3132
import tests.integ
@@ -169,7 +170,7 @@ def test_predict_invocation_with_target_variant(sagemaker_session, multi_variant
169170
predictor = Predictor(
170171
endpoint_name=multi_variant_endpoint.endpoint_name,
171172
sagemaker_session=sagemaker_session,
172-
serializer=csv_serializer,
173+
serializer=CSVSerializer(),
173174
content_type=CONTENT_TYPE_CSV,
174175
accept=CONTENT_TYPE_CSV,
175176
)
@@ -297,7 +298,7 @@ def test_predict_invocation_with_target_variant_local_mode(
297298
predictor = Predictor(
298299
endpoint_name=multi_variant_endpoint.endpoint_name,
299300
sagemaker_session=sagemaker_session,
300-
serializer=csv_serializer,
301+
serializer=CSVSerializer(),
301302
content_type=CONTENT_TYPE_CSV,
302303
accept=CONTENT_TYPE_CSV,
303304
)

tests/integ/test_tfs.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import tests.integ
2525
import tests.integ.timeout
2626
from sagemaker.tensorflow.model import TensorFlowModel, TensorFlowPredictor
27+
from sagemaker.serializers import CSVSerializer
2728

2829

2930
@pytest.fixture(scope="module")
@@ -236,9 +237,7 @@ def test_predict_csv(tfs_predictor):
236237
expected_result = {"predictions": [[3.5, 4.0, 5.5], [3.5, 4.0, 5.5]]}
237238

238239
predictor = TensorFlowPredictor(
239-
tfs_predictor.endpoint_name,
240-
tfs_predictor.sagemaker_session,
241-
serializer=sagemaker.predictor.csv_serializer,
240+
tfs_predictor.endpoint_name, tfs_predictor.sagemaker_session, serializer=CSVSerializer(),
242241
)
243242

244243
result = predictor.predict(input_data)

0 commit comments

Comments
 (0)