|
3 | 3 | from sklearn.utils import tosequence
|
4 | 4 |
|
5 | 5 |
|
| 6 | +def _call_fit(fit_method, X, y=None, **kwargs): |
| 7 | + """ |
| 8 | + helper function, calls the fit or fit_transform method with the correct |
| 9 | + number of parameters |
| 10 | +
|
| 11 | + fit_method: fit or fit_transform method of the transformer |
| 12 | + X: the data to fit |
| 13 | + y: the target vector relative to X, optional |
| 14 | + kwargs: any keyword arguments to the fit method |
| 15 | +
|
| 16 | + return: the result of the fit or fit_transform method |
| 17 | +
|
| 18 | + WARNING: if this function raises a TypeError exception, test the fit |
| 19 | + or fit_transform method passed to it in isolation as _call_fit will not |
| 20 | + distinguish TypeError due to incorrect number of arguments from |
| 21 | + other TypeError |
| 22 | + """ |
| 23 | + try: |
| 24 | + return fit_method(X, y, **kwargs) |
| 25 | + except TypeError: |
| 26 | + # fit takes only one argument |
| 27 | + return fit_method(X, **kwargs) |
| 28 | + |
| 29 | + |
6 | 30 | class TransformerPipeline(Pipeline):
|
7 | 31 | """
|
8 | 32 | Pipeline that expects all steps to be transformers taking a single X argument,
|
@@ -40,42 +64,26 @@ def _pre_transform(self, X, y=None, **fit_params):
|
40 | 64 | Xt = X
|
41 | 65 | for name, transform in self.steps[:-1]:
|
42 | 66 | if hasattr(transform, "fit_transform"):
|
43 |
| - try: |
44 |
| - Xt = transform.fit_transform(Xt, y, **fit_params_steps[name]) |
45 |
| - except TypeError: |
46 |
| - # fit takes only one argument |
47 |
| - Xt = transform.fit_transform(Xt, **fit_params_steps[name]) |
| 67 | + Xt = _call_fit(transform.fit_transform, |
| 68 | + Xt, y, **fit_params_steps[name]) |
48 | 69 | else:
|
49 |
| - try: |
50 |
| - Xt = transform.fit(Xt, y, **fit_params_steps[name]).transform(Xt) |
51 |
| - except TypeError: |
52 |
| - # fit takes only one argument |
53 |
| - Xt = transform.fit(Xt, **fit_params_steps[name]).transform(Xt) |
| 70 | + Xt = _call_fit(transform.fit, |
| 71 | + Xt, y, **fit_params_steps[name]).transform(Xt) |
54 | 72 | return Xt, fit_params_steps[self.steps[-1][0]]
|
55 | 73 |
|
56 | 74 | def fit(self, X, y=None, **fit_params):
|
57 | 75 | Xt, fit_params = self._pre_transform(X, y, **fit_params)
|
58 |
| - try: |
59 |
| - self.steps[-1][-1].fit(Xt, y, **fit_params) |
60 |
| - except TypeError: |
61 |
| - # fit takes only one argument |
62 |
| - self.steps[-1][-1].fit(Xt, **fit_params) |
| 76 | + _call_fit(self.steps[-1][-1].fit, Xt, y, **fit_params) |
63 | 77 | return self
|
64 | 78 |
|
65 | 79 | def fit_transform(self, X, y=None, **fit_params):
|
66 | 80 | Xt, fit_params = self._pre_transform(X, y, **fit_params)
|
67 | 81 | if hasattr(self.steps[-1][-1], 'fit_transform'):
|
68 |
| - try: |
69 |
| - return self.steps[-1][-1].fit_transform(Xt, y, **fit_params) |
70 |
| - except TypeError: |
71 |
| - # fit_transform takes only one argument |
72 |
| - return self.steps[-1][-1].fit_transform(Xt, **fit_params) |
| 82 | + return _call_fit(self.steps[-1][-1].fit_transform, |
| 83 | + Xt, y, **fit_params) |
73 | 84 | else:
|
74 |
| - try: |
75 |
| - return self.steps[-1][-1].fit(Xt, y, **fit_params).transform(Xt) |
76 |
| - except: |
77 |
| - # fit takes only one argument |
78 |
| - return self.steps[-1][-1].fit(Xt, **fit_params).transform(Xt) |
| 85 | + return _call_fit(self.steps[-1][-1].fit, |
| 86 | + Xt, y, **fit_params).transform(Xt) |
79 | 87 |
|
80 | 88 |
|
81 | 89 | def make_transformer_pipeline(*steps):
|
|
0 commit comments