Skip to content

Commit 3c7818e

Browse files
committed
ENH allow y to be a list in GridSearchCV, cross_val_score and train_test_split.
1 parent 4e438a9 commit 3c7818e

File tree

5 files changed

+81
-21
lines changed

5 files changed

+81
-21
lines changed

sklearn/cross_validation.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import scipy.sparse as sp
2323

2424
from .base import is_classifier, clone
25-
from .utils import check_arrays, check_random_state, safe_mask
25+
from .utils import check_arrays, check_random_state, safe_indexing
2626
from .utils.validation import _num_samples
2727
from .externals.joblib import Parallel, delayed, logger
2828
from .externals.six import with_metaclass
@@ -1136,8 +1136,6 @@ def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1,
11361136
"""
11371137
X, y = check_arrays(X, y, sparse_format='csr', allow_lists=True,
11381138
allow_nans=True, allow_nd=True)
1139-
if y is not None:
1140-
y = np.asarray(y)
11411139

11421140
cv = _check_cv(cv, X, y, classifier=is_classifier(estimator))
11431141
scorer = check_scoring(estimator, score_func=score_func, scoring=scoring)
@@ -1278,10 +1276,10 @@ def _safe_split(estimator, X, y, indices, train_indices=None):
12781276
else:
12791277
X_subset = X[np.ix_(indices, train_indices)]
12801278
else:
1281-
X_subset = X[safe_mask(X, indices)]
1279+
X_subset = safe_indexing(X, indices)
12821280

12831281
if y is not None:
1284-
y_subset = y[safe_mask(y, indices)]
1282+
y_subset = safe_indexing(y, indices)
12851283
else:
12861284
y_subset = None
12871285

@@ -1527,12 +1525,12 @@ def train_test_split(*arrays, **options):
15271525
[0, 1],
15281526
[6, 7]])
15291527
>>> b_train
1530-
array([2, 0, 3])
1528+
[2, 0, 3]
15311529
>>> a_test
15321530
array([[2, 3],
15331531
[8, 9]])
15341532
>>> b_test
1535-
array([1, 4])
1533+
[1, 4]
15361534
15371535
"""
15381536
n_arrays = len(arrays)
@@ -1544,18 +1542,21 @@ def train_test_split(*arrays, **options):
15441542
random_state = options.pop('random_state', None)
15451543
options['sparse_format'] = 'csr'
15461544
options['allow_nans'] = True
1545+
if not "allow_lists" in options:
1546+
options["allow_lists"] = True
15471547

15481548
if test_size is None and train_size is None:
15491549
test_size = 0.25
15501550

15511551
arrays = check_arrays(*arrays, **options)
1552-
n_samples = arrays[0].shape[0]
1552+
n_samples = _num_samples(arrays[0])
15531553
cv = ShuffleSplit(n_samples, test_size=test_size,
15541554
train_size=train_size,
15551555
random_state=random_state)
15561556

15571557
train, test = next(iter(cv))
1558-
return list(chain.from_iterable((a[train], a[test]) for a in arrays))
1558+
return list(chain.from_iterable((safe_indexing(a, train),
1559+
safe_indexing(a, test)) for a in arrays))
15591560

15601561

15611562
train_test_split.__test__ = False # to avoid a pb with nosetests

sklearn/grid_search.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,6 @@ def _fit(self, X, y, parameter_iterable):
355355
raise ValueError('Target variable (y) has a different number '
356356
'of samples (%i) than data (X: %i samples)'
357357
% (len(y), n_samples))
358-
y = np.asarray(y)
359358
cv = check_cv(cv, X, y, classifier=is_classifier(estimator))
360359

361360
if self.verbose > 0:

sklearn/tests/test_cross_validation.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,20 @@
4242
class MockListClassifier(BaseEstimator):
4343
"""Dummy classifier to test the cross-validation.
4444
45-
Checks that GridSearchCV didn't convert X to array.
45+
Checks that GridSearchCV didn't convert X or Y to array.
4646
"""
47-
def __init__(self, foo_param=0):
47+
def __init__(self, foo_param=0, check_y=False, check_X=True):
4848
self.foo_param = foo_param
49+
self.check_y = check_y
50+
self.check_X = check_X
4951

5052
def fit(self, X, Y):
5153
assert_true(len(X) == len(Y))
52-
assert_true(isinstance(X, list))
54+
if self.check_X:
55+
assert_true(isinstance(X, list))
56+
if self.check_y:
57+
assert_true(isinstance(Y, list))
58+
5359
return self
5460

5561
def predict(self, T):
@@ -513,6 +519,9 @@ def test_cross_val_score():
513519
clf = MockListClassifier()
514520
scores = cval.cross_val_score(clf, X.tolist(), y.tolist())
515521

522+
clf = MockListClassifier(check_X=False, check_y=True)
523+
scores = cval.cross_val_score(clf, X, y.tolist())
524+
516525
assert_raises(ValueError, cval.cross_val_score, clf, X, y,
517526
scoring="sklearn")
518527

@@ -596,7 +605,7 @@ def test_train_test_split():
596605
X = np.arange(100).reshape((10, 10))
597606
X_s = coo_matrix(X)
598607
y = range(10)
599-
split = cval.train_test_split(X, X_s, y)
608+
split = cval.train_test_split(X, X_s, y, allow_lists=False)
600609
X_train, X_test, X_s_train, X_s_test, y_train, y_test = split
601610
assert_array_equal(X_train, X_s_train.toarray())
602611
assert_array_equal(X_test, X_s_test.toarray())
@@ -606,6 +615,11 @@ def test_train_test_split():
606615
X_train, X_test, y_train, y_test = split
607616
assert_equal(len(y_test), len(y_train))
608617

618+
split = cval.train_test_split(X, X_s, y)
619+
X_train, X_test, X_s_train, X_s_test, y_train, y_test = split
620+
assert_true(isinstance(y_train, list))
621+
assert_true(isinstance(y_test, list))
622+
609623

610624
def test_cross_val_score_with_score_func_classification():
611625
iris = load_iris()
@@ -911,7 +925,7 @@ def test_train_test_split_allow_nans():
911925
X = np.arange(200, dtype=np.float64).reshape(10, -1)
912926
X[2, :] = np.nan
913927
y = np.repeat([0, 1], X.shape[0]/2)
914-
split = cval.train_test_split(X, y, test_size=0.2, random_state=42)
928+
cval.train_test_split(X, y, test_size=0.2, random_state=42)
915929

916930

917931
def test_permutation_test_score_allow_nans():

sklearn/tests/test_grid_search.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,20 @@ def set_params(self, **params):
7979
class MockListClassifier(object):
8080
"""Dummy classifier to test the cross-validation.
8181
82-
Checks that GridSearchCV didn't convert X to array.
82+
Checks that GridSearchCV didn't convert X (or y) to array.
8383
"""
84-
def __init__(self, foo_param=0):
84+
def __init__(self, foo_param=0, check_y=False, check_X=True):
8585
self.foo_param = foo_param
86+
self.check_y = check_y
87+
self.check_X = check_X
8688

8789
def fit(self, X, Y):
8890
assert_true(len(X) == len(Y))
89-
assert_true(isinstance(X, list))
91+
if self.check_X:
92+
assert_true(isinstance(X, list))
93+
if self.check_y:
94+
assert_true(isinstance(Y, list))
95+
9096
return self
9197

9298
def predict(self, T):
@@ -100,10 +106,15 @@ def score(self, X=None, Y=None):
100106
return score
101107

102108
def get_params(self, deep=False):
103-
return {'foo_param': self.foo_param}
109+
return {'foo_param': self.foo_param, 'check_X': self.check_X,
110+
'check_y': self.check_y}
104111

105112
def set_params(self, **params):
106113
self.foo_param = params['foo_param']
114+
if "check_y" in params:
115+
self.check_y = params["check_y"]
116+
if "check_X" in params:
117+
self.check_X = params["check_X"]
107118
return self
108119

109120

@@ -478,6 +489,18 @@ def test_X_as_list():
478489
assert_true(hasattr(grid_search, "grid_scores_"))
479490

480491

492+
def test_y_as_list():
493+
"""Pass y as list in GridSearchCV"""
494+
X = np.arange(100).reshape(10, 10)
495+
y = np.array([0] * 5 + [1] * 5)
496+
497+
clf = MockListClassifier(check_X=False, check_y=True)
498+
cv = KFold(n=len(X), n_folds=3)
499+
grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, cv=cv)
500+
grid_search.fit(X, y.tolist()).score(X, y)
501+
assert_true(hasattr(grid_search, "grid_scores_"))
502+
503+
481504
def test_unsupervised_grid_search():
482505
# test grid-search with unsupervised estimator
483506
X, y = make_blobs(random_state=0)

sklearn/utils/__init__.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,12 @@
1818

1919
__all__ = ["murmurhash3_32", "as_float_array", "check_arrays", "safe_asarray",
2020
"assert_all_finite", "array2d", "atleast2d_or_csc",
21-
"atleast2d_or_csr", "warn_if_not_float", "check_random_state",
22-
"compute_class_weight", "minimum_spanning_tree", "column_or_1d"]
21+
"atleast2d_or_csr",
22+
"warn_if_not_float",
23+
"check_random_state",
24+
"compute_class_weight",
25+
"minimum_spanning_tree",
26+
"column_or_1d", "safe_indexing"]
2327

2428

2529
class deprecated(object):
@@ -129,6 +133,25 @@ def safe_mask(X, mask):
129133
return mask
130134

131135

136+
def safe_indexing(X, indices):
137+
"""Return items or rows from X using indices.
138+
139+
Allows simple indexing of lists or arrays.
140+
141+
Parameters
142+
----------
143+
X : array-like, sparse-matrix, list.
144+
Data from which to sample rows or items.
145+
146+
indices : array-like, list
147+
Indices according to which X will be subsampled.
148+
"""
149+
if hasattr(X, "shape"):
150+
return X[indices]
151+
else:
152+
return [X[idx] for idx in indices]
153+
154+
132155
def resample(*arrays, **options):
133156
"""Resample arrays or sparse matrices in a consistent way
134157

0 commit comments

Comments
 (0)