1212# Olivier Grisel <[email protected] > 1313# License: BSD Style.
1414
15- import itertools
15+ from itertools import izip
1616import numpy as np
1717from scipy .sparse import coo_matrix
1818
@@ -855,14 +855,17 @@ def precision_recall_curve(y_true, probas_pred):
855855 raise ValueError ("y_true contains non binary labels: %r" % labels )
856856
857857 # Sort probas_pred (and corresponding true labels) by decreasing probas_pred value
858- sort_idxs = np .argsort (probas_pred , kind = "mergesort" )[::- 1 ]
859- probas_pred = probas_pred [sort_idxs ]
860- y_true = y_true [sort_idxs ]
861-
862- # Get indices where values of probas_pred decreases
863- thresh_idxs = np .r_ [0 ,
864- np .where (np .diff (probas_pred ))[0 ] + 1 ,
865- len (probas_pred )]
858+ decreasing_probas_indices = np .argsort (probas_pred , kind = "mergesort" )[::- 1 ]
859+ probas_pred = probas_pred [decreasing_probas_indices ]
860+ y_true = y_true [decreasing_probas_indices ]
861+
862+ # probas_pred typically has many tied values. Here we extract the
863+ # indices at which each new distinct value starts. We also add 0 and
864+ # len(probas_pred) as end points so consecutive pairs bound each run of ties.
865+ distinct_value_indices = np .where (np .diff (probas_pred ))[0 ] + 1
866+ threshold_idxs = np .r_ [0 ,
867+ distinct_value_indices ,
868+ len (probas_pred )]
866869
867870 # Initialize true and false positive counts, precision and recall
868871 total_positive = float (y_true .sum ())
@@ -871,20 +874,20 @@ def precision_recall_curve(y_true, probas_pred):
871874 recall = [0. ]
872875 thresholds = []
873876
874- # Iterate over indices which indicate distinct values of probas_pred --
875- # each of these distinct values will be represented in the curve with a
876- # coordinate in precision-recall space. To calculate the precision and
877- # recall associated with each point, we use these indices to select all
878- # labels associated with the predictions. By incrementally keeping track
879- # of the number of positive and negative labels seen so far, we can
880- # calculate precision and recall.
881- for l_idx , r_idx in itertools . izip (thresh_idxs [:- 1 ], thresh_idxs [1 :]):
882- thresh_labels = y_true [l_idx :r_idx ]
883- n_thresh = r_idx - l_idx
884- n_pos_thresh = thresh_labels .sum ()
885- n_neg_thresh = n_thresh - n_pos_thresh
886- tp_count += n_pos_thresh
887- fp_count += n_neg_thresh
877+ # Iterate over indices which indicate distinct values (thresholds) of
878+ # probas_pred. Each of these threshold values will be represented in the
879+ # curve with a coordinate in precision-recall space. To calculate the
880+ # precision and recall associated with each point, we use these indices to
881+ # select all labels associated with the predictions. By incrementally
882+ # keeping track of the number of positive and negative labels seen so far,
883+ # we can calculate precision and recall.
884+ for l_idx , r_idx in izip (threshold_idxs [:- 1 ], threshold_idxs [1 :]):
885+ threshold_labels = y_true [l_idx :r_idx ]
886+ n_at_threshold = r_idx - l_idx
887+ n_pos_at_threshold = threshold_labels .sum ()
888+ n_neg_at_threshold = n_at_threshold - n_pos_at_threshold
889+ tp_count += n_pos_at_threshold
890+ fp_count += n_neg_at_threshold
888891 fn_count = total_positive - tp_count
889892 precision .append (tp_count / (tp_count + fp_count ))
890893 recall .append (tp_count / (tp_count + fn_count ))
0 commit comments