Skip to content

Commit 1a549d6

Browse files
ndingwallGaelVaroquaux
authored andcommitted
Adds support for step-wise interpolation to auc and average_precision_score
1 parent 0fb9a50 commit 1a549d6

File tree

1 file changed

+79
-23
lines changed

1 file changed

+79
-23
lines changed

sklearn/metrics/ranking.py

Lines changed: 79 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,10 @@
3636
from .base import _average_binary_score
3737

3838

39-
def auc(x, y, reorder=False):
40-
"""Compute Area Under the Curve (AUC) using the trapezoidal rule
39+
def auc(x, y, reorder=False, interpolation='linear',
40+
interpolation_direction='right'):
41+
"""Estimate Area Under the Curve (AUC) using finitely many points and an
42+
interpolation strategy.
4143
4244
This is a general function, given points on a curve. For computing the
4345
area under the ROC-curve, see :func:`roc_auc_score`.
@@ -54,6 +56,24 @@ def auc(x, y, reorder=False):
5456
If True, assume that the curve is ascending in the case of ties, as for
5557
an ROC curve. If the curve is non-ascending, the result will be wrong.
5658
59+
interpolation : string ['trapezoid' (default), 'step']
60+
This determines the type of interpolation performed on the data.
61+
62+
``'linear'``:
63+
Use the trapezoidal rule (linearly interpolating between points).
64+
``'step'``:
65+
Use a step function where we ascend/descend from each point to the
66+
y-value of the subsequent point.
67+
68+
interpolation_direction : string ['right' (default), 'left']
69+
This determines the direction to interpolate from. The value is ignored
70+
unless interpolation is 'step'.
71+
72+
``'right'``:
73+
Intermediate points inherit their y-value from the subsequent point.
74+
``'left'``:
75+
Intermediate points inherit their y-value from the previous point.
76+
5777
Returns
5878
-------
5979
auc : float
@@ -76,13 +96,6 @@ def auc(x, y, reorder=False):
7696
Compute precision-recall pairs for different probability thresholds
7797
7898
"""
79-
check_consistent_length(x, y)
80-
x = column_or_1d(x)
81-
y = column_or_1d(y)
82-
83-
if x.shape[0] < 2:
84-
raise ValueError('At least 2 points are needed to compute'
85-
' area under curve, but x.shape = %s' % x.shape)
8699

87100
direction = 1
88101
if reorder:
@@ -99,20 +112,42 @@ def auc(x, y, reorder=False):
99112
raise ValueError("Reordering is not turned on, and "
100113
"the x array is not increasing: %s" % x)
101114

102-
area = direction * np.trapz(y, x)
103-
if isinstance(area, np.memmap):
104-
# Reductions such as .sum used internally in np.trapz do not return a
105-
# scalar by default for numpy.memmap instances contrary to
106-
# regular numpy.ndarray instances.
107-
area = area.dtype.type(area)
115+
if interpolation == 'linear':
116+
117+
area = direction * np.trapz(y, x)
118+
119+
elif interpolation == 'step':
120+
121+
# we need the data to start in ascending order
122+
if direction == -1:
123+
x, y = list(reversed(x)), list(reversed(y))
124+
125+
if interpolation_direction == 'right':
126+
# The left-most y-value is not used
127+
area = sum(np.diff(x) * np.array(y)[1:])
128+
129+
elif interpolation_direction == 'left':
130+
# The right-most y-value is not used
131+
area = sum(np.diff(x) * np.array(y)[:-1])
132+
133+
else:
134+
raise ValueError("interpolation_direction '{}' not recognised."
135+
" Should be one of ['right', 'left']".format(
136+
interpolation_direction))
137+
else:
138+
raise ValueError("interpolation value '{}' not recognized. "
139+
"Should be one of ['linear', 'step']".format(
140+
interpolation))
141+
108142
return area
109143

110144

111145
def average_precision_score(y_true, y_score, average="macro",
112-
sample_weight=None):
146+
sample_weight=None, interpolation="linear"):
113147
"""Compute average precision (AP) from prediction scores
114148
115-
This score corresponds to the area under the precision-recall curve.
149+
This score corresponds to the area under the precision-recall curve, where
150+
points are joined using either linear or step-wise interpolation.
116151
117152
Note: this implementation is restricted to the binary classification task
118153
or multilabel classification task.
@@ -126,8 +161,7 @@ def average_precision_score(y_true, y_score, average="macro",
126161
127162
y_score : array, shape = [n_samples] or [n_samples, n_classes]
128163
Target scores, can either be probability estimates of the positive
129-
class, confidence values, or non-thresholded measure of decisions
130-
(as returned by "decision_function" on some classifiers).
164+
class, confidence values, or binary decisions.
131165
132166
average : string, [None, 'micro', 'macro' (default), 'samples', 'weighted']
133167
If ``None``, the scores for each class are returned. Otherwise,
@@ -148,14 +182,24 @@ def average_precision_score(y_true, y_score, average="macro",
148182
sample_weight : array-like of shape = [n_samples], optional
149183
Sample weights.
150184
185+
interpolation : string ['linear' (default), 'step']
186+
Determines the kind of interpolation used when computed AUC. If there are
187+
many repeated scores, 'step' is recommended to avoid under- or over-
188+
estimating the AUC. See www.roamanalytics.com/etc for details.
189+
190+
``'linear'``:
191+
Linearly interpolates between operating points.
192+
``'step'``:
193+
Uses a step function to interpolate between operating points.
194+
151195
Returns
152196
-------
153197
average_precision : float
154198
155199
References
156200
----------
157201
.. [1] `Wikipedia entry for the Average precision
158-
<https://en.wikipedia.org/wiki/Average_precision>`_
202+
<http://en.wikipedia.org/wiki/Average_precision>`_
159203
160204
See also
161205
--------
@@ -177,8 +221,20 @@ def average_precision_score(y_true, y_score, average="macro",
177221
def _binary_average_precision(y_true, y_score, sample_weight=None):
178222
precision, recall, thresholds = precision_recall_curve(
179223
y_true, y_score, sample_weight=sample_weight)
180-
return auc(recall, precision)
181-
224+
return auc(recall, precision, interpolation=interpolation,
225+
interpolation_direction='right')
226+
227+
if interpolation == "linear":
228+
# Check for number of unique predictions. If this is substantially less
229+
# than the number of predictions, linear interpolation is likely to be
230+
# biased.
231+
n_discrete_predictions = len(np.unique(y_score))
232+
if n_discrete_predictions < 0.75 * len(y_score):
233+
warnings.warn("Number of unique scores is less than 75% of the "
234+
"number of scores provided. Linear interpolation "
235+
"is likely to be biased in this case. You may wish "
236+
"to use step interpolation instead. See docstring "
237+
"for details.")
182238
return _average_binary_score(_binary_average_precision, y_true, y_score,
183239
average, sample_weight=sample_weight)
184240

@@ -252,7 +308,7 @@ def _binary_roc_auc_score(y_true, y_score, sample_weight=None):
252308

253309
fpr, tpr, tresholds = roc_curve(y_true, y_score,
254310
sample_weight=sample_weight)
255-
return auc(fpr, tpr, reorder=True)
311+
return auc(fpr, tpr, reorder=True, interpolation='linear')
256312

257313
return _average_binary_score(
258314
_binary_roc_auc_score, y_true, y_score, average,

0 commit comments

Comments
 (0)