Skip to content

Commit 3c95c53

Browse files
committed
feature: Add LibSVM serializer
1 parent cfa9c97 commit 3c95c53

File tree

4 files changed

+52
-2
lines changed

4 files changed

+52
-2
lines changed

src/sagemaker/serializers.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,3 +270,27 @@ def serialize(self, data):
270270
buffer = io.BytesIO()
271271
scipy.sparse.save_npz(buffer, data)
272272
return buffer.getvalue()
273+
274+
275+
class LibSVMSerializer(BaseSerializer):
276+
"""Searilize data of various formats to a LibSVM-formatted string."""
277+
278+
CONTENT_TYPE = "text/libsvm"
279+
280+
def serialize(self, data):
281+
"""Serialize data of various formats to a LibSVM-formatted string.
282+
283+
Args:
284+
data (object): Data to be serialized. Can be a string or a
285+
file-like object.
286+
287+
Returns:
288+
str: The data serialized as a LibSVM-formatted string.
289+
"""
290+
if isinstance(data, str):
291+
return data
292+
293+
if hasattr(data, "read"):
294+
return data.read()
295+
296+
raise ValueError("Unable to handle input format: ", type(data))

src/sagemaker/xgboost/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from sagemaker.fw_utils import model_code_key_prefix
2222
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2323
from sagemaker.predictor import Predictor
24-
from sagemaker.serializers import NumpySerializer
24+
from sagemaker.serializers import LibSVMSerializer
2525
from sagemaker.xgboost.defaults import XGBOOST_NAME
2626

2727
logger = logging.getLogger("sagemaker")
@@ -44,7 +44,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
4444
chain.
4545
"""
4646
super(XGBoostPredictor, self).__init__(
47-
endpoint_name, sagemaker_session, NumpySerializer(), CSVDeserializer()
47+
endpoint_name, sagemaker_session, LibSVMSerializer(), CSVDeserializer()
4848
)
4949

5050

tests/data/xgboost_abalone/abalone

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
15 1:1 2:0.455 3:0.365 4:0.095 5:0.514 6:0.2245 7:0.101 8:0.15
2+
7 1:1 2:0.35 3:0.265 4:0.09 5:0.2255 6:0.0995 7:0.0485 8:0.07
3+
9 1:2 2:0.53 3:0.42 4:0.135 5:0.677 6:0.2565 7:0.1415 8:0.21
4+
10 1:1 2:0.44 3:0.365 4:0.125 5:0.516 6:0.2155 7:0.114 8:0.155
5+
7 1:3 2:0.33 3:0.255 4:0.08 5:0.205 6:0.0895 7:0.0395 8:0.055

tests/unit/sagemaker/test_serializers.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
IdentitySerializer,
2828
SparseMatrixSerializer,
2929
JSONLinesSerializer,
30+
LibSVMSerializer,
3031
)
3132
from tests.unit import DATA_DIR
3233

@@ -310,3 +311,23 @@ def test_sparse_matrix_serializer(sparse_matrix_serializer):
310311
result = scipy.sparse.load_npz(stream).toarray()
311312
expected = data.toarray()
312313
assert np.array_equal(result, expected)
314+
315+
316+
@pytest.fixture
317+
def libsvm_serializer():
318+
return LibSVMSerializer()
319+
320+
321+
def test_libsvm_serializer_str(libsvm_serializer):
322+
original = "0 0:1 5:1"
323+
result = libsvm_serializer.serialize("0 0:1 5:1")
324+
assert result == original
325+
326+
327+
def test_libsvm_serializer_file_like(libsvm_serializer):
328+
libsvm_file_path = os.path.join(DATA_DIR, "xgboost_abalone", "abalone")
329+
with open(libsvm_file_path) as libsvm_file:
330+
validation_data = libsvm_file.read()
331+
libsvm_file.seek(0)
332+
result = libsvm_serializer.serialize(libsvm_file)
333+
assert result == validation_data

0 commit comments

Comments
 (0)