Skip to content

Commit eb25279

Browse files
committed
feature: Add LibSVM serializer
1 parent 77086de commit eb25279

File tree

4 files changed

+52
-2
lines changed

4 files changed

+52
-2
lines changed

src/sagemaker/serializers.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,3 +241,27 @@ def serialize(self, data):
241241
buffer = io.BytesIO()
242242
scipy.sparse.save_npz(buffer, data)
243243
return buffer.getvalue()
244+
245+
246+
class LibSVMSerializer(BaseSerializer):
247+
"""Searilize data of various formats to a LibSVM-formatted string."""
248+
249+
CONTENT_TYPE = "text/libsvm"
250+
251+
def serialize(self, data):
252+
"""Serialize data of various formats to a LibSVM-formatted string.
253+
254+
Args:
255+
data (object): Data to be serialized. Can be a string or a
256+
file-like object.
257+
258+
Returns:
259+
str: The data serialized as a LibSVM-formatted string.
260+
"""
261+
if isinstance(data, str):
262+
return data
263+
264+
if hasattr(data, "read"):
265+
return data.read()
266+
267+
raise ValueError("Unable to handle input format: ", type(data))

src/sagemaker/xgboost/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from sagemaker.fw_utils import model_code_key_prefix
2222
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2323
from sagemaker.predictor import Predictor
24-
from sagemaker.serializers import NumpySerializer
24+
from sagemaker.serializers import LibSVMSerializer
2525
from sagemaker.xgboost.defaults import XGBOOST_NAME
2626

2727
logger = logging.getLogger("sagemaker")
@@ -44,7 +44,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
4444
chain.
4545
"""
4646
super(XGBoostPredictor, self).__init__(
47-
endpoint_name, sagemaker_session, NumpySerializer(), CSVDeserializer()
47+
endpoint_name, sagemaker_session, LibSVMSerializer(), CSVDeserializer()
4848
)
4949

5050

tests/data/xgboost_abalone/abalone

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
15 1:1 2:0.455 3:0.365 4:0.095 5:0.514 6:0.2245 7:0.101 8:0.15
2+
7 1:1 2:0.35 3:0.265 4:0.09 5:0.2255 6:0.0995 7:0.0485 8:0.07
3+
9 1:2 2:0.53 3:0.42 4:0.135 5:0.677 6:0.2565 7:0.1415 8:0.21
4+
10 1:1 2:0.44 3:0.365 4:0.125 5:0.516 6:0.2155 7:0.114 8:0.155
5+
7 1:3 2:0.33 3:0.255 4:0.08 5:0.205 6:0.0895 7:0.0395 8:0.055

tests/unit/sagemaker/test_serializers.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
JSONSerializer,
2727
SparseMatrixSerializer,
2828
JSONLinesSerializer,
29+
LibSVMSerializer,
2930
)
3031
from tests.unit import DATA_DIR
3132

@@ -298,3 +299,23 @@ def test_sparse_matrix_serializer(sparse_matrix_serializer):
298299
result = scipy.sparse.load_npz(stream).toarray()
299300
expected = data.toarray()
300301
assert np.array_equal(result, expected)
302+
303+
304+
@pytest.fixture
305+
def libsvm_serializer():
306+
return LibSVMSerializer()
307+
308+
309+
def test_libsvm_serializer_str(libsvm_serializer):
310+
original = "0 0:1 5:1"
311+
result = libsvm_serializer.serialize("0 0:1 5:1")
312+
assert result == original
313+
314+
315+
def test_libsvm_serializer_file_like(libsvm_serializer):
316+
libsvm_file_path = os.path.join(DATA_DIR, "xgboost_abalone", "abalone")
317+
with open(libsvm_file_path) as libsvm_file:
318+
validation_data = libsvm_file.read()
319+
libsvm_file.seek(0)
320+
result = libsvm_serializer.serialize(libsvm_file)
321+
assert result == validation_data

0 commit comments

Comments
 (0)