Skip to content

Commit acd2316

Browse files
committed
Merge pull request scikit-learn#2941 from perimosocordiae/patch-1
BUG: avoid NaNs throwing off class probabilities
2 parents 3324a41 + 55ef5a4 commit acd2316

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

sklearn/neighbors/classification.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,10 @@ def predict_proba(self, X):
200200
weights = _get_weights(neigh_dist, self.weights)
201201
if weights is None:
202202
weights = np.ones_like(neigh_ind)
203+
else:
204+
# Some weights may be infinite (zero distance), which can cause
205+
# downstream NaN values when used for normalization.
206+
weights[np.isinf(weights)] = np.finfo('f').max
203207

204208
all_rows = np.arange(X.shape[0])
205209
probabilities = []

sklearn/neighbors/tests/test_neighbors.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,14 @@ def test_kneighbors_classifier_predict_proba():
201201
cls.fit(X, y.astype(str))
202202
y_prob = cls.predict_proba(X)
203203
assert_array_equal(real_prob, y_prob)
204+
# Check that it works with weights='distance'
205+
cls = neighbors.KNeighborsClassifier(
206+
n_neighbors=2, p=1, weights='distance')
207+
cls.fit(X, y)
208+
y_prob = cls.predict_proba(np.array([[0, 2, 0], [2, 2, 2]]))
209+
real_prob = np.array([[0, 1, 0], [0, 0.4, 0.6]])
210+
assert_array_almost_equal(real_prob, y_prob)
211+
204212

205213

206214
def test_radius_neighbors_classifier(n_samples=40,

0 commit comments

Comments
 (0)