Skip to content

Commit a54590b

Browse files
committed
Merge pull request scikit-learn#3656 from dougalsutherland/fix-rbf-samp
FIX RBFSampler's incorrect bandwidth
2 parents cd3c2db + f0f26bc commit a54590b

File tree

3 files changed

+16
-2
lines changed

3 files changed

+16
-2
lines changed

doc/whats_new.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,13 @@ Bug fixes
9191
appropriate error message and suggests a work around.
9292
By `Danny Sullivan`_.
9393

94+
- :class:`RBFSampler <kernel_approximation.RBFSampler>` with ``gamma=g``
95+
formerly approximated :func:`rbf_kernel <metrics.pairwise.rbf_kernel>`
96+
with ``gamma=g/2.``; the definition of ``gamma`` is now consistent,
97+
which may substantially change your results if you use a fixed value.
98+
(If you cross-validated over ``gamma``, it probably doesn't matter
99+
too much.) By `Dougal Sutherland`_.
100+
94101

95102
API changes summary
96103
-------------------
@@ -2969,3 +2976,5 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
29692976
.. _Nikolay Mayorov: https://github.com/nmayorov
29702977

29712978
.. _Jatin Shah: http://jatinshah.org/
2979+
2980+
.. _Dougal Sutherland: https://github.com/dougalsutherland

sklearn/kernel_approximation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def fit(self, X, y=None):
6969
random_state = check_random_state(self.random_state)
7070
n_features = X.shape[1]
7171

72-
self.random_weights_ = (np.sqrt(self.gamma) * random_state.normal(
72+
self.random_weights_ = (np.sqrt(2 * self.gamma) * random_state.normal(
7373
size=(n_features, self.n_components)))
7474

7575
self.random_offset_ = random_state.uniform(0, 2 * np.pi,

sklearn/tests/test_kernel_approximation.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from sklearn.utils.testing import assert_array_equal, assert_equal
55
from sklearn.utils.testing import assert_not_equal
66
from sklearn.utils.testing import assert_array_almost_equal, assert_raises
7+
from sklearn.utils.testing import assert_less_equal
78

89
from sklearn.metrics.pairwise import kernel_metrics
910
from sklearn.kernel_approximation import RBFSampler
@@ -121,7 +122,11 @@ def test_rbf_sampler():
121122
Y_trans = rbf_transform.transform(Y)
122123
kernel_approx = np.dot(X_trans, Y_trans.T)
123124

124-
assert_array_almost_equal(kernel, kernel_approx, 1)
125+
error = kernel - kernel_approx
126+
assert_less_equal(np.abs(np.mean(error)), 0.01) # close to unbiased
127+
np.abs(error, out=error)
128+
assert_less_equal(np.max(error), 0.1) # nothing too far off
129+
assert_less_equal(np.mean(error), 0.05) # mean is fairly close
125130

126131

127132
def test_input_validation():

0 commit comments

Comments
 (0)