Skip to content

Commit 0c22cbb

Browse files
committed
FIX numerical stability issue in BernoulliRBM
Fixes scikit-learn#2785 by replacing log1p(exp(x)) with logaddexp(0, x).
1 parent 7324e8f commit 0c22cbb

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

sklearn/neural_network/rbm.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
"""Restricted Boltzmann Machine
22
"""
33

4-
# Main author: Yann N. Dauphin <[email protected]>
5-
# Author: Vlad Niculae
6-
# Author: Gabriel Synnaeve
7-
# License: BSD Style.
4+
# Authors: Yann N. Dauphin <[email protected]>
5+
# Vlad Niculae
6+
# Gabriel Synnaeve
7+
# License: BSD 3 clause
88

99
import time
1010

@@ -188,8 +188,8 @@ def _free_energy(self, v):
188188
The value of the free energy.
189189
"""
190190
return (- safe_sparse_dot(v, self.intercept_visible_)
191-
- np.log1p(np.exp(safe_sparse_dot(v, self.components_.T)
192-
+ self.intercept_hidden_)).sum(axis=1))
191+
- np.logaddexp(0, safe_sparse_dot(v, self.components_.T)
192+
+ self.intercept_hidden_).sum(axis=1))
193193

194194
def gibbs(self, v):
195195
"""Perform one Gibbs sampling step.

sklearn/neural_network/tests/test_rbm.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,11 @@ def test_score_samples():
127127
s_score = rbm1.score_samples(lil_matrix(X))
128128
assert_almost_equal(d_score, s_score)
129129

130+
# Test numerical stability (#2785): would previously generate infinities
131+
# and crash with an exception.
132+
with np.errstate(under='ignore'):
133+
rbm1.score_samples(np.arange(1000) * 100)
134+
130135

131136
def test_rbm_verbose():
132137
rbm = BernoulliRBM(n_iter=2, verbose=10)

0 commit comments

Comments
 (0)