|
6 | 6 | from nose.tools import assert_true |
7 | 7 | from nose.tools import assert_raises |
8 | 8 | from nose.tools import assert_equal |
| 9 | +from nose.tools import assert_less_equal |
9 | 10 |
|
10 | 11 | from ..base import BaseEstimator |
11 | 12 | from ..datasets import make_regression |
@@ -62,12 +63,20 @@ def test_stratified_shuffle_split(): |
62 | 63 | # Check that errors are raised if there is not enough samples |
63 | 64 | assert_raises(ValueError, cross_validation.StratifiedShuffleSplit, y, 3, 0.5, 0.6) |
64 | 65 |
|
65 | | - # Check if returns balanced classes |
66 | | - sss = cross_validation.StratifiedShuffleSplit(y, 6, test_size=0.33) |
| 66 | + # Check if returns better balanced classes than ShuffleSplit |
| 67 | + sss = cross_validation.StratifiedShuffleSplit(y, 6, test_size=0.33, random_state=0) |
| 68 | + ss = cross_validation.ShuffleSplit(y.size, 6, 0.33, random_state=0) |
| 69 | + |
| 70 | + train_std = [] |
| 71 | + test_std = [] |
67 | 72 |
|
68 | 73 | for train, test in sss: |
69 | | - assert_array_equal(y[train], np.unique(y)) |
70 | | - assert_array_equal(np.unique(y[test]), np.unique(y)) |
| 74 | + train_std.append(np.std(np.bincount(y[train]))) |
| 75 | + test_std.append(np.std(np.bincount(y[test]))) |
| 76 | + |
| 77 | + for i, [train, test] in enumerate(ss): |
| 78 | + assert_less_equal(np.std(np.bincount(y[train])), train_std[i]) |
| 79 | + assert_less_equal(np.std(np.bincount(y[test])), test_std[i]) |
71 | 80 |
|
72 | 81 | def test_cross_val_score(): |
73 | 82 | clf = MockClassifier() |
|
0 commit comments