|
5 | 5 | import sys |
6 | 6 | from sklearn.externals.six.moves import cStringIO as StringIO |
7 | 7 | import numpy as np |
| 8 | +import warnings |
8 | 9 | from sklearn.base import BaseEstimator |
9 | 10 | from sklearn.learning_curve import learning_curve, validation_curve |
10 | 11 | from sklearn.utils.testing import assert_raises |
@@ -84,8 +85,11 @@ def test_learning_curve(): |
84 | 85 | n_redundant=0, n_classes=2, |
85 | 86 | n_clusters_per_class=1, random_state=0) |
86 | 87 | estimator = MockImprovingEstimator(20) |
87 | | - train_sizes, train_scores, test_scores = learning_curve( |
88 | | - estimator, X, y, cv=3, train_sizes=np.linspace(0.1, 1.0, 10)) |
| 88 | + with warnings.catch_warnings(record=True) as w: |
| 89 | + train_sizes, train_scores, test_scores = learning_curve( |
| 90 | + estimator, X, y, cv=3, train_sizes=np.linspace(0.1, 1.0, 10)) |
| 91 | + if len(w) > 0: |
| 92 | + raise RuntimeError("Unexpected warning: %r" % w[0].message) |
89 | 93 | assert_equal(train_scores.shape, (10, 3)) |
90 | 94 | assert_equal(test_scores.shape, (10, 3)) |
91 | 95 | assert_array_equal(train_sizes, np.linspace(2, 20, 10)) |
@@ -239,8 +243,12 @@ def test_validation_curve(): |
239 | 243 | n_redundant=0, n_classes=2, |
240 | 244 | n_clusters_per_class=1, random_state=0) |
241 | 245 | param_range = np.linspace(0, 1, 10) |
242 | | - train_scores, test_scores = validation_curve(MockEstimatorWithParameter(), |
243 | | - X, y, param_name="param", |
244 | | - param_range=param_range, cv=2) |
| 246 | + with warnings.catch_warnings(record=True) as w: |
| 247 | + train_scores, test_scores = validation_curve( |
| 248 | + MockEstimatorWithParameter(), X, y, param_name="param", |
| 249 | + param_range=param_range, cv=2) |
| 250 | + if len(w) > 0: |
| 251 | + raise RuntimeError("Unexpected warning: %r" % w[0].message) |
| 252 | + |
245 | 253 | assert_array_almost_equal(train_scores.mean(axis=1), param_range) |
246 | 254 | assert_array_almost_equal(test_scores.mean(axis=1), 1 - param_range) |
0 commit comments