1
1
import pytest
2
- from sklearn_pandas .pipeline import TransformerPipeline
2
+ from sklearn_pandas .pipeline import TransformerPipeline , _call_fit
3
+ from functools import partial
4
+
5
+ # In py3, mock is included with the unittest standard library
6
+ # In py2, it's a separate package
7
+ try :
8
+ from unittest .mock import patch
9
+ except ImportError :
10
+ from mock import patch
3
11
4
12
5
13
class NoTransformT (object ):
@@ -16,6 +24,39 @@ def transform(self, x):
16
24
return self
17
25
18
26
27
+ class Trans (object ):
28
+ """
29
+ Transformer with fit and transform methods
30
+ """
31
+ def fit (self , x , y = None ):
32
+ return self
33
+
34
+ def transform (self , x ):
35
+ return self
36
+
37
+
38
+ def func_x_y (x , y , kwarg = 'kwarg' ):
39
+ """
40
+ Function with required x and y arguments
41
+ """
42
+ return
43
+
44
+
45
+ def func_x (x , kwarg = 'kwarg' ):
46
+ """
47
+ Function with required x argument
48
+ """
49
+ return
50
+
51
+
52
+ def func_raise_type_err (x , y , kwarg = 'kwarg' ):
53
+ """
54
+ Function with required x and y arguments,
55
+ raises TypeError
56
+ """
57
+ raise TypeError
58
+
59
+
19
60
def test_all_steps_fit_transform ():
20
61
"""
21
62
All steps must implement fit and transform. Otherwise, raise TypeError.
@@ -25,3 +66,36 @@ def test_all_steps_fit_transform():
25
66
26
67
with pytest .raises (TypeError ):
27
68
TransformerPipeline ([('svc' , NoFitT ())])
69
+
70
+
71
+ @patch .object (Trans , 'fit' , side_effect = func_x_y )
72
+ def test_called_with_x_and_y (mock_fit ):
73
+ """
74
+ Fit method with required X and y arguments is called with both and with
75
+ any additional keywords
76
+ """
77
+ _call_fit (Trans ().fit , 'X' , 'y' , kwarg = 'kwarg' )
78
+ mock_fit .assert_called_with ('X' , 'y' , kwarg = 'kwarg' )
79
+
80
+
81
+ @patch .object (Trans , 'fit' , side_effect = func_x )
82
+ def test_called_with_x (mock_fit ):
83
+ """
84
+ Fit method with a required X arguments is called with it and with
85
+ any additional keywords
86
+ """
87
+ _call_fit (Trans ().fit , 'X' , 'y' , kwarg = 'kwarg' )
88
+ mock_fit .assert_called_with ('X' , kwarg = 'kwarg' )
89
+
90
+ _call_fit (Trans ().fit , 'X' , kwarg = 'kwarg' )
91
+ mock_fit .assert_called_with ('X' , kwarg = 'kwarg' )
92
+
93
+
94
+ @patch .object (Trans , 'fit' , side_effect = func_raise_type_err )
95
+ def test_raises_type_error (mock_fit ):
96
+ """
97
+ If a fit method with required X and y arguments raises a TypeError, it's
98
+ re-raised (for a different reason) when it's called with one argument
99
+ """
100
+ with pytest .raises (TypeError ):
101
+ _call_fit (Trans ().fit , 'X' , 'y' , kwarg = 'kwarg' )
0 commit comments