Skip to content

Commit 41b312c

Browse files
committed
BUG: StratifiedShuffleSplit not obeying n_train
StratifiedShuffleSplit was not giving the n_train and n_test requested
1 parent 06d7536 commit 41b312c

File tree

2 files changed

+28
-9
lines changed

2 files changed

+28
-9
lines changed

sklearn/cross_validation.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,6 +1004,19 @@ def _iter_indices(self):
10041004
train.extend(cls_i[:n_i[i]])
10051005
test.extend(cls_i[n_i[i]:n_i[i] + t_i[i]])
10061006

1007+
# Because of rounding issues (as n_train and n_test are not
1008+
# divisors of the number of elements per class), we may end
1009+
# up here with fewer samples in train and test than asked for.
1010+
if len(train) < self.n_train or len(test) < self.n_test:
1011+
# We complete by randomly assigning the missing indexes
1012+
missing_idx = np.where(
1013+
np.bincount(train + test,
1014+
minlength=len(self.y)) == 0,
1015+
)[0]
1016+
missing_idx = rng.permutation(missing_idx)
1017+
train.extend(missing_idx[:(self.n_train - len(train))])
1018+
test.extend(missing_idx[:(self.n_test - len(test))])
1019+
10071020
train = rng.permutation(train)
10081021
test = rng.permutation(test)
10091022

sklearn/tests/test_cross_validation.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,8 @@ def test_stratified_shuffle_split_iter_no_indices():
400400
def test_shuffle_split_even():
401401
# Test that in StratifiedShuffleSplit, indices are drawn with an
402402
# equal chance
403+
n_folds = 5
404+
n_iter = 1000
403405

404406
def assert_counts_are_ok(idx_counts, p):
405407
# Here we test that the distribution of the counts
@@ -412,26 +414,30 @@ def assert_counts_are_ok(idx_counts, p):
412414
"An index is not drawn with chance corresponding "
413415
"to even draws")
414416

415-
for n_labels in (6, 22):
416-
labels = np.array((n_labels // 2) * [0, 1])
417-
n_folds = 5
418-
splits = cval.StratifiedShuffleSplit(labels, n_iter=1000,
417+
for n_samples in (6, 22):
418+
labels = np.array((n_samples // 2) * [0, 1])
419+
splits = cval.StratifiedShuffleSplit(labels, n_iter=n_iter,
419420
test_size=1./n_folds, random_state=0)
420421

421-
train_counts = [0] * len(labels)
422-
test_counts = [0] * len(labels)
422+
train_counts = [0] * n_samples
423+
test_counts = [0] * n_samples
424+
n_splits = 0
423425
for train, test in splits:
426+
n_splits += 1
424427
for counter, ids in [(train_counts, train), (test_counts, test)]:
425428
for id in ids:
426429
counter[id] += 1
430+
assert_equal(n_splits, n_iter)
431+
432+
assert_equal(len(train), splits.n_train)
433+
assert_equal(len(test), splits.n_test)
427434

428-
n_splits = len(splits)
429435
label_counts = np.unique(labels)
430436
assert_equal(splits.test_size, 1.0 / n_folds)
431437
assert_equal(splits.n_train + splits.n_test, len(labels))
432438
assert_equal(len(label_counts), 2)
433-
ex_test_p = (1. * splits.n_test) / n_labels
434-
ex_train_p = 1.0 - ex_test_p
439+
ex_test_p = float(splits.n_test) / n_samples
440+
ex_train_p = float(splits.n_train) / n_samples
435441

436442
assert_counts_are_ok(train_counts, ex_train_p)
437443
assert_counts_are_ok(test_counts, ex_test_p)

0 commit comments

Comments
 (0)