Skip to content

Commit 7675f33

Browse files
author
maxime
committed
Allow Ridge's fit() method to take initial coefficients (coef_init), used as the starting point (x0) of the sparse_cg solver
1 parent c2b2ce4 commit 7675f33

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

sklearn/linear_model/ridge.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from ..metrics.scorer import check_scoring
3030

3131

32-
def _solve_sparse_cg(X, y, alpha, max_iter=None, tol=1e-3):
32+
def _solve_sparse_cg(X, y, alpha, max_iter=None, tol=1e-3, coef_init=None):
3333
n_samples, n_features = X.shape
3434
X1 = sp_linalg.aslinearoperator(X)
3535
coefs = np.empty((y.shape[1], n_features))
@@ -54,15 +54,19 @@ def _mv(x):
5454
# w = X.T * inv(X X^t + alpha*Id) y
5555
C = sp_linalg.LinearOperator(
5656
(n_samples, n_samples), matvec=mv, dtype=X.dtype)
57-
coef, info = sp_linalg.cg(C, y_column, tol=tol)
57+
if coef_init is not None:
58+
x0 = X1.matvec(coef_init)
59+
else:
60+
x0 = None
61+
coef, info = sp_linalg.cg(C, y_column, x0=x0, tol=tol)
5862
coefs[i] = X1.rmatvec(coef)
5963
else:
6064
# linear ridge
6165
# w = inv(X^t X + alpha*Id) * X.T y
6266
y_column = X1.rmatvec(y_column)
6367
C = sp_linalg.LinearOperator(
6468
(n_features, n_features), matvec=mv, dtype=X.dtype)
65-
coefs[i], info = sp_linalg.cg(C, y_column, maxiter=max_iter,
69+
coefs[i], info = sp_linalg.cg(C, y_column, x0=coef_init, maxiter=max_iter,
6670
tol=tol)
6771
if info != 0:
6872
raise ValueError("Failed with error code %d" % info)
@@ -187,7 +191,7 @@ def _deprecate_dense_cholesky(solver):
187191

188192

189193
def ridge_regression(X, y, alpha, sample_weight=None, solver='auto',
190-
max_iter=None, tol=1e-3):
194+
max_iter=None, tol=1e-3, coef_init=None):
191195
"""Solve the ridge equation by the method of normal equations.
192196
193197
Parameters
@@ -306,7 +310,7 @@ def ridge_regression(X, y, alpha, sample_weight=None, solver='auto',
306310
raise ValueError('Solver %s not understood' % solver)
307311

308312
if solver == 'sparse_cg':
309-
coef = _solve_sparse_cg(X, y, alpha, max_iter, tol)
313+
coef = _solve_sparse_cg(X, y, alpha, max_iter, tol, coef_init)
310314

311315
elif solver == "lsqr":
312316
coef = _solve_lsqr(X, y, alpha, max_iter, tol)
@@ -353,7 +357,7 @@ def __init__(self, alpha=1.0, fit_intercept=True, normalize=False,
353357
self.tol = tol
354358
self.solver = solver
355359

356-
def fit(self, X, y, sample_weight=None):
360+
def fit(self, X, y, coef_init=None, sample_weight=None):
357361
X, y = check_X_y(X, y, ['csr', 'csc', 'coo'], dtype=np.float, multi_output=True)
358362

359363
if ((sample_weight is not None) and
@@ -371,7 +375,9 @@ def fit(self, X, y, sample_weight=None):
371375
sample_weight=sample_weight,
372376
max_iter=self.max_iter,
373377
tol=self.tol,
374-
solver=solver)
378+
solver=solver,
379+
coef_init=coef_init,
380+
)
375381
self._set_intercept(X_mean, y_mean, X_std)
376382
return self
377383

@@ -464,7 +470,7 @@ def __init__(self, alpha=1.0, fit_intercept=True, normalize=False,
464470
normalize=normalize, copy_X=copy_X,
465471
max_iter=max_iter, tol=tol, solver=solver)
466472

467-
def fit(self, X, y, sample_weight=None):
473+
def fit(self, X, y, coef_init=None, sample_weight=None):
468474
"""Fit Ridge regression model
469475
470476
Parameters
@@ -482,7 +488,7 @@ def fit(self, X, y, sample_weight=None):
482488
-------
483489
self : returns an instance of self.
484490
"""
485-
return super(Ridge, self).fit(X, y, sample_weight=sample_weight)
491+
return super(Ridge, self).fit(X, y, coef_init=coef_init, sample_weight=sample_weight)
486492

487493

488494
class RidgeClassifier(LinearClassifierMixin, _BaseRidge):

0 commit comments

Comments
 (0)