Skip to content

Commit 2dfe13d

Browse files
committed
BUG in RadiusNeighborsClassifier outlier handling
Strengthened the test, which previously verified very little, and added a test for c0d4015. Also replaced flatten (which always copies) with ravel, and improved the error message raised for outliers when outlier_label=None.
1 parent c0d4015 commit 2dfe13d

File tree

2 files changed

+24
-16
lines changed

2 files changed

+24
-16
lines changed

sklearn/neighbors/classification.py

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -289,22 +289,24 @@ def predict(self, X):
289289
pred_labels = [self._y[ind] for ind in neigh_ind]
290290

291291
if self.outlier_label is not None:
292-
outlier_label = np.array((self.outlier_label, ))
293-
small_value = np.array((1e-6, ))
292+
outlier_label = np.array([self.outlier_label])
293+
small_value = np.array([1e-6])
294294
for i, pl in enumerate(pred_labels):
295295
# Check that all have at least 1 neighbor
296296
if len(pl) < 1:
297297
pred_labels[i] = outlier_label
298298
neigh_dist[i] = small_value
299299
else:
300-
for pl in pred_labels:
300+
for i, pl in enumerate(pred_labels):
301301
# Check that all have at least 1 neighbor
302+
# TODO we should gather all outliers, or the first k,
303+
# before constructing the error message.
302304
if len(pl) < 1:
303-
raise ValueError('no neighbors found for a test sample, '
305+
raise ValueError('No neighbors found for test sample %d, '
304306
'you can try using larger radius, '
305307
'give a label for outliers, '
306-
'or consider removing them in your '
307-
'dataset')
308+
'or consider removing it from your '
309+
'dataset.' % i)
308310

309311
weights = _get_weights(neigh_dist, self.weights)
310312

@@ -316,10 +318,10 @@ def predict(self, X):
316318
for (pl, w) in zip(pred_labels, weights)],
317319
dtype=np.int)
318320

319-
mode = mode.flatten().astype(np.int)
321+
mode = mode.ravel().astype(np.int)
320322
# map indices to classes
321323
prediction = self.classes_.take(mode)
322324
if self.outlier_label is not None:
323325
# reset outlier label
324-
prediction[mode == outlier_label] = self.outlier_label
326+
prediction[prediction == outlier_label] = self.outlier_label
325327
return prediction

sklearn/neighbors/tests/test_neighbors.py

Lines changed: 14 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -240,14 +240,20 @@ def test_radius_neighbors_classifier_when_no_neighbors():
240240

241241
weight_func = _weight_func
242242

243-
for algorithm in ALGORITHMS:
244-
for weights in ['uniform', 'distance', weight_func]:
245-
clf = neighbors.RadiusNeighborsClassifier(radius=radius,
246-
weights=weights,
247-
algorithm=algorithm)
248-
clf.fit(X, y)
249-
clf.predict(z1)
250-
assert_raises(ValueError, clf.predict, z2)
243+
for outlier_label in [0, -1, None]:
244+
for algorithm in ALGORITHMS:
245+
for weights in ['uniform', 'distance', weight_func]:
246+
rnc = neighbors.RadiusNeighborsClassifier
247+
clf = rnc(radius=radius, weights=weights, algorithm=algorithm,
248+
outlier_label=outlier_label)
249+
clf.fit(X, y)
250+
assert_array_equal(np.array([1, 2]),
251+
clf.predict(z1))
252+
if outlier_label is None:
253+
assert_raises(ValueError, clf.predict, z2)
254+
elif False:
255+
assert_array_equal(np.array([1, outlier_label]),
256+
clf.predict(z2))
251257

252258

253259
def test_radius_neighbors_classifier_outlier_labeling():

0 commit comments

Comments
 (0)