Skip to content

Commit 7b5c703

Browse files
authored
ENH Add fitted check for kde (scikit-learn#16762)
1 parent bcc24c9 commit 7b5c703

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

sklearn/neighbors/_kde.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from scipy.special import gammainc
99
from ..base import BaseEstimator
1010
from ..utils import check_array, check_random_state
11-
from ..utils.validation import _check_sample_weight
11+
from ..utils.validation import _check_sample_weight, check_is_fitted
1212

1313
from ..utils.extmath import row_norms
1414
from ._ball_tree import BallTree, DTYPE
@@ -184,6 +184,7 @@ def score_samples(self, X):
184184
probability densities, so values will be low for high-dimensional
185185
data.
186186
"""
187+
check_is_fitted(self)
187188
# The returned density is normalized to the number of points.
188189
# For it to be a probability, we must scale it. For this reason
189190
# we'll also scale atol.
@@ -241,6 +242,7 @@ def sample(self, n_samples=1, random_state=None):
241242
X : array_like, shape (n_samples, n_features)
242243
List of samples.
243244
"""
245+
check_is_fitted(self)
244246
# TODO: implement sampling for other valid kernel shapes
245247
if self.kernel not in ['gaussian', 'tophat']:
246248
raise NotImplementedError()

sklearn/neighbors/tests/test_kde.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from sklearn.datasets import make_blobs
1010
from sklearn.model_selection import GridSearchCV
1111
from sklearn.preprocessing import StandardScaler
12+
from sklearn.exceptions import NotFittedError
1213
import joblib
1314

1415

@@ -235,3 +236,15 @@ def test_pickling(tmpdir, sample_weight):
235236
scores_pickled = kde.score_samples(X)
236237

237238
assert_allclose(scores, scores_pickled)
239+
240+
241+
@pytest.mark.parametrize('method', ['score_samples', 'sample'])
242+
def test_check_is_fitted(method):
243+
# Check that predict raises an exception in an unfitted estimator.
244+
# Unfitted estimators should raise a NotFittedError.
245+
rng = np.random.RandomState(0)
246+
X = rng.randn(10, 2)
247+
kde = KernelDensity()
248+
249+
with pytest.raises(NotFittedError):
250+
getattr(kde, method)(X)

0 commit comments

Comments
 (0)