
Commit e31376a

conradlee authored and amueller committed

metrics.py: COSMIT more comments on precision_recall_curve
1 parent 090bed6 commit e31376a

File tree

1 file changed: +26, -23 lines

sklearn/metrics/metrics.py

Lines changed: 26 additions & 23 deletions
@@ -12,7 +12,7 @@
 # Olivier Grisel <[email protected]>
 # License: BSD Style.
 
-import itertools
+from itertools import izip
 import numpy as np
 from scipy.sparse import coo_matrix
 
@@ -855,14 +855,17 @@ def precision_recall_curve(y_true, probas_pred):
         raise ValueError("y_true contains non binary labels: %r" % labels)
 
     # Sort pred_probas (and corresponding true labels) by pred_proba value
-    sort_idxs = np.argsort(probas_pred, kind="mergesort")[::-1]
-    probas_pred = probas_pred[sort_idxs]
-    y_true = y_true[sort_idxs]
-
-    # Get indices where values of probas_pred decreases
-    thresh_idxs = np.r_[0,
-                        np.where(np.diff(probas_pred))[0] + 1,
-                        len(probas_pred)]
+    decreasing_probas_indices = np.argsort(probas_pred, kind="mergesort")[::-1]
+    probas_pred = probas_pred[decreasing_probas_indices]
+    y_true = y_true[decreasing_probas_indices]
+
+    # Probas_pred typically has many tied values. Here we extract
+    # the indices associated with the distinct values. We also
+    # concatenate values onto the ends of the curve.
+    distinct_value_indices = np.where(np.diff(probas_pred))[0] + 1
+    threshold_idxs = np.r_[0,
+                           distinct_value_indices,
+                           len(probas_pred)]
 
     # Initialize true and false positive counts, precision and recall
     total_positive = float(y_true.sum())
@@ -871,20 +874,20 @@ def precision_recall_curve(y_true, probas_pred):
     recall = [0.]
     thresholds = []
 
-    # Iterate over indices which indicate distinct values of probas_pred --
-    # each of these distinct values will be represented in the curve with a
-    # coordinate in precision-recall space. To calculate the precision and
-    # recall associated with each point, we use these indices to select all
-    # labels associated with the predictions. By incrementally keeping track
-    # of the number of positive and negative labels seen so far, we can
-    # calculate precision and recall.
-    for l_idx, r_idx in itertools.izip(thresh_idxs[:-1], thresh_idxs[1:]):
-        thresh_labels = y_true[l_idx:r_idx]
-        n_thresh = r_idx - l_idx
-        n_pos_thresh = thresh_labels.sum()
-        n_neg_thresh = n_thresh - n_pos_thresh
-        tp_count += n_pos_thresh
-        fp_count += n_neg_thresh
+    # Iterate over indices which indicate distinct values (thresholds) of
+    # probas_pred. Each of these threshold values will be represented in the
+    # curve with a coordinate in precision-recall space. To calculate the
+    # precision and recall associated with each point, we use these indices to
+    # select all labels associated with the predictions. By incrementally
+    # keeping track of the number of positive and negative labels seen so far,
+    # we can calculate precision and recall.
+    for l_idx, r_idx in izip(threshold_idxs[:-1], threshold_idxs[1:]):
+        threshold_labels = y_true[l_idx:r_idx]
+        n_at_threshold = r_idx - l_idx
+        n_pos_at_threshold = threshold_labels.sum()
+        n_neg_at_threshold = n_at_threshold - n_pos_at_threshold
+        tp_count += n_pos_at_threshold
+        fp_count += n_neg_at_threshold
         fn_count = total_positive - tp_count
         precision.append(tp_count / (tp_count + fp_count))
         recall.append(tp_count / (tp_count + fn_count))
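For readers following the diff, here is a minimal standalone sketch of the approach the renamed variables describe: sort the predictions in decreasing order (mergesort keeps ties stable), find the boundaries between runs of tied probability values, and accumulate true- and false-positive counts one run at a time to get one precision/recall point per distinct threshold. This is not the function from metrics.py; the name pr_curve_sketch, the Python 3 zip (in place of izip), and the initial precision value of 1.0 are illustrative assumptions.

import numpy as np

def pr_curve_sketch(y_true, probas_pred):
    # Illustrative sketch only, not the scikit-learn implementation.
    y_true = np.asarray(y_true, dtype=float)
    probas_pred = np.asarray(probas_pred, dtype=float)

    # Sort predictions (and matching labels) by decreasing probability;
    # mergesort is stable, so tied values keep their relative order.
    order = np.argsort(probas_pred, kind="mergesort")[::-1]
    probas_pred = probas_pred[order]
    y_true = y_true[order]

    # Indices where the sorted probabilities change value, padded with the
    # start and end of the array so every run of tied values is one segment.
    distinct_value_indices = np.where(np.diff(probas_pred))[0] + 1
    threshold_idxs = np.r_[0, distinct_value_indices, len(probas_pred)]

    total_positive = y_true.sum()
    tp_count = fp_count = 0.0
    precision, recall, thresholds = [1.0], [0.0], []  # initial point assumed

    # One precision/recall coordinate per distinct threshold, computed from
    # cumulative true- and false-positive counts.
    for l_idx, r_idx in zip(threshold_idxs[:-1], threshold_idxs[1:]):
        run = y_true[l_idx:r_idx]
        tp_count += run.sum()
        fp_count += len(run) - run.sum()
        precision.append(tp_count / (tp_count + fp_count))
        recall.append(tp_count / total_positive)
        thresholds.append(probas_pred[l_idx])
    return precision, recall, thresholds

For example, pr_curve_sketch([0, 1, 1, 0], [0.1, 0.8, 0.8, 0.4]) groups the two tied 0.8 scores into a single threshold and returns three precision/recall points after the initial one, which is the grouping behaviour the new comments in the hunk above describe.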
