Skip to content

Commit 17ecf4e

Browse files
Yannick Schwartz authored and GaelVaroquaux committed
updated stratified shuffle split test
1 parent 9419451 commit 17ecf4e

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

sklearn/tests/test_cross_validation.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from nose.tools import assert_true
77
from nose.tools import assert_raises
88
from nose.tools import assert_equal
9+
from nose.tools import assert_less_equal
910

1011
from ..base import BaseEstimator
1112
from ..datasets import make_regression
@@ -62,12 +63,20 @@ def test_stratified_shuffle_split():
6263
# Check that errors are raised if there is not enough samples
6364
assert_raises(ValueError, cross_validation.StratifiedShuffleSplit, y, 3, 0.5, 0.6)
6465

65-
# Check if returns balanced classes
66-
sss = cross_validation.StratifiedShuffleSplit(y, 6, test_size=0.33)
66+
# Check if returns better balanced classes than ShuffleSplit
67+
sss = cross_validation.StratifiedShuffleSplit(y, 6, test_size=0.33, random_state=0)
68+
ss = cross_validation.ShuffleSplit(y.size, 6, 0.33, random_state=0)
69+
70+
train_std = []
71+
test_std = []
6772

6873
for train, test in sss:
69-
assert_array_equal(y[train], np.unique(y))
70-
assert_array_equal(np.unique(y[test]), np.unique(y))
74+
train_std.append(np.std(np.bincount(y[train])))
75+
test_std.append(np.std(np.bincount(y[test])))
76+
77+
for i, [train, test] in enumerate(ss):
78+
assert_less_equal(np.std(np.bincount(y[train])), train_std[i])
79+
assert_less_equal(np.std(np.bincount(y[test])), test_std[i])
7180

7281
def test_cross_val_score():
7382
clf = MockClassifier()

0 commit comments

Comments
 (0)