Skip to content

Commit ea7e9a9

Browse files
committed
feature: Data Serializer capable of supporting image and audio files
1 parent 26b984a commit ea7e9a9

File tree

3 files changed

+57
-1
lines changed

3 files changed

+57
-1
lines changed

src/sagemaker/serializers.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import csv
1919
import io
2020
import json
21-
21+
import os
2222
import numpy as np
2323
from six import with_metaclass
2424

@@ -357,3 +357,35 @@ def serialize(self, data):
357357
return data.read()
358358

359359
raise ValueError("Unable to handle input format: %s" % type(data))
360+
361+
362+
class DataSerializer(SimpleBaseSerializer):
    """Serialize data in any file by extracting raw bytes from the file."""

    def __init__(self, content_type="file-path/raw-bytes"):
        """Initialize a ``DataSerializer`` instance.

        Args:
            content_type (str): The MIME type to signal to the inference endpoint when sending
                request data (default: "file-path/raw-bytes").
        """
        super(DataSerializer, self).__init__(content_type=content_type)

    def serialize(self, data):
        """Serialize a file of any format to raw bytes.

        Args:
            data (object): Data to be serialized. The data can be a string
                representing a file path, or the raw bytes from a file.

        Returns:
            bytes: The contents of the file at ``data`` when ``data`` is a
                string, or ``data`` itself when it is already ``bytes``.

        Raises:
            ValueError: If ``data`` is a string that does not name an existing
                path, or if ``data`` is neither ``str`` nor ``bytes``.
        """
        if isinstance(data, str):
            if not os.path.exists(data):
                raise ValueError(f"{data} is not a valid file path.")
            # Use a context manager so the file handle is closed as soon as the
            # bytes are read instead of lingering until garbage collection.
            with open(data, "rb") as data_file:
                return data_file.read()
        if isinstance(data, bytes):
            # Already raw bytes: pass through untouched.
            return data

        raise ValueError(f"Object of type {type(data)} is not Data serializable.")

tests/data/cuteCat.raw

6.43 KB
Binary file not shown.

tests/unit/sagemaker/test_serializers.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
SparseMatrixSerializer,
2929
JSONLinesSerializer,
3030
LibSVMSerializer,
31+
DataSerializer,
3132
)
3233
from tests.unit import DATA_DIR
3334

@@ -331,3 +332,26 @@ def test_libsvm_serializer_file_like(libsvm_serializer):
331332
libsvm_file.seek(0)
332333
result = libsvm_serializer.serialize(libsvm_file)
333334
assert result == validation_data
335+
336+
337+
@pytest.fixture
def data_serializer():
    """Provide a ``DataSerializer`` built with its default arguments."""
    serializer = DataSerializer()
    return serializer
340+
341+
342+
def test_data_serializer_raw(data_serializer):
    """Check that serializing in-memory bytes yields the expected raw bytes.

    The bytes of ``cuteCat.jpg`` are read up front and handed to the
    serializer; the result must equal the contents of ``cuteCat.raw``.
    """
    input_image_file_path = os.path.join(DATA_DIR, "cuteCat.jpg")
    with open(input_image_file_path, "rb") as image:
        input_image = image.read()
    input_image_data = data_serializer.serialize(input_image)

    validation_image_file_path = os.path.join(DATA_DIR, "cuteCat.raw")
    # Read the reference file under a context manager so the handle is closed
    # deterministically (avoids a ResourceWarning for an unclosed file).
    with open(validation_image_file_path, "rb") as validation_file:
        validation_image_data = validation_file.read()

    assert input_image_data == validation_image_data
350+
351+
352+
def test_data_serializer_file_like(data_serializer):
    """Check that serializing a file path yields the expected raw bytes.

    Despite the name, the serializer is given a path string here (not a
    file-like object); the result must equal the contents of ``cuteCat.raw``.
    """
    input_image_file_path = os.path.join(DATA_DIR, "cuteCat.jpg")
    validation_image_file_path = os.path.join(DATA_DIR, "cuteCat.raw")

    input_image_data = data_serializer.serialize(input_image_file_path)

    # Read the reference file under a context manager so the handle is closed
    # deterministically (avoids a ResourceWarning for an unclosed file).
    with open(validation_image_file_path, "rb") as validation_file:
        validation_image_data = validation_file.read()

    assert input_image_data == validation_image_data

0 commit comments

Comments
 (0)