1313import re
1414import pkgutil
1515
16+ import pytest
17+
1618from sklearn .utils .testing import assert_false , clean_warning_registry
1719from sklearn .utils .testing import all_estimators
1820from sklearn .utils .testing import assert_equal
@@ -41,34 +43,57 @@ def test_all_estimator_no_base_class():
4143
4244
4345def test_all_estimators ():
44- # Test that estimators are default-constructible, cloneable
45- # and have working repr.
4646 estimators = all_estimators (include_meta_estimators = True )
4747
4848 # Meta sanity-check to make sure that the estimator introspection runs
4949 # properly
5050 assert_greater (len (estimators ), 0 )
5151
52- for name , Estimator in estimators :
53- # some can just not be sensibly default constructed
54- yield check_parameters_default_constructible , name , Estimator
5552
53+ @pytest .mark .parametrize (
54+ 'name, Estimator' ,
55+ all_estimators (include_meta_estimators = True )
56+ )
57+ def test_parameters_default_constructible (name , Estimator ):
58+ # Test that estimators are default-constructible
59+ check_parameters_default_constructible (name , Estimator )
5660
57- def test_non_meta_estimators ():
58- # input validation etc for non-meta estimators
59- estimators = all_estimators ()
60- for name , Estimator in estimators :
61+
62+ def _tested_non_meta_estimators ():
63+ for name , Estimator in all_estimators ():
6164 if issubclass (Estimator , BiclusterMixin ):
6265 continue
6366 if name .startswith ("_" ):
6467 continue
68+ yield name , Estimator
69+
70+
71+ def _generate_checks_per_estimator (check_generator , estimators ):
72+ for name , Estimator in estimators :
6573 estimator = Estimator ()
66- # check this on class
67- yield check_no_attributes_set_in_init , name , estimator
74+ for check in check_generator ( name , estimator ):
75+ yield name , Estimator , check
6876
69- for check in _yield_all_checks (name , estimator ):
70- set_checking_parameters (estimator )
71- yield check , name , estimator
77+
78+ @pytest .mark .parametrize (
79+ "name, Estimator, check" ,
80+ _generate_checks_per_estimator (_yield_all_checks ,
81+ _tested_non_meta_estimators ())
82+ )
83+ def test_non_meta_estimators (name , Estimator , check ):
84+ # Common tests for non-meta estimators
85+ estimator = Estimator ()
86+ set_checking_parameters (estimator )
87+ check (name , estimator )
88+
89+
90+ @pytest .mark .parametrize ("name, Estimator" ,
91+ _tested_non_meta_estimators ())
92+ def test_no_attributes_set_in_init (name , Estimator ):
93+ # input validation etc for non-meta estimators
94+ estimator = Estimator ()
95+ # check this on class
96+ check_no_attributes_set_in_init (name , estimator )
7297
7398
7499def test_configure ():
@@ -95,19 +120,21 @@ def test_configure():
95120 os .chdir (cwd )
96121
97122
98- def test_class_weight_balanced_linear_classifiers ():
123+ def _tested_linear_classifiers ():
99124 classifiers = all_estimators (type_filter = 'classifier' )
100125
101126 clean_warning_registry ()
102127 with warnings .catch_warnings (record = True ):
103- linear_classifiers = [
104- (name , clazz )
105- for name , clazz in classifiers
128+ for name , clazz in classifiers :
106129 if ('class_weight' in clazz ().get_params ().keys () and
107- issubclass (clazz , LinearClassifierMixin ))]
130+ issubclass (clazz , LinearClassifierMixin )):
131+ yield name , clazz
132+
108133
109- for name , Classifier in linear_classifiers :
110- yield check_class_weight_balanced_linear_classifier , name , Classifier
134+ @pytest .mark .parametrize ("name, Classifier" ,
135+ _tested_linear_classifiers ())
136+ def test_class_weight_balanced_linear_classifiers (name , Classifier ):
137+ check_class_weight_balanced_linear_classifier (name , Classifier )
111138
112139
113140@ignore_warnings
0 commit comments