Skip to content

Commit 270d47a

Browse files
committed
Fix gibbs sampling behavior in RBM with integral random_state.
1 parent 73e5cf5 commit 270d47a

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

sklearn/neural_network/rbm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -210,10 +210,10 @@ def gibbs(self, v):
210210
Values of the visible layer after one Gibbs step.
211211
"""
212212
check_is_fitted(self, "components_")
213-
214-
rng = check_random_state(self.random_state)
215-
h_ = self._sample_hiddens(v, rng)
216-
v_ = self._sample_visibles(h_, rng)
213+
if not hasattr(self, "random_state_"):
214+
self.random_state_ = check_random_state(self.random_state)
215+
h_ = self._sample_hiddens(v, self.random_state_)
216+
v_ = self._sample_visibles(h_, self.random_state_)
217217

218218
return v_
219219

sklearn/neural_network/tests/test_rbm.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,14 +131,16 @@ def test_fit_gibbs_sparse():
131131

132132

133133
def test_gibbs_smoke():
134-
"""Check if we don't get NaNs sampling the full digits dataset."""
135-
rng = np.random.RandomState(42)
134+
"""Check if we don't get NaNs sampling the full digits dataset.
135+
Also check that sampling again will yield different results."""
136136
X = Xdigits
137137
rbm1 = BernoulliRBM(n_components=42, batch_size=40,
138-
n_iter=20, random_state=rng)
138+
n_iter=20, random_state=42)
139139
rbm1.fit(X)
140140
X_sampled = rbm1.gibbs(X)
141141
assert_all_finite(X_sampled)
142+
X_sampled2 = rbm1.gibbs(X)
143+
assert_true(np.all((X_sampled != X_sampled2).max(axis=1)))
142144

143145

144146
def test_score_samples():

0 commit comments

Comments
 (0)