
Commit 46f4c55

Merge pull request scikit-learn#1169 from larsmans/ridge-cg
Ridge CG performance improvements
2 parents 986ad8c + 0e00956 commit 46f4c55

3 files changed: +80 -62 lines changed

doc/whats_new.rst

Lines changed: 5 additions & 2 deletions
@@ -8,8 +8,11 @@
 Changelog
 ---------
 
-   - :class:`feature_selection.SelectPercentile` now breaks ties deterministically
-     instead of returning all equally ranked features.
+   - :class:`feature_selection.SelectPercentile` now breaks ties
+     deterministically instead of returning all equally ranked features.
+
+   - Ridge regression and ridge classification fitting no longer has
+     quadratic memory complexity.
 
    - Speed up of :func:`metrics.precision_recall_curve` by Conrad Lee.
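To illustrate the memory note above: with the CG-based solver, fitting Ridge on sparse data works from matrix-vector products with X, so neither X X^T nor X^T X has to be formed explicitly. A minimal sketch, not part of the commit; the shapes and alpha below are arbitrary illustration values:

    # Illustrative sketch only: arbitrary toy shapes and alpha, not from the commit.
    import numpy as np
    import scipy.sparse as sp
    from sklearn.linear_model import Ridge

    X = sp.rand(2000, 50000, density=1e-4, format='csr')  # wide, sparse design
    y = np.random.randn(2000)

    # The CG path solves the ridge system through matvec/rmatvec calls on X,
    # without materializing a Gram matrix.
    Ridge(alpha=1.0, tol=1e-3).fit(X, y)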

examples/document_classification_20newsgroups.py

Lines changed: 3 additions & 1 deletion
@@ -32,6 +32,7 @@
 from sklearn.datasets import fetch_20newsgroups
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.feature_selection import SelectKBest, chi2
+from sklearn.linear_model import RidgeClassifier
 from sklearn.svm import LinearSVC
 from sklearn.linear_model import SGDClassifier
 from sklearn.linear_model import Perceptron
@@ -190,7 +191,8 @@ def benchmark(clf):
 
 
 results = []
-for clf, name in ((Perceptron(n_iter=50), "Perceptron"),
+for clf, name in ((RidgeClassifier(tol=1e-1), "Ridge Classifier"),
+                  (Perceptron(n_iter=50), "Perceptron"),
                   (KNeighborsClassifier(n_neighbors=10), "kNN")):
     print 80 * '='
     print name
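For readers who want to try the newly benchmarked estimator outside the full example script, here is a minimal hedged sketch on sparse TF-IDF features; the toy corpus and the tol value are illustrative, not taken from the example:

    # Minimal sketch: toy corpus and illustrative parameters only.
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.linear_model import RidgeClassifier

    docs = ["conjugate gradient solver converged",
            "the meeting is on friday",
            "ridge regression with sparse features",
            "friday agenda and meeting minutes"]
    labels = [0, 1, 0, 1]

    X = TfidfVectorizer().fit_transform(docs)       # sparse matrix
    clf = RidgeClassifier(tol=1e-1).fit(X, labels)  # sparse input goes through the CG path
    print(clf.predict(X))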

sklearn/linear_model/ridge.py

Lines changed: 72 additions & 59 deletions
@@ -2,14 +2,19 @@
 Ridge regression
 """
 
-# Author: Mathieu Blondel <[email protected]>
-#         Reuben Fletcher-Costin <[email protected]>
+# Author: Mathieu Blondel <[email protected]>
+#         Reuben Fletcher-Costin <[email protected]>
+#         Fabian Pedregosa <[email protected]>
 # License: Simplified BSD
 
 
 from abc import ABCMeta, abstractmethod
 import warnings
+
 import numpy as np
+from scipy import linalg
+from scipy import sparse
+from scipy.sparse import linalg as sp_linalg
 
 from .base import LinearClassifierMixin, LinearModel
 from ..base import RegressorMixin
@@ -19,49 +24,22 @@
 from ..grid_search import GridSearchCV
 
 
-def _solve(A, b, solver, tol):
-    # helper method for ridge_regression, A is symmetric positive
-
-    if solver == 'auto':
-        if hasattr(A, 'todense'):
-            solver = 'sparse_cg'
-        else:
-            solver = 'dense_cholesky'
-
-    if solver == 'sparse_cg':
-        if b.ndim < 2:
-            from scipy.sparse import linalg as sp_linalg
-            sol, error = sp_linalg.cg(A, b, tol=tol)
-            if error:
-                raise ValueError("Failed with error code %d" % error)
-            return sol
-        else:
-            # sparse_cg cannot handle a 2-d b.
-            sol = []
-            for j in range(b.shape[1]):
-                sol.append(_solve(A, b[:, j], solver="sparse_cg", tol=tol))
-            return np.array(sol).T
-
-    elif solver == 'dense_cholesky':
-        from scipy import linalg
-        if hasattr(A, 'todense'):
-            A = A.todense()
-        return linalg.solve(A, b, sym_pos=True, overwrite_a=True)
-    else:
-        raise NotImplementedError('Solver %s not implemented' % solver)
-
-
-def ridge_regression(X, y, alpha, sample_weight=1.0, solver='auto', tol=1e-3):
+def ridge_regression(X, y, alpha, sample_weight=1.0, solver='auto',
+                     max_iter=None, tol=1e-3):
     """Solve the ridge equation by the method of normal equations.
 
     Parameters
     ----------
-    X : {array-like, sparse matrix}, shape = [n_samples, n_features]
+    X : {array-like, sparse matrix, LinearOperator}, shape = [n_samples, n_features]
         Training data
 
     y : array-like, shape = [n_samples] or [n_samples, n_responses]
         Target values
 
+    max_iter : int, optional
+        Maximum number of iterations for conjugate gradient solver.
+        The default value is determined by scipy.sparse.linalg.
+
     sample_weight : float or numpy array of shape [n_samples]
         Individual weights for each sample
 
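As a quick illustration of the new signature, the sketch below calls ridge_regression directly with the CG solver; the toy shapes and parameter values are assumptions made for this example, not part of the commit:

    # Hedged sketch: arbitrary toy data; parameter values are illustrative.
    import numpy as np
    import scipy.sparse as sp
    from sklearn.linear_model import ridge_regression

    X = sp.rand(1000, 300, density=0.01, format='csr')  # tall, sparse design
    y = np.random.randn(1000)

    # max_iter and tol are forwarded to scipy.sparse.linalg.cg by the 'sparse_cg' path.
    w = ridge_regression(X, y, alpha=1.0, solver='sparse_cg', max_iter=100, tol=1e-3)
    print(w.shape)  # (300,)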
@@ -86,26 +64,55 @@ def ridge_regression(X, y, alpha, sample_weight=1.0, solver='auto', tol=1e-3):
     """
 
     n_samples, n_features = X.shape
-    is_sparse = False
 
-    if hasattr(X, 'todense'):  # lazy import of scipy.sparse
-        from scipy import sparse
-        is_sparse = sparse.issparse(X)
-
-    if is_sparse:
-        if n_features > n_samples or \
-           isinstance(sample_weight, np.ndarray) or \
-           sample_weight != 1.0:
+    if solver == 'auto':
+        # cholesky if it's a dense array and cg in
+        # any other case
+        if hasattr(X, '__array__'):
+            solver = 'dense_cholesky'
+        else:
+            solver = 'sparse_cg'
 
-            I = sparse.lil_matrix((n_samples, n_samples))
-            I.setdiag(np.ones(n_samples) * alpha * sample_weight)
-            c = _solve(X * X.T + I, y, solver, tol)
-            coef = X.T * c
+    if solver == 'sparse_cg':
+        # gradient descent
+        X1 = sp_linalg.aslinearoperator(X)
+        if y.ndim == 1:
+            y1 = np.reshape(y, (-1, 1))
         else:
-            I = sparse.lil_matrix((n_features, n_features))
-            I.setdiag(np.ones(n_features) * alpha)
-            coef = _solve(X.T * X + I, X.T * y, solver, tol)
+            y1 = y
+        coefs = np.empty((y1.shape[1], n_features))
+
+        for i in range(y1.shape[1]):
+            y_column = y1[:, i]
+            if n_features > n_samples:
+                # kernel ridge
+                # w = X.T * inv(X X^t + alpha*Id) y
+                def mv(x):
+                    return X1.matvec(X1.rmatvec(x)) + alpha * x
+                C = sp_linalg.LinearOperator(
+                    (n_samples, n_samples), matvec=mv, dtype=X.dtype)
+                coef, info = sp_linalg.cg(C, y_column, tol=tol)
+                coefs[i] = X1.rmatvec(coef)
+            else:
+                # ridge
+                # w = inv(X^t X + alpha*Id) * X.T y
+                def mv(x):
+                    return X1.rmatvec(X1.matvec(x)) + alpha * x
+                y_column = X1.rmatvec(y_column)
+                C = sp_linalg.LinearOperator(
+                    (n_features, n_features), matvec=mv, dtype=X.dtype)
+                coefs[i], info = sp_linalg.cg(C, y_column, maxiter=max_iter,
+                                              tol=tol)
+            if info != 0:
+                raise ValueError("Failed with error code %d" % info)
+
+        if y.ndim == 1:
+            return np.ravel(coefs)
+        return coefs
     else:
+        # normal equations (cholesky) method
+        if sparse.issparse(X):
+            X = X.toarray()
         if n_features > n_samples or \
            isinstance(sample_weight, np.ndarray) or \
            sample_weight != 1.0:
@@ -114,13 +121,13 @@ def ridge_regression(X, y, alpha, sample_weight=1.0, solver='auto', tol=1e-3):
             # w = X.T * inv(X X^t + alpha*Id) y
             A = np.dot(X, X.T)
             A.flat[::n_samples + 1] += alpha * sample_weight
-            coef = np.dot(X.T, _solve(A, y, solver, tol))
+            coef = np.dot(X.T, linalg.solve(A, y, sym_pos=True, overwrite_a=True))
         else:
             # ridge
             # w = inv(X^t X + alpha*Id) * X.T y
             A = np.dot(X.T, X)
             A.flat[::n_features + 1] += alpha
-            coef = _solve(A, np.dot(X.T, y), solver, tol)
+            coef = linalg.solve(A, np.dot(X.T, y), sym_pos=True, overwrite_a=True)
 
     return coef.T
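The heart of the new 'sparse_cg' branch is that CG never needs the Gram matrix itself, only products with it: a scipy LinearOperator wraps X and supplies matvec/rmatvec, and the operator v -> X^T X v + alpha*v (or X X^T v + alpha*v in the kernelized case) is handed to scipy.sparse.linalg.cg. A self-contained sketch of the same idea, with toy data and an illustrative alpha, separate from the library code above:

    # Standalone sketch of the matrix-free CG solve (toy data, illustrative alpha).
    import numpy as np
    import scipy.sparse as sp
    from scipy.sparse.linalg import LinearOperator, aslinearoperator, cg

    X = sp.rand(500, 50, density=0.05, format='csr')
    y = np.random.randn(500)
    alpha = 1.0

    X1 = aslinearoperator(X)

    def mv(v):
        # (X^T X + alpha*I) v, computed with only two matrix-vector products
        return X1.rmatvec(X1.matvec(v)) + alpha * v

    C = LinearOperator((X.shape[1], X.shape[1]), matvec=mv, dtype=X.dtype)
    w, info = cg(C, X1.rmatvec(y))
    assert info == 0  # info == 0 means CG converged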

@@ -130,7 +137,7 @@ class _BaseRidge(LinearModel):
 
     @abstractmethod
     def __init__(self, alpha=1.0, fit_intercept=True, normalize=False,
-                 copy_X=True, tol=1e-3):
+                 copy_X=True, max_iter=None, tol=1e-3):
         self.alpha = alpha
         self.fit_intercept = fit_intercept
         self.normalize = normalize
@@ -204,6 +211,10 @@ class Ridge(_BaseRidge, RegressorMixin):
     copy_X : boolean, optional, default True
         If True, X will be copied; else, it may be overwritten.
 
+    max_iter : int, optional
+        Maximum number of iterations for conjugate gradient solver.
+        The default value is determined by scipy.sparse.linalg.
+
     tol : float
         Precision of the solution.
 
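In estimator form, the new max_iter and tol parameters are forwarded to the conjugate gradient solver, so a looser tolerance or a capped iteration count trades a little precision for speed. A hedged usage sketch; the data shapes and hyperparameter values are illustrative only:

    # Hedged usage sketch: toy data, illustrative hyperparameter values.
    import numpy as np
    import scipy.sparse as sp
    from sklearn.linear_model import Ridge

    X = sp.rand(1000, 200, density=0.01, format='csr')
    y = np.random.randn(1000)

    reg = Ridge(alpha=1.0, max_iter=50, tol=1e-2).fit(X, y)
    print(reg.coef_.shape)  # (200,)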
@@ -257,6 +268,10 @@ class RidgeClassifier(LinearClassifierMixin, _BaseRidge):
     copy_X : boolean, optional, default True
         If True, X will be copied; else, it may be overwritten.
 
+    max_iter : int, optional
+        Maximum number of iterations for conjugate gradient solver.
+        The default value is determined by scipy.sparse.linalg.
+
     tol : float
         Precision of the solution.
 
@@ -281,10 +296,10 @@ class RidgeClassifier(LinearClassifierMixin, _BaseRidge):
         advantage of the multi-variate response support in Ridge.
     """
     def __init__(self, alpha=1.0, fit_intercept=True, normalize=False,
-                 copy_X=True, tol=1e-3, class_weight=None):
+                 copy_X=True, max_iter=None, tol=1e-3, class_weight=None):
         super(RidgeClassifier, self).__init__(alpha=alpha,
                 fit_intercept=fit_intercept, normalize=normalize,
-                copy_X=copy_X, tol=tol)
+                copy_X=copy_X, max_iter=max_iter, tol=tol)
         self.class_weight = class_weight
 
     def fit(self, X, y, solver='auto'):
@@ -381,7 +396,6 @@ def __init__(self, alphas=[0.1, 1.0, 10.0], fit_intercept=True,
     def _pre_compute(self, X, y):
         # even if X is very sparse, K is usually very dense
         K = safe_sparse_dot(X, X.T, dense_output=True)
-        from scipy import linalg
         v, Q = linalg.eigh(K)
         QT_y = np.dot(Q.T, y)
         return v, Q, QT_y
@@ -418,7 +432,6 @@ def _values(self, alpha, y, v, Q, QT_y):
         return y - (c / G_diag), c
 
     def _pre_compute_svd(self, X, y):
-        from scipy import sparse
        if sparse.issparse(X) and hasattr(X, 'toarray'):
            X = X.toarray()
        U, s, _ = np.linalg.svd(X, full_matrices=0)
