Skip to content

Commit f18295c

Browse files
committed
TST improve test-coverage in base, remove unreachable code-path
1 parent 7c70198 commit f18295c

File tree

2 files changed

+40
-6
lines changed

2 files changed

+40
-6
lines changed

sklearn/base.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ def get_params(self, deep=True):
196196
out = dict()
197197
for key in self._get_param_names():
198198
value = getattr(self, key, None)
199+
# XXX: should we rather test if instance of estimator?
199200
if deep and hasattr(value, 'get_params'):
200201
deep_items = value.get_params().items()
201202
out.update((key + '__' + k, val) for k, val in deep_items)
@@ -227,12 +228,6 @@ def set_params(self, **params):
227228
raise ValueError('Invalid parameter %s for estimator %s'
228229
% (name, self))
229230
sub_object = valid_params[name]
230-
if not hasattr(sub_object, 'get_params'):
231-
raise TypeError(
232-
'Parameter %s of %s is not an estimator, cannot set '
233-
'sub parameter %s' %
234-
(sub_name, self.__class__.__name__, sub_name)
235-
)
236231
sub_object.set_params(**{sub_name: value})
237232
else:
238233
# simple objects case

sklearn/tests/test_base.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,23 @@ def __init__(self, a=None):
4444
self.a = 1
4545

4646

47+
class NoEstimator(object):
48+
def __init__(self):
49+
pass
50+
51+
def fit(self, X=None, y=None):
52+
return self
53+
54+
def predict(self, X=None):
55+
return None
56+
57+
58+
class VargEstimator(BaseEstimator):
59+
"""Sklearn estimators shouldn't have vargs."""
60+
def __init__(self, *vargs):
61+
pass
62+
63+
4764
#############################################################################
4865
# The tests
4966

@@ -88,6 +105,12 @@ def test_clone_buggy():
88105
buggy.a = 2
89106
assert_raises(RuntimeError, clone, buggy)
90107

108+
no_estimator = NoEstimator()
109+
assert_raises(TypeError, clone, no_estimator)
110+
111+
varg_est = VargEstimator()
112+
assert_raises(RuntimeError, clone, varg_est)
113+
91114

92115
def test_clone_empty_array():
93116
"""Regression test for cloning estimators with empty arrays"""
@@ -110,6 +133,9 @@ def test_repr():
110133
"T(a=K(c=None, d=None), b=K(c=None, d=None))"
111134
)
112135

136+
some_est = T(a=["long_params"] * 1000)
137+
assert_equal(len(repr(some_est)), 415)
138+
113139

114140
def test_str():
115141
"""Smoke test the str of the base estimator"""
@@ -135,3 +161,16 @@ def test_is_classifier():
135161
assert_true(is_classifier(Pipeline([('svc', svc)])))
136162
assert_true(is_classifier(Pipeline([('svc_cv',
137163
GridSearchCV(svc, {'C': [0.1, 1]}))])))
164+
165+
166+
def test_set_params():
167+
# test nested estimator parameter setting
168+
clf = Pipeline([("svc", SVC())])
169+
# non-existing parameter in svc
170+
assert_raises(ValueError, clf.set_params, svc__stupid_param=True)
171+
# non-existing parameter of pipeline
172+
assert_raises(ValueError, clf.set_params, svm__stupid_param=True)
173+
# we don't currently catch if the things in pipeline are estimators
174+
#bad_pipeline = Pipeline([("bad", NoEstimator())])
175+
#assert_raises(AttributeError, bad_pipeline.set_params,
176+
#bad__stupid_param=True)

0 commit comments

Comments
 (0)