Skip to content

Commit 38f9e1c

Browse files
committed
Merge remote-tracking branch 'upstream/pr/6182'
2 parents fa76b2e + 6148686 commit 38f9e1c

File tree

5 files changed

+39
-6
lines changed

5 files changed

+39
-6
lines changed

doc/whats_new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,9 @@ Enhancements
102102
Bug fixes
103103
.........
104104

105+
- :class:`StratifiedKFold` now raises an error if the number of labels for every individual class is less than n_folds.
106+
(`#6182 <https://github.com/scikit-learn/scikit-learn/pull/6182>`_) by `Devashish Deshpande`_.
107+
105108
- :class:`RandomizedPCA` default number of `iterated_power` is 2 instead of 3.
106109
This is a speed up with a minor precision decrease. (`#5141 <https://github.com/scikit-learn/scikit-learn/pull/5141>`_) by `Giorgio Patrini`_.
107110

sklearn/cross_validation.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,10 @@ def __init__(self, y, n_folds=3, shuffle=False,
519519
unique_labels, y_inversed = np.unique(y, return_inverse=True)
520520
label_counts = bincount(y_inversed)
521521
min_labels = np.min(label_counts)
522+
if np.all(self.n_folds > label_counts):
523+
raise ValueError("All the n_labels for individual classes"
524+
" are less than %d folds."
525+
% (self.n_folds))
522526
if self.n_folds > min_labels:
523527
warnings.warn(("The least populated class in y has only %d"
524528
" members, which is too few. The minimum"

sklearn/model_selection/_split.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,10 @@ def _make_test_folds(self, X, y=None, labels=None):
564564
unique_y, y_inversed = np.unique(y, return_inverse=True)
565565
y_counts = bincount(y_inversed)
566566
min_labels = np.min(y_counts)
567+
if np.all(self.n_folds > y_counts):
568+
raise ValueError("All the n_labels for individual classes"
569+
" are less than %d folds."
570+
% (self.n_folds))
567571
if self.n_folds > min_labels:
568572
warnings.warn(("The least populated class in y has only %d"
569573
" members, which is too few. The minimum"

sklearn/model_selection/tests/test_split.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from sklearn.utils.testing import assert_array_almost_equal
2222
from sklearn.utils.testing import assert_array_equal
2323
from sklearn.utils.testing import assert_warns_message
24+
from sklearn.utils.testing import assert_raise_message
2425
from sklearn.utils.testing import ignore_warnings
2526
from sklearn.utils.validation import _num_samples
2627
from sklearn.utils.mocking import MockDataFrame
@@ -206,7 +207,7 @@ def test_kfold_valueerrors():
206207

207208
# Check that a warning is raised if the least populated class has too few
208209
# members.
209-
y = np.array([3, 3, -1, -1, 2])
210+
y = np.array([3, 3, -1, -1, 3])
210211

211212
skf_3 = StratifiedKFold(3)
212213
assert_warns_message(Warning, "The least populated class",
@@ -219,11 +220,21 @@ def test_kfold_valueerrors():
219220
warnings.simplefilter("ignore")
220221
check_cv_coverage(skf_3, X2, y, labels=None, expected_n_iter=3)
221222

223+
# Check that an error is raised if the label counts of all individual
224+
# classes are less than n_folds.
225+
y = np.array([3, 3, -1, -1, 2])
226+
227+
assert_raises(ValueError, next, skf_3.split(X2, y))
228+
222229
# Error when number of folds is <= 1
223230
assert_raises(ValueError, KFold, 0)
224231
assert_raises(ValueError, KFold, 1)
225-
assert_raises(ValueError, StratifiedKFold, 0)
226-
assert_raises(ValueError, StratifiedKFold, 1)
232+
error_string = ("k-fold cross-validation requires at least one"
233+
" train/test split")
234+
assert_raise_message(ValueError, error_string,
235+
StratifiedKFold, 0)
236+
assert_raise_message(ValueError, error_string,
237+
StratifiedKFold, 1)
227238

228239
# When n_folds is not integer:
229240
assert_raises(ValueError, KFold, 1.5)

sklearn/tests/test_cross_validation.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from sklearn.utils.testing import assert_array_almost_equal
2020
from sklearn.utils.testing import assert_array_equal
2121
from sklearn.utils.testing import assert_warns_message
22+
from sklearn.utils.testing import assert_raise_message
2223
from sklearn.utils.testing import ignore_warnings
2324
from sklearn.utils.mocking import CheckingClassifier, MockDataFrame
2425

@@ -160,7 +161,7 @@ def test_kfold_valueerrors():
160161

161162
# Check that a warning is raised if the least populated class has too few
162163
# members.
163-
y = [3, 3, -1, -1, 2]
164+
y = [3, 3, -1, -1, 3]
164165

165166
cv = assert_warns_message(Warning, "The least populated class",
166167
cval.StratifiedKFold, y, 3)
@@ -170,11 +171,21 @@ def test_kfold_valueerrors():
170171
# side of the split at each split
171172
check_cv_coverage(cv, expected_n_iter=3, n_samples=len(y))
172173

174+
# Check that an error is raised if the label counts of all individual
175+
# classes are less than n_folds.
176+
y = [3, 3, -1, -1, 2]
177+
178+
assert_raises(ValueError, cval.StratifiedKFold, y, 3)
179+
173180
# Error when number of folds is <= 1
174181
assert_raises(ValueError, cval.KFold, 2, 0)
175182
assert_raises(ValueError, cval.KFold, 2, 1)
176-
assert_raises(ValueError, cval.StratifiedKFold, y, 0)
177-
assert_raises(ValueError, cval.StratifiedKFold, y, 1)
183+
error_string = ("k-fold cross validation requires at least one"
184+
" train / test split")
185+
assert_raise_message(ValueError, error_string,
186+
cval.StratifiedKFold, y, 0)
187+
assert_raise_message(ValueError, error_string,
188+
cval.StratifiedKFold, y, 1)
178189

179190
# When n is not integer:
180191
assert_raises(ValueError, cval.KFold, 2.5, 2)

0 commit comments

Comments
 (0)