Skip to content

Commit 7c9a353

Browse files
zarchjnothman
authored andcommitted
FIX unused pos_label parameter in metrics.precision_recall_curve
1 parent 1f00663 commit 7c9a353

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

sklearn/metrics/metrics.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,7 @@ def precision_recall_curve(y_true, probas_pred, pos_label=None,
799799
800800
"""
801801
fps, tps, thresholds = _binary_clf_curve(y_true, probas_pred,
802+
pos_label=pos_label,
802803
sample_weight=sample_weight)
803804

804805
precision = tps / (tps + fps)

sklearn/metrics/tests/test_metrics.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1111,6 +1111,21 @@ def test_precision_recall_curve():
11111111
assert_equal(p.size, t.size + 1)
11121112

11131113

1114+
def test_precision_recall_curve_pos_label():
1115+
y_true, _, probas_pred = make_prediction(binary=False)
1116+
pos_label = 2
1117+
p, r, thresholds = precision_recall_curve(y_true,
1118+
probas_pred[:, pos_label],
1119+
pos_label=pos_label)
1120+
p2, r2, thresholds2 = precision_recall_curve(y_true == pos_label,
1121+
probas_pred[:, pos_label])
1122+
assert_array_almost_equal(p, p2)
1123+
assert_array_almost_equal(r, r2)
1124+
assert_array_almost_equal(thresholds, thresholds2)
1125+
assert_equal(p.size, r.size)
1126+
assert_equal(p.size, thresholds.size + 1)
1127+
1128+
11141129
def _test_precision_recall_curve(y_true, probas_pred):
11151130
"""Test Precision-Recall and aread under PR curve"""
11161131
p, r, thresholds = precision_recall_curve(y_true, probas_pred)

0 commit comments

Comments
 (0)