Skip to content

Commit 575d2d6

Browse files
committed
refactored to use _call_fit
Refactored pipeline and dataframe_mapper to use the _call_fit helper function, consolidating the code that handles fit methods taking 1 or 2 arguments
1 parent 44b3591 commit 575d2d6

File tree

2 files changed

+38
-38
lines changed

2 files changed

+38
-38
lines changed

sklearn_pandas/dataframe_mapper.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from sklearn.base import BaseEstimator, TransformerMixin
66

77
from .cross_validation import DataWrapper
8-
from .pipeline import make_transformer_pipeline
8+
from .pipeline import make_transformer_pipeline, _call_fit
99

1010
# load in the correct stringtype: str for py3, basestring for py2
1111
string_types = str if sys.version_info >= (3, 0) else basestring
@@ -136,21 +136,13 @@ def fit(self, X, y=None):
136136
"""
137137
for columns, transformers in self.features:
138138
if transformers is not None:
139-
try:
140-
transformers.fit(self._get_col_subset(X, columns), y)
141-
except TypeError:
142-
# fit takes only one argument
143-
transformers.fit(self._get_col_subset(X, columns))
139+
_call_fit(transformers.fit,
140+
self._get_col_subset(X, columns), y)
144141

145142
# handle features not explicitly selected
146143
if self.default: # not False and not None
147-
try:
148-
self.default.fit(
149-
self._get_col_subset(X, self._unselected_columns(X)), y)
150-
except TypeError:
151-
# fit takes only one argument
152-
self.default.fit(
153-
self._get_col_subset(X, self._unselected_columns(X)))
144+
_call_fit(self.default.fit,
145+
self._get_col_subset(X, self._unselected_columns(X)), y)
154146
return self
155147

156148
def transform(self, X):

sklearn_pandas/pipeline.py

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,30 @@
33
from sklearn.utils import tosequence
44

55

6+
def _call_fit(fit_method, X, y=None, **kwargs):
7+
"""
8+
helper function, calls the fit or fit_transform method with the correct
9+
number of parameters
10+
11+
fit_method: fit or fit_transform method of the transformer
12+
X: the data to fit
13+
y: the target vector relative to X, optional
14+
kwargs: any keyword arguments to the fit method
15+
16+
return: the result of the fit or fit_transform method
17+
18+
WARNING: if this function raises a TypeError exception, test the fit
19+
or fit_transform method passed to it in isolation as _call_fit will not
20+
distinguish TypeError due to incorrect number of arguments from
21+
other TypeError
22+
"""
23+
try:
24+
return fit_method(X, y, **kwargs)
25+
except TypeError:
26+
# fit takes only one argument
27+
return fit_method(X, **kwargs)
28+
29+
630
class TransformerPipeline(Pipeline):
731
"""
832
Pipeline that expects all steps to be transformers taking a single X argument,
@@ -40,42 +64,26 @@ def _pre_transform(self, X, y=None, **fit_params):
4064
Xt = X
4165
for name, transform in self.steps[:-1]:
4266
if hasattr(transform, "fit_transform"):
43-
try:
44-
Xt = transform.fit_transform(Xt, y, **fit_params_steps[name])
45-
except TypeError:
46-
# fit takes only one argument
47-
Xt = transform.fit_transform(Xt, **fit_params_steps[name])
67+
Xt = _call_fit(transform.fit_transform,
68+
Xt, y, **fit_params_steps[name])
4869
else:
49-
try:
50-
Xt = transform.fit(Xt, y, **fit_params_steps[name]).transform(Xt)
51-
except TypeError:
52-
# fit takes only one argument
53-
Xt = transform.fit(Xt, **fit_params_steps[name]).transform(Xt)
70+
Xt = _call_fit(transform.fit,
71+
Xt, y, **fit_params_steps[name]).transform(Xt)
5472
return Xt, fit_params_steps[self.steps[-1][0]]
5573

5674
def fit(self, X, y=None, **fit_params):
5775
Xt, fit_params = self._pre_transform(X, y, **fit_params)
58-
try:
59-
self.steps[-1][-1].fit(Xt, y, **fit_params)
60-
except TypeError:
61-
# fit takes only one argument
62-
self.steps[-1][-1].fit(Xt, **fit_params)
76+
_call_fit(self.steps[-1][-1].fit, Xt, y, **fit_params)
6377
return self
6478

6579
def fit_transform(self, X, y=None, **fit_params):
6680
Xt, fit_params = self._pre_transform(X, y, **fit_params)
6781
if hasattr(self.steps[-1][-1], 'fit_transform'):
68-
try:
69-
return self.steps[-1][-1].fit_transform(Xt, y, **fit_params)
70-
except TypeError:
71-
# fit_transform takes only one argument
72-
return self.steps[-1][-1].fit_transform(Xt, **fit_params)
82+
return _call_fit(self.steps[-1][-1].fit_transform,
83+
Xt, y, **fit_params)
7384
else:
74-
try:
75-
return self.steps[-1][-1].fit(Xt, y, **fit_params).transform(Xt)
76-
except:
77-
# fit takes only one argument
78-
return self.steps[-1][-1].fit(Xt, **fit_params).transform(Xt)
85+
return _call_fit(self.steps[-1][-1].fit,
86+
Xt, y, **fit_params).transform(Xt)
7987

8088

8189
def make_transformer_pipeline(*steps):

0 commit comments

Comments
 (0)