
Commit 8307eee

Ridge regression now can use sample_weights in feature space. Summary commit over around 20 commits to avoid failing tests
1 parent 50e35b5 commit 8307eee

File tree

3 files changed: +306 -52 lines changed


doc/whats_new.rst

Lines changed: 4 additions & 0 deletions
@@ -182,6 +182,10 @@ Changelog
     :class:`cluster.WardAgglomeration` when no samples are given,
     rather than returning meaningless clustering.
 
+  - Ridge regression can now deal with sample weights in feature space
+    (only sample space until then). By `Michael Eickenberg`_.
+    Both solutions are provided by the Cholesky solver.
+
 API changes summary
 -------------------

sklearn/linear_model/ridge.py

Lines changed: 74 additions & 41 deletions
@@ -84,18 +84,20 @@ def _solve_lsqr(X, y, alpha, max_iter=None, tol=1e-3):
     return coefs
 
 
-def _solve_dense_cholesky(X, y, alpha, sample_weight=None):
+def _solve_cholesky(X, y, alpha, sample_weight=None):
     # w = inv(X^t X + alpha*Id) * X.T y
     n_samples, n_features = X.shape
     n_targets = y.shape[1]
 
-    has_sw = (sample_weight is not None) and (
-        isinstance(sample_weight, np.ndarray) or sample_weight != 1.)
+    has_sw = sample_weight is not None
 
     if has_sw:
-        sample_weight = np.atleast_1d(sample_weight).ravel()
-        A = safe_sparse_dot(X.T * sample_weight, X, dense_output=True)
-        Xy = safe_sparse_dot(X.T * sample_weight, y, dense_output=True)
+        sample_weight = sample_weight * np.ones(n_samples)
+        sample_weight_matrix = sparse.dia_matrix((sample_weight, 0),
+                                                 shape=(n_samples, n_samples))
+        weighted_X = safe_sparse_dot(sample_weight_matrix, X)
+        A = safe_sparse_dot(weighted_X.T, X, dense_output=True)
+        Xy = safe_sparse_dot(weighted_X.T, y, dense_output=True)
     else:
         A = safe_sparse_dot(X.T, X, dense_output=True)
         Xy = safe_sparse_dot(X.T, y, dense_output=True)
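The new branch assembles the weighted normal equations A = X^T W X and Xy = X^T W y with W = diag(sample_weight), going through a sparse diagonal matrix so the same path serves dense and sparse X. A minimal dense NumPy sketch of the same computation (the helper name is illustrative, not part of the commit):

    import numpy as np
    from scipy import linalg

    def weighted_ridge_primal(X, y, alpha, sample_weight):
        # A = X^T W X + alpha * Id and Xy = X^T W y, with W = diag(sample_weight)
        n_samples, n_features = X.shape
        w = sample_weight * np.ones(n_samples)   # broadcasts a scalar to a vector
        weighted_X = X * w[:, np.newaxis]        # row-scale X by the weights
        A = weighted_X.T.dot(X) + alpha * np.eye(n_features)
        Xy = weighted_X.T.dot(y)
        return linalg.solve(A, Xy)               # closed-form (Cholesky-style) solve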
@@ -116,16 +118,17 @@ def _solve_dense_cholesky(X, y, alpha, sample_weight=None):
     return coefs
 
 
-def _solve_dense_cholesky_kernel(K, y, alpha, sample_weight=1.0):
+def _solve_cholesky_kernel(K, y, alpha, sample_weight=None):
     # dual_coef = inv(X X^t + alpha*Id) y
     n_samples = K.shape[0]
     n_targets = y.shape[1]
 
     one_alpha = np.array_equal(alpha, len(alpha) * [alpha[0]])
-    has_sw = isinstance(sample_weight, np.ndarray) or sample_weight != 1.0
+
+    has_sw = sample_weight is not None
 
     if has_sw:
-        sw = np.sqrt(sample_weight)
+        sw = np.sqrt(np.atleast_1d(sample_weight))
         y = y * sw[:, np.newaxis]
         K *= np.outer(sw, sw)
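This is the sample-space (kernel) counterpart used when n_features > n_samples: with sw = sqrt(sample_weight), rescaling y by sw and K by outer(sw, sw) turns the weighted problem into an unweighted one, and mapping back to primal coefficients multiplies by sw again. A hedged sketch checking that equivalence against the primal solution (illustrative, not from the commit):

    import numpy as np
    from scipy import linalg

    rng = np.random.RandomState(0)
    X = rng.randn(6, 12)                    # n_features > n_samples
    y = rng.randn(6)
    w = rng.rand(6) + 0.5                   # positive sample weights
    alpha = 1.0

    # Primal weighted solution: (X^T W X + alpha*Id)^-1 X^T W y
    A = X.T.dot(w[:, np.newaxis] * X) + alpha * np.eye(12)
    coef_primal = linalg.solve(A, X.T.dot(w * y))

    # Kernel-space trick from the diff: rescale, solve unweighted, undo rescaling
    sw = np.sqrt(w)
    K = X.dot(X.T) * np.outer(sw, sw)
    dual = linalg.solve(K + alpha * np.eye(6), sw * y)
    coef_dual = X.T.dot(sw * dual)

    assert np.allclose(coef_primal, coef_dual)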

@@ -172,7 +175,18 @@ def _solve_svd(X, y, alpha):
     return np.dot(Vt.T, d_UT_y).T
 
 
-def ridge_regression(X, y, alpha, sample_weight=1.0, solver='auto',
+def _deprecate_dense_cholesky(solver):
+    if solver == 'dense_cholesky':
+        import warnings
+        warnings.warn(DeprecationWarning("The name 'dense_cholesky' is "
+                                         "deprecated. Using 'cholesky' "
+                                         "instead. Changed in 0.15"))
+        solver = 'cholesky'
+
+    return solver
+
+
+def ridge_regression(X, y, alpha, sample_weight=None, solver='auto',
                      max_iter=None, tol=1e-3):
     """Solve the ridge equation by the method of normal equations.
 
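A quick sketch of what the shim means for callers, assuming the private helper stays importable from sklearn.linear_model.ridge:

    import warnings
    from sklearn.linear_model.ridge import _deprecate_dense_cholesky

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        solver = _deprecate_dense_cholesky('dense_cholesky')

    assert solver == 'cholesky'
    assert any(issubclass(c.category, DeprecationWarning) for c in caught)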
@@ -195,24 +209,25 @@ def ridge_regression(X, y, alpha, sample_weight=1.0, solver='auto',
         The default value is determined by scipy.sparse.linalg.
 
     sample_weight : float or numpy array of shape [n_samples]
-        Individual weights for each sample
+        Individual weights for each sample. If sample_weight is set, then
+        the solver will automatically be set to 'cholesky'
 
-    solver : {'auto', 'svd', 'dense_cholesky', 'lsqr', 'sparse_cg'}
+    solver : {'auto', 'svd', 'cholesky', 'lsqr', 'sparse_cg'}
         Solver to use in the computational routines:
 
         - 'auto' chooses the solver automatically based on the type of data.
 
         - 'svd' uses a Singular Value Decomposition of X to compute the Ridge
           coefficients. More stable for singular matrices than
-          'dense_cholesky'.
+          'cholesky'.
 
-        - 'dense_cholesky' uses the standard scipy.linalg.solve function to
+        - 'cholesky' uses the standard scipy.linalg.solve function to
           obtain a closed-form solution via a Cholesky decomposition of
           dot(X.T, X)
 
         - 'sparse_cg' uses the conjugate gradient solver as found in
          scipy.sparse.linalg.cg. As an iterative algorithm, this solver is
-          more appropriate than 'dense_cholesky' for large-scale data
+          more appropriate than 'cholesky' for large-scale data
           (possibility to set `tol` and `max_iter`).
 
         - 'lsqr' uses the dedicated regularized least-squares routine
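Per the updated docstring, a hedged usage sketch of the function-level API with random data (illustrative values only):

    import numpy as np
    from sklearn.linear_model import ridge_regression

    rng = np.random.RandomState(0)
    X, y = rng.randn(20, 5), rng.randn(20)
    w = rng.rand(20)

    # solver='auto' combined with sample_weight routes to the 'cholesky' path
    coef = ridge_regression(X, y, alpha=1.0, sample_weight=w)
    print(coef.shape)    # (5,)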
@@ -250,13 +265,15 @@ def ridge_regression(X, y, alpha, sample_weight=1.0, solver='auto',
         raise ValueError("Number of samples in X and y does not correspond:"
                          " %d != %d" % (n_samples, n_samples_))
 
-    has_sw = isinstance(sample_weight, np.ndarray) or sample_weight != 1.0
+    has_sw = sample_weight is not None
+
+    solver = _deprecate_dense_cholesky(solver)
 
     if solver == 'auto':
         # cholesky if it's a dense array and cg in
         # any other case
-        if hasattr(X, '__array__'):
-            solver = 'dense_cholesky'
+        if not sparse.issparse(X) or has_sw:
+            solver = 'cholesky'
         else:
             solver = 'sparse_cg'
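The new 'auto' rule, restated as a standalone sketch (the helper name is illustrative):

    from scipy import sparse

    def pick_solver(X, sample_weight):
        # dense input, or any sample weighting, takes the Cholesky path;
        # sparse unweighted input keeps the conjugate-gradient solver
        if not sparse.issparse(X) or sample_weight is not None:
            return 'cholesky'
        return 'sparse_cg'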

@@ -265,10 +282,15 @@ def ridge_regression(X, y, alpha, sample_weight=1.0, solver='auto',
                       to sparse_cg.""")
         solver = 'sparse_cg'
 
-    if has_sw and solver != "dense_cholesky":
-        warnings.warn("""sample_weight and class_weight not supported in %s,
-                      fall back to dense_cholesky.""" % solver)
-        solver = 'dense_cholesky'
+    if has_sw:
+        if np.atleast_1d(sample_weight).ndim > 1:
+            raise ValueError("Sample weights must be 1D array or scalar")
+
+        if solver != "cholesky":
+            warnings.warn("sample_weight and class_weight not"
+                          " supported in %s, fall back to "
+                          "cholesky." % solver)
+            solver = 'cholesky'
 
     # There should be either 1 or n_targets penalties
     alpha = safe_asarray(alpha).ravel()
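A hedged sketch of the new guard rails: 2-D weights raise, and a non-Cholesky solver combined with weights warns and falls back (illustrative data):

    import warnings
    import numpy as np
    from sklearn.linear_model import ridge_regression

    X, y = np.eye(3), np.ones(3)

    try:
        ridge_regression(X, y, alpha=1.0, sample_weight=np.ones((3, 1)))
    except ValueError:
        pass                        # "Sample weights must be 1D array or scalar"

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        ridge_regression(X, y, alpha=1.0, sample_weight=np.ones(3), solver='lsqr')
    assert len(caught) > 0          # fell back to 'cholesky' with a warning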
@@ -280,7 +302,7 @@ def ridge_regression(X, y, alpha, sample_weight=1.0, solver='auto',
     if alpha.size == 1 and n_targets > 1:
         alpha = np.repeat(alpha, n_targets)
 
-    if solver not in ('sparse_cg', 'dense_cholesky', 'svd', 'lsqr'):
+    if solver not in ('sparse_cg', 'cholesky', 'svd', 'lsqr'):
         ValueError('Solver %s not understood' % solver)
 
     if solver == 'sparse_cg':
@@ -289,11 +311,11 @@ def ridge_regression(X, y, alpha, sample_weight=1.0, solver='auto',
     elif solver == "lsqr":
         coef = _solve_lsqr(X, y, alpha, max_iter, tol)
 
-    elif solver == 'dense_cholesky':
+    elif solver == 'cholesky':
         if n_features > n_samples:
             K = safe_sparse_dot(X, X.T, dense_output=True)
             try:
-                dual_coef = _solve_dense_cholesky_kernel(K, y, alpha,
+                dual_coef = _solve_cholesky_kernel(K, y, alpha,
                                                    sample_weight)
 
                 coef = safe_sparse_dot(X.T, dual_coef, dense_output=True).T
@@ -303,7 +325,7 @@ def ridge_regression(X, y, alpha, sample_weight=1.0, solver='auto',
 
         else:
             try:
-                coef = _solve_dense_cholesky(X, y, alpha, sample_weight)
+                coef = _solve_cholesky(X, y, alpha, sample_weight)
             except linalg.LinAlgError:
                 # use SVD solver if matrix is singular
                 solver = 'svd'
@@ -331,20 +353,26 @@ def __init__(self, alpha=1.0, fit_intercept=True, normalize=False,
         self.tol = tol
         self.solver = solver
 
-    def fit(self, X, y, sample_weight=1.0):
+    def fit(self, X, y, sample_weight=None):
         X = safe_asarray(X, dtype=np.float)
         y = np.asarray(y, dtype=np.float)
 
+        if ((sample_weight is not None) and
+                np.atleast_1d(sample_weight).ndim > 1):
+            raise ValueError("Sample weights must be 1D array or scalar")
+
         X, y, X_mean, y_mean, X_std = self._center_data(
             X, y, self.fit_intercept, self.normalize, self.copy_X,
             sample_weight=sample_weight)
 
+        solver = _deprecate_dense_cholesky(self.solver)
+
         self.coef_ = ridge_regression(X, y,
                                       alpha=self.alpha,
                                       sample_weight=sample_weight,
                                       max_iter=self.max_iter,
                                       tol=self.tol,
-                                      solver=self.solver)
+                                      solver=solver)
         self._set_intercept(X_mean, y_mean, X_std)
         return self
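End to end, the estimator now accepts weights in either regime; a hedged sketch with random data against a scikit-learn of this vintage (illustrative values):

    import numpy as np
    from sklearn.linear_model import Ridge

    rng = np.random.RandomState(42)
    X = rng.randn(5, 10)            # n_features > n_samples: dual/kernel path
    y = rng.randn(5)
    w = rng.rand(5) + 0.5

    model = Ridge(alpha=1.0, solver='cholesky').fit(X, y, sample_weight=w)
    print(model.coef_.shape)        # (10,)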

@@ -383,21 +411,21 @@ class Ridge(_BaseRidge, RegressorMixin):
     normalize : boolean, optional, default False
         If True, the regressors X will be normalized before regression.
 
-    solver : {'auto', 'svd', 'dense_cholesky', 'lsqr', 'sparse_cg'}
+    solver : {'auto', 'svd', 'cholesky', 'lsqr', 'sparse_cg'}
         Solver to use in the computational routines:
 
         - 'auto' chooses the solver automatically based on the type of data.
 
         - 'svd' uses a Singular Value Decomposition of X to compute the Ridge
           coefficients. More stable for singular matrices than
-          'dense_cholesky'.
+          'cholesky'.
 
-        - 'dense_cholesky' uses the standard scipy.linalg.solve function to
+        - 'cholesky' uses the standard scipy.linalg.solve function to
           obtain a closed-form solution.
 
         - 'sparse_cg' uses the conjugate gradient solver as found in
           scipy.sparse.linalg.cg. As an iterative algorithm, this solver is
-          more appropriate than 'dense_cholesky' for large-scale data
+          more appropriate than 'cholesky' for large-scale data
           (possibility to set `tol` and `max_iter`).
 
         - 'lsqr' uses the dedicated regularized least-squares routine
@@ -437,7 +465,7 @@ def __init__(self, alpha=1.0, fit_intercept=True, normalize=False,
                          normalize=normalize, copy_X=copy_X,
                          max_iter=max_iter, tol=tol, solver=solver)
 
-    def fit(self, X, y, sample_weight=1.0):
+    def fit(self, X, y, sample_weight=None):
         """Fit Ridge regression model
 
         Parameters
@@ -489,10 +517,10 @@ class RidgeClassifier(LinearClassifierMixin, _BaseRidge):
     normalize : boolean, optional, default False
         If True, the regressors X will be normalized before regression.
 
-    solver : {'auto', 'svd', 'dense_cholesky', 'lsqr', 'sparse_cg'}
+    solver : {'auto', 'svd', 'cholesky', 'lsqr', 'sparse_cg'}
         Solver to use in the computational
         routines. 'svd' will use a Singular value decomposition to obtain
-        the solution, 'dense_cholesky' will use the standard
+        the solution, 'cholesky' will use the standard
         scipy.linalg.solve function, 'sparse_cg' will use the
         conjugate gradient solver as found in
         scipy.sparse.linalg.cg while 'auto' will chose the most
@@ -551,7 +579,7 @@ def fit(self, X, y):
             # get the class weight corresponding to each sample
             sample_weight = cw[np.searchsorted(self.classes_, y)]
         else:
-            sample_weight = 1.0
+            sample_weight = None
 
         super(RidgeClassifier, self).fit(X, Y, sample_weight=sample_weight)
         return self
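For reference, the class_weight-to-sample_weight mapping used just above, as a tiny sketch (illustrative numbers):

    import numpy as np

    classes = np.array([0, 1])              # sorted, as np.unique returns them
    cw = np.array([0.3, 0.7])               # one weight per class
    y = np.array([0, 1, 1, 0, 1])

    sample_weight = cw[np.searchsorted(classes, y)]
    # -> array([ 0.3,  0.7,  0.7,  0.3,  0.7])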
@@ -679,7 +707,7 @@ def _values_svd(self, alpha, y, v, U, UT_y):
             G_diag = G_diag[:, np.newaxis]
         return y - (c / G_diag), c
 
-    def fit(self, X, y, sample_weight=1.0):
+    def fit(self, X, y, sample_weight=None):
         """Fit Ridge regression model
 
         Parameters
@@ -744,10 +772,13 @@ def fit(self, X, y, sample_weight=1.0):
         error = scorer is None
 
         for i, alpha in enumerate(self.alphas):
+            weighted_alpha = (sample_weight * alpha
+                              if sample_weight is not None
+                              else alpha)
             if error:
-                out, c = _errors(sample_weight * alpha, y, v, Q, QT_y)
+                out, c = _errors(weighted_alpha, y, v, Q, QT_y)
             else:
-                out, c = _values(sample_weight * alpha, y, v, Q, QT_y)
+                out, c = _values(weighted_alpha, y, v, Q, QT_y)
             cv_values[:, i] = out.ravel()
             C.append(c)
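The guard keeps the unweighted GCV path unchanged; with weights, the scalar penalty becomes a per-sample vector before it is handed to _errors/_values. A tiny sketch of the guard itself (illustrative values):

    import numpy as np

    alpha = 10.0
    sample_weight = np.array([1.0, 2.0, 0.5])

    weighted_alpha = (sample_weight * alpha
                      if sample_weight is not None
                      else alpha)
    # -> array([ 10.,  20.,   5.])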

@@ -797,7 +828,7 @@ def __init__(self, alphas=np.array([0.1, 1.0, 10.0]),
         self.gcv_mode = gcv_mode
         self.store_cv_values = store_cv_values
 
-    def fit(self, X, y, sample_weight=1.0):
+    def fit(self, X, y, sample_weight=None):
         """Fit Ridge regression model
 
         Parameters
@@ -1002,7 +1033,7 @@ def __init__(self, alphas=np.array([0.1, 1.0, 10.0]), fit_intercept=True,
             score_func=score_func, loss_func=loss_func, cv=cv)
         self.class_weight = class_weight
 
-    def fit(self, X, y, sample_weight=1.0, class_weight=None):
+    def fit(self, X, y, sample_weight=None, class_weight=None):
         """Fit the ridge classifier.
 
         Parameters
@@ -1035,6 +1066,8 @@ def fit(self, X, y, sample_weight=1.0, class_weight=None):
                           " Using it in the 'fit' method is deprecated and "
                           "will be removed in 0.15.", DeprecationWarning,
                           stacklevel=2)
+            if sample_weight is None:
+                sample_weight = 1.
 
         self._label_binarizer = LabelBinarizer(pos_label=1, neg_label=-1)
         Y = self._label_binarizer.fit_transform(y)
