|
6 | 6 |
|
7 | 7 | import numpy as np |
8 | 8 | from scipy import interpolate, sparse |
9 | | -from copy import deepcopy |
10 | 9 |
|
11 | 10 | from sklearn.utils.testing import assert_array_almost_equal |
12 | 11 | from sklearn.utils.testing import assert_almost_equal |
@@ -470,12 +469,23 @@ def test_precompute_invalid_argument(): |
470 | 469 |
|
471 | 470 | def test_warm_start_convergence(): |
472 | 471 | X, y, _, _ = build_dataset() |
473 | | - cold_model = ElasticNet(alpha=1e-3, tol=1e-5, max_iter=2000) |
474 | | - cold_model.fit(X, y) |
475 | | - warm_model = deepcopy(cold_model).set_params(warm_start=True) |
476 | | - warm_model.fit(X, y) |
477 | | - assert_greater(cold_model.n_iter_, warm_model.n_iter_) |
478 | | - assert_greater(warm_model.n_iter_, 1) |
| 472 | + model = ElasticNet(alpha=1e-3, tol=1e-3).fit(X, y) |
| 473 | + n_iter_reference = model.n_iter_ |
| 474 | + |
| 475 | + # This dataset is not trivial enough for the model to converge in one pass. |
| 476 | + assert_greater(n_iter_reference, 2) |
| 477 | + |
| 478 | + # Fit the same model again, using a cold start |
| 479 | + model.fit(X, y) |
| 480 | + n_iter_cold_start = model.n_iter_ |
| 481 | + assert_equal(n_iter_cold_start, n_iter_reference) |
| 482 | + |
| 483 | + # Fit the same model again, using a warm start: the optimizer just perform |
| 484 | + # a single pass before checking that it has already converged |
| 485 | + model.set_params(warm_start=True) |
| 486 | + model.fit(X, y) |
| 487 | + n_iter_warm_start = model.n_iter_ |
| 488 | + assert_equal(n_iter_warm_start, 1) |
479 | 489 |
|
480 | 490 |
|
481 | 491 | if __name__ == '__main__': |
|
0 commit comments