@@ -52,6 +52,7 @@ def _yield_non_meta_checks(name, Estimator):
5252 yield check_estimators_dtypes
5353 yield check_fit_score_takes_y
5454 yield check_dtype_object
55+ yield check_estimators_fit_returns_self
5556
5657 # Check that all estimator yield informative messages when
5758 # trained on empty datasets
@@ -777,6 +778,21 @@ def check_classifiers_train(name, Classifier):
777778 assert_raises (ValueError , classifier .predict_proba , X .T )
778779
779780
781+ def check_estimators_fit_returns_self (name , Estimator ):
782+ """Check if self is returned when calling fit"""
783+ X , y = make_blobs (random_state = 0 , n_samples = 9 , n_features = 4 )
784+ y = multioutput_estimator_convert_y_2d (name , y )
785+ # some want non-negative input
786+ X -= X .min ()
787+
788+ estimator = Estimator ()
789+
790+ set_fast_parameters (estimator )
791+ set_random_state (estimator )
792+
793+ assert_true (estimator .fit (X , y ) is estimator )
794+
795+
780796def check_estimators_unfitted (name , Estimator ):
781797 """Check if NotFittedError is raised when calling predict and related
782798 functions"""
@@ -1202,7 +1218,7 @@ def check_parameters_default_constructible(name, Estimator):
12021218 # test __repr__
12031219 repr (estimator )
12041220 # test that set_params returns self
1205- assert_true (isinstance ( estimator .set_params (), Estimator ) )
1221+ assert_true (estimator .set_params () is estimator )
12061222
12071223 # test if init does nothing but set parameters
12081224 # this is important for grid_search etc.
0 commit comments