
Commit 5147fd0

Arthur Mensch authored and ogrisel committed
Add SAGA solver for LogisticRegression and Ridge (scikit-learn#8446)
1 parent 7877f3c commit 5147fd0

File tree

13 files changed: +1112 -256 lines


benchmarks/bench_mnist.py

Lines changed: 4 additions & 1 deletion
@@ -91,7 +91,10 @@ def load_data(dtype=np.float32, order='F'):
         Nystroem(gamma=0.015, n_components=1000), LinearSVC(C=100)),
     'SampledRBF-SVM': make_pipeline(
         RBFSampler(gamma=0.015, n_components=1000), LinearSVC(C=100)),
-    'LinearRegression-SAG': LogisticRegression(solver='sag', tol=1e-1, C=1e4),
+    'LogisticRegression-SAG': LogisticRegression(solver='sag', tol=1e-1,
+                                                 C=1e4),
+    'LogisticRegression-SAGA': LogisticRegression(solver='saga', tol=1e-1,
+                                                  C=1e4),
     'MultilayerPerceptron': MLPClassifier(
         hidden_layer_sizes=(100, 100), max_iter=400, alpha=1e-4,
         solver='sgd', learning_rate_init=0.2, momentum=0.9, verbose=1,

benchmarks/bench_saga.py

Lines changed: 244 additions & 0 deletions
@@ -0,0 +1,244 @@
"""Author: Arthur Mensch

Benchmarks of sklearn SAGA vs lightning SAGA vs Liblinear. Shows the gain
in using multinomial logistic regression in terms of learning time.
"""
import json
import time
from os.path import expanduser

import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets import fetch_rcv1, load_iris, load_digits, \
    fetch_20newsgroups_vectorized
from sklearn.externals.joblib import delayed, Parallel, Memory
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer, LabelEncoder
from sklearn.utils.extmath import safe_sparse_dot, softmax


def fit_single(solver, X, y, penalty='l2', single_target=True, C=1,
               max_iter=10, skip_slow=False):
    if skip_slow and solver == 'lightning' and penalty == 'l1':
        print('skipping l1 logistic regression with solver lightning.')
        return

    print('Solving %s logistic regression with penalty %s, solver %s.'
          % ('binary' if single_target else 'multinomial',
             penalty, solver))

    if solver == 'lightning':
        from lightning.classification import SAGAClassifier

    if single_target or solver not in ['sag', 'saga']:
        multi_class = 'ovr'
    else:
        multi_class = 'multinomial'

    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42,
                                                        stratify=y)
    n_samples = X_train.shape[0]
    n_classes = np.unique(y_train).shape[0]
    test_scores = [1]
    train_scores = [1]
    accuracies = [1 / n_classes]
    times = [0]

    if penalty == 'l2':
        alpha = 1. / (C * n_samples)
        beta = 0
        lightning_penalty = None
    else:
        alpha = 0.
        beta = 1. / (C * n_samples)
        lightning_penalty = 'l1'

    for this_max_iter in range(1, max_iter + 1, 2):
        print('[%s, %s, %s] Max iter: %s' %
              ('binary' if single_target else 'multinomial',
               penalty, solver, this_max_iter))
        if solver == 'lightning':
            lr = SAGAClassifier(loss='log', alpha=alpha, beta=beta,
                                penalty=lightning_penalty,
                                tol=-1, max_iter=this_max_iter)
        else:
            lr = LogisticRegression(solver=solver,
                                    multi_class=multi_class,
                                    C=C,
                                    penalty=penalty,
                                    fit_intercept=False, tol=1e-24,
                                    max_iter=this_max_iter,
                                    random_state=42,
                                    )
        t0 = time.clock()
        lr.fit(X_train, y_train)
        train_time = time.clock() - t0

        scores = []
        for (X, y) in [(X_train, y_train), (X_test, y_test)]:
            try:
                y_pred = lr.predict_proba(X)
            except NotImplementedError:
                # Lightning predict_proba is not implemented for n_classes > 2
                y_pred = _predict_proba(lr, X)
            score = log_loss(y, y_pred, normalize=False) / n_samples
            score += (0.5 * alpha * np.sum(lr.coef_ ** 2) +
                      beta * np.sum(np.abs(lr.coef_)))
            scores.append(score)
        train_score, test_score = tuple(scores)

        y_pred = lr.predict(X_test)
        accuracy = np.sum(y_pred == y_test) / y_test.shape[0]
        test_scores.append(test_score)
        train_scores.append(train_score)
        accuracies.append(accuracy)
        times.append(train_time)
    return lr, times, train_scores, test_scores, accuracies


def _predict_proba(lr, X):
    pred = safe_sparse_dot(X, lr.coef_.T)
    if hasattr(lr, "intercept_"):
        pred += lr.intercept_
    return softmax(pred)


def exp(solvers, penalties, single_target, n_samples=30000, max_iter=20,
        dataset='rcv1', n_jobs=1, skip_slow=False):
    mem = Memory(cachedir=expanduser('~/cache'), verbose=0)

    if dataset == 'rcv1':
        rcv1 = fetch_rcv1()

        lbin = LabelBinarizer()
        lbin.fit(rcv1.target_names)

        X = rcv1.data
        y = rcv1.target
        y = lbin.inverse_transform(y)
        le = LabelEncoder()
        y = le.fit_transform(y)
        if single_target:
            y_n = y.copy()
            y_n[y > 16] = 1
            y_n[y <= 16] = 0
            y = y_n

    elif dataset == 'digits':
        digits = load_digits()
        X, y = digits.data, digits.target
        if single_target:
            y_n = y.copy()
            y_n[y < 5] = 1
            y_n[y >= 5] = 0
            y = y_n
    elif dataset == 'iris':
        iris = load_iris()
        X, y = iris.data, iris.target
    elif dataset == '20newspaper':
        ng = fetch_20newsgroups_vectorized()
        X = ng.data
        y = ng.target
        if single_target:
            y_n = y.copy()
            y_n[y > 4] = 1
            y_n[y <= 16] = 0
            y = y_n

    X = X[:n_samples]
    y = y[:n_samples]

    cached_fit = mem.cache(fit_single)
    out = Parallel(n_jobs=n_jobs, mmap_mode=None)(
        delayed(cached_fit)(solver, X, y,
                            penalty=penalty, single_target=single_target,
                            C=1, max_iter=max_iter, skip_slow=skip_slow)
        for solver in solvers
        for penalty in penalties)

    res = []
    idx = 0
    for solver in solvers:
        for penalty in penalties:
            if not (skip_slow and solver == 'lightning' and penalty == 'l1'):
                lr, times, train_scores, test_scores, accuracies = out[idx]
                this_res = dict(solver=solver, penalty=penalty,
                                single_target=single_target,
                                times=times, train_scores=train_scores,
                                test_scores=test_scores,
                                accuracies=accuracies)
                res.append(this_res)
            idx += 1

    with open('bench_saga.json', 'w+') as f:
        json.dump(res, f)


def plot():
    import pandas as pd
    with open('bench_saga.json', 'r') as f:
        f = json.load(f)
    res = pd.DataFrame(f)
    res.set_index(['single_target', 'penalty'], inplace=True)

    grouped = res.groupby(level=['single_target', 'penalty'])

    colors = {'saga': 'blue', 'liblinear': 'orange', 'lightning': 'green'}

    for idx, group in grouped:
        single_target, penalty = idx
        fig = plt.figure(figsize=(12, 4))
        ax = fig.add_subplot(131)

        train_scores = group['train_scores'].values
        ref = np.min(np.concatenate(train_scores)) * 0.999

        for scores, times, solver in zip(group['train_scores'],
                                         group['times'], group['solver']):
            scores = scores / ref - 1
            ax.plot(times, scores, label=solver, color=colors[solver])
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Training objective (relative to min)')
        ax.set_yscale('log')

        ax = fig.add_subplot(132)

        test_scores = group['test_scores'].values
        ref = np.min(np.concatenate(test_scores)) * 0.999

        for scores, times, solver in zip(group['test_scores'],
                                         group['times'], group['solver']):
            scores = scores / ref - 1
            ax.plot(times, scores, label=solver, color=colors[solver])
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Test objective (relative to min)')
        ax.set_yscale('log')

        ax = fig.add_subplot(133)

        for accuracy, times, solver in zip(group['accuracies'],
                                           group['times'], group['solver']):
            ax.plot(times, accuracy, label=solver, color=colors[solver])
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Test accuracy')
        ax.legend()
        name = 'single_target' if single_target else 'multi_target'
        name += '_%s' % penalty
        plt.suptitle(name)
        name += '.png'
        fig.tight_layout()
        fig.subplots_adjust(top=0.9)
        plt.savefig(name)
        plt.close(fig)


if __name__ == '__main__':
    solvers = ['saga', 'liblinear', 'lightning']
    penalties = ['l1', 'l2']
    single_target = True
    exp(solvers, penalties, single_target, n_samples=None, n_jobs=1,
        dataset='20newspaper', max_iter=20)
    plot()
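
As a quick way to exercise this script without downloading RCV1 or 20 newsgroups (an illustrative invocation only, not part of the commit; the parameter values are arbitrary), one could run it on the bundled digits dataset:

    # Hypothetical quick run: 'digits' ships with scikit-learn, and a small
    # max_iter keeps each fit short while still producing the three plots.
    exp(['saga', 'liblinear'], ['l2'], single_target=True,
        n_samples=1500, max_iter=5, dataset='digits', n_jobs=1)
    plot()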

doc/modules/linear_model.rst

Lines changed: 32 additions & 19 deletions
@@ -721,7 +721,7 @@ optimization problem
 .. math:: \underset{w, c}{min\,} \|w\|_1 + C \sum_{i=1}^n \log(\exp(- y_i (X_i^T w + c)) + 1) .
 
 The solvers implemented in the class :class:`LogisticRegression`
-are "liblinear", "newton-cg", "lbfgs" and "sag":
+are "liblinear", "newton-cg", "lbfgs", "sag" and "saga":
 
 The solver "liblinear" uses a coordinate descent (CD) algorithm, and relies
 on the excellent C++ `LIBLINEAR library
@@ -739,25 +739,31 @@ The "lbfgs", "sag" and "newton-cg" solvers only support L2 penalization and
 are found to converge faster for some high dimensional data. Setting
 `multi_class` to "multinomial" with these solvers learns a true multinomial
 logistic regression model [5]_, which means that its probability estimates
-should be better calibrated than the default "one-vs-rest" setting. The
-"lbfgs", "sag" and "newton-cg" solvers cannot optimize L1-penalized models,
-therefore the "multinomial" setting does not learn sparse models.
+should be better calibrated than the default "one-vs-rest" setting.
 
-The solver "sag" uses a Stochastic Average Gradient descent [6]_. It is faster
+The "sag" solver uses a Stochastic Average Gradient descent [6]_. It is faster
 than other solvers for large datasets, when both the number of samples and the
 number of features are large.
 
+The "saga" solver [7]_ is a variant of "sag" that also supports the
+non-smooth `penalty="l1"` option. This is therefore the solver of choice
+for sparse multinomial logistic regression.
+
 In a nutshell, one may choose the solver with the following rules:
 
-================================= =============================
+================================= =====================================
 Case                              Solver
-================================= =============================
-Small dataset or L1 penalty       "liblinear"
-Multinomial loss or large dataset "lbfgs", "sag" or "newton-cg"
-Very Large dataset                "sag"
-================================= =============================
+================================= =====================================
+L1 penalty                        "liblinear" or "saga"
+Multinomial loss                  "lbfgs", "sag", "saga" or "newton-cg"
+Very Large dataset (`n_samples`)  "sag" or "saga"
+================================= =====================================
+
+The "saga" solver is often the best choice. The "liblinear" solver is
+used by default for historical reasons.
 
-For large dataset, you may also consider using :class:`SGDClassifier` with 'log' loss.
+For large dataset, you may also consider using :class:`SGDClassifier`
+with 'log' loss.
 
 .. topic:: Examples:
 
@@ -767,6 +773,10 @@ For large dataset, you may also consider using :class:`SGDClassifier` with 'log'
 
 * :ref:`sphx_glr_auto_examples_linear_model_plot_logistic_multinomial.py`
 
+* :ref:`sphx_glr_auto_examples_linear_model_plot_sparse_logistic_regression_20newsgroups.py`
+
+* :ref:`sphx_glr_auto_examples_linear_model_plot_sparse_logistic_regression_mnist.py`
+
 .. _liblinear_differences:
 
 .. topic:: Differences from liblinear:
 
@@ -788,20 +798,23 @@ For large dataset, you may also consider using :class:`SGDClassifier` with 'log'
 thus be used to perform feature selection, as detailed in
 :ref:`l1_feature_selection`.
 
-:class:`LogisticRegressionCV` implements Logistic Regression with builtin
-cross-validation to find out the optimal C parameter. "newton-cg", "sag" and
-"lbfgs" solvers are found to be faster for high-dimensional dense data, due to
-warm-starting. For the multiclass case, if `multi_class` option is set to
-"ovr", an optimal C is obtained for each class and if the `multi_class` option
-is set to "multinomial", an optimal C is obtained by minimizing the cross-
-entropy loss.
+:class:`LogisticRegressionCV` implements Logistic Regression with
+builtin cross-validation to find out the optimal C parameter.
+"newton-cg", "sag", "saga" and "lbfgs" solvers are found to be faster
+for high-dimensional dense data, due to warm-starting. For the
+multiclass case, if `multi_class` option is set to "ovr", an optimal C
+is obtained for each class and if the `multi_class` option is set to
+"multinomial", an optimal C is obtained by minimizing the cross-entropy
+loss.
 
 .. topic:: References:
 
     .. [5] Christopher M. Bishop: Pattern Recognition and Machine Learning, Chapter 4.3.4
 
     .. [6] Mark Schmidt, Nicolas Le Roux, and Francis Bach: `Minimizing Finite Sums with the Stochastic Average Gradient. <https://hal.inria.fr/hal-00860051/document>`_
 
+    .. [7] Aaron Defazio, Francis Bach, Simon Lacoste-Julien: `SAGA: A Fast Incremental Gradient Method With Support for Non-Strongly Convex Composite Objectives. <https://arxiv.org/abs/1407.0202>`_
+
 Stochastic Gradient Descent - SGD
 =================================
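
To illustrate the documentation change above, here is a minimal sketch, not part of the commit; the dataset and hyperparameters are arbitrary choices for demonstration. It fits a sparse multinomial logistic regression with the new "saga" solver:

    from sklearn.datasets import load_digits
    from sklearn.linear_model import LogisticRegression

    X, y = load_digits(return_X_y=True)
    # "saga" is the only solver that combines the non-smooth L1 penalty with
    # the multinomial loss, so the coefficient matrix it learns is sparse.
    clf = LogisticRegression(solver='saga', penalty='l1',
                             multi_class='multinomial', C=0.1, max_iter=200)
    clf.fit(X, y)
    print('Fraction of zero coefficients: %.2f' % (clf.coef_ == 0).mean())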

doc/whats_new.rst

Lines changed: 8 additions & 0 deletions
@@ -50,6 +50,13 @@ New features
      particularly useful for targets with an exponential trend.
      :issue:`7655` by :user:`Karan Desai <karandesai-96>`.
 
+   - Added solver ``saga`` that implements the improved version of Stochastic
+     Average Gradient, in :class:`linear_model.LogisticRegression` and
+     :class:`linear_model.Ridge`. It allows the use of L1 penalty with
+     multinomial logistic loss, and behaves marginally better than 'sag'
+     during the first epochs of ridge and logistic regression.
+     By `Arthur Mensch`_.
+
 Enhancements
 ............
 
@@ -5037,3 +5044,4 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
 .. _Anish Shah: https://github.com/AnishShah
 
 .. _Neeraj Gangwar: http://neerajgangwar.in
+.. _Arthur Mensch: https://amensch.fr
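
To make the Ridge half of this changelog entry concrete, here is a hedged sketch (illustrative only; the synthetic data and the alpha value are arbitrary and not part of the commit):

    import numpy as np
    from sklearn.linear_model import Ridge

    rng = np.random.RandomState(0)
    X = rng.randn(500, 30)
    y = X.dot(rng.randn(30)) + 0.1 * rng.randn(500)

    # 'saga' is the new stochastic solver added by this commit; like 'sag',
    # it is mainly interesting when n_samples is large.
    reg = Ridge(alpha=1.0, solver='saga', max_iter=1000, random_state=0)
    reg.fit(X, y)
    print(reg.score(X, y))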

0 commit comments
