Skip to content

Commit b162aca

Browse files
amuellerthomasjpfan
authored andcommitted
MAINT Slight common tests cleanup (scikit-learn#14511)
1 parent 7a87ac5 commit b162aca

File tree

3 files changed

+7
-23
lines changed

3 files changed

+7
-23
lines changed

sklearn/tests/test_common.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
_safe_tags,
3333
set_checking_parameters,
3434
check_parameters_default_constructible,
35-
check_no_attributes_set_in_init,
3635
check_class_weight_balanced_linear_classifier)
3736

3837

@@ -111,22 +110,6 @@ def test_estimators(estimator, check):
111110
check(name, estimator)
112111

113112

114-
@pytest.mark.parametrize("name, estimator",
115-
_tested_estimators())
116-
def test_no_attributes_set_in_init(name, estimator):
117-
# input validation etc for all estimators
118-
with ignore_warnings(category=(DeprecationWarning, ConvergenceWarning,
119-
UserWarning, FutureWarning)):
120-
tags = _safe_tags(estimator)
121-
if tags['_skip_test']:
122-
warnings.warn("Explicit SKIP via _skip_test tag for "
123-
"{}.".format(name),
124-
SkipTestWarning)
125-
return
126-
# check this on class
127-
check_no_attributes_set_in_init(name, estimator)
128-
129-
130113
@ignore_warnings(category=DeprecationWarning)
131114
# ignore deprecated open(.., 'U') in numpy distutils
132115
def test_configure():

sklearn/utils/estimator_checks.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def _safe_tags(estimator, key=None):
7272

7373
def _yield_checks(name, estimator):
7474
tags = _safe_tags(estimator)
75+
yield check_no_attributes_set_in_init
7576
yield check_estimators_dtypes
7677
yield check_fit_score_takes_y
7778
yield check_sample_weights_pandas_series
@@ -288,7 +289,6 @@ def check_estimator(Estimator):
288289
name = Estimator.__name__
289290
estimator = Estimator()
290291
check_parameters_default_constructible(name, Estimator)
291-
check_no_attributes_set_in_init(name, estimator)
292292
else:
293293
# got an instance
294294
estimator = Estimator
@@ -2056,9 +2056,10 @@ def check_estimators_overwrite_params(name, estimator_orig):
20562056
% (name, param_name, original_value, new_value))
20572057

20582058

2059-
def check_no_attributes_set_in_init(name, estimator):
2059+
@ignore_warnings(category=(DeprecationWarning, FutureWarning))
2060+
def check_no_attributes_set_in_init(name, estimator_orig):
20602061
"""Check setting during init. """
2061-
2062+
estimator = clone(estimator_orig)
20622063
if hasattr(type(estimator).__init__, "deprecated_original"):
20632064
return
20642065

sklearn/utils/tests/test_estimator_checks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ def test_check_estimator():
414414

415415
# doesn't error on actual estimator
416416
check_estimator(LogisticRegression)
417-
check_estimator(LogisticRegression())
417+
check_estimator(LogisticRegression(C=0.01))
418418
check_estimator(MultiTaskElasticNet)
419419
check_estimator(MultiTaskElasticNet())
420420

@@ -483,11 +483,11 @@ def test_check_estimators_unfitted():
483483

484484

485485
def test_check_no_attributes_set_in_init():
486-
class NonConformantEstimatorPrivateSet:
486+
class NonConformantEstimatorPrivateSet(BaseEstimator):
487487
def __init__(self):
488488
self.you_should_not_set_this_ = None
489489

490-
class NonConformantEstimatorNoParamSet:
490+
class NonConformantEstimatorNoParamSet(BaseEstimator):
491491
def __init__(self, you_should_set_this_=None):
492492
pass
493493

0 commit comments

Comments
 (0)