3636from .base import _average_binary_score
3737
3838
39- def auc (x , y , reorder = False ):
40- """Compute Area Under the Curve (AUC) using the trapezoidal rule
39+ def auc (x , y , reorder = False , interpolation = 'linear' ,
40+ interpolation_direction = 'right' ):
41+ """Estimate Area Under the Curve (AUC) using finitely many points and an
42+ interpolation strategy.
4143
4244 This is a general function, given points on a curve. For computing the
4345 area under the ROC-curve, see :func:`roc_auc_score`.
@@ -54,6 +56,24 @@ def auc(x, y, reorder=False):
5456 If True, assume that the curve is ascending in the case of ties, as for
5557 an ROC curve. If the curve is non-ascending, the result will be wrong.
5658
59+ interpolation : string ['trapezoid' (default), 'step']
60+ This determines the type of interpolation performed on the data.
61+
62+ ``'linear'``:
63+ Use the trapezoidal rule (linearly interpolating between points).
64+ ``'step'``:
65+ Use a step function where we ascend/descend from each point to the
66+ y-value of the subsequent point.
67+
68+ interpolation_direction : string ['right' (default), 'left']
69+ This determines the direction to interpolate from. The value is ignored
70+ unless interpolation is 'step'.
71+
72+ ``'right'``:
73+ Intermediate points inherit their y-value from the subsequent point.
74+ ``'left'``:
75+ Intermediate points inherit their y-value from the previous point.
76+
5777 Returns
5878 -------
5979 auc : float
@@ -76,13 +96,6 @@ def auc(x, y, reorder=False):
7696 Compute precision-recall pairs for different probability thresholds
7797
7898 """
79- check_consistent_length (x , y )
80- x = column_or_1d (x )
81- y = column_or_1d (y )
82-
83- if x .shape [0 ] < 2 :
84- raise ValueError ('At least 2 points are needed to compute'
85- ' area under curve, but x.shape = %s' % x .shape )
8699
87100 direction = 1
88101 if reorder :
@@ -99,20 +112,42 @@ def auc(x, y, reorder=False):
99112 raise ValueError ("Reordering is not turned on, and "
100113 "the x array is not increasing: %s" % x )
101114
102- area = direction * np .trapz (y , x )
103- if isinstance (area , np .memmap ):
104- # Reductions such as .sum used internally in np.trapz do not return a
105- # scalar by default for numpy.memmap instances contrary to
106- # regular numpy.ndarray instances.
107- area = area .dtype .type (area )
115+ if interpolation == 'linear' :
116+
117+ area = direction * np .trapz (y , x )
118+
119+ elif interpolation == 'step' :
120+
121+ # we need the data to start in ascending order
122+ if direction == - 1 :
123+ x , y = list (reversed (x )), list (reversed (y ))
124+
125+ if interpolation_direction == 'right' :
126+ # The left-most y-value is not used
127+ area = sum (np .diff (x ) * np .array (y )[1 :])
128+
129+ elif interpolation_direction == 'left' :
130+ # The right-most y-value is not used
131+ area = sum (np .diff (x ) * np .array (y )[:- 1 ])
132+
133+ else :
134+ raise ValueError ("interpolation_direction '{}' not recognised."
135+ " Should be one of ['right', 'left']" .format (
136+ interpolation_direction ))
137+ else :
138+ raise ValueError ("interpolation value '{}' not recognized. "
139+ "Should be one of ['linear', 'step']" .format (
140+ interpolation ))
141+
108142 return area
109143
110144
111145def average_precision_score (y_true , y_score , average = "macro" ,
112- sample_weight = None ):
146+ sample_weight = None , interpolation = "linear" ):
113147 """Compute average precision (AP) from prediction scores
114148
115- This score corresponds to the area under the precision-recall curve.
149+ This score corresponds to the area under the precision-recall curve, where
150+ points are joined using either linear or step-wise interpolation.
116151
117152 Note: this implementation is restricted to the binary classification task
118153 or multilabel classification task.
@@ -126,8 +161,7 @@ def average_precision_score(y_true, y_score, average="macro",
126161
127162 y_score : array, shape = [n_samples] or [n_samples, n_classes]
128163 Target scores, can either be probability estimates of the positive
129- class, confidence values, or non-thresholded measure of decisions
130- (as returned by "decision_function" on some classifiers).
164+ class, confidence values, or binary decisions.
131165
132166 average : string, [None, 'micro', 'macro' (default), 'samples', 'weighted']
133167 If ``None``, the scores for each class are returned. Otherwise,
@@ -148,14 +182,24 @@ def average_precision_score(y_true, y_score, average="macro",
148182 sample_weight : array-like of shape = [n_samples], optional
149183 Sample weights.
150184
185+ interpolation : string ['linear' (default), 'step']
186+ Determines the kind of interpolation used when computed AUC. If there are
187+ many repeated scores, 'step' is recommended to avoid under- or over-
188+ estimating the AUC. See www.roamanalytics.com/etc for details.
189+
190+ ``'linear'``:
191+ Linearly interpolates between operating points.
192+ ``'step'``:
193+ Uses a step function to interpolate between operating points.
194+
151195 Returns
152196 -------
153197 average_precision : float
154198
155199 References
156200 ----------
157201 .. [1] `Wikipedia entry for the Average precision
158- <https ://en.wikipedia.org/wiki/Average_precision>`_
202+ <http ://en.wikipedia.org/wiki/Average_precision>`_
159203
160204 See also
161205 --------
@@ -177,8 +221,20 @@ def average_precision_score(y_true, y_score, average="macro",
177221 def _binary_average_precision (y_true , y_score , sample_weight = None ):
178222 precision , recall , thresholds = precision_recall_curve (
179223 y_true , y_score , sample_weight = sample_weight )
180- return auc (recall , precision )
181-
224+ return auc (recall , precision , interpolation = interpolation ,
225+ interpolation_direction = 'right' )
226+
227+ if interpolation == "linear" :
228+ # Check for number of unique predictions. If this is substantially less
229+ # than the number of predictions, linear interpolation is likely to be
230+ # biased.
231+ n_discrete_predictions = len (np .unique (y_score ))
232+ if n_discrete_predictions < 0.75 * len (y_score ):
233+ warnings .warn ("Number of unique scores is less than 75% of the "
234+ "number of scores provided. Linear interpolation "
235+ "is likely to be biased in this case. You may wish "
236+ "to use step interpolation instead. See docstring "
237+ "for details." )
182238 return _average_binary_score (_binary_average_precision , y_true , y_score ,
183239 average , sample_weight = sample_weight )
184240
@@ -252,7 +308,7 @@ def _binary_roc_auc_score(y_true, y_score, sample_weight=None):
252308
253309 fpr , tpr , tresholds = roc_curve (y_true , y_score ,
254310 sample_weight = sample_weight )
255- return auc (fpr , tpr , reorder = True )
311+ return auc (fpr , tpr , reorder = True , interpolation = 'linear' )
256312
257313 return _average_binary_score (
258314 _binary_roc_auc_score , y_true , y_score , average ,
0 commit comments