Skip to content

Commit a03abd6

Browse files
committed
Merge pull request scikit-learn#3090 from ogrisel/learning-curves-warnings
[MRG] FIX: remove deprecation warnings in learning curves under Python 3
2 parents ea91673 + 6c0d41a commit a03abd6

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

sklearn/learning_curve.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def learning_curve(estimator, X, y, train_sizes=np.linspace(0.1, 1.0, 5),
132132
verbose, parameters=None, fit_params=None, return_train_score=True)
133133
for train, test in cv for n_train_samples in train_sizes_abs)
134134
out = np.array(out)[:, :2]
135-
n_cv_folds = out.shape[0] / n_unique_ticks
135+
n_cv_folds = out.shape[0] // n_unique_ticks
136136
out = out.reshape(n_cv_folds, n_unique_ticks, 2)
137137

138138
out = np.asarray(out).transpose((2, 1, 0))
@@ -297,7 +297,7 @@ def validation_curve(estimator, X, y, param_name, param_range, cv=None,
297297

298298
out = np.asarray(out)[:, :2]
299299
n_params = len(param_range)
300-
n_cv_folds = out.shape[0] / n_params
300+
n_cv_folds = out.shape[0] // n_params
301301
out = out.reshape(n_cv_folds, n_params, 2).transpose((2, 1, 0))
302302

303303
return out[0], out[1]

sklearn/tests/test_learning_curve.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import sys
66
from sklearn.externals.six.moves import cStringIO as StringIO
77
import numpy as np
8+
import warnings
89
from sklearn.base import BaseEstimator
910
from sklearn.learning_curve import learning_curve, validation_curve
1011
from sklearn.utils.testing import assert_raises
@@ -84,8 +85,11 @@ def test_learning_curve():
8485
n_redundant=0, n_classes=2,
8586
n_clusters_per_class=1, random_state=0)
8687
estimator = MockImprovingEstimator(20)
87-
train_sizes, train_scores, test_scores = learning_curve(
88-
estimator, X, y, cv=3, train_sizes=np.linspace(0.1, 1.0, 10))
88+
with warnings.catch_warnings(record=True) as w:
89+
train_sizes, train_scores, test_scores = learning_curve(
90+
estimator, X, y, cv=3, train_sizes=np.linspace(0.1, 1.0, 10))
91+
if len(w) > 0:
92+
raise RuntimeError("Unexpected warning: %r" % w[0].message)
8993
assert_equal(train_scores.shape, (10, 3))
9094
assert_equal(test_scores.shape, (10, 3))
9195
assert_array_equal(train_sizes, np.linspace(2, 20, 10))
@@ -239,8 +243,12 @@ def test_validation_curve():
239243
n_redundant=0, n_classes=2,
240244
n_clusters_per_class=1, random_state=0)
241245
param_range = np.linspace(0, 1, 10)
242-
train_scores, test_scores = validation_curve(MockEstimatorWithParameter(),
243-
X, y, param_name="param",
244-
param_range=param_range, cv=2)
246+
with warnings.catch_warnings(record=True) as w:
247+
train_scores, test_scores = validation_curve(
248+
MockEstimatorWithParameter(), X, y, param_name="param",
249+
param_range=param_range, cv=2)
250+
if len(w) > 0:
251+
raise RuntimeError("Unexpected warning: %r" % w[0].message)
252+
245253
assert_array_almost_equal(train_scores.mean(axis=1), param_range)
246254
assert_array_almost_equal(test_scores.mean(axis=1), 1 - param_range)

0 commit comments

Comments
 (0)