Skip to content

Commit 665d66d

Browse files
authored
Merge pull request #59 from vzaretsk/add-y-arg-to-fit
added y argument to fit methods
2 parents 7a3c997 + c01abd3 commit 665d66d

File tree

5 files changed

+188
-28
lines changed

5 files changed

+188
-28
lines changed

README.rst

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -167,19 +167,36 @@ passing it as the ``default`` argument to the mapper:
167167
... ('pet', sklearn.preprocessing.LabelBinarizer()),
168168
... ('children', None)
169169
... ], default=sklearn.preprocessing.StandardScaler())
170-
>>> np.round(mapper4.fit_transform(data.copy()))
171-
array([[ 1., 0., 0., 4., 2.],
172-
[ 0., 1., 0., 6., -1.],
173-
[ 0., 1., 0., 3., 0.],
174-
[ 0., 0., 1., 3., -1.],
175-
[ 1., 0., 0., 2., -0.],
176-
[ 0., 1., 0., 3., 1.],
177-
[ 1., 0., 0., 5., -0.],
178-
[ 0., 0., 1., 4., -1.]])
170+
>>> np.round(mapper4.fit_transform(data.copy()), 1)
171+
array([[ 1. , 0. , 0. , 4. , 2.3],
172+
[ 0. , 1. , 0. , 6. , -0.9],
173+
[ 0. , 1. , 0. , 3. , 0.1],
174+
[ 0. , 0. , 1. , 3. , -0.7],
175+
[ 1. , 0. , 0. , 2. , -0.5],
176+
[ 0. , 1. , 0. , 3. , 0.8],
177+
[ 1. , 0. , 0. , 5. , -0.3],
178+
[ 0. , 0. , 1. , 4. , -0.7]])
179179

180180
Using ``default=False`` (the default) drops unselected columns. Using
181181
``default=None`` pass the unselected columns unchanged.
182182

183+
Feature selection and other supervised transformations
184+
******************************************************
185+
186+
``DataFrameMapper`` supports transformers that require both X and y arguments. An example of this is feature selection. Treating the 'pet' column as the target, we will select the column that best predicts it.
187+
188+
>>> from sklearn.feature_selection import SelectKBest, chi2
189+
>>> mapper_fs = DataFrameMapper([(['children','salary'], SelectKBest(chi2, k=1))])
190+
>>> mapper_fs.fit_transform(data[['children','salary']], data['pet'])
191+
array([[ 90.],
192+
[ 24.],
193+
[ 44.],
194+
[ 27.],
195+
[ 32.],
196+
[ 59.],
197+
[ 36.],
198+
[ 27.]])
199+
183200
Working with sparse features
184201
****************************
185202

sklearn_pandas/dataframe_mapper.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from sklearn.base import BaseEstimator, TransformerMixin
66

77
from .cross_validation import DataWrapper
8-
from .pipeline import make_transformer_pipeline
8+
from .pipeline import make_transformer_pipeline, _call_fit
99

1010
# load in the correct stringtype: str for py3, basestring for py2
1111
string_types = str if sys.version_info >= (3, 0) else basestring
@@ -130,16 +130,19 @@ def fit(self, X, y=None):
130130
Fit a transformation from the pipeline
131131
132132
X the data to fit
133+
134+
y the target vector relative to X, optional
135+
133136
"""
134137
for columns, transformers in self.features:
135138
if transformers is not None:
136-
transformers.fit(self._get_col_subset(X, columns))
139+
_call_fit(transformers.fit,
140+
self._get_col_subset(X, columns), y)
137141

138142
# handle features not explicitly selected
139143
if self.default: # not False and not None
140-
self.default.fit(
141-
self._get_col_subset(X, self._unselected_columns(X))
142-
)
144+
_call_fit(self.default.fit,
145+
self._get_col_subset(X, self._unselected_columns(X)), y)
143146
return self
144147

145148
def transform(self, X):

sklearn_pandas/pipeline.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,37 @@
33
from sklearn.utils import tosequence
44

55

6+
def _call_fit(fit_method, X, y=None, **kwargs):
7+
"""
8+
helper function, calls the fit or fit_transform method with the correct
9+
number of parameters
10+
11+
fit_method: fit or fit_transform method of the transformer
12+
X: the data to fit
13+
y: the target vector relative to X, optional
14+
kwargs: any keyword arguments to the fit method
15+
16+
return: the result of the fit or fit_transform method
17+
18+
WARNING: if this function raises a TypeError exception, test the fit
19+
or fit_transform method passed to it in isolation as _call_fit will not
20+
distinguish TypeError due to incorrect number of arguments from
21+
other TypeError
22+
"""
23+
try:
24+
return fit_method(X, y, **kwargs)
25+
except TypeError:
26+
# fit takes only one argument
27+
return fit_method(X, **kwargs)
28+
29+
630
class TransformerPipeline(Pipeline):
731
"""
8-
Pipeline that expects all steps to be transformers taking a single argument
32+
Pipeline that expects all steps to be transformers taking a single X argument,
33+
an optional y argument,
934
and having fit and transform methods.
1035
11-
Code is copied from sklearn's Pipeline, leaving out the `y=None` argument.
36+
Code is copied from sklearn's Pipeline
1237
"""
1338
def __init__(self, steps):
1439
names, estimators = zip(*steps)
@@ -31,31 +56,34 @@ def __init__(self, steps):
3156
"'%s' (type %s) doesn't)"
3257
% (estimator, type(estimator)))
3358

34-
def _pre_transform(self, X, **fit_params):
59+
def _pre_transform(self, X, y=None, **fit_params):
3560
fit_params_steps = dict((step, {}) for step, _ in self.steps)
3661
for pname, pval in six.iteritems(fit_params):
3762
step, param = pname.split('__', 1)
3863
fit_params_steps[step][param] = pval
3964
Xt = X
4065
for name, transform in self.steps[:-1]:
4166
if hasattr(transform, "fit_transform"):
42-
Xt = transform.fit_transform(Xt, **fit_params_steps[name])
67+
Xt = _call_fit(transform.fit_transform,
68+
Xt, y, **fit_params_steps[name])
4369
else:
44-
Xt = transform.fit(Xt, **fit_params_steps[name]) \
45-
.transform(Xt)
70+
Xt = _call_fit(transform.fit,
71+
Xt, y, **fit_params_steps[name]).transform(Xt)
4672
return Xt, fit_params_steps[self.steps[-1][0]]
4773

48-
def fit(self, X, **fit_params):
49-
Xt, fit_params = self._pre_transform(X, **fit_params)
50-
self.steps[-1][-1].fit(Xt, **fit_params)
74+
def fit(self, X, y=None, **fit_params):
75+
Xt, fit_params = self._pre_transform(X, y, **fit_params)
76+
_call_fit(self.steps[-1][-1].fit, Xt, y, **fit_params)
5177
return self
5278

53-
def fit_transform(self, X, **fit_params):
54-
Xt, fit_params = self._pre_transform(X, **fit_params)
79+
def fit_transform(self, X, y=None, **fit_params):
80+
Xt, fit_params = self._pre_transform(X, y, **fit_params)
5581
if hasattr(self.steps[-1][-1], 'fit_transform'):
56-
return self.steps[-1][-1].fit_transform(Xt, **fit_params)
82+
return _call_fit(self.steps[-1][-1].fit_transform,
83+
Xt, y, **fit_params)
5784
else:
58-
return self.steps[-1][-1].fit(Xt, **fit_params).transform(Xt)
85+
return _call_fit(self.steps[-1][-1].fit,
86+
Xt, y, **fit_params).transform(Xt)
5987

6088

6189
def make_transformer_pipeline(*steps):

tests/test_dataframe_mapper.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from sklearn.svm import SVC
1919
from sklearn.feature_extraction.text import CountVectorizer
2020
from sklearn.preprocessing import Imputer, StandardScaler, OneHotEncoder
21+
from sklearn.feature_selection import SelectKBest, chi2
2122
from sklearn.base import BaseEstimator, TransformerMixin
2223
import numpy as np
2324
from numpy.testing import assert_array_equal
@@ -69,6 +70,13 @@ def simple_dataframe():
6970
return pd.DataFrame({'a': [1, 2, 3]})
7071

7172

73+
@pytest.fixture
74+
def complex_dataframe():
75+
return pd.DataFrame({'target': ['a', 'a', 'a', 'b', 'b', 'b'],
76+
'feat1': [1, 2, 3, 4, 5, 6],
77+
'feat2': [1, 2, 3, 2, 3, 4]})
78+
79+
7280
def test_nonexistent_columns_explicit_fail(simple_dataframe):
7381
"""
7482
If a nonexistent column is selected, KeyError is raised.
@@ -306,6 +314,37 @@ def test_sparse_off(simple_dataframe):
306314
assert type(dmatrix) != sparse.csr.csr_matrix
307315

308316

317+
def test_fit_with_optional_y_arg(complex_dataframe):
318+
"""
319+
Transformers with an optional y argument in the fit method
320+
are handled correctly
321+
"""
322+
df = complex_dataframe
323+
mapper = DataFrameMapper([(['feat1', 'feat2'], MockTClassifier())])
324+
# doesn't fail
325+
mapper.fit(df[['feat1', 'feat2']], df['target'])
326+
327+
328+
def test_fit_with_required_y_arg(complex_dataframe):
329+
"""
330+
Transformers with a required y argument in the fit method
331+
are handled and perform correctly
332+
"""
333+
df = complex_dataframe
334+
mapper = DataFrameMapper([(['feat1', 'feat2'], SelectKBest(chi2, k=1))])
335+
336+
# fit, doesn't fail
337+
ft_arr = mapper.fit(df[['feat1', 'feat2']], df['target'])
338+
339+
# fit_transform
340+
ft_arr = mapper.fit_transform(df[['feat1', 'feat2']], df['target'])
341+
assert_array_equal(ft_arr, df[['feat1']].values)
342+
343+
# transform
344+
t_arr = mapper.transform(df[['feat1', 'feat2']])
345+
assert_array_equal(t_arr, df[['feat1']].values)
346+
347+
309348
# Integration tests with real dataframes
310349

311350
@pytest.fixture

tests/test_pipeline.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
import pytest
2-
from sklearn_pandas.pipeline import TransformerPipeline
2+
from sklearn_pandas.pipeline import TransformerPipeline, _call_fit
3+
4+
# In py3, mock is included with the unittest standard library
5+
# In py2, it's a separate package
6+
try:
7+
from unittest.mock import patch
8+
except ImportError:
9+
from mock import patch
310

411

512
class NoTransformT(object):
@@ -16,6 +23,39 @@ def transform(self, x):
1623
return self
1724

1825

26+
class Trans(object):
27+
"""
28+
Transformer with fit and transform methods
29+
"""
30+
def fit(self, x, y=None):
31+
return self
32+
33+
def transform(self, x):
34+
return self
35+
36+
37+
def func_x_y(x, y, kwarg='kwarg'):
38+
"""
39+
Function with required x and y arguments
40+
"""
41+
return
42+
43+
44+
def func_x(x, kwarg='kwarg'):
45+
"""
46+
Function with required x argument
47+
"""
48+
return
49+
50+
51+
def func_raise_type_err(x, y, kwarg='kwarg'):
52+
"""
53+
Function with required x and y arguments,
54+
raises TypeError
55+
"""
56+
raise TypeError
57+
58+
1959
def test_all_steps_fit_transform():
2060
"""
2161
All steps must implement fit and transform. Otherwise, raise TypeError.
@@ -25,3 +65,36 @@ def test_all_steps_fit_transform():
2565

2666
with pytest.raises(TypeError):
2767
TransformerPipeline([('svc', NoFitT())])
68+
69+
70+
@patch.object(Trans, 'fit', side_effect=func_x_y)
71+
def test_called_with_x_and_y(mock_fit):
72+
"""
73+
Fit method with required X and y arguments is called with both and with
74+
any additional keywords
75+
"""
76+
_call_fit(Trans().fit, 'X', 'y', kwarg='kwarg')
77+
mock_fit.assert_called_with('X', 'y', kwarg='kwarg')
78+
79+
80+
@patch.object(Trans, 'fit', side_effect=func_x)
81+
def test_called_with_x(mock_fit):
82+
"""
83+
Fit method with a required X arguments is called with it and with
84+
any additional keywords
85+
"""
86+
_call_fit(Trans().fit, 'X', 'y', kwarg='kwarg')
87+
mock_fit.assert_called_with('X', kwarg='kwarg')
88+
89+
_call_fit(Trans().fit, 'X', kwarg='kwarg')
90+
mock_fit.assert_called_with('X', kwarg='kwarg')
91+
92+
93+
@patch.object(Trans, 'fit', side_effect=func_raise_type_err)
94+
def test_raises_type_error(mock_fit):
95+
"""
96+
If a fit method with required X and y arguments raises a TypeError, it's
97+
re-raised (for a different reason) when it's called with one argument
98+
"""
99+
with pytest.raises(TypeError):
100+
_call_fit(Trans().fit, 'X', 'y', kwarg='kwarg')

0 commit comments

Comments
 (0)