Skip to content

Commit d173ab1

Browse files
committed
COSMIT in cross-validation tests
1 parent 2c984ee commit d173ab1

File tree

1 file changed

+33
-36
lines changed

1 file changed

+33
-36
lines changed

sklearn/tests/test_cross_validation.py

Lines changed: 33 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,14 @@ def __init__(self, a=0):
3535
def fit(self, X, Y=None, sample_weight=None, class_prior=None):
3636
if sample_weight is not None:
3737
assert_true(sample_weight.shape[0] == X.shape[0],
38-
'MockClassifier extra fit_param sample_weight.shape[0] is {0}, '
39-
'should be {1}'.format(sample_weight.shape[0], X.shape[0]))
38+
'MockClassifier extra fit_param sample_weight.shape[0]'
39+
' is {0}, should be {1}'.format(sample_weight.shape[0],
40+
X.shape[0]))
4041
if class_prior is not None:
4142
assert_true(class_prior.shape[0] == len(np.unique(y)),
42-
'MockClassifier extra fit_param class_prior.shape[0] is {0}, '
43-
'should be {1}'.format(class_prior.shape[0], len(np.unique(y))))
43+
'MockClassifier extra fit_param class_prior.shape[0]'
44+
' is {0}, should be {1}'.format(class_prior.shape[0],
45+
len(np.unique(y))))
4446
return self
4547

4648
def predict(self, T):
@@ -144,26 +146,23 @@ def test_stratified_shuffle_split_init():
144146

145147

146148
def test_stratified_shuffle_split_iter():
147-
ys = [
148-
np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]),
149-
np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]),
150-
np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2]),
151-
np.array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4]),
152-
np.array([-1] * 800 + [1] * 50)
153-
]
149+
ys = [np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]),
150+
np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]),
151+
np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2]),
152+
np.array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4]),
153+
np.array([-1] * 800 + [1] * 50)
154+
]
154155

155156
for y in ys:
156157
sss = cval.StratifiedShuffleSplit(y, 6, test_size=0.33,
157158
random_state=0, indices=True)
158159
for train, test in sss:
159160
assert_array_equal(unique(y[train]), unique(y[test]))
160161
# Checks if folds keep classes proportions
161-
p_train = np.bincount(
162-
unique(y[train], return_inverse=True)[1]
163-
) / float(len(y[train]))
164-
p_test = np.bincount(
165-
unique(y[test], return_inverse=True)[1]
166-
) / float(len(y[test]))
162+
p_train = (np.bincount(unique(y[train], return_inverse=True)[1]) /
163+
float(len(y[train])))
164+
p_test = (np.bincount(unique(y[test], return_inverse=True)[1]) /
165+
float(len(y[test])))
167166
assert_array_almost_equal(p_train, p_test, 1)
168167
assert_equal(y[train].size + y[test].size, y.size)
169168
assert_array_equal(np.lib.arraysetops.intersect1d(train, test), [])
@@ -245,27 +244,26 @@ class BrokenEstimator:
245244

246245
def test_train_test_split_errors():
247246
assert_raises(ValueError, cval.train_test_split)
247+
assert_raises(ValueError, cval.train_test_split, range(3), train_size=1.1)
248+
assert_raises(ValueError, cval.train_test_split, range(3), test_size=0.6,
249+
train_size=0.6)
248250
assert_raises(ValueError, cval.train_test_split, range(3),
249-
train_size=1.1)
251+
test_size=np.float32(0.6), train_size=np.float32(0.6))
250252
assert_raises(ValueError, cval.train_test_split, range(3),
251-
test_size=0.6, train_size=0.6)
252-
assert_raises(ValueError, cval.train_test_split, range(3),
253-
test_size=np.float32(0.6), train_size=np.float32(0.6))
254-
assert_raises(ValueError, cval.train_test_split, range(3),
255-
test_size="wrong_type")
256-
assert_raises(ValueError, cval.train_test_split, range(3),
257-
test_size=2, train_size=4)
253+
test_size="wrong_type")
254+
assert_raises(ValueError, cval.train_test_split, range(3), test_size=2,
255+
train_size=4)
258256
assert_raises(TypeError, cval.train_test_split, range(3),
259-
some_argument=1.1)
257+
some_argument=1.1)
260258
assert_raises(ValueError, cval.train_test_split, range(3), range(42))
261259

262260

263261
def test_train_test_split():
264262
X = np.arange(100).reshape((10, 10))
265263
X_s = coo_matrix(X)
266264
y = range(10)
267-
X_train, X_test, X_s_train, X_s_test, y_train, y_test = \
268-
cval.train_test_split(X, X_s, y)
265+
split = cval.train_test_split(X, X_s, y)
266+
X_train, X_test, X_s_train, X_s_test, y_train, y_test = split
269267
assert_array_equal(X_train, X_s_train.toarray())
270268
assert_array_equal(X_test, X_s_test.toarray())
271269
assert_array_equal(X_train[:, 0], y_train * 10)
@@ -283,13 +281,13 @@ def test_cross_val_score_with_score_func_classification():
283281
# Correct classification score (aka. zero / one score) - should be the
284282
# same as the default estimator score
285283
zo_scores = cval.cross_val_score(clf, iris.data, iris.target,
286-
score_func=zero_one_score, cv=5)
284+
score_func=zero_one_score, cv=5)
287285
assert_array_almost_equal(zo_scores, [1., 0.97, 0.90, 0.97, 1.], 2)
288286

289287
# F1 score (class are balanced so f1_score should be equal to zero/one
290288
# score
291289
f1_scores = cval.cross_val_score(clf, iris.data, iris.target,
292-
score_func=f1_score, cv=5)
290+
score_func=f1_score, cv=5)
293291
assert_array_almost_equal(f1_scores, [1., 0.97, 0.90, 0.97, 1.], 2)
294292

295293

@@ -309,13 +307,13 @@ def test_cross_val_score_with_score_func_regression():
309307

310308
# Mean squared error
311309
mse_scores = cval.cross_val_score(reg, X, y, cv=5,
312-
score_func=mean_squared_error)
310+
score_func=mean_squared_error)
313311
expected_mse = np.array([763.07, 553.16, 274.38, 273.26, 1681.99])
314312
assert_array_almost_equal(mse_scores, expected_mse, 2)
315313

316314
# Explained variance
317315
ev_scores = cval.cross_val_score(reg, X, y, cv=5,
318-
score_func=explained_variance_score)
316+
score_func=explained_variance_score)
319317
assert_array_almost_equal(ev_scores, [0.94, 0.97, 0.97, 0.99, 0.92], 2)
320318

321319

@@ -353,7 +351,7 @@ def test_permutation_score():
353351
y = np.mod(np.arange(len(y)), 3)
354352

355353
score, scores, pvalue = cval.permutation_test_score(svm, X, y,
356-
zero_one_score, cv)
354+
zero_one_score, cv)
357355

358356
assert_less(score, 0.5)
359357
assert_greater(pvalue, 0.4)
@@ -411,11 +409,10 @@ def test_shufflesplit_errors():
411409
assert_raises(ValueError, cval.ShuffleSplit, 10, test_size=2.0)
412410
assert_raises(ValueError, cval.ShuffleSplit, 10, test_size=1.0)
413411
assert_raises(ValueError, cval.ShuffleSplit, 10, test_size=0.1,
414-
train_size=0.95)
412+
train_size=0.95)
415413
assert_raises(ValueError, cval.ShuffleSplit, 10, test_size=11)
416414
assert_raises(ValueError, cval.ShuffleSplit, 10, test_size=10)
417-
assert_raises(ValueError, cval.ShuffleSplit, 10, test_size=8,
418-
train_size=3)
415+
assert_raises(ValueError, cval.ShuffleSplit, 10, test_size=8, train_size=3)
419416
assert_raises(ValueError, cval.ShuffleSplit, 10, train_size=1j)
420417

421418

0 commit comments

Comments
 (0)