Skip to content

Commit 4646097

Browse files
committed
Merge branch 'unpickle_list_shim'. Fixes #45.
2 parents 647516e + e2938a2 commit 4646097

File tree

3 files changed

+22
-1
lines changed

3 files changed

+22
-1
lines changed

README.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ Changelog
194194
*******************
195195

196196
* Use custom ``TransformerPipeline`` class to allow transformation steps accepting only an X argument. Fixes #46.
197+
* Add compatibility shim for unpickling mappers with list of transformers created before 1.0.0. Fixes #45.
197198

198199

199200
1.0.0 (2015-11-28)

sklearn_pandas/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from sklearn.base import BaseEstimator, TransformerMixin
99
from sklearn import cross_validation
1010
from sklearn import grid_search
11-
from pipeline import make_transformer_pipeline
11+
from pipeline import make_transformer_pipeline, TransformerPipeline
1212

1313
# load in the correct stringtype: str for py3, basestring for py2
1414
string_types = str if sys.version_info >= (3, 0) else basestring
@@ -96,6 +96,12 @@ def __init__(self, features, sparse=False):
9696
self.features = features
9797
self.sparse = sparse
9898

99+
def __setstate__(self, state):
    """
    Restore the mapper from its pickled ``state`` dictionary.

    Acts as a compatibility shim for pickles created with
    sklearn-pandas<1.0.0: every ``(columns, transformers)`` entry of
    ``state['features']`` is passed through ``_build_transformer``
    again, so mappers pickled with a plain list of transformers are
    rebuilt on load.

    :param state: dictionary of pickled instance attributes; it must
                  contain a ``'features'`` key, while ``'sparse'`` is
                  optional and defaults to ``False``.
    """
    # restore every pickled attribute first, so that anything besides
    # ``features`` and ``sparse`` is not silently lost on unpickling
    self.__dict__.update(state)
    # compatibility shim for pickles created with sklearn-pandas<1.0.0
    self.features = [(columns, _build_transformer(transformers))
                     for (columns, transformers) in state['features']]
    # older pickles may not carry the ``sparse`` attribute at all
    self.sparse = state.get('sparse', False)
104+
99105
def _get_col_subset(self, X, cols):
100106
"""
101107
Get a subset of columns from the given table X.

tests/test_dataframe_mapper.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@
1818
from sklearn.base import BaseEstimator, TransformerMixin
1919
import numpy as np
2020
from numpy.testing import assert_array_equal
21+
import pickle
2122

2223
from sklearn_pandas import (
2324
DataFrameMapper,
2425
PassthroughTransformer,
2526
cross_val_score,
2627
_build_transformer,
2728
_handle_feature,
29+
TransformerPipeline,
2830
)
2931

3032

@@ -191,6 +193,18 @@ def test_list_transformers():
191193
assert (abs(dmatrix.std(axis=0) - 1) <= 1e-6).all()
192194

193195

196+
def test_list_transformers_old_unpickle(simple_dataframe):
    """
    A mapper pickled with a plain list of transformers (as stored by
    < 1.0.0 code) must come back from unpickling with that list wrapped
    in a ``TransformerPipeline``.
    """
    old_style_mapper = DataFrameMapper(None)
    # mimic a mapper that was created with < 1.0.0 code, which kept
    # the raw list of transformers in ``features``
    old_style_mapper.features = [('a', [MockXTransformer()])]

    # round-trip through pickle to trigger the unpickling shim
    restored_mapper = pickle.loads(pickle.dumps(old_style_mapper))

    first_feature = restored_mapper.features[0]
    pipeline = first_feature[1]
    assert isinstance(pipeline, TransformerPipeline)
    first_step = pipeline.steps[0]
    assert isinstance(first_step[1], MockXTransformer)
206+
207+
194208
def test_sparse_features(simple_dataframe):
195209
"""
196210
If any of the extracted features is sparse and "sparse" argument

0 commit comments

Comments
 (0)