5
5
6
6
class TransformerPipeline (Pipeline ):
7
7
"""
8
- Pipeline that expects all steps to be transformers taking a single argument
8
+ Pipeline that expects all steps to be transformers taking a single X argument,
9
+ an optional y argument,
9
10
and having fit and transform methods.
10
11
11
- Code is copied from sklearn's Pipeline, leaving out the `y=None` argument.
12
+ Code is copied from sklearn's Pipeline
12
13
"""
13
14
def __init__ (self , steps ):
14
15
names , estimators = zip (* steps )
@@ -31,31 +32,50 @@ def __init__(self, steps):
31
32
"'%s' (type %s) doesn't)"
32
33
% (estimator , type (estimator )))
33
34
34
def _pre_transform(self, X, y=None, **fit_params):
    """Fit and apply every step except the last one.

    Keyword arguments named ``<step>__<param>`` are routed to the step
    called ``<step>`` as ``<param>``.

    Each intermediate step is first called with ``(Xt, y)``. If that call
    raises ``TypeError`` the step is assumed to accept only ``X`` and is
    called again without ``y``. Note that a ``TypeError`` raised inside the
    step for any other reason triggers the same retry.

    Returns a tuple ``(Xt, params)``: the data after the last intermediate
    step, and the dict of fit parameters routed to the final step.
    """
    fit_params_steps = dict((step, {}) for step, _ in self.steps)
    for pname, pval in six.iteritems(fit_params):
        step, param = pname.split('__', 1)
        fit_params_steps[step][param] = pval
    Xt = X
    for name, transform in self.steps[:-1]:
        step_params = fit_params_steps[name]
        if hasattr(transform, "fit_transform"):
            try:
                Xt = transform.fit_transform(Xt, y, **step_params)
            except TypeError:
                # fit_transform takes only one argument
                Xt = transform.fit_transform(Xt, **step_params)
        else:
            # Only the fit call is guarded: a TypeError coming from
            # transform() must propagate instead of causing a second fit.
            try:
                fitted = transform.fit(Xt, y, **step_params)
            except TypeError:
                # fit takes only one argument
                fitted = transform.fit(Xt, **step_params)
            Xt = fitted.transform(Xt)
    return Xt, fit_params_steps[self.steps[-1][0]]
47
55
48
def fit(self, X, y=None, **fit_params):
    """Fit all intermediate steps, then fit the final step.

    The final step is first called as ``fit(Xt, y, **params)``; if that
    raises ``TypeError`` it is called again as ``fit(Xt, **params)``.

    Returns ``self``.
    """
    transformed, last_params = self._pre_transform(X, y, **fit_params)
    final_step = self.steps[-1][-1]
    try:
        final_step.fit(transformed, y, **last_params)
    except TypeError:
        # fit takes only one argument
        final_step.fit(transformed, **last_params)
    return self
52
64
53
def fit_transform(self, X, y=None, **fit_params):
    """Fit all steps and return the data transformed by the final step.

    Intermediate steps are handled by ``_pre_transform``. The final step's
    ``fit_transform`` is used when it has one, otherwise ``fit`` followed by
    ``transform``. In both cases the call is first made with ``(Xt, y)``;
    if that raises ``TypeError`` it is repeated without ``y``.

    Any exception other than ``TypeError`` propagates to the caller.
    """
    Xt, fit_params = self._pre_transform(X, y, **fit_params)
    last_step = self.steps[-1][-1]
    if hasattr(last_step, 'fit_transform'):
        try:
            return last_step.fit_transform(Xt, y, **fit_params)
        except TypeError:
            # fit_transform takes only one argument
            return last_step.fit_transform(Xt, **fit_params)
    # Catch TypeError only (not a bare except) and guard just the fit call,
    # so unrelated errors and failures in transform() are not hidden by a
    # silent second fit.
    try:
        fitted = last_step.fit(Xt, y, **fit_params)
    except TypeError:
        # fit takes only one argument
        fitted = last_step.fit(Xt, **fit_params)
    return fitted.transform(Xt)
59
79
60
80
61
81
def make_transformer_pipeline (* steps ):
0 commit comments