Skip to content

Commit 9df5626

Browse files
ndawearjoly
authored andcommitted
add test_base.test_score_sample_weight
1 parent 08669f4 commit 9df5626

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

sklearn/tests/test_base.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from sklearn.utils.testing import assert_true
1010
from sklearn.utils.testing import assert_false
1111
from sklearn.utils.testing import assert_equal
12+
from sklearn.utils.testing import assert_not_equal
1213
from sklearn.utils.testing import assert_raises
1314

1415
from sklearn.base import BaseEstimator, clone, is_classifier
@@ -202,3 +203,28 @@ def test_set_params():
202203
#bad_pipeline = Pipeline([("bad", NoEstimator())])
203204
#assert_raises(AttributeError, bad_pipeline.set_params,
204205
#bad__stupid_param=True)
206+
207+
208+
def test_score_sample_weight():
209+
from sklearn.tree import DecisionTreeClassifier
210+
from sklearn.tree import DecisionTreeRegressor
211+
from sklearn import datasets
212+
213+
rng = np.random.RandomState(0)
214+
215+
# test both ClassifierMixin and RegressorMixin
216+
estimators = [DecisionTreeClassifier(max_depth=2),
217+
DecisionTreeRegressor(max_depth=2)]
218+
sets = [datasets.load_iris(),
219+
datasets.load_boston()]
220+
221+
for est, ds in zip(estimators, sets):
222+
est.fit(ds.data, ds.target)
223+
# generate random sample weights
224+
sample_weight = rng.randint(1, 10, size=len(ds.target))
225+
# check that the score with and without sample weights are different
226+
assert_not_equal(est.score(ds.data, ds.target),
227+
est.score(ds.data, ds.target,
228+
sample_weight=sample_weight),
229+
msg="Unweighted and weighted scores "
230+
"are unexpectedly equal")

0 commit comments

Comments
 (0)