
Commit 43e5454

Merge pull request scikit-learn#5251 from TomDLT/sag_multi
[MRG+1] add multinomial SAG solver for LogisticRegression
2 parents 62b48b9 + 9f136ff commit 43e5454

File tree: 14 files changed, +888 -483 lines


doc/modules/linear_model.rst

Lines changed: 53 additions & 54 deletions
@@ -683,64 +683,62 @@ Logistic regression
 
 Logistic regression, despite its name, is a linear model for classification
 rather than regression. Logistic regression is also known in the literature as
-logit regression, maximum-entropy classification (MaxEnt)
-or the log-linear classifier. In this model, the probabilities describing the possible outcomes of a single trial are modeled using a `logistic function <http://en.wikipedia.org/wiki/Logistic_function>`_.
+logit regression, maximum-entropy classification (MaxEnt) or the log-linear
+classifier. In this model, the probabilities describing the possible outcomes
+of a single trial are modeled using a `logistic function
+<http://en.wikipedia.org/wiki/Logistic_function>`_.
 
 The implementation of logistic regression in scikit-learn can be accessed from
-class :class:`LogisticRegression`. This
-implementation can fit a multiclass (one-vs-rest) logistic regression with optional
-L2 or L1 regularization.
+class :class:`LogisticRegression`. This implementation can fit binary, One-vs-
+Rest, or multinomial logistic regression with optional L2 or L1
+regularization.
 
-As an optimization problem, binary class L2 penalized logistic regression minimizes
-the following cost function:
+As an optimization problem, binary class L2 penalized logistic regression
+minimizes the following cost function:
 
 .. math:: \underset{w, c}{min\,} \frac{1}{2}w^T w + C \sum_{i=1}^n \log(\exp(- y_i (X_i^T w + c)) + 1) .
 
-Similarly, L1 regularized logistic regression solves the following optimization problem
+Similarly, L1 regularized logistic regression solves the following
+optimization problem
 
 .. math:: \underset{w, c}{min\,} \|w\|_1 + C \sum_{i=1}^n \log(\exp(- y_i (X_i^T w + c)) + 1) .
 
 The solvers implemented in the class :class:`LogisticRegression`
-are "liblinear" (which is a wrapper around the C++ library,
-LIBLINEAR), "newton-cg", "lbfgs" and "sag".
-
-The "lbfgs" and "newton-cg" solvers only support L2 penalization and are found
-to converge faster for some high dimensional data. L1 penalization yields
-sparse predicting weights.
-
-The solver "liblinear" uses a coordinate descent (CD) algorithm based on
-Liblinear. For L1 penalization :func:`sklearn.svm.l1_min_c` allows to
-calculate the lower bound for C in order to get a non "null" (all feature weights to
-zero) model. This relies on the excellent
-`LIBLINEAR library <http://www.csie.ntu.edu.tw/~cjlin/liblinear/>`_,
-which is shipped with scikit-learn. However, the CD algorithm implemented in
-liblinear cannot learn a true multinomial (multiclass) model;
-instead, the optimization problem is decomposed in a "one-vs-rest" fashion
-so separate binary classifiers are trained for all classes.
-This happens under the hood, so :class:`LogisticRegression` instances
-using this solver behave as multiclass classifiers.
-
-Setting `multi_class` to "multinomial" with the "lbfgs" or "newton-cg" solver
-in :class:`LogisticRegression` learns a true multinomial logistic
-regression model, which means that its probability estimates should
-be better calibrated than the default "one-vs-rest" setting.
-"lbfgs", "newton-cg" and "sag" solvers cannot optimize L1-penalized models, though, so the "multinomial" setting does not learn sparse models.
-
-The solver "sag" uses a Stochastic Average Gradient descent [3]_. It does not
-handle "multinomial" case, and is limited to L2-penalized models, yet it is
-often faster than other solvers for large datasets, when both the number of
-samples and the number of features are large.
+are "liblinear", "newton-cg", "lbfgs" and "sag":
+
+The solver "liblinear" uses a coordinate descent (CD) algorithm, and relies
+on the excellent C++ `LIBLINEAR library
+<http://www.csie.ntu.edu.tw/~cjlin/liblinear/>`_, which is shipped with
+scikit-learn. However, the CD algorithm implemented in liblinear cannot learn
+a true multinomial (multiclass) model; instead, the optimization problem is
+decomposed in a "one-vs-rest" fashion so separate binary classifiers are
+trained for all classes. This happens under the hood, so
+:class:`LogisticRegression` instances using this solver behave as multiclass
+classifiers. For L1 penalization :func:`sklearn.svm.l1_min_c` allows to
+calculate the lower bound for C in order to get a non "null" (all feature
+weights to zero) model.
+
+The "lbfgs", "sag" and "newton-cg" solvers only support L2 penalization and
+are found to converge faster for some high dimensional data. Setting
+`multi_class` to "multinomial" with these solvers learns a true multinomial
+logistic regression model [3]_, which means that its probability estimates
+should be better calibrated than the default "one-vs-rest" setting. The
+"lbfgs", "sag" and "newton-cg" solvers cannot optimize L1-penalized models,
+therefore the "multinomial" setting does not learn sparse models.
+
+The solver "sag" uses a Stochastic Average Gradient descent [4]_. It is faster
+than other solvers for large datasets, when both the number of samples and the
+number of features are large.
 
 In a nutshell, one may choose the solver with the following rules:
 
-=========================== ======================
-Case                        Solver
-=========================== ======================
-Small dataset or L1 penalty "liblinear"
-Multinomial loss            "lbfgs" or "newton-cg"
-Large dataset               "sag"
-=========================== ======================
-
+================================= =============================
+Case                              Solver
+================================= =============================
+Small dataset or L1 penalty       "liblinear"
+Multinomial loss or large dataset "lbfgs", "sag" or "newton-cg"
+Very large dataset                "sag"
+================================= =============================
 For large dataset, you may also consider using :class:`SGDClassifier` with 'log' loss.
 
 .. topic:: Examples:
@@ -770,18 +768,19 @@ For large dataset, you may also consider using :class:`SGDClassifier` with 'log'
 thus be used to perform feature selection, as detailed in
 :ref:`l1_feature_selection`.
 
-:class:`LogisticRegressionCV` implements Logistic Regression with
-builtin cross-validation to find out the optimal C parameter.
-"newton-cg", "sag" and "lbfgs" solvers are found to be faster
-for high-dimensional dense data, due to warm-starting.
-For the multiclass case, if `multi_class`
-option is set to "ovr", an optimal C is obtained for each class and if
-the `multi_class` option is set to "multinomial", an optimal C is
-obtained that minimizes the cross-entropy loss.
+:class:`LogisticRegressionCV` implements Logistic Regression with builtin
+cross-validation to find out the optimal C parameter. "newton-cg", "sag" and
+"lbfgs" solvers are found to be faster for high-dimensional dense data, due to
+warm-starting. For the multiclass case, if `multi_class` option is set to
+"ovr", an optimal C is obtained for each class and if the `multi_class` option
+is set to "multinomial", an optimal C is obtained by minimizing the cross-
+entropy loss.
 
 .. topic:: References:
 
-    .. [3] Mark Schmidt, Nicolas Le Roux, and Francis Bach: `Minimizing Finite Sums with the Stochastic Average Gradient. <http://hal.inria.fr/hal-00860051/PDF/sag_journal.pdf>`_
+    .. [3] Christopher M. Bishop: Pattern Recognition and Machine Learning, Chapter 4.3.4
+
+    .. [4] Mark Schmidt, Nicolas Le Roux, and Francis Bach: `Minimizing Finite Sums with the Stochastic Average Gradient. <http://hal.inria.fr/hal-00860051/PDF/sag_journal.pdf>`_
 
 Stochastic Gradient Descent - SGD
 =================================
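
To make the documented behaviour concrete, here is a minimal sketch of fitting One-vs-Rest versus true multinomial logistic regression with the "sag" solver; the toy dataset and parameter values (max_iter=200, random_state=0) are illustrative and are not taken from this commit:

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression

    # toy 3-class problem; SAG converges best on roughly standardized
    # features, which make_classification already provides
    X, y = make_classification(n_samples=500, n_features=20, n_informative=5,
                               n_classes=3, random_state=0)

    # default decomposition into three binary one-vs-rest problems
    ovr = LogisticRegression(solver='sag', multi_class='ovr',
                             max_iter=200, random_state=0).fit(X, y)

    # a single model minimizing the multinomial (cross-entropy) loss;
    # only L2 penalization is available with "sag", "lbfgs" and "newton-cg"
    multinomial = LogisticRegression(solver='sag', multi_class='multinomial',
                                     max_iter=200, random_state=0).fit(X, y)

    print("ovr accuracy:         %.3f" % ovr.score(X, y))
    print("multinomial accuracy: %.3f" % multinomial.score(X, y))

The multinomial variant models the class probabilities jointly with a softmax, rather than normalizing independent one-vs-rest estimates after the fact, which is why its probability estimates tend to be better calibrated.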

doc/whats_new.rst

Lines changed: 16 additions & 11 deletions
@@ -40,24 +40,29 @@ Enhancements
 
 - The random forest, extra trees and decision tree estimators now have a
   method ``decision_path`` which returns the decision path of samples in
-  the tree. By `Arnaud Joly`_
+  the tree. By `Arnaud Joly`_.
 
 
 - The random forest, extra tree and decision tree estimators now have a
   method ``decision_path`` which returns the decision path of samples in
-  the tree. By `Arnaud Joly`_
+  the tree. By `Arnaud Joly`_.
 
 - A new example has been added unveiling the decision tree structure.
-  By `Arnaud Joly`_
+  By `Arnaud Joly`_.
 
 - Random forest, extra trees, decision trees and gradient boosting estimators
   accept the parameters ``min_samples_split`` and ``min_samples_leaf``
   provided as a percentage of the training samples. By
-  `yelite`_ and `Arnaud Joly`_
+  `yelite`_ and `Arnaud Joly`_.
+
+- Codebase does not contain C/C++ cython generated files: they are
+  generated during build. Distribution packages will still contain generated
+  C/C++ files. By `Arthur Mensch`_.
 
-- Codebase does not contain C/C++ cython generated files: they are
-  generated during build. Distribution packages will still contain generated
-  C/C++ files. By `Arthur Mensch`_
+- In :class:`linear_model.LogisticRegression`, the SAG solver is now
+  available in the multinomial case.
+  (`#5251 <https://github.com/scikit-learn/scikit-learn/pull/5251>`_)
+  By `Tom Dupre la Tour`_.
 
 Bug fixes
 .........
@@ -155,10 +160,6 @@ New features
   shuffling step in the ``cd`` solver.
   By `Tom Dupre la Tour`_ and `Mathieu Blondel`_.
 
-- **IndexError** bug `#5495
-  <https://github.com/scikit-learn/scikit-learn/issues/5495>`_ when
-  doing OVR(SVC(decision_function_shape="ovr")). Fixed by `Elvis Dohmatob`_.
-
 Enhancements
 ............
 - :class:`manifold.TSNE` now supports approximate optimization via the
@@ -435,6 +436,10 @@ Bug fixes
   ``class_weight='balanced'`` or ``class_weight='auto'``.
   By `Tom Dupre la Tour`_.
 
+- Fixed bug `#5495 <https://github.com/scikit-learn/scikit-learn/issues/5495>`_ when
+  doing OVR(SVC(decision_function_shape="ovr")). Fixed by `Elvis Dohmatob`_.
+
+
 API changes summary
 -------------------
 - Attribute `data_min`, `data_max` and `data_range` in
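
The ``decision_path`` method mentioned in the changelog entries above can be exercised as follows; this is a small sketch on a standard dataset, not code from this commit, and only the single-tree case is shown (forest estimators additionally return a node-pointer array):

    from sklearn.datasets import load_iris
    from sklearn.tree import DecisionTreeClassifier

    iris = load_iris()
    tree = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

    # sparse indicator matrix of shape (n_samples, n_nodes):
    # entry [i, j] is nonzero if sample i passes through node j
    indicator = tree.decision_path(iris.data[:2])
    print(indicator.toarray())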
New example file

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+"""
+====================================================
+Plot multinomial and One-vs-Rest Logistic Regression
+====================================================
+
+Plot decision surface of multinomial and One-vs-Rest Logistic Regression.
+The hyperplanes corresponding to the three One-vs-Rest (OVR) classifiers
+are represented by the dashed lines.
+"""
+print(__doc__)
+# Authors: Tom Dupre la Tour <[email protected]>
+# Licence: BSD 3 clause
+
+import numpy as np
+import matplotlib.pyplot as plt
+from sklearn.datasets import make_blobs
+from sklearn.linear_model import LogisticRegression
+
+# make 3-class dataset for classification
+centers = [[-5, 0], [0, 1.5], [5, -1]]
+X, y = make_blobs(n_samples=1000, centers=centers, random_state=40)
+transformation = [[0.4, 0.2], [-0.4, 1.2]]
+X = np.dot(X, transformation)
+
+for multi_class in ('multinomial', 'ovr'):
+    clf = LogisticRegression(solver='sag', max_iter=100, random_state=42,
+                             multi_class=multi_class).fit(X, y)
+
+    # print the training scores
+    print("training score : %.3f (%s)" % (clf.score(X, y), multi_class))
+
+    # create a mesh to plot in
+    h = .02  # step size in the mesh
+    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
+    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
+    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
+                         np.arange(y_min, y_max, h))
+
+    # Plot the decision boundary. For that, we will assign a color to each
+    # point in the mesh [x_min, x_max]x[y_min, y_max].
+    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
+    # Put the result into a color plot
+    Z = Z.reshape(xx.shape)
+    plt.figure()
+    plt.contourf(xx, yy, Z, cmap=plt.cm.Paired)
+    plt.title("Decision surface of LogisticRegression (%s)" % multi_class)
+    plt.axis('tight')
+
+    # Plot also the training points
+    colors = "bry"
+    for i, color in zip(clf.classes_, colors):
+        idx = np.where(y == i)
+        plt.scatter(X[idx, 0], X[idx, 1], c=color, cmap=plt.cm.Paired)
+
+    # Plot the three one-against-all classifiers
+    xmin, xmax = plt.xlim()
+    ymin, ymax = plt.ylim()
+    coef = clf.coef_
+    intercept = clf.intercept_
+
+    def plot_hyperplane(c, color):
+        def line(x0):
+            return (-(x0 * coef[c, 0]) - intercept[c]) / coef[c, 1]
+        plt.plot([xmin, xmax], [line(xmin), line(xmax)],
+                 ls="--", color=color)
+
+    for i, color in zip(clf.classes_, colors):
+        plot_hyperplane(i, color)
+
+plt.show()

sklearn/linear_model/base.py

Lines changed: 2 additions & 2 deletions
@@ -61,8 +61,8 @@ def make_dataset(X, y, sample_weight, random_state=None):
     seed = rng.randint(1, np.iinfo(np.int32).max)
 
     if sp.issparse(X):
-        dataset = CSRDataset(X.data, X.indptr, X.indices,
-                             y, sample_weight, seed=seed)
+        dataset = CSRDataset(X.data, X.indptr, X.indices, y, sample_weight,
+                             seed=seed)
         intercept_decay = SPARSE_INTERCEPT_DECAY
     else:
         dataset = ArrayDataset(X, y, sample_weight, seed=seed)
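
For context, ``make_dataset`` is an internal helper (not public API) that wraps the training data in the sequential-access dataset objects consumed by the SAG and SGD solvers. A hedged sketch of how it might be called follows; the ``(dataset, intercept_decay)`` return pair is an assumption based on the surrounding code, which is not shown in this hunk:

    import numpy as np
    import scipy.sparse as sp
    from sklearn.linear_model.base import make_dataset

    rng = np.random.RandomState(0)
    X = sp.csr_matrix(rng.rand(10, 3))   # sparse input takes the CSRDataset branch
    y = rng.rand(10)
    sample_weight = np.ones(10)

    # for sparse input the smaller SPARSE_INTERCEPT_DECAY is returned,
    # as shown in the hunk above
    dataset, intercept_decay = make_dataset(X, y, sample_weight, random_state=0)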
