
Commit f79ac7c

MAINT: Remove deprecation warnings in enet_path and lasso_path
1 parent 87b4784 commit f79ac7c

4 files changed, +22 -152 lines changed

doc/whats_new.rst

Lines changed: 4 additions & 0 deletions

@@ -281,6 +281,10 @@ API changes summary
   only. Similar changes apply to `'precision'` and `'recall'`.
   By `Joel Nothman`_.
 
+- The ``fit_intercept``, ``normalize`` and ``return_models`` parameters in
+  :func:`linear_model.enet_path` and :func:`linear_model.lasso_path` have
+  been removed. They were deprecated since 0.14
+
 .. _changes_0_15_2:
 
 0.15.2
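For orientation, here is a minimal sketch (not part of the commit) of calling ``lasso_path`` after this change; it reuses the toy data from the docstring in the diff below and assumes the path functions now always return the alphas, the coefficients and the dual gaps:

import numpy as np
from sklearn.linear_model import lasso_path

X = np.array([[1, 2, 3.1], [2.3, 5.4, 4.3]]).T
y = np.array([1, 2, 3.1])

# No fit_intercept, normalize or return_models keywords any more; centering
# and scaling are expected to be handled by the estimators that call the path.
alphas, coefs, dual_gaps = lasso_path(X, y, alphas=[5., 1., .5])
print(coefs.shape)  # (n_features, n_alphas) == (2, 3)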

sklearn/linear_model/coordinate_descent.py

Lines changed: 14 additions & 127 deletions

@@ -104,10 +104,8 @@ def _alpha_grid(X, y, Xy=None, l1_ratio=1.0, fit_intercept=True,
 
 
 def lasso_path(X, y, eps=1e-3, n_alphas=100, alphas=None,
-               precompute='auto', Xy=None, fit_intercept=None,
-               normalize=None, copy_X=True, coef_init=None,
-               verbose=False, return_models=False, return_n_iter=False,
-               positive=False, **params):
+               precompute='auto', Xy=None, copy_X=True, coef_init=None,
+               verbose=False, return_n_iter=False, positive=False, **params):
     """Compute Lasso path with coordinate descent
 
     The Lasso optimization function varies for mono and multi-outputs.
@@ -156,14 +154,6 @@ def lasso_path(X, y, eps=1e-3, n_alphas=100, alphas=None,
         Xy = np.dot(X.T, y) that can be precomputed. It is useful
         only when the Gram matrix is precomputed.
 
-    fit_intercept : bool
-        Fit or not an intercept.
-        WARNING : deprecated, will be removed in 0.16.
-
-    normalize : boolean, optional, default False
-        If ``True``, the regressors X will be normalized before regression.
-        WARNING : deprecated, will be removed in 0.16.
-
     copy_X : boolean, optional, default True
         If ``True``, X will be copied; else, it may be overwritten.
 
@@ -173,12 +163,6 @@ def lasso_path(X, y, eps=1e-3, n_alphas=100, alphas=None,
     verbose : bool or integer
         Amount of verbosity.
 
-    return_models : boolean, optional, default True
-        If ``True``, the function will return list of models. Setting it
-        to ``False`` will change the function output returning the values
-        of the alphas and the coefficients along the path. Returning the
-        model list will be removed in version 0.16.
-
     params : kwargs
         keyword arguments passed to the coordinate descent solver.
 
@@ -187,30 +171,19 @@ def lasso_path(X, y, eps=1e-3, n_alphas=100, alphas=None,
 
     Returns
     -------
-    models : a list of models along the regularization path
-        (Is returned if ``return_models`` is set ``True`` (default).
-
     alphas : array, shape (n_alphas,)
         The alphas along the path where models are computed.
-        (Is returned, along with ``coefs``, when ``return_models`` is set
-        to ``False``)
 
     coefs : array, shape (n_features, n_alphas) or
             (n_outputs, n_features, n_alphas)
         Coefficients along the path.
-        (Is returned, along with ``alphas``, when ``return_models`` is set
-        to ``False``).
 
     dual_gaps : array, shape (n_alphas,)
         The dual gaps at the end of the optimization for each alpha.
-        (Is returned, along with ``alphas``, when ``return_models`` is set
-        to ``False``).
 
     n_iters : array-like, shape (n_alphas,)
         The number of iterations taken by the coordinate descent optimizer to
         reach the specified tolerance for each alpha.
-        (Is returned, along with ``alphas``, when ``return_models`` is set
-        to ``False`` and ``return_n_iter`` is set to ``True``).
 
     Notes
     -----
@@ -225,11 +198,6 @@ def lasso_path(X, y, eps=1e-3, n_alphas=100, alphas=None,
     interpolation can be used to retrieve model coefficients between the
     values output by lars_path
 
-    Deprecation Notice: Setting ``return_models`` to ``False`` will make
-    the Lasso Path return an output in the style used by :func:`lars_path`.
-    This will be become the norm as of version 0.16. Leaving ``return_models``
-    set to `True` will let the function return a list of models as before.
-
     Examples
     ---------
 
@@ -238,8 +206,7 @@ def lasso_path(X, y, eps=1e-3, n_alphas=100, alphas=None,
     >>> X = np.array([[1, 2, 3.1], [2.3, 5.4, 4.3]]).T
     >>> y = np.array([1, 2, 3.1])
     >>> # Use lasso_path to compute a coefficient path
-    >>> _, coef_path, _ = lasso_path(X, y, alphas=[5., 1., .5],
-    ...                              fit_intercept=False)
+    >>> _, coef_path, _ = lasso_path(X, y, alphas=[5., 1., .5])
     >>> print(coef_path)
     [[ 0.          0.          0.46874778]
      [ 0.2159048   0.4425765   0.23689075]]
@@ -267,17 +234,13 @@ def lasso_path(X, y, eps=1e-3, n_alphas=100, alphas=None,
     """
     return enet_path(X, y, l1_ratio=1., eps=eps, n_alphas=n_alphas,
                      alphas=alphas, precompute=precompute, Xy=Xy,
-                     fit_intercept=fit_intercept, normalize=normalize,
                      copy_X=copy_X, coef_init=coef_init, verbose=verbose,
-                     return_models=return_models, positive=positive,
-                     **params)
+                     positive=positive, **params)
 
 
 def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
-              precompute='auto', Xy=None, fit_intercept=True,
-              normalize=False, copy_X=True, coef_init=None,
-              verbose=False, return_models=False, return_n_iter=False,
-              positive=False, **params):
+              precompute='auto', Xy=None, copy_X=True, coef_init=None,
+              verbose=False, return_n_iter=False, positive=False, **params):
     """Compute elastic net path with coordinate descent
 
     The elastic net optimization function varies for mono and multi-outputs.
@@ -334,14 +297,6 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
         Xy = np.dot(X.T, y) that can be precomputed. It is useful
         only when the Gram matrix is precomputed.
 
-    fit_intercept : bool
-        Fit or not an intercept.
-        WARNING : deprecated, will be removed in 0.16.
-
-    normalize : boolean, optional, default False
-        If ``True``, the regressors X will be normalized before regression.
-        WARNING : deprecated, will be removed in 0.16.
-
     copy_X : boolean, optional, default True
         If ``True``, X will be copied; else, it may be overwritten.
 
@@ -351,12 +306,6 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
     verbose : bool or integer
         Amount of verbosity.
 
-    return_models : boolean, optional, default False
-        If ``True``, the function will return list of models. Setting it
-        to ``False`` will change the function output returning the values
-        of the alphas and the coefficients along the path. Returning the
-        model list will be removed in version 0.16.
-
     params : kwargs
         keyword arguments passed to the coordinate descent solver.
 
@@ -368,75 +317,33 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
 
     Returns
     -------
-    models : a list of models along the regularization path
-        (Is returned if ``return_models`` is set ``True`` (default).
-
     alphas : array, shape (n_alphas,)
         The alphas along the path where models are computed.
-        (Is returned, along with ``coefs``, when ``return_models`` is set
-        to ``False``)
 
     coefs : array, shape (n_features, n_alphas) or
             (n_outputs, n_features, n_alphas)
         Coefficients along the path.
-        (Is returned, along with ``alphas``, when ``return_models`` is set
-        to ``False``).
 
     dual_gaps : array, shape (n_alphas,)
         The dual gaps at the end of the optimization for each alpha.
-        (Is returned, along with ``alphas``, when ``return_models`` is set
-        to ``False``).
 
     n_iters : array-like, shape (n_alphas,)
         The number of iterations taken by the coordinate descent optimizer to
         reach the specified tolerance for each alpha.
-        (Is returned, along with ``alphas``, when ``return_models`` is set
-        to ``False`` and ``return_n_iter`` is set to True).
+        (Is returned when ``return_n_iter`` is set to True).
 
     Notes
     -----
     See examples/plot_lasso_coordinate_descent_path.py for an example.
 
-    Deprecation Notice: Setting ``return_models`` to ``False`` will make
-    the Lasso Path return an output in the style used by :func:`lars_path`.
-    This will be become the norm as of version 0.15. Leaving ``return_models``
-    set to `True` will let the function return a list of models as before.
-
     See also
     --------
     MultiTaskElasticNet
     MultiTaskElasticNetCV
     ElasticNet
     ElasticNetCV
     """
-    if return_models:
-        warnings.warn("Use enet_path(return_models=False), as it returns the"
-                      " coefficients and alphas instead of just a list of"
-                      " models as previously `lasso_path`/`enet_path` did."
-                      " `return_models` will eventually be removed in 0.16,"
-                      " after which, returning alphas and coefs"
-                      " will become the norm.",
-                      DeprecationWarning, stacklevel=2)
-
-    if normalize is True:
-        warnings.warn("normalize param will be removed in 0.16."
-                      " Intercept fitting and feature normalization will be"
-                      " done in estimators.",
-                      DeprecationWarning, stacklevel=2)
-    else:
-        normalize = False
-
-    if fit_intercept is True or fit_intercept is None:
-        warnings.warn("fit_intercept param will be removed in 0.16."
-                      " Intercept fitting and feature normalization will be"
-                      " done in estimators.",
-                      DeprecationWarning, stacklevel=2)
-
-    if fit_intercept is None:
-        fit_intercept = True
-
-    X = check_array(X, 'csc', dtype=np.float64, order='F', copy=copy_X and
-                    fit_intercept)
+    X = check_array(X, 'csc', dtype=np.float64, order='F', copy=copy_X)
     n_samples, n_features = X.shape
 
     multi_output = False
@@ -453,8 +360,11 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
     else:
         X_sparse_scaling = np.zeros(n_features)
 
+    # X should be normalized and fit already.
     X, y, X_mean, y_mean, X_std, precompute, Xy = \
-        _pre_fit(X, y, Xy, precompute, normalize, fit_intercept, copy=False)
+        _pre_fit(X, y, Xy, precompute, normalize=False, fit_intercept=False,
+                 copy=False)
+
     if alphas is None:
         # No need to normalize of fit_intercept: it has been done
         # above
@@ -520,24 +430,6 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
                           ' to increase the number of iterations',
                           ConvergenceWarning)
 
-        if return_models:
-            if not multi_output:
-                model = ElasticNet(
-                    alpha=alpha, l1_ratio=l1_ratio,
-                    fit_intercept=fit_intercept
-                    if sparse.isspmatrix(X) else False,
-                    precompute=precompute)
-            else:
-                model = MultiTaskElasticNet(
-                    alpha=alpha, l1_ratio=l1_ratio, fit_intercept=False)
-            model.dual_gap_ = dual_gaps[i]
-            model.coef_ = coefs[..., i]
-            model.n_iter_ = n_iters[i]
-            if (fit_intercept and not sparse.isspmatrix(X)) or multi_output:
-                model.fit_intercept = True
-                model._set_intercept(X_mean, y_mean, X_std)
-            models.append(model)
-
         if verbose:
             if verbose > 2:
                 print(model)
@@ -546,12 +438,9 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
             else:
                 sys.stderr.write('.')
 
-    if return_models:
-        return models
-    elif return_n_iter:
+    if return_n_iter:
         return alphas, coefs, dual_gaps, n_iters
-    else:
-        return alphas, coefs, dual_gaps
+    return alphas, coefs, dual_gaps
 
 
 ###############################################################################
@@ -998,8 +887,6 @@ def _path_residuals(X, y, train, test, path, path_params, alphas=None,
                  copy=False)
 
     path_params = path_params.copy()
-    path_params['fit_intercept'] = False
-    path_params['normalize'] = False
     path_params['Xy'] = Xy
     path_params['X_mean'] = X_mean
     path_params['X_std'] = X_std
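Since ``return_models`` is gone, code that previously consumed the list of fitted models has to rebuild it from the returned path. A rough migration sketch (mine, not from this commit), fitting one ElasticNet per alpha on synthetic data:

import numpy as np
from sklearn.linear_model import ElasticNet, enet_path

rng = np.random.RandomState(0)
X = rng.randn(50, 10)
y = rng.randn(50)

# enet_path now always returns the lars_path-style (alphas, coefs, dual_gaps).
alphas, coefs, dual_gaps = enet_path(X, y, l1_ratio=0.5, n_alphas=5)

# If fitted estimator objects are still needed, fit one per alpha instead of
# relying on the removed return_models=True output.
models = [ElasticNet(alpha=a, l1_ratio=0.5).fit(X, y) for a in alphas]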

sklearn/linear_model/tests/test_coordinate_descent.py

Lines changed: 3 additions & 11 deletions

@@ -206,28 +206,20 @@ def test_lasso_path_return_models_vs_new_return_gives_same_coefficients():
     X = np.array([[1, 2, 3.1], [2.3, 5.4, 4.3]]).T
     y = np.array([1, 2, 3.1])
     alphas = [5., 1., .5]
-    # Compute the lasso_path
-    f = ignore_warnings
-    coef_path = [e.coef_ for e in f(lasso_path)(X, y, alphas=alphas,
-                                                return_models=True,
-                                                fit_intercept=False)]
 
     # Use lars_path and lasso_path(new output) with 1D linear interpolation
     # to compute the the same path
     alphas_lars, _, coef_path_lars = lars_path(X, y, method='lasso')
     coef_path_cont_lars = interpolate.interp1d(alphas_lars[::-1],
                                                coef_path_lars[:, ::-1])
     alphas_lasso2, coef_path_lasso2, _ = lasso_path(X, y, alphas=alphas,
-                                                    fit_intercept=False,
                                                     return_models=False)
     coef_path_cont_lasso = interpolate.interp1d(alphas_lasso2[::-1],
                                                 coef_path_lasso2[:, ::-1])
 
-    np.testing.assert_array_almost_equal(coef_path_cont_lasso(alphas),
-                                         np.asarray(coef_path).T, decimal=1)
-    np.testing.assert_array_almost_equal(coef_path_cont_lasso(alphas),
-                                         coef_path_cont_lars(alphas),
-                                         decimal=1)
+    assert_array_almost_equal(
+        coef_path_cont_lasso(alphas), coef_path_cont_lars(alphas),
+        decimal=1)
 
 
 def test_enet_path():
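The assertion that survives compares the coordinate-descent path against the LARS path by interpolating both onto the same alphas. A condensed, standalone version of that pattern (an illustration assuming scipy, as the test module already does), not the test itself:

import numpy as np
from scipy import interpolate
from sklearn.linear_model import lars_path, lasso_path

X = np.array([[1, 2, 3.1], [2.3, 5.4, 4.3]]).T
y = np.array([1, 2, 3.1])
alphas = [5., 1., .5]

# Interpolate both coefficient paths onto a common alpha grid so the two
# solvers can be compared even though they use different alpha sequences.
alphas_lars, _, coef_lars = lars_path(X, y, method='lasso')
lars_interp = interpolate.interp1d(alphas_lars[::-1], coef_lars[:, ::-1])
alphas_lasso, coef_lasso, _ = lasso_path(X, y, alphas=alphas)
lasso_interp = interpolate.interp1d(alphas_lasso[::-1], coef_lasso[:, ::-1])

np.testing.assert_array_almost_equal(
    lasso_interp(alphas), lars_interp(alphas), decimal=1)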

sklearn/linear_model/tests/test_least_angle.py

Lines changed: 1 addition & 14 deletions

@@ -325,20 +325,7 @@ def test_lasso_lars_vs_lasso_cd_ill_conditioned():
                                                            tol=1e-6,
                                                            fit_intercept=False)
 
-    # Check that the deprecated return_models=True yields the same coefs path
-    with ignore_warnings():
-        lasso_coef = np.zeros((w.shape[0], len(lars_alphas)))
-        iter_models = enumerate(linear_model.lasso_path(X, y,
-                                                        alphas=lars_alphas,
-                                                        tol=1e-6,
-                                                        return_models=True,
-                                                        fit_intercept=False))
-        for i, model in iter_models:
-            lasso_coef[:, i] = model.coef_
-
-    np.testing.assert_array_almost_equal(lars_coef, lasso_coef, decimal=1)
-    np.testing.assert_array_almost_equal(lars_coef, lasso_coef2, decimal=1)
-    np.testing.assert_array_almost_equal(lasso_coef, lasso_coef2, decimal=1)
+    assert_array_almost_equal(lars_coef, lasso_coef2, decimal=1)
 
 
 def test_lasso_lars_vs_lasso_cd_ill_conditioned2():
