Skip to content

Commit c52476a

Browse files
committed
Merge pull request scikit-learn#2882 from larsmans/expit
Speed up RBM training with scipy.special.expit
2 parents 89d94ca + fd54575 commit c52476a

File tree

5 files changed

+71
-55
lines changed

5 files changed

+71
-55
lines changed

sklearn/neural_network/rbm.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# Authors: Yann N. Dauphin <[email protected]>
55
# Vlad Niculae
66
# Gabriel Synnaeve
7+
# Lars Buitinck
78
# License: BSD 3 clause
89

910
import time
@@ -19,7 +20,8 @@
1920
from ..utils import gen_even_slices
2021
from ..utils import issparse
2122
from ..utils.extmath import safe_sparse_dot
22-
from ..utils.extmath import logistic_sigmoid
23+
from ..utils.extmath import log_logistic
24+
from ..utils.fixes import expit # logistic function
2325

2426

2527
class BernoulliRBM(BaseEstimator, TransformerMixin):
@@ -130,8 +132,9 @@ def _mean_hiddens(self, v):
130132
h : array-like, shape (n_samples, n_components)
131133
Corresponding mean field values for the hidden layer.
132134
"""
133-
return logistic_sigmoid(safe_sparse_dot(v, self.components_.T)
134-
+ self.intercept_hidden_)
135+
p = safe_sparse_dot(v, self.components_.T)
136+
p += self.intercept_hidden_
137+
return expit(p, out=p)
135138

136139
def _sample_hiddens(self, v, rng):
137140
"""Sample from the distribution P(h|v).
@@ -169,8 +172,9 @@ def _sample_visibles(self, h, rng):
169172
v : array-like, shape (n_samples, n_features)
170173
Values of the visible layer.
171174
"""
172-
p = logistic_sigmoid(np.dot(h, self.components_)
173-
+ self.intercept_visible_)
175+
p = np.dot(h, self.components_)
176+
p += self.intercept_visible_
177+
expit(p, out=p)
174178
p[rng.uniform(size=p.shape) < p] = 1.
175179
return np.floor(p, p)
176180

@@ -274,7 +278,7 @@ def score_samples(self, X):
274278

275279
fe = self._free_energy(v)
276280
fe_ = self._free_energy(v_)
277-
return v.shape[1] * logistic_sigmoid(fe_ - fe, log=True)
281+
return v.shape[1] * log_logistic(fe_ - fe)
278282

279283
def fit(self, X, y=None):
280284
"""Fit the model to the data X.

sklearn/utils/extmath.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from scipy.sparse import issparse
1515
from distutils.version import LooseVersion
1616

17-
from . import check_random_state
17+
from . import check_random_state, deprecated
1818
from .fixes import qr_economic
1919
from ._logistic_sigmoid import _log_logistic_sigmoid
2020
from ..externals.six.moves import xrange
@@ -567,36 +567,38 @@ def svd_flip(u, v):
567567
return u, v
568568

569569

570+
@deprecated('to be removed in 0.17; use scipy.special.expit or log_logistic')
570571
def logistic_sigmoid(X, log=False, out=None):
571-
"""
572-
Implements the logistic function, ``1 / (1 + e ** -x)`` and its log.
572+
"""Logistic function, ``1 / (1 + e ** (-x))``, or its log."""
573+
from .fixes import expit
574+
fn = log_logistic if log else expit
575+
return fn(X, out)
576+
573577

574-
This implementation is more stable by splitting on positive and negative
575-
values and computing::
576578

577-
1 / (1 + exp(-x_i)) if x_i > 0
578-
exp(x_i) / (1 + exp(x_i)) if x_i <= 0
579+
def log_logistic(X, out=None):
580+
"""Compute the log of the logistic function, ``log(1 / (1 + e ** -x))``.
579581
580-
The log is computed using::
582+
This implementation is numerically stable because it splits positive and
583+
negative values::
581584
582-
-log(1 + exp(-x_i)) if x_i > 0
585+
-log(1 + exp(-x_i)) if x_i > 0
583586
x_i - log(1 + exp(x_i)) if x_i <= 0
584587
588+
For the ordinary logistic function, use ``sklearn.utils.fixes.expit``.
589+
585590
Parameters
586591
----------
587592
X: array-like, shape (M, N)
588593
Argument to the logistic function
589594
590-
log: boolean, default: False
591-
Whether to compute the logarithm of the logistic function.
592-
593595
out: array-like, shape: (M, N), optional:
594596
Preallocated output array.
595597
596598
Returns
597599
-------
598600
out: array, shape (M, N)
599-
Value of the logistic function evaluated at every point in x
601+
Log of the logistic function evaluated at every point in x
600602
601603
Notes
602604
-----
@@ -611,15 +613,7 @@ def logistic_sigmoid(X, log=False, out=None):
611613
if out is None:
612614
out = np.empty_like(X)
613615

614-
if log:
615-
_log_logistic_sigmoid(n_samples, n_features, X, out)
616-
else:
617-
# logistic(x) = (1 + tanh(x / 2)) / 2
618-
out[:] = X
619-
out *= .5
620-
np.tanh(out, out)
621-
out += 1
622-
out *= .5
616+
_log_logistic_sigmoid(n_samples, n_features, X, out)
623617

624618
if is_1d:
625619
return np.squeeze(out)

sklearn/utils/fixes.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# Authors: Emmanuelle Gouillart <[email protected]>
77
# Gael Varoquaux <[email protected]>
88
# Fabian Pedregosa <[email protected]>
9+
# Lars Buitinck
910
#
1011
# License: BSD 3 clause
1112

@@ -109,6 +110,27 @@ def _logaddexp(x1, x2, out=None):
109110
logaddexp = np.logaddexp
110111

111112

113+
try:
114+
from scipy.special import expit # SciPy >= 0.10
115+
except ImportError:
116+
def expit(x, out=None):
117+
"""Logistic sigmoid function, ``1 / (1 + exp(-x))``.
118+
119+
See sklearn.utils.extmath.log_logistic for the log of this function.
120+
"""
121+
if out is None:
122+
out = np.copy(x)
123+
124+
# 1 / (1 + exp(-x)) = (1 + tanh(x / 2)) / 2
125+
# This way of computing the logistic is both fast and stable.
126+
out *= .5
127+
np.tanh(out, out)
128+
out += 1
129+
out *= .5
130+
131+
return out
132+
133+
112134
def _bincount(X, weights=None, minlength=None):
113135
"""Replacing np.bincount in numpy < 1.6 to provide minlength."""
114136
result = np.bincount(X, weights)

sklearn/utils/tests/test_extmath.py

Lines changed: 11 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# Denis Engemann <[email protected]>
44
#
55
# License: BSD 3 clause
6+
import warnings
7+
68
import numpy as np
79
from scipy import sparse
810
from scipy import linalg
@@ -22,7 +24,7 @@
2224
from sklearn.utils.extmath import row_norms
2325
from sklearn.utils.extmath import weighted_mode
2426
from sklearn.utils.extmath import cartesian
25-
from sklearn.utils.extmath import logistic_sigmoid
27+
from sklearn.utils.extmath import log_logistic, logistic_sigmoid
2628
from sklearn.utils.extmath import fast_dot, _fast_dot
2729
from sklearn.datasets.samples_generator import make_low_rank_matrix
2830

@@ -273,32 +275,16 @@ def test_cartesian():
273275

274276
def test_logistic_sigmoid():
275277
"""Check correctness and robustness of logistic sigmoid implementation"""
276-
naive_logsig = lambda x: 1 / (1 + np.exp(-x))
277-
naive_log_logsig = lambda x: np.log(naive_logsig(x))
278-
279-
# Simulate the previous Cython implementations of logistic_sigmoid based on
280-
#http://fa.bianp.net/blog/2013/numerical-optimizers-for-logistic-regression
281-
def stable_logsig(x):
282-
out = np.zeros_like(x)
283-
positive = x > 0
284-
negative = x <= 0
285-
out[positive] = 1. / (1 + np.exp(-x[positive]))
286-
out[negative] = np.exp(x[negative]) / (1. + np.exp(x[negative]))
287-
return out
278+
naive_logistic = lambda x: 1 / (1 + np.exp(-x))
279+
naive_log_logistic = lambda x: np.log(naive_logistic(x))
288280

289281
x = np.linspace(-2, 2, 50)
290-
assert_array_almost_equal(logistic_sigmoid(x), naive_logsig(x))
291-
assert_array_almost_equal(logistic_sigmoid(x, log=True),
292-
naive_log_logsig(x))
293-
assert_array_almost_equal(logistic_sigmoid(x), stable_logsig(x),
294-
decimal=16)
295-
296-
extreme_x = np.array([-100, 100], dtype=np.float)
297-
assert_array_almost_equal(logistic_sigmoid(extreme_x), [0, 1])
298-
assert_array_almost_equal(logistic_sigmoid(extreme_x, log=True), [-100, 0])
299-
assert_array_almost_equal(logistic_sigmoid(extreme_x),
300-
stable_logsig(extreme_x),
301-
decimal=16)
282+
with warnings.catch_warnings(record=True):
283+
assert_array_almost_equal(logistic_sigmoid(x), naive_logistic(x))
284+
assert_array_almost_equal(log_logistic(x), naive_log_logistic(x))
285+
286+
extreme_x = np.array([-100., 100.])
287+
assert_array_almost_equal(log_logistic(extreme_x), [-100, 0])
302288

303289

304290
def test_fast_dot():

sklearn/utils/tests/test_fixes.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
import numpy as np
66

77
from nose.tools import assert_equal
8-
from numpy.testing import assert_array_equal
8+
from numpy.testing import assert_almost_equal, assert_array_equal
99

10-
from ..fixes import _in1d, _copysign, divide
10+
from ..fixes import _in1d, _copysign, divide, expit
1111

1212

1313
def test_in1d():
@@ -16,6 +16,16 @@ def test_in1d():
1616
assert_equal(_in1d(a, b).sum(), 5)
1717

1818

19+
def test_expit():
20+
"""Check numerical stability of expit (logistic function)."""
21+
22+
# Simulate our previous Cython implementation, based on
23+
#http://fa.bianp.net/blog/2013/numerical-optimizers-for-logistic-regression
24+
assert_almost_equal(expit(100.), 1. / (1. + np.exp(-100.)), decimal=16)
25+
assert_almost_equal(expit(-100.), np.exp(-100.) / (1. + np.exp(-100.)),
26+
decimal=16)
27+
28+
1929
def test_divide():
2030
assert_equal(divide(.6, 1), .600000000000)
2131

0 commit comments

Comments
 (0)