Skip to content

Commit 58da238

Browse files
committed
BUG: avoid same indices in test and train
COSMIT: better names, as suggested by @schwarty. Address comments by @schwarty.
1 parent 41b312c commit 58da238

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

sklearn/cross_validation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1015,7 +1015,7 @@ def _iter_indices(self):
10151015
)[0]
10161016
missing_idx = rng.permutation(missing_idx)
10171017
train.extend(missing_idx[:(self.n_train - len(train))])
1018-
test.extend(missing_idx[:(self.n_test - len(test))])
1018+
test.extend(missing_idx[-(self.n_test - len(test)):])
10191019

10201020
train = rng.permutation(train)
10211021
test = rng.permutation(test)

sklearn/tests/test_cross_validation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,8 +397,8 @@ def test_stratified_shuffle_split_iter_no_indices():
397397
assert_array_equal(sorted(test_indices), np.where(test_mask)[0])
398398

399399

400-
def test_shuffle_split_even():
401-
# Test the in StratifiedShuffleSplit, indices are drawn with a
400+
def test_stratified_shuffle_split_even():
401+
# Test the StratifiedShuffleSplit, indices are drawn with a
402402
# equal chance
403403
n_folds = 5
404404
n_iter = 1000
@@ -431,6 +431,7 @@ def assert_counts_are_ok(idx_counts, p):
431431

432432
assert_equal(len(train), splits.n_train)
433433
assert_equal(len(test), splits.n_test)
434+
assert_equal(len(set(train).intersection(test)), 0)
434435

435436
label_counts = np.unique(labels)
436437
assert_equal(splits.test_size, 1.0 / n_folds)

0 commit comments

Comments
 (0)