Skip to content

Commit d074e40

Browse files
cbrummitt authored and lesteve committed
[MRG+1] Fix bug in StratifiedShuffleSplit for multi-label data with targets having > 1000 labels (scikit-learn#9922)
* Use ' '.join(row) for multi-label targets in StratifiedShuffleSplit because str(row) uses an ellipsis when len(row) > 1000
* Add a new test for multilabel problems with more than a thousand labels
1 parent a7f8c32 commit d074e40

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

sklearn/model_selection/_split.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1534,8 +1534,9 @@ def _iter_indices(self, X, y, groups=None):
15341534
self.train_size)
15351535

15361536
if y.ndim == 2:
1537-
# for multi-label y, map each distinct row to its string repr:
1538-
y = np.array([str(row) for row in y])
1537+
# for multi-label y, map each distinct row to a string repr
1538+
# using join because str(row) uses an ellipsis if len(row) > 1000
1539+
y = np.array([' '.join(row.astype('str')) for row in y])
15391540

15401541
classes, y_indices = np.unique(y, return_inverse=True)
15411542
n_classes = classes.shape[0]

sklearn/model_selection/tests/test_split.py

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -726,6 +726,29 @@ def test_stratified_shuffle_split_multilabel():
726726
assert_equal(expected_ratio, np.mean(y_test[:, 0]))
727727

728728

729+
def test_stratified_shuffle_split_multilabel_many_labels():
    """Check stratification of multilabel targets with > 1000 labels.

    Regression test for PR #9922.  ``str(row)`` abbreviates a row of more
    than 1000 labels with an ellipsis, so rows that differ only in the
    elided middle section were mapped to the same string and were not
    correctly split by the powerset transformation (multilabel ->
    multiclass).  This test checks that such rows are stratified as
    distinct classes.
    """
    # Both row types share the same three leading and trailing labels and
    # differ only in the 1000 labels in between.
    edge_labels = [1, 0, 1]
    mostly_zeros = edge_labels + [0] * 1000 + edge_labels
    mostly_ones = edge_labels + [1] * 1000 + edge_labels
    y = np.array([mostly_zeros] * 10 + [mostly_ones] * 100)
    X = np.ones_like(y)

    splitter = StratifiedShuffleSplit(
        n_splits=1, test_size=0.5, random_state=0)
    train, test = next(splitter.split(X=X, y=y))

    # Correct stratification of entire rows: by construction, y[:, 4]
    # uniquely determines the whole row, so its mean is the class ratio
    # that both the train and the test subsets must reproduce.
    expected_ratio = np.mean(y[:, 4])
    for subset in (train, test):
        assert_equal(expected_ratio, np.mean(y[subset][:, 4]))
750+
751+
729752
def test_predefinedsplit_with_kfold_split():
730753
# Check that PredefinedSplit can reproduce a split generated by Kfold.
731754
folds = -1 * np.ones(10)

0 commit comments

Comments
 (0)