Skip to content

Commit 5d8b384

Browse files
Yannick Schwartzamueller
authored andcommitted
stratified shuffle split can return masks
1 parent 694e575 commit 5d8b384

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

sklearn/cross_validation.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -970,7 +970,15 @@ def __iter__(self):
970970
train.extend(cls_i[:n_i[i]])
971971
test.extend(cls_i[n_i[i]:n_i[i] + t_i[i]])
972972

973-
yield train, test
973+
if self.indices:
974+
yield train, test
975+
else:
976+
train_m = np.zeros(self.n, dtype='bool')
977+
test_m = np.zeros(self.n, dtype='bool')
978+
train_m[train] = True
979+
test_m[test] = True
980+
981+
yield train_m, test_m
974982

975983
def __repr__(self):
976984
return ('%s(labels=%s, n_iterations=%d, test_size=%s, indices=%s, '

0 commit comments

Comments
 (0)