Skip to content

Commit 44b3591

Browse files
committed
added y argument to fit methods
added optional y argument to fit methods of TransformerPipeline and DataFrameMapper
1 parent 7a3c997 commit 44b3591

File tree

2 files changed

+48
-17
lines changed

2 files changed

+48
-17
lines changed

sklearn_pandas/dataframe_mapper.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,16 +130,27 @@ def fit(self, X, y=None):
130130
Fit a transformation from the pipeline
131131
132132
X the data to fit
133+
134+
y the target vector relative to X, optional
135+
133136
"""
134137
for columns, transformers in self.features:
135138
if transformers is not None:
136-
transformers.fit(self._get_col_subset(X, columns))
139+
try:
140+
transformers.fit(self._get_col_subset(X, columns), y)
141+
except TypeError:
142+
# fit takes only one argument
143+
transformers.fit(self._get_col_subset(X, columns))
137144

138145
# handle features not explicitly selected
139146
if self.default: # not False and not None
140-
self.default.fit(
141-
self._get_col_subset(X, self._unselected_columns(X))
142-
)
147+
try:
148+
self.default.fit(
149+
self._get_col_subset(X, self._unselected_columns(X)), y)
150+
except TypeError:
151+
# fit takes only one argument
152+
self.default.fit(
153+
self._get_col_subset(X, self._unselected_columns(X)))
143154
return self
144155

145156
def transform(self, X):

sklearn_pandas/pipeline.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55

66
class TransformerPipeline(Pipeline):
77
"""
8-
Pipeline that expects all steps to be transformers taking a single argument
8+
Pipeline that expects all steps to be transformers taking a single X argument,
9+
an optional y argument,
910
and having fit and transform methods.
1011
11-
Code is copied from sklearn's Pipeline, leaving out the `y=None` argument.
12+
Code is copied from sklearn's Pipeline
1213
"""
1314
def __init__(self, steps):
1415
names, estimators = zip(*steps)
@@ -31,31 +32,50 @@ def __init__(self, steps):
3132
"'%s' (type %s) doesn't)"
3233
% (estimator, type(estimator)))
3334

34-
def _pre_transform(self, X, **fit_params):
35+
def _pre_transform(self, X, y=None, **fit_params):
3536
fit_params_steps = dict((step, {}) for step, _ in self.steps)
3637
for pname, pval in six.iteritems(fit_params):
3738
step, param = pname.split('__', 1)
3839
fit_params_steps[step][param] = pval
3940
Xt = X
4041
for name, transform in self.steps[:-1]:
4142
if hasattr(transform, "fit_transform"):
42-
Xt = transform.fit_transform(Xt, **fit_params_steps[name])
43+
try:
44+
Xt = transform.fit_transform(Xt, y, **fit_params_steps[name])
45+
except TypeError:
46+
# fit_transform takes only one argument
47+
Xt = transform.fit_transform(Xt, **fit_params_steps[name])
4348
else:
44-
Xt = transform.fit(Xt, **fit_params_steps[name]) \
45-
.transform(Xt)
49+
try:
50+
Xt = transform.fit(Xt, y, **fit_params_steps[name]).transform(Xt)
51+
except TypeError:
52+
# fit takes only one argument
53+
Xt = transform.fit(Xt, **fit_params_steps[name]).transform(Xt)
4654
return Xt, fit_params_steps[self.steps[-1][0]]
4755

48-
def fit(self, X, **fit_params):
49-
Xt, fit_params = self._pre_transform(X, **fit_params)
50-
self.steps[-1][-1].fit(Xt, **fit_params)
56+
def fit(self, X, y=None, **fit_params):
57+
Xt, fit_params = self._pre_transform(X, y, **fit_params)
58+
try:
59+
self.steps[-1][-1].fit(Xt, y, **fit_params)
60+
except TypeError:
61+
# fit takes only one argument
62+
self.steps[-1][-1].fit(Xt, **fit_params)
5163
return self
5264

53-
def fit_transform(self, X, **fit_params):
54-
Xt, fit_params = self._pre_transform(X, **fit_params)
65+
def fit_transform(self, X, y=None, **fit_params):
66+
Xt, fit_params = self._pre_transform(X, y, **fit_params)
5567
if hasattr(self.steps[-1][-1], 'fit_transform'):
56-
return self.steps[-1][-1].fit_transform(Xt, **fit_params)
68+
try:
69+
return self.steps[-1][-1].fit_transform(Xt, y, **fit_params)
70+
except TypeError:
71+
# fit_transform takes only one argument
72+
return self.steps[-1][-1].fit_transform(Xt, **fit_params)
5773
else:
58-
return self.steps[-1][-1].fit(Xt, **fit_params).transform(Xt)
74+
try:
75+
return self.steps[-1][-1].fit(Xt, y, **fit_params).transform(Xt)
76+
except:
77+
# fit takes only one argument
78+
return self.steps[-1][-1].fit(Xt, **fit_params).transform(Xt)
5979

6080

6181
def make_transformer_pipeline(*steps):

0 commit comments

Comments
 (0)