Skip to content

Commit 2dbc747

Browse files
committed
All steps in the TransformerPipeline must implement fit/transform.
1 parent c2a1c59 commit 2dbc747

File tree

2 files changed

+51
-1
lines changed

2 files changed

+51
-1
lines changed

pipeline.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,36 @@
11
import six
22
from sklearn.pipeline import _name_estimators, Pipeline
3+
from sklearn.utils import tosequence
34

45

56
class TransformerPipeline(Pipeline):
67
"""
7-
Pipeline that expects all steps to be transformers taking a single argument.
8+
Pipeline that expects all steps to be transformers taking a single argument
9+
and having fit and transform methods.
810
911
Code is copied from sklearn's Pipeline, leaving out the `y=None` argument.
1012
"""
13+
def __init__(self, steps):
    """
    Validate the pipeline steps and store them on ``self.steps``.

    Parameters
    ----------
    steps : iterable of (name, transformer) pairs
        Every transformer must expose ``transform`` and at least one of
        ``fit`` / ``fit_transform``; the last one must expose ``fit``.

    Raises
    ------
    ValueError
        If two steps share the same name.
    TypeError
        If a step does not satisfy the method requirements above.
    """
    # Normalise once, up front, so that a one-shot iterable is not
    # exhausted by zip() before dict() and len() get to look at it.
    # NOTE(review): sklearn documents tosequence as avoiding a copy when
    # possible, so self.steps may alias the caller's list -- confirm
    # before relying on it being an independent copy.
    steps = tosequence(steps)

    names, estimators = zip(*steps)
    if len(dict(steps)) != len(steps):
        raise ValueError("Provided step names are not unique: %s" % (names,))

    self.steps = steps

    for e in estimators:
        can_fit = hasattr(e, "fit") or hasattr(e, "fit_transform")
        if not can_fit or not hasattr(e, "transform"):
            raise TypeError("All steps of the chain should "
                            "be transforms and implement fit and transform; "
                            "'%s' (type %s) doesn't" % (e, type(e)))

    # fit_transform alone is accepted for a step above, but the final
    # step is additionally required to expose a plain fit.
    estimator = estimators[-1]
    if not hasattr(estimator, "fit"):
        raise TypeError("Last step of chain should implement fit; "
                        "'%s' (type %s) doesn't"
                        % (estimator, type(estimator)))
33+
1134
def _pre_transform(self, X, **fit_params):
1235
fit_params_steps = dict((step, {}) for step, _ in self.steps)
1336
for pname, pval in six.iteritems(fit_params):

test_pipeline.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import pytest
2+
from pipeline import TransformerPipeline
3+
4+
5+
class NoTransformT(object):
    """Stub step that can be fitted but deliberately has no ``transform``.

    Used to check that TransformerPipeline rejects such a step.
    """

    def fit(self, x):
        """Ignore ``x`` and return this instance."""
        return self
10+
11+
12+
class NoFitT(object):
    """Stub step that exposes ``transform`` but deliberately has no ``fit``.

    Used to check that TransformerPipeline rejects such a step.
    """

    def transform(self, x):
        """Ignore ``x`` and return this instance (placeholder result)."""
        return self
17+
18+
19+
def test_all_steps_fit_transform():
    """
    A step lacking either fit or transform must be rejected with a
    TypeError when the pipeline is constructed.
    """
    # One stub is missing transform, the other is missing fit.
    for incomplete_step in (NoTransformT, NoFitT):
        with pytest.raises(TypeError):
            TransformerPipeline([('svc', incomplete_step())])

0 commit comments

Comments
 (0)