Skip to content

Commit 8372f35

Browse files
committed
added unit tests for _call_fit
added unit tests for _call_fit to test_pipeline
1 parent 575d2d6 commit 8372f35

File tree

1 file changed

+75
-1
lines changed

1 file changed

+75
-1
lines changed

tests/test_pipeline.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
import pytest
2-
from sklearn_pandas.pipeline import TransformerPipeline
2+
from sklearn_pandas.pipeline import TransformerPipeline, _call_fit
3+
from functools import partial
4+
5+
# In py3, mock is included with the unittest standard library
6+
# In py2, it's a separate package
7+
try:
8+
from unittest.mock import patch
9+
except ImportError:
10+
from mock import patch
311

412

513
class NoTransformT(object):
@@ -16,6 +24,39 @@ def transform(self, x):
1624
return self
1725

1826

27+
class Trans(object):
28+
"""
29+
Transformer with fit and transform methods
30+
"""
31+
def fit(self, x, y=None):
32+
return self
33+
34+
def transform(self, x):
35+
return self
36+
37+
38+
def func_x_y(x, y, kwarg='kwarg'):
39+
"""
40+
Function with required x and y arguments
41+
"""
42+
return
43+
44+
45+
def func_x(x, kwarg='kwarg'):
46+
"""
47+
Function with required x argument
48+
"""
49+
return
50+
51+
52+
def func_raise_type_err(x, y, kwarg='kwarg'):
53+
"""
54+
Function with required x and y arguments,
55+
raises TypeError
56+
"""
57+
raise TypeError
58+
59+
1960
def test_all_steps_fit_transform():
2061
"""
2162
All steps must implement fit and transform. Otherwise, raise TypeError.
@@ -25,3 +66,36 @@ def test_all_steps_fit_transform():
2566

2667
with pytest.raises(TypeError):
2768
TransformerPipeline([('svc', NoFitT())])
69+
70+
71+
@patch.object(Trans, 'fit', side_effect=func_x_y)
72+
def test_called_with_x_and_y(mock_fit):
73+
"""
74+
Fit method with required X and y arguments is called with both and with
75+
any additional keywords
76+
"""
77+
_call_fit(Trans().fit, 'X', 'y', kwarg='kwarg')
78+
mock_fit.assert_called_with('X', 'y', kwarg='kwarg')
79+
80+
81+
@patch.object(Trans, 'fit', side_effect=func_x)
82+
def test_called_with_x(mock_fit):
83+
"""
84+
Fit method with a required X arguments is called with it and with
85+
any additional keywords
86+
"""
87+
_call_fit(Trans().fit, 'X', 'y', kwarg='kwarg')
88+
mock_fit.assert_called_with('X', kwarg='kwarg')
89+
90+
_call_fit(Trans().fit, 'X', kwarg='kwarg')
91+
mock_fit.assert_called_with('X', kwarg='kwarg')
92+
93+
94+
@patch.object(Trans, 'fit', side_effect=func_raise_type_err)
95+
def test_raises_type_error(mock_fit):
96+
"""
97+
If a fit method with required X and y arguments raises a TypeError, it's
98+
re-raised (for a different reason) when it's called with one argument
99+
"""
100+
with pytest.raises(TypeError):
101+
_call_fit(Trans().fit, 'X', 'y', kwarg='kwarg')

0 commit comments

Comments
 (0)