-
Notifications
You must be signed in to change notification settings - Fork 418
added y argument to fit methods #59
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,12 +3,37 @@ | |
from sklearn.utils import tosequence | ||
|
||
|
||
def _call_fit(fit_method, X, y=None, **kwargs): | ||
""" | ||
helper function, calls the fit or fit_transform method with the correct | ||
number of parameters | ||
|
||
fit_method: fit or fit_transform method of the transformer | ||
X: the data to fit | ||
y: the target vector relative to X, optional | ||
kwargs: any keyword arguments to the fit method | ||
|
||
return: the result of the fit or fit_transform method | ||
|
||
WARNING: if this function raises a TypeError exception, test the fit | ||
or fit_transform method passed to it in isolation as _call_fit will not | ||
distinguish TypeError due to incorrect number of arguments from | ||
other TypeError | ||
""" | ||
try: | ||
return fit_method(X, y, **kwargs) | ||
except TypeError: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about:
I know it looks a bit hacky but I guess it will solve your warning above. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is pretty clever, I hadn't thought of doing something like this before. Unfortunately when I tested it, the error message accompanying a TypeError varies between Python 2 and 3. On 2, it's "test_func() takes exactly 2 arguments (1 given)" but on 3 it's "test_func() missing 1 required positional argument:". It's probably safer to leave as is. |
||
# fit takes only one argument | ||
return fit_method(X, **kwargs) | ||
|
||
|
||
class TransformerPipeline(Pipeline): | ||
""" | ||
Pipeline that expects all steps to be transformers taking a single argument | ||
Pipeline that expects all steps to be transformers taking a single X argument, | ||
an optional y argument, | ||
and having fit and transform methods. | ||
|
||
Code is copied from sklearn's Pipeline, leaving out the `y=None` argument. | ||
Code is copied from sklearn's Pipeline | ||
""" | ||
def __init__(self, steps): | ||
names, estimators = zip(*steps) | ||
|
@@ -31,31 +56,34 @@ def __init__(self, steps): | |
"'%s' (type %s) doesn't)" | ||
% (estimator, type(estimator))) | ||
|
||
def _pre_transform(self, X, **fit_params): | ||
def _pre_transform(self, X, y=None, **fit_params): | ||
fit_params_steps = dict((step, {}) for step, _ in self.steps) | ||
for pname, pval in six.iteritems(fit_params): | ||
step, param = pname.split('__', 1) | ||
fit_params_steps[step][param] = pval | ||
Xt = X | ||
for name, transform in self.steps[:-1]: | ||
if hasattr(transform, "fit_transform"): | ||
Xt = transform.fit_transform(Xt, **fit_params_steps[name]) | ||
Xt = _call_fit(transform.fit_transform, | ||
Xt, y, **fit_params_steps[name]) | ||
else: | ||
Xt = transform.fit(Xt, **fit_params_steps[name]) \ | ||
.transform(Xt) | ||
Xt = _call_fit(transform.fit, | ||
Xt, y, **fit_params_steps[name]).transform(Xt) | ||
return Xt, fit_params_steps[self.steps[-1][0]] | ||
|
||
def fit(self, X, **fit_params): | ||
Xt, fit_params = self._pre_transform(X, **fit_params) | ||
self.steps[-1][-1].fit(Xt, **fit_params) | ||
def fit(self, X, y=None, **fit_params): | ||
Xt, fit_params = self._pre_transform(X, y, **fit_params) | ||
_call_fit(self.steps[-1][-1].fit, Xt, y, **fit_params) | ||
return self | ||
|
||
def fit_transform(self, X, **fit_params): | ||
Xt, fit_params = self._pre_transform(X, **fit_params) | ||
def fit_transform(self, X, y=None, **fit_params): | ||
Xt, fit_params = self._pre_transform(X, y, **fit_params) | ||
if hasattr(self.steps[-1][-1], 'fit_transform'): | ||
return self.steps[-1][-1].fit_transform(Xt, **fit_params) | ||
return _call_fit(self.steps[-1][-1].fit_transform, | ||
Xt, y, **fit_params) | ||
else: | ||
return self.steps[-1][-1].fit(Xt, **fit_params).transform(Xt) | ||
return _call_fit(self.steps[-1][-1].fit, | ||
Xt, y, **fit_params).transform(Xt) | ||
|
||
|
||
def make_transformer_pipeline(*steps): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,13 @@ | ||
import pytest | ||
from sklearn_pandas.pipeline import TransformerPipeline | ||
from sklearn_pandas.pipeline import TransformerPipeline, _call_fit | ||
from functools import partial | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Imported but unused; please remove this import. |
||
|
||
# In py3, mock is included with the unittest standard library | ||
# In py2, it's a separate package | ||
try: | ||
from unittest.mock import patch | ||
except ImportError: | ||
from mock import patch | ||
|
||
|
||
class NoTransformT(object): | ||
|
@@ -16,6 +24,39 @@ def transform(self, x): | |
return self | ||
|
||
|
||
class Trans(object): | ||
""" | ||
Transformer with fit and transform methods | ||
""" | ||
def fit(self, x, y=None): | ||
return self | ||
|
||
def transform(self, x): | ||
return self | ||
|
||
|
||
def func_x_y(x, y, kwarg='kwarg'): | ||
""" | ||
Function with required x and y arguments | ||
""" | ||
return | ||
|
||
|
||
def func_x(x, kwarg='kwarg'): | ||
""" | ||
Function with required x argument | ||
""" | ||
return | ||
|
||
|
||
def func_raise_type_err(x, y, kwarg='kwarg'): | ||
""" | ||
Function with required x and y arguments, | ||
raises TypeError | ||
""" | ||
raise TypeError | ||
|
||
|
||
def test_all_steps_fit_transform(): | ||
""" | ||
All steps must implement fit and transform. Otherwise, raise TypeError. | ||
|
@@ -25,3 +66,36 @@ def test_all_steps_fit_transform(): | |
|
||
with pytest.raises(TypeError): | ||
TransformerPipeline([('svc', NoFitT())]) | ||
|
||
|
||
@patch.object(Trans, 'fit', side_effect=func_x_y) | ||
def test_called_with_x_and_y(mock_fit): | ||
""" | ||
Fit method with required X and y arguments is called with both and with | ||
any additional keywords | ||
""" | ||
_call_fit(Trans().fit, 'X', 'y', kwarg='kwarg') | ||
mock_fit.assert_called_with('X', 'y', kwarg='kwarg') | ||
|
||
|
||
@patch.object(Trans, 'fit', side_effect=func_x) | ||
def test_called_with_x(mock_fit): | ||
""" | ||
Fit method with a required X argument is called with it and with | ||
any additional keywords | ||
""" | ||
_call_fit(Trans().fit, 'X', 'y', kwarg='kwarg') | ||
mock_fit.assert_called_with('X', kwarg='kwarg') | ||
|
||
_call_fit(Trans().fit, 'X', kwarg='kwarg') | ||
mock_fit.assert_called_with('X', kwarg='kwarg') | ||
|
||
|
||
@patch.object(Trans, 'fit', side_effect=func_raise_type_err) | ||
def test_raises_type_error(mock_fit): | ||
""" | ||
If a fit method with required X and y arguments raises a TypeError, it's | ||
re-raised (for a different reason) when it's called with one argument | ||
""" | ||
with pytest.raises(TypeError): | ||
_call_fit(Trans().fit, 'X', 'y', kwarg='kwarg') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why does one need to pass "1" as second argument to this transform, and why is the output different from the previous case in the last column?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The "1" is the second argument of np.round. This test case was failing even though I don't think I modified anything that affected it. The issue seems to be that on my machine np.round(-0.3) equals "0.", not "-0." Changing it to round to 1 decimal place fixed the test case.