Skip to content

Commit 059c834

Browse files
committed
TST: Better to test that warm_start runs only once after the prev model has converged
1 parent 947da72 commit 059c834

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

sklearn/linear_model/tests/test_coordinate_descent.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
import numpy as np
88
from scipy import interpolate, sparse
9-
from copy import deepcopy
109

1110
from sklearn.utils.testing import assert_array_almost_equal
1211
from sklearn.utils.testing import assert_almost_equal
@@ -470,12 +469,23 @@ def test_precompute_invalid_argument():
470469

471470
def test_warm_start_convergence():
472471
X, y, _, _ = build_dataset()
473-
cold_model = ElasticNet(alpha=1e-3, tol=1e-5, max_iter=2000)
474-
cold_model.fit(X, y)
475-
warm_model = deepcopy(cold_model).set_params(warm_start=True)
476-
warm_model.fit(X, y)
477-
assert_greater(cold_model.n_iter_, warm_model.n_iter_)
478-
assert_greater(warm_model.n_iter_, 1)
472+
model = ElasticNet(alpha=1e-3, tol=1e-3).fit(X, y)
473+
n_iter_reference = model.n_iter_
474+
475+
# This dataset is not trivial enough for the model to converge in one pass.
476+
assert_greater(n_iter_reference, 2)
477+
478+
# Fit the same model again, using a cold start
479+
model.fit(X, y)
480+
n_iter_cold_start = model.n_iter_
481+
assert_equal(n_iter_cold_start, n_iter_reference)
482+
483+
# Fit the same model again, using a warm start: the optimizer just perform
484+
# a single pass before checking that it has already converged
485+
model.set_params(warm_start=True)
486+
model.fit(X, y)
487+
n_iter_warm_start = model.n_iter_
488+
assert_equal(n_iter_warm_start, 1)
479489

480490

481491
if __name__ == '__main__':

0 commit comments

Comments
 (0)