Skip to content

Commit 647516e

Browse files
committed
Merge branch 'transformer_pipeline'. Fixes #46.
2 parents b72335d + 042a62a commit 647516e

File tree

5 files changed

+124
-3
lines changed

5 files changed

+124
-3
lines changed

README.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,12 @@ Sklearn-pandas' ``cross_val_score`` function provides exactly the same interface
190190
Changelog
191191
---------
192192

193+
1.1.0 (development)
194+
*******************
195+
196+
* Use custom ``TransformerPipeline`` class to allow transformation steps accepting only an X argument. Fixes #46.
197+
198+
193199
1.0.0 (2015-11-28)
194200
*******************
195201

pipeline.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import six
2+
from sklearn.pipeline import _name_estimators, Pipeline
3+
from sklearn.utils import tosequence
4+
5+
6+
class TransformerPipeline(Pipeline):
    """
    Pipeline that expects all steps to be transformers taking a single argument
    and having fit and transform methods.

    Code is copied from sklearn's Pipeline, leaving out the `y=None` argument.

    Parameters
    ----------
    steps : list of (name, transformer) tuples
        Transformers to chain, in application order. Names must be unique.

    Raises
    ------
    ValueError
        If ``steps`` is empty or the step names are not unique.
    TypeError
        If any step lacks ``transform`` or lacks both ``fit`` and
        ``fit_transform``, or if the last step lacks ``fit``.
    """
    def __init__(self, steps):
        # Fail with an explicit message; ``zip(*steps)`` on an empty list
        # would otherwise raise an opaque tuple-unpacking ValueError.
        if len(steps) == 0:
            raise ValueError("TransformerPipeline requires at least one "
                             "(name, transformer) step")

        names, estimators = zip(*steps)
        # A dict collapses duplicated names, so a size mismatch means at
        # least one name was repeated.
        if len(dict(steps)) != len(steps):
            raise ValueError("Provided step names are not unique: %s" % (names,))

        # Normalise steps through sklearn's ``tosequence``.
        # NOTE(review): the original comment called this a "shallow copy";
        # for an input that is already a list this presumably returns the
        # same object rather than a copy -- confirm against sklearn.
        self.steps = tosequence(steps)
        estimator = estimators[-1]

        # Unlike sklearn's Pipeline, every step (including the last one) has
        # to be a transformer.
        for e in estimators:
            if (not (hasattr(e, "fit") or hasattr(e, "fit_transform")) or not
                    hasattr(e, "transform")):
                raise TypeError("All steps of the chain should "
                                "be transforms and implement fit and transform"
                                " '%s' (type %s) doesn't)" % (e, type(e)))

        # ``fit`` on the pipeline calls ``fit`` on the last step directly.
        if not hasattr(estimator, "fit"):
            raise TypeError("Last step of chain should implement fit "
                            "'%s' (type %s) doesn't)"
                            % (estimator, type(estimator)))

    def _pre_transform(self, X, **fit_params):
        """
        Fit and apply every step but the last one to ``X``.

        ``fit_params`` keys use the ``<step>__<parameter>`` format and are
        routed to the fit call of the named step.

        Returns the transformed data and the dict of fit parameters that
        belong to the last step.

        Raises ValueError if a fit parameter name does not contain ``__``,
        and KeyError if it names a step that is not in the pipeline.
        """
        fit_params_steps = dict((step, {}) for step, _ in self.steps)
        for pname, pval in six.iteritems(fit_params):
            # Without the separator the split below would fail with an
            # opaque unpacking ValueError; raise a readable one instead.
            if '__' not in pname:
                raise ValueError("Fit parameter '%s' must use the "
                                 "'<step>__<parameter>' format" % (pname,))
            step, param = pname.split('__', 1)
            fit_params_steps[step][param] = pval
        Xt = X
        for name, transform in self.steps[:-1]:
            if hasattr(transform, "fit_transform"):
                Xt = transform.fit_transform(Xt, **fit_params_steps[name])
            else:
                Xt = transform.fit(Xt, **fit_params_steps[name]) \
                              .transform(Xt)
        return Xt, fit_params_steps[self.steps[-1][0]]

    def fit(self, X, **fit_params):
        """
        Fit all steps on ``X``: every step but the last is fitted and applied
        in turn, then the last step is fitted on the transformed data.

        Returns the pipeline itself.
        """
        Xt, fit_params = self._pre_transform(X, **fit_params)
        self.steps[-1][-1].fit(Xt, **fit_params)
        return self

    def fit_transform(self, X, **fit_params):
        """
        Fit all steps on ``X`` and return the output of the last step.

        The last step's ``fit_transform`` is used when it has one; otherwise
        ``fit`` followed by ``transform`` is called on it.
        """
        Xt, fit_params = self._pre_transform(X, **fit_params)
        if hasattr(self.steps[-1][-1], 'fit_transform'):
            return self.steps[-1][-1].fit_transform(Xt, **fit_params)
        else:
            return self.steps[-1][-1].fit(Xt, **fit_params).transform(Xt)
60+
61+
def make_transformer_pipeline(*steps):
    """
    Build a TransformerPipeline out of the given transformers.

    Step names are produced by sklearn's ``_name_estimators`` helper, so the
    caller only passes the transformer objects themselves.
    """
    named_steps = _name_estimators(steps)
    return TransformerPipeline(named_steps)

sklearn_pandas/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
__version__ = '1.0.0'
22

3+
4+
import sys
35
import numpy as np
46
import pandas as pd
57
from scipy import sparse
68
from sklearn.base import BaseEstimator, TransformerMixin
7-
from sklearn.pipeline import make_pipeline
89
from sklearn import cross_validation
910
from sklearn import grid_search
10-
import sys
11+
from pipeline import make_transformer_pipeline
1112

1213
# load in the correct stringtype: str for py3, basestring for py2
1314
string_types = str if sys.version_info >= (3, 0) else basestring
@@ -68,7 +69,7 @@ def _handle_feature(fea):
6869

6970
def _build_transformer(transformers):
7071
if isinstance(transformers, list):
71-
transformers = make_pipeline(*transformers)
72+
transformers = make_transformer_pipeline(*transformers)
7273
return transformers
7374

7475

test_pipeline.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import pytest
2+
from pipeline import TransformerPipeline
3+
4+
5+
class NoTransformT(object):
    """
    Test double: a step that can be fitted but has no ``transform`` method.
    """

    def fit(self, x):
        """Ignore ``x`` and hand back this same instance."""
        return self
10+
11+
12+
class NoFitT(object):
    """
    Test double: a step that offers ``transform`` but has no ``fit`` method.
    """

    def transform(self, x):
        """Ignore ``x`` and hand back this same instance."""
        return self
17+
18+
19+
def test_all_steps_fit_transform():
    """
    Building a pipeline from a step that lacks either ``fit`` or
    ``transform`` must raise TypeError.
    """
    # Checked in the same order as before: missing transform first, then
    # missing fit. Each instance is created inside its own raises-context.
    for incomplete_cls in (NoTransformT, NoFitT):
        with pytest.raises(TypeError):
            TransformerPipeline([('svc', incomplete_cls())])

tests/test_dataframe_mapper.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,17 @@
2828
)
2929

3030

31+
class MockXTransformer(object):
    """
    Stand-in transformer whose ``fit`` and ``transform`` take only ``X``
    (no ``y`` argument).
    """

    def fit(self, X):
        """Do nothing and return this instance."""
        return self

    def transform(self, X):
        """Return ``X`` as received."""
        return X
40+
41+
3142
class MockTClassifier(object):
3243
"""
3344
Mock transformer/classifier.
@@ -148,6 +159,18 @@ def test_build_transformers():
148159
assert pipeline.steps[ix][1] == transformer
149160

150161

162+
def test_list_transformers_single_arg(simple_dataframe):
    """
    A list of transformers is accepted even when some of them take a single
    X argument rather than the usual (X, y) pair.
    """
    features = [('a', [MockXTransformer()])]
    mapper = DataFrameMapper(features)
    # The test passes as long as this call raises nothing.
    mapper.fit_transform(simple_dataframe)
172+
173+
151174
def test_list_transformers():
152175
"""
153176
Specifying a list of transformers applies them sequentially to the

0 commit comments

Comments
 (0)