Skip to content

Commit d0cdcde

Browse files
committed
Test correctness of average_precision_score.
1 parent 67eb1b8 commit d0cdcde

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

sklearn/metrics/tests/test_metrics.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,28 @@ def _auc(y_true, y_score):
316316
return n_correct / float(len(pos) * len(neg))
317317

318318

319+
def _average_precision(y_true, y_score):
320+
pos_label = np.unique(y_true)[1]
321+
n_pos = np.sum(y_true == pos_label)
322+
order = np.argsort(y_score)[::-1]
323+
y_score = y_score[order]
324+
y_true = y_true[order]
325+
326+
score = 0
327+
for i in xrange(len(y_score)):
328+
if y_true[i] == pos_label:
329+
# Compute precision up to document i
330+
# i.e, percentage of relevant documents up to document i.
331+
prec = 0
332+
for j in xrange(0, i + 1):
333+
if y_true[j] == pos_label:
334+
prec += 1.0
335+
prec /= (i + 1.0)
336+
score += prec
337+
338+
return score / n_pos
339+
340+
319341
def test_roc_curve():
320342
"""Test Area under Receiver Operating Characteristic (ROC) curve"""
321343
y_true, _, probas_pred = make_prediction(binary=True)
@@ -917,6 +939,8 @@ def _test_precision_recall_curve(y_true, probas_pred):
917939
assert_array_almost_equal(precision_recall_auc, 0.85, 2)
918940
assert_array_almost_equal(precision_recall_auc,
919941
average_precision_score(y_true, probas_pred))
942+
assert_almost_equal(_average_precision(y_true, probas_pred),
943+
precision_recall_auc, 1)
920944
assert_equal(p.size, r.size)
921945
assert_equal(p.size, thresholds.size + 1)
922946
# Smoke test in the case of proba having only one value

0 commit comments

Comments
 (0)