Skip to content

Commit 67cc975

Browse files
rthglemaitre
authored andcommitted
[MRG+1] MAINT Parametrize common estimator tests with pytest (scikit-learn#11063)
1 parent aaf9cf0 commit 67cc975

File tree

2 files changed

+64
-23
lines changed

2 files changed

+64
-23
lines changed

doc/developers/tips.rst

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,22 @@ will be displayed as a color background behind the line number.
6464
Useful pytest aliases and flags
6565
-------------------------------
6666

67-
We recommend using pytest to run unit tests. When a unit tests fail, the
68-
following tricks can make debugging easier:
67+
The full test suite takes fairly long to run. For faster iterations,
68+
it is possibly to select a subset of tests using pytest selectors.
69+
In particular, one can run a `single test based on its node ID
70+
<https://docs.pytest.org/en/latest/example/markers.html#selecting-tests-based-on-their-node-id>`_::
71+
72+
pytest -v sklearn/linear_model/tests/test_logistic.py::test_sparsify
73+
74+
or use the `-k pytest parameter
75+
<https://docs.pytest.org/en/latest/example/markers.html#using-k-expr-to-select-tests-based-on-their-name>`_
76+
to select tests based on their name. For instance,::
77+
78+
pytest sklearn/tests/test_common.py -v -k LogisticRegression
79+
80+
will run all :term:`common tests` for the ``LogisticRegression`` estimator.
81+
82+
When a unit tests fail, the following tricks can make debugging easier:
6983

7084
1. The command line argument ``pytest -l`` instructs pytest to print the local
7185
variables when a failure occurs.

sklearn/tests/test_common.py

Lines changed: 48 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import re
1414
import pkgutil
1515

16+
import pytest
17+
1618
from sklearn.utils.testing import assert_false, clean_warning_registry
1719
from sklearn.utils.testing import all_estimators
1820
from sklearn.utils.testing import assert_equal
@@ -41,34 +43,57 @@ def test_all_estimator_no_base_class():
4143

4244

4345
def test_all_estimators():
44-
# Test that estimators are default-constructible, cloneable
45-
# and have working repr.
4646
estimators = all_estimators(include_meta_estimators=True)
4747

4848
# Meta sanity-check to make sure that the estimator introspection runs
4949
# properly
5050
assert_greater(len(estimators), 0)
5151

52-
for name, Estimator in estimators:
53-
# some can just not be sensibly default constructed
54-
yield check_parameters_default_constructible, name, Estimator
5552

53+
@pytest.mark.parametrize(
54+
'name, Estimator',
55+
all_estimators(include_meta_estimators=True)
56+
)
57+
def test_parameters_default_constructible(name, Estimator):
58+
# Test that estimators are default-constructible
59+
check_parameters_default_constructible(name, Estimator)
5660

57-
def test_non_meta_estimators():
58-
# input validation etc for non-meta estimators
59-
estimators = all_estimators()
60-
for name, Estimator in estimators:
61+
62+
def _tested_non_meta_estimators():
63+
for name, Estimator in all_estimators():
6164
if issubclass(Estimator, BiclusterMixin):
6265
continue
6366
if name.startswith("_"):
6467
continue
68+
yield name, Estimator
69+
70+
71+
def _generate_checks_per_estimator(check_generator, estimators):
72+
for name, Estimator in estimators:
6573
estimator = Estimator()
66-
# check this on class
67-
yield check_no_attributes_set_in_init, name, estimator
74+
for check in check_generator(name, estimator):
75+
yield name, Estimator, check
6876

69-
for check in _yield_all_checks(name, estimator):
70-
set_checking_parameters(estimator)
71-
yield check, name, estimator
77+
78+
@pytest.mark.parametrize(
79+
"name, Estimator, check",
80+
_generate_checks_per_estimator(_yield_all_checks,
81+
_tested_non_meta_estimators())
82+
)
83+
def test_non_meta_estimators(name, Estimator, check):
84+
# Common tests for non-meta estimators
85+
estimator = Estimator()
86+
set_checking_parameters(estimator)
87+
check(name, estimator)
88+
89+
90+
@pytest.mark.parametrize("name, Estimator",
91+
_tested_non_meta_estimators())
92+
def test_no_attributes_set_in_init(name, Estimator):
93+
# input validation etc for non-meta estimators
94+
estimator = Estimator()
95+
# check this on class
96+
check_no_attributes_set_in_init(name, estimator)
7297

7398

7499
def test_configure():
@@ -95,19 +120,21 @@ def test_configure():
95120
os.chdir(cwd)
96121

97122

98-
def test_class_weight_balanced_linear_classifiers():
123+
def _tested_linear_classifiers():
99124
classifiers = all_estimators(type_filter='classifier')
100125

101126
clean_warning_registry()
102127
with warnings.catch_warnings(record=True):
103-
linear_classifiers = [
104-
(name, clazz)
105-
for name, clazz in classifiers
128+
for name, clazz in classifiers:
106129
if ('class_weight' in clazz().get_params().keys() and
107-
issubclass(clazz, LinearClassifierMixin))]
130+
issubclass(clazz, LinearClassifierMixin)):
131+
yield name, clazz
132+
108133

109-
for name, Classifier in linear_classifiers:
110-
yield check_class_weight_balanced_linear_classifier, name, Classifier
134+
@pytest.mark.parametrize("name, Classifier",
135+
_tested_linear_classifiers())
136+
def test_class_weight_balanced_linear_classifiers(name, Classifier):
137+
check_class_weight_balanced_linear_classifier(name, Classifier)
111138

112139

113140
@ignore_warnings

0 commit comments

Comments
 (0)