Skip to content

feature: Add LibSVM serializer #1776

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions doc/v2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,12 @@ To view logs after attaching a training job to an estimator, use :func:`sagemake
until the completion of the Hyperparameter Tuning Job or Batch Transform Job, respectively.
To make the function non-blocking, use ``wait=False``.

XGBoost Predictor
-----------------

The default serializer of ``sagemaker.xgboost.model.XGBoostPredictor`` has been changed from ``NumpySerializer`` to ``LibSVMSerializer``.


Parameter and Class Name Changes
================================

Expand Down Expand Up @@ -263,6 +269,8 @@ The follow serializer/deserializer classes have been renamed and/or moved:
| ``sagemaker.predictor._JsonDeserializer`` | ``sagemaker.deserializers.JSONDeserializer`` |
+--------------------------------------------------------+-------------------------------------------------------+

``sagemaker.serializers.LibSVMSerializer`` has been added in v2.0.

``distributions``
~~~~~~~~~~~~~~~~~

Expand Down
35 changes: 33 additions & 2 deletions src/sagemaker/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def CONTENT_TYPE(self):


class CSVSerializer(BaseSerializer):
"""Searilize data of various formats to a CSV-formatted string."""
"""Serialize data of various formats to a CSV-formatted string."""

CONTENT_TYPE = "text/csv"

Expand Down Expand Up @@ -102,7 +102,7 @@ def _serialize_row(self, data):
csv_writer.writerow(data)
return csv_buffer.getvalue().rstrip("\r\n")

raise ValueError("Unable to handle input format: ", type(data))
raise ValueError("Unable to handle input format: %s" % type(data))

def _is_sequence_like(self, data):
"""Returns true if obj is iterable and subscriptable."""
Expand Down Expand Up @@ -270,3 +270,34 @@ def serialize(self, data):
buffer = io.BytesIO()
scipy.sparse.save_npz(buffer, data)
return buffer.getvalue()


class LibSVMSerializer(BaseSerializer):
"""Serialize data of various formats to a LibSVM-formatted string.

The data must already be in LIBSVM file format:
<label> <index1>:<value1> <index2>:<value2> ...

It is suitable for sparse datasets since it does not store zero-valued
features.
"""

CONTENT_TYPE = "text/libsvm"

def serialize(self, data):
"""Serialize data of various formats to a LibSVM-formatted string.

Args:
data (object): Data to be serialized. Can be a string or a
file-like object.

Returns:
str: The data serialized as a LibSVM-formatted string.
"""
if isinstance(data, str):
return data

if hasattr(data, "read"):
return data.read()

raise ValueError("Unable to handle input format: %s" % type(data))
4 changes: 2 additions & 2 deletions src/sagemaker/xgboost/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from sagemaker.fw_utils import model_code_key_prefix
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.predictor import Predictor
from sagemaker.serializers import NumpySerializer
from sagemaker.serializers import LibSVMSerializer
from sagemaker.xgboost.defaults import XGBOOST_NAME

logger = logging.getLogger("sagemaker")
Expand All @@ -44,7 +44,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
chain.
"""
super(XGBoostPredictor, self).__init__(
endpoint_name, sagemaker_session, NumpySerializer(), CSVDeserializer()
endpoint_name, sagemaker_session, LibSVMSerializer(), CSVDeserializer()
)


Expand Down
5 changes: 5 additions & 0 deletions tests/data/xgboost_abalone/abalone
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
15 1:1 2:0.455 3:0.365 4:0.095 5:0.514 6:0.2245 7:0.101 8:0.15
7 1:1 2:0.35 3:0.265 4:0.09 5:0.2255 6:0.0995 7:0.0485 8:0.07
9 1:2 2:0.53 3:0.42 4:0.135 5:0.677 6:0.2565 7:0.1415 8:0.21
10 1:1 2:0.44 3:0.365 4:0.125 5:0.516 6:0.2155 7:0.114 8:0.155
7 1:3 2:0.33 3:0.255 4:0.08 5:0.205 6:0.0895 7:0.0395 8:0.055
21 changes: 21 additions & 0 deletions tests/unit/sagemaker/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
IdentitySerializer,
SparseMatrixSerializer,
JSONLinesSerializer,
LibSVMSerializer,
)
from tests.unit import DATA_DIR

Expand Down Expand Up @@ -310,3 +311,23 @@ def test_sparse_matrix_serializer(sparse_matrix_serializer):
result = scipy.sparse.load_npz(stream).toarray()
expected = data.toarray()
assert np.array_equal(result, expected)


@pytest.fixture
def libsvm_serializer():
return LibSVMSerializer()


def test_libsvm_serializer_str(libsvm_serializer):
original = "0 0:1 5:1"
result = libsvm_serializer.serialize("0 0:1 5:1")
assert result == original


def test_libsvm_serializer_file_like(libsvm_serializer):
libsvm_file_path = os.path.join(DATA_DIR, "xgboost_abalone", "abalone")
with open(libsvm_file_path) as libsvm_file:
validation_data = libsvm_file.read()
libsvm_file.seek(0)
result = libsvm_serializer.serialize(libsvm_file)
assert result == validation_data