Skip to content

Commit 58994ca

Browse files
fbchow and salliewalecka
authored and committed
MAINT Standardize sample weights validation in KernelDensity (scikit-learn#15519)
Co-authored-by: Sallie Walecka <[email protected]>
1 parent 225a8f3 commit 58994ca

File tree

2 files changed

+18
-8
lines changed

2 files changed

+18
-8
lines changed

sklearn/neighbors/_kde.py

Lines changed: 3 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,8 @@
77
import numpy as np
88
from scipy.special import gammainc
99
from ..base import BaseEstimator
10-
from ..utils import check_array, check_random_state, check_consistent_length
10+
from ..utils import check_array, check_random_state
11+
from ..utils.validation import _check_sample_weight
1112

1213
from ..utils.extmath import row_norms
1314
from ._ball_tree import BallTree, DTYPE
@@ -154,13 +155,7 @@ def fit(self, X, y=None, sample_weight=None):
154155
X = check_array(X, order='C', dtype=DTYPE)
155156

156157
if sample_weight is not None:
157-
sample_weight = check_array(sample_weight, order='C', dtype=DTYPE,
158-
ensure_2d=False)
159-
if sample_weight.ndim != 1:
160-
raise ValueError("the shape of sample_weight must be ({0},),"
161-
" but was {1}".format(X.shape[0],
162-
sample_weight.shape))
163-
check_consistent_length(X, sample_weight)
158+
sample_weight = _check_sample_weight(sample_weight, X, DTYPE)
164159
if sample_weight.min() <= 0:
165160
raise ValueError("sample_weight must have positive values")
166161

sklearn/neighbors/tests/test_kde.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -204,6 +204,21 @@ def test_kde_sample_weights():
204204
assert_allclose(scores_scaled_weight, scores_weight)
205205

206206

207+
def test_sample_weight_invalid():
208+
# Check sample weighting raises errors.
209+
kde = KernelDensity()
210+
data = np.reshape([1., 2., 3.], (-1, 1))
211+
212+
sample_weight = [0.1, 0.2]
213+
with pytest.raises(ValueError):
214+
kde.fit(data, sample_weight=sample_weight)
215+
216+
sample_weight = [0.1, -0.2, 0.3]
217+
expected_err = "sample_weight must have positive values"
218+
with pytest.raises(ValueError, match=expected_err):
219+
kde.fit(data, sample_weight=sample_weight)
220+
221+
207222
@pytest.mark.parametrize('sample_weight', [None, [0.1, 0.2, 0.3]])
208223
def test_pickling(tmpdir, sample_weight):
209224
# Make sure that predictions are the same before and after pickling. Used

0 commit comments

Comments
 (0)