Skip to content

Commit 6d11398

Browse files
authored
Merge branch 'zwei' into neo-uris
2 parents c5902f7 + f7b3f6b commit 6d11398

File tree

3 files changed

+94
-5
lines changed

3 files changed

+94
-5
lines changed

src/sagemaker/cli/compatibility/v2/modifiers/serde.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ def node_should_be_modified(self, node):
274274
the ``sagemaker.amazon.common`` module.
275275
"""
276276
return node.module == "sagemaker.amazon.common" and any(
277-
[alias.name in OLD_AMAZON_CLASS_NAMES for alias in node.names]
277+
alias.name in OLD_AMAZON_CLASS_NAMES for alias in node.names
278278
)
279279

280280
def modify_node(self, node):
@@ -368,7 +368,11 @@ class SerializerImportInserter(_ImportInserter):
368368
"""
369369

370370
def __init__(self):
371-
# Amazon SerDe are not defined in the sagemaker.serializers module.
371+
"""Initialize the ``class_names`` and ``import_node`` attributes.
372+
373+
Amazon-specific serializers are ignored because they are not defined in
374+
the ``sagemaker.serializers`` module.
375+
"""
372376
class_names = {
373377
class_name
374378
for class_name in NEW_CLASS_NAMES - NEW_AMAZON_CLASS_NAMES
@@ -398,7 +402,11 @@ class DeserializerImportInserter(_ImportInserter):
398402
"""
399403

400404
def __init__(self):
401-
# Amazon SerDe are not defined in the sagemaker.serializers module.
405+
"""Initialize the ``class_names`` and ``import_node`` attributes.
406+
407+
Amazon-specific deserializers are ignored because they are not defined
408+
in the ``sagemaker.deserializers`` module.
409+
"""
402410
class_names = {
403411
class_name
404412
for class_name in NEW_CLASS_NAMES - NEW_AMAZON_CLASS_NAMES

src/sagemaker/serializers.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
import abc
17+
from collections.abc import Iterable
1718
import csv
1819
import io
1920
import json
@@ -23,7 +24,7 @@
2324
from sagemaker.utils import DeferredError
2425

2526
try:
26-
import scipy
27+
import scipy.sparse
2728
except ImportError as e:
2829
scipy = DeferredError(e)
2930

@@ -192,6 +193,34 @@ def serialize(self, data):
192193
return json.dumps(data)
193194

194195

196+
class JSONLinesSerializer(BaseSerializer):
197+
"""Serialize data to a JSON Lines formatted string."""
198+
199+
CONTENT_TYPE = "application/jsonlines"
200+
201+
def serialize(self, data):
202+
"""Serialize data of various formats to a JSON Lines formatted string.
203+
204+
Args:
205+
data (object): Data to be serialized. The data can be a string,
206+
iterable of JSON serializable objects, or a file-like object.
207+
208+
Returns:
209+
str: The data serialized as a string containing newline-separated
210+
JSON values.
211+
"""
212+
if isinstance(data, str):
213+
return data
214+
215+
if hasattr(data, "read"):
216+
return data.read()
217+
218+
if isinstance(data, Iterable):
219+
return "\n".join(json.dumps(element) for element in data)
220+
221+
raise ValueError("Object of type %s is not JSON Lines serializable." % type(data))
222+
223+
195224
class SparseMatrixSerializer(BaseSerializer):
196225
"""Serialize a sparse matrix to a buffer using the .npz format."""
197226

tests/unit/sagemaker/test_serializers.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,14 @@
1818

1919
import numpy as np
2020
import pytest
21-
import scipy
21+
import scipy.sparse
2222

2323
from sagemaker.serializers import (
2424
CSVSerializer,
2525
NumpySerializer,
2626
JSONSerializer,
2727
SparseMatrixSerializer,
28+
JSONLinesSerializer,
2829
)
2930
from tests.unit import DATA_DIR
3031

@@ -235,6 +236,57 @@ def test_json_serializer_csv_buffer(json_serializer):
235236
assert result == validation_value
236237

237238

239+
@pytest.fixture
240+
def json_lines_serializer():
241+
return JSONLinesSerializer()
242+
243+
244+
@pytest.mark.parametrize(
245+
"input, expected",
246+
[
247+
('["Name", "Score"]\n["Gilbert", 24]', '["Name", "Score"]\n["Gilbert", 24]'),
248+
(
249+
'{"Name": "Gilbert", "Score": 24}\n{"Name": "Alexa", "Score": 29}',
250+
'{"Name": "Gilbert", "Score": 24}\n{"Name": "Alexa", "Score": 29}',
251+
),
252+
],
253+
)
254+
def test_json_lines_serializer_string(json_lines_serializer, input, expected):
255+
actual = json_lines_serializer.serialize(input)
256+
assert actual == expected
257+
258+
259+
@pytest.mark.parametrize(
260+
"input, expected",
261+
[
262+
([["Name", "Score"], ["Gilbert", 24]], '["Name", "Score"]\n["Gilbert", 24]'),
263+
(
264+
[{"Name": "Gilbert", "Score": 24}, {"Name": "Alexa", "Score": 29}],
265+
'{"Name": "Gilbert", "Score": 24}\n{"Name": "Alexa", "Score": 29}',
266+
),
267+
],
268+
)
269+
def test_json_lines_serializer_list(json_lines_serializer, input, expected):
270+
actual = json_lines_serializer.serialize(input)
271+
assert actual == expected
272+
273+
274+
@pytest.mark.parametrize(
275+
"source, expected",
276+
[
277+
('["Name", "Score"]\n["Gilbert", 24]', '["Name", "Score"]\n["Gilbert", 24]'),
278+
(
279+
'{"Name": "Gilbert", "Score": 24}\n{"Name": "Alexa", "Score": 29}',
280+
'{"Name": "Gilbert", "Score": 24}\n{"Name": "Alexa", "Score": 29}',
281+
),
282+
],
283+
)
284+
def test_json_lines_serializer_file_like(json_lines_serializer, source, expected):
285+
input = io.StringIO(source)
286+
actual = json_lines_serializer.serialize(input)
287+
assert actual == expected
288+
289+
238290
@pytest.fixture
239291
def sparse_matrix_serializer():
240292
return SparseMatrixSerializer()

0 commit comments

Comments
 (0)