Skip to content

Commit 7fcb5d6

Browse files
committed
feature: Add LibSVM serializer
1 parent cfa9c97 commit 7fcb5d6

File tree

5 files changed

+69
-4
lines changed

5 files changed

+69
-4
lines changed

doc/v2.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,12 @@ To view logs after attaching a training job to an estimator, use :func:`sagemake
203203
until the completion of the Hyperparameter Tuning Job or Batch Transform Job, respectively.
204204
To make the function non-blocking, use ``wait=False``.
205205

206+
XGBoost Predictor
207+
-----------------
208+
209+
The default serializer of ``sagemaker.xgboost.model.XGBoostPredictor`` has been changed from ``NumpySerializer`` to ``LibSVMSerializer``.
210+
211+
206212
Parameter and Class Name Changes
207213
================================
208214

@@ -263,6 +269,8 @@ The following serializer/deserializer classes have been renamed and/or moved:
263269
| ``sagemaker.predictor._JsonDeserializer`` | ``sagemaker.deserializers.JSONDeserializer`` |
264270
+--------------------------------------------------------+-------------------------------------------------------+
265271

272+
``sagemaker.serializers.LibSVMSerializer`` has been added in v2.0.
273+
266274
``distributions``
267275
~~~~~~~~~~~~~~~~~
268276

src/sagemaker/serializers.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def CONTENT_TYPE(self):
5454

5555

5656
class CSVSerializer(BaseSerializer):
57-
"""Searilize data of various formats to a CSV-formatted string."""
57+
"""Serialize data of various formats to a CSV-formatted string."""
5858

5959
CONTENT_TYPE = "text/csv"
6060

@@ -102,7 +102,7 @@ def _serialize_row(self, data):
102102
csv_writer.writerow(data)
103103
return csv_buffer.getvalue().rstrip("\r\n")
104104

105-
raise ValueError("Unable to handle input format: ", type(data))
105+
raise ValueError("Unable to handle input format: %s" % type(data))
106106

107107
def _is_sequence_like(self, data):
108108
"""Returns true if obj is iterable and subscriptable."""
@@ -270,3 +270,34 @@ def serialize(self, data):
270270
buffer = io.BytesIO()
271271
scipy.sparse.save_npz(buffer, data)
272272
return buffer.getvalue()
273+
274+
275+
class LibSVMSerializer(BaseSerializer):
    """Pass LibSVM-formatted input through as the request payload.

    No conversion is performed; the input must already follow the LIBSVM
    file format:
    <label> <index1>:<value1> <index2>:<value2> ...

    The format suits sparse datasets because zero-valued features are
    simply omitted.
    """

    CONTENT_TYPE = "text/libsvm"

    def serialize(self, data):
        """Return LibSVM-formatted text taken from a string or file-like object.

        Args:
            data (object): Data to be serialized. Either a string, which is
                returned unchanged, or an object exposing ``read()``, whose
                result is returned.

        Returns:
            str: The data serialized as a LibSVM-formatted string.

        Raises:
            ValueError: If ``data`` is neither a string nor an object with
                a ``read`` attribute.
        """
        if not isinstance(data, str):
            # Anything that is not a string must be readable like a file.
            if not hasattr(data, "read"):
                raise ValueError("Unable to handle input format: %s" % type(data))
            return data.read()

        return data

src/sagemaker/xgboost/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from sagemaker.fw_utils import model_code_key_prefix
2222
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2323
from sagemaker.predictor import Predictor
24-
from sagemaker.serializers import NumpySerializer
24+
from sagemaker.serializers import LibSVMSerializer
2525
from sagemaker.xgboost.defaults import XGBOOST_NAME
2626

2727
logger = logging.getLogger("sagemaker")
@@ -44,7 +44,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
4444
chain.
4545
"""
4646
super(XGBoostPredictor, self).__init__(
47-
endpoint_name, sagemaker_session, NumpySerializer(), CSVDeserializer()
47+
endpoint_name, sagemaker_session, LibSVMSerializer(), CSVDeserializer()
4848
)
4949

5050

tests/data/xgboost_abalone/abalone

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
15 1:1 2:0.455 3:0.365 4:0.095 5:0.514 6:0.2245 7:0.101 8:0.15
2+
7 1:1 2:0.35 3:0.265 4:0.09 5:0.2255 6:0.0995 7:0.0485 8:0.07
3+
9 1:2 2:0.53 3:0.42 4:0.135 5:0.677 6:0.2565 7:0.1415 8:0.21
4+
10 1:1 2:0.44 3:0.365 4:0.125 5:0.516 6:0.2155 7:0.114 8:0.155
5+
7 1:3 2:0.33 3:0.255 4:0.08 5:0.205 6:0.0895 7:0.0395 8:0.055

tests/unit/sagemaker/test_serializers.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
IdentitySerializer,
2828
SparseMatrixSerializer,
2929
JSONLinesSerializer,
30+
LibSVMSerializer,
3031
)
3132
from tests.unit import DATA_DIR
3233

@@ -310,3 +311,23 @@ def test_sparse_matrix_serializer(sparse_matrix_serializer):
310311
result = scipy.sparse.load_npz(stream).toarray()
311312
expected = data.toarray()
312313
assert np.array_equal(result, expected)
314+
315+
316+
@pytest.fixture
def libsvm_serializer():
    """Provide a fresh ``LibSVMSerializer`` instance for each test."""
    serializer = LibSVMSerializer()
    return serializer
319+
320+
321+
def test_libsvm_serializer_str(libsvm_serializer):
    """A string payload is returned unchanged by the serializer."""
    payload = "0 0:1 5:1"
    assert libsvm_serializer.serialize(payload) == payload
325+
326+
327+
def test_libsvm_serializer_file_like(libsvm_serializer):
    """A file-like payload is serialized to the file's full contents."""
    sample_path = os.path.join(DATA_DIR, "xgboost_abalone", "abalone")
    with open(sample_path) as handle:
        expected = handle.read()
        # Rewind so the serializer reads the file from the beginning.
        handle.seek(0)
        serialized = libsvm_serializer.serialize(handle)
    assert serialized == expected

0 commit comments

Comments
 (0)