Skip to content

Commit 1678bd5

Browse files
committed
Merge pull request scikit-learn#4694 from betatim/check-fit-returns-self
[MRG + 1] Check fit() returns self for all estimators
2 parents 69e344f + 175f50e commit 1678bd5

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

sklearn/utils/estimator_checks.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def _yield_non_meta_checks(name, Estimator):
5252
yield check_estimators_dtypes
5353
yield check_fit_score_takes_y
5454
yield check_dtype_object
55+
yield check_estimators_fit_returns_self
5556

5657
# Check that all estimator yield informative messages when
5758
# trained on empty datasets
@@ -777,6 +778,21 @@ def check_classifiers_train(name, Classifier):
777778
assert_raises(ValueError, classifier.predict_proba, X.T)
778779

779780

781+
def check_estimators_fit_returns_self(name, Estimator):
782+
"""Check if self is returned when calling fit"""
783+
X, y = make_blobs(random_state=0, n_samples=9, n_features=4)
784+
y = multioutput_estimator_convert_y_2d(name, y)
785+
# some want non-negative input
786+
X -= X.min()
787+
788+
estimator = Estimator()
789+
790+
set_fast_parameters(estimator)
791+
set_random_state(estimator)
792+
793+
assert_true(estimator.fit(X, y) is estimator)
794+
795+
780796
def check_estimators_unfitted(name, Estimator):
781797
"""Check if NotFittedError is raised when calling predict and related
782798
functions"""
@@ -1202,7 +1218,7 @@ def check_parameters_default_constructible(name, Estimator):
12021218
# test __repr__
12031219
repr(estimator)
12041220
# test that set_params returns self
1205-
assert_true(isinstance(estimator.set_params(), Estimator))
1221+
assert_true(estimator.set_params() is estimator)
12061222

12071223
# test if init does nothing but set parameters
12081224
# this is important for grid_search etc.

0 commit comments

Comments
 (0)