@@ -67,7 +67,12 @@ def _yield_checks(estimator):
6767 yield check_sample_weights_not_an_array
6868 yield check_sample_weights_list
6969 yield check_sample_weights_shape
70- yield check_sample_weights_invariance
70+ if (has_fit_parameter (estimator , "sample_weight" )
71+ and not (hasattr (estimator , "_pairwise" )
72+ and estimator ._pairwise )):
73+ # We skip pairwise because the data is not pairwise
74+ yield partial (check_sample_weights_invariance , kind = 'ones' )
75+ yield partial (check_sample_weights_invariance , kind = 'zeros' )
7176 yield check_estimators_fit_returns_self
7277 yield partial (check_estimators_fit_returns_self , readonly_memmap = True )
7378
@@ -836,41 +841,55 @@ def check_sample_weights_shape(name, estimator_orig):
836841
837842
@ignore_warnings(category=FutureWarning)
def check_sample_weights_invariance(name, estimator_orig, kind="ones"):
    """Check that fitting is invariant to trivial ``sample_weight`` changes.

    Two clones of ``estimator_orig`` are fitted with the same random state:
    the first on a small reference dataset without sample weights, the second
    on a dataset/weight combination that should be equivalent to it. The
    outputs of ``predict``, ``predict_proba``, ``decision_function`` and
    ``transform`` (whichever the estimator exposes) on the reference data are
    then compared with ``assert_allclose_dense_sparse``.

    Parameters
    ----------
    name : str
        Estimator name, only used in the failure message.
    estimator_orig : estimator object
        Estimator to check. It is cloned, never fitted in place.
    kind : {"ones", "zeros"}, default="ones"
        - "ones": fitting with unit weights must match fitting with
          ``sample_weight=None``.
        - "zeros": samples with a zero weight must be equivalent to samples
          that were removed from the training set.

    Raises
    ------
    ValueError
        If ``kind`` is neither "ones" nor "zeros".
    """
    estimator1 = clone(estimator_orig)
    estimator2 = clone(estimator_orig)
    set_random_state(estimator1, random_state=0)
    set_random_state(estimator2, random_state=0)

    X1 = np.array([[1, 3], [1, 3], [1, 3], [1, 3],
                   [2, 1], [2, 1], [2, 1], [2, 1],
                   [3, 3], [3, 3], [3, 3], [3, 3],
                   [4, 1], [4, 1], [4, 1], [4, 1]], dtype=np.float64)
    # The builtin ``int`` is what the ``np.int`` alias pointed to; the alias
    # itself was deprecated in NumPy 1.20 and later removed.
    y1 = np.array([1, 1, 1, 1, 2, 2, 2, 2,
                   1, 1, 1, 1, 2, 2, 2, 2], dtype=int)

    if kind == 'ones':
        X2 = X1
        y2 = y1
        sw2 = np.ones(shape=len(y1))
        err_msg = (f"For {name} sample_weight=None is not equivalent to "
                   f"sample_weight=ones")
    elif kind == 'zeros':
        # Construct a dataset that is very different to (X1, y1) if weights
        # are disregarded, but identical to (X1, y1) given weights: the
        # second half holds shifted points with swapped labels and is
        # entirely zero-weighted.
        X2 = np.vstack([X1, X1 + 1])
        y2 = np.hstack([y1, 3 - y1])
        sw2 = np.ones(shape=len(y1) * 2)
        sw2[len(y1):] = 0
        X2, y2, sw2 = shuffle(X2, y2, sw2, random_state=0)

        err_msg = (f"For {name}, a zero sample_weight is not equivalent "
                   f"to removing the sample")
    else:  # pragma: no cover
        raise ValueError(
            f"kind must be 'ones' or 'zeros', got {kind!r} instead."
        )

    y1 = _enforce_estimator_tags_y(estimator1, y1)
    y2 = _enforce_estimator_tags_y(estimator2, y2)

    estimator1.fit(X1, y=y1, sample_weight=None)
    estimator2.fit(X2, y=y2, sample_weight=sw2)

    # Both fitted estimators are evaluated on the reference data X1.
    for method in ["predict", "predict_proba",
                   "decision_function", "transform"]:
        if hasattr(estimator_orig, method):
            X_pred1 = getattr(estimator1, method)(X1)
            X_pred2 = getattr(estimator2, method)(X1)
            assert_allclose_dense_sparse(X_pred1, X_pred2, err_msg=err_msg)
874893
875894
876895@ignore_warnings (category = (FutureWarning , UserWarning ))
0 commit comments