Skip to content

Commit c46b00c

Browse files
author
Balaji Veeramani
committed
Clean code
1 parent dcfc735 commit c46b00c

File tree

3 files changed

+26
-11
lines changed

3 files changed

+26
-11
lines changed

src/sagemaker/serializers.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,14 @@ def serialize(self, data):
196196
class IdentitySerializer(BaseSerializer):
197197
"""Serialize data by returning data without modification."""
198198

199-
CONTENT_TYPE = "application/json"
199+
def __init__(self, content_type="application/octet-stream"):
200+
"""Initialize the ``content_type`` attribute.
201+
202+
Args:
203+
content_type (str): The MIME type of the serialized data (default:
204+
"application/octet-stream").
205+
"""
206+
self.content_type = content_type
200207

201208
def serialize(self, data):
202209
"""Return data without modification.
@@ -209,6 +216,11 @@ def serialize(self, data):
209216
"""
210217
return data
211218

219+
@property
220+
def CONTENT_TYPE(self):
221+
"""The MIME type of the data sent to the inference endpoint."""
222+
return self.content_type
223+
212224

213225
class JSONLinesSerializer(BaseSerializer):
214226
"""Serialize data to a JSON Lines formatted string."""

tests/integ/test_tfs.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@
2323
import sagemaker.utils
2424
import tests.integ
2525
import tests.integ.timeout
26+
from sagemaker.deserializers import JSONLinesDeserializer
2627
from sagemaker.tensorflow.model import TensorFlowModel, TensorFlowPredictor
27-
from sagemaker.serializers import CSVSerializer
28+
from sagemaker.serializers import CSVSerializer, IdentitySerializer, JSONLinesSerializer
2829

2930

3031
@pytest.fixture(scope="module")
@@ -190,7 +191,7 @@ def test_predict_jsons_json_content_type(tfs_predictor):
190191
predictor = sagemaker.Predictor(
191192
tfs_predictor.endpoint_name,
192193
tfs_predictor.sagemaker_session,
193-
serializer=sagemaker.serializers.IdentitySerializer(),
194+
serializer=IdentitySerializer(content_type="application/json"),
194195
deserializer=sagemaker.deserializers.JSONDeserializer(),
195196
)
196197

@@ -205,7 +206,7 @@ def test_predict_jsons(tfs_predictor):
205206
predictor = sagemaker.Predictor(
206207
tfs_predictor.endpoint_name,
207208
tfs_predictor.sagemaker_session,
208-
serializer=sagemaker.serializer.IdentitySerializer(),
209+
serializer=IdentitySerializer(content_type="application/json"),
209210
deserializer=sagemaker.deserializers.JSONDeserializer(),
210211
)
211212

@@ -220,8 +221,8 @@ def test_predict_jsonlines(tfs_predictor):
220221
predictor = sagemaker.Predictor(
221222
tfs_predictor.endpoint_name,
222223
tfs_predictor.sagemaker_session,
223-
serializer=sagemaker.serializers.JSONLinesSerializer(),
224-
deserializer=sagemaker.deserializers.JSONLinesDeserializer(),
224+
serializer=JSONLinesSerializer(),
225+
deserializer=JSONLinesDeserializer(),
225226
)
226227

227228
result = predictor.predict(input_data)

tests/unit/sagemaker/test_serializers.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -237,15 +237,17 @@ def test_json_serializer_csv_buffer(json_serializer):
237237
assert result == validation_value
238238

239239

240-
@pytest.fixture
241-
def identity_serializer():
242-
return IdentitySerializer()
243-
244-
245240
def test_identity_serializer(identity_serializer):
241+
identity_serializer = IdentitySerializer()
246242
assert identity_serializer.serialize(b"{}") == b"{}"
247243

248244

245+
def test_identity_serializer_with_custom_content_type():
246+
identity_serializer = IdentitySerializer(content_type="text/csv")
247+
assert identity_serializer.serialize(b"a,b\n1,2") == b"a,b\n1,2"
248+
assert identity_serializer.CONTENT_TYPE == "text/csv"
249+
250+
249251
@pytest.fixture
250252
def json_lines_serializer():
251253
return JSONLinesSerializer()

0 commit comments

Comments
 (0)