Commit db79e2d

thomasjpfan authored and qinhanmin2014 committed
FEA Plotting API starting with ROC curve (scikit-learn#14357)
1 parent 696c8a9 commit db79e2d

File tree

12 files changed

+480
-28
lines changed

doc/developers/contributing.rst

Lines changed: 51 additions & 0 deletions
@@ -1674,3 +1674,54 @@ make this task easier and faster (in no particular order).
16741674
<https://git-scm.com/docs/git-grep#_examples>`_) is also extremely
16751675
useful to see every occurrence of a pattern (e.g. a function call or a
16761676
variable) in the code base.
1677+
1678+
1679+
.. _plotting_api:
1680+
1681+
Plotting API
1682+
============
1683+
1684+
Scikit-learn defines a simple API for creating visualizations for machine
1685+
learning. The key features of this API are to run calculations once and to have
1686+
the flexibility to adjust the visualizations after the fact. This logic is
1687+
encapsulated into a display object where the computed data is stored and
1688+
the plotting is done in a `plot` method. The display object's `__init__`
1689+
method contains only the data needed to create the visualization. The `plot`
1690+
method takes in parameters that only have to do with visualization, such as a
1691+
matplotlib axes. The `plot` method will store the matplotlib artists as
1692+
attributes allowing for style adjustments through the display object. A
1693+
`plot_*` helper function accepts parameters to do the computation and the
1694+
parameters used for plotting. After the helper function creates the display
1695+
object with the computed values, it calls the display's plot method. Note
1696+
that the `plot` method defines attributes related to matplotlib, such as the
1697+
line artist. This allows for customizations after calling the `plot` method.
1698+
1699+
For example, the `RocCurveDisplay` defines the following methods and
1700+
attributes:
1701+
1702+
.. code-block:: python
1703+
1704+
class RocCurveDisplay:
1705+
def __init__(self, fpr, tpr, roc_auc, estimator_name):
1706+
...
1707+
self.fpr = fpr
1708+
self.tpr = tpr
1709+
self.roc_auc = roc_auc
1710+
self.estimator_name = estimator_name
1711+
1712+
def plot(self, ax=None, name=None, **kwargs):
1713+
...
1714+
self.line_ = ...
1715+
self.ax_ = ax
1716+
self.figure_ = ax.figure
1717+
1718+
def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None,
1719+
drop_intermediate=True, response_method="auto",
1720+
name=None, ax=None, **kwargs):
1721+
# do computation
1722+
viz = RocCurveDisplay(fpr, tpr, roc_auc,
1723+
estimator.__class__.__name__)
1724+
return viz.plot(ax=ax, name=name, **kwargs)
1726+
1727+
Read more in the :ref:`User Guide <visualizations>`.
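
The display-object pattern described in this section can be sketched as follows. This is a minimal illustration, not scikit-learn's implementation: `ToyRocDisplay` is a hypothetical name, the legend label format is an assumption, and matplotlib is imported lazily inside `plot` only when no axes is supplied. A matplotlib axes exposes its parent figure as `ax.figure`, which is what the sketch stores in `figure_`.

```python
class ToyRocDisplay:
    """Hypothetical stand-in for a display object (not scikit-learn's class)."""

    def __init__(self, fpr, tpr, roc_auc, estimator_name):
        # Only precomputed data is stored here; nothing is drawn yet.
        self.fpr = fpr
        self.tpr = tpr
        self.roc_auc = roc_auc
        self.estimator_name = estimator_name

    def plot(self, ax=None, name=None, **kwargs):
        if ax is None:
            # Lazy import: building the display never requires matplotlib.
            import matplotlib.pyplot as plt
            _, ax = plt.subplots()
        name = self.estimator_name if name is None else name
        # Assumed label format, chosen for illustration only.
        label = "{} (AUC = {:0.2f})".format(name, self.roc_auc)
        # Axes.plot returns a list of line artists; keep the first one.
        self.line_ = ax.plot(self.fpr, self.tpr, label=label, **kwargs)[0]
        self.ax_ = ax
        self.figure_ = ax.figure
        # Returning self lets a helper function hand the display to the caller.
        return self
```

Because `plot` returns the display object, a helper such as `plot_roc_curve` can return `viz.plot(...)` directly, and the caller can still reach `line_`, `ax_`, and `figure_` to restyle the plot later without recomputing the curve.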

doc/modules/classes.rst

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,6 +1007,26 @@ See the :ref:`metrics` section of the user guide for further details.
10071007
metrics.pairwise_distances_chunked
10081008

10091009

1010+
Plotting
1011+
--------
1012+
1013+
See the :ref:`visualizations` section of the user guide for further details.
1014+
1015+
.. currentmodule:: sklearn
1016+
1017+
.. autosummary::
1018+
:toctree: generated/
1019+
:template: function.rst
1020+
1021+
metrics.plot_roc_curve
1022+
1023+
.. autosummary::
1024+
:toctree: generated/
1025+
:template: class.rst
1026+
1027+
metrics.RocCurveDisplay
1028+
1029+
10101030
.. _mixture_ref:
10111031

10121032
:mod:`sklearn.mixture`: Gaussian Mixture Models

doc/user_guide.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ User Guide
1919
unsupervised_learning.rst
2020
model_selection.rst
2121
inspection.rst
22+
visualizations.rst
2223
data_transforms.rst
2324
Dataset loading utilities <datasets/index.rst>
2425
modules/computing.rst

doc/visualizations.rst

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
.. include:: includes/big_toc_css.rst
2+
3+
.. _visualizations:
4+
5+
==============
6+
Visualizations
7+
==============
8+
9+
Scikit-learn defines a simple API for creating visualizations for machine
10+
learning. The key feature of this API is to allow for quick plotting and
11+
visual adjustments without recalculation. In the following example, we plot a
12+
ROC curve for a fitted support vector machine:
13+
14+
.. code-block:: python
15+
16+
from sklearn.model_selection import train_test_split
17+
from sklearn.svm import SVC
18+
from sklearn.metrics import plot_roc_curve
19+
from sklearn.datasets import load_wine
20+
21+
X, y = load_wine(return_X_y=True)
y = y == 2  # binary problem: class 2 vs. the rest
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
22+
svc = SVC(random_state=42)
23+
svc.fit(X_train, y_train)
24+
25+
svc_disp = plot_roc_curve(svc, X_test, y_test)
26+
27+
.. figure:: ../auto_examples/images/sphx_glr_plot_roc_curve_visualization_api_001.png
28+
:target: ../auto_examples/plot_roc_curve_visualization_api.html
29+
:align: center
30+
:scale: 75%
31+
32+
The returned `svc_disp` object allows us to continue using the already computed
33+
ROC curve for SVC in future plots. In this case, the `svc_disp` is a
34+
:class:`~sklearn.metrics.RocCurveDisplay` that stores the computed values as
35+
attributes called `roc_auc`, `fpr`, and `tpr`. Next, we train a random forest
36+
classifier and plot the previously computed ROC curve again by using the `plot`
37+
method of the `Display` object.
38+
39+
.. code-block:: python
40+
41+
import matplotlib.pyplot as plt
42+
from sklearn.ensemble import RandomForestClassifier
43+
44+
rfc = RandomForestClassifier(random_state=42)
45+
rfc.fit(X_train, y_train)
46+
47+
ax = plt.gca()
48+
rfc_disp = plot_roc_curve(rfc, X_test, y_test, ax=ax, alpha=0.8)
49+
svc_disp.plot(ax=ax, alpha=0.8)
50+
51+
.. figure:: ../auto_examples/images/sphx_glr_plot_roc_curve_visualization_api_002.png
52+
:target: ../auto_examples/plot_roc_curve_visualization_api.html
53+
:align: center
54+
:scale: 75%
55+
56+
Notice that we pass `alpha=0.8` to the plot functions to adjust the alpha
57+
values of the curves.
58+
59+
.. topic:: Examples:
60+
61+
* :ref:`sphx_glr_auto_examples_plot_roc_curve_visualization_api.py`
62+
63+
Available Plotting Utilities
64+
============================
65+
66+
Functions
67+
---------
68+
69+
.. currentmodule:: sklearn
70+
71+
.. autosummary::
72+
73+
metrics.plot_roc_curve
74+
75+
76+
Display Objects
77+
---------------
78+
79+
.. currentmodule:: sklearn
80+
81+
.. autosummary::
82+
83+
metrics.RocCurveDisplay
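
To make the stored attributes concrete, the sketch below computes `fpr`, `tpr`, and the area under the curve for a tiny hand-made example in plain Python. The helper names `roc_points` and `trapezoid_auc` are hypothetical, and the code only illustrates the kind of values a `RocCurveDisplay` holds. It is not how scikit-learn computes them (for instance, it keeps every threshold rather than dropping intermediate points) and it assumes both classes are present in `y_true`.

```python
def roc_points(y_true, scores):
    """Return (fpr, tpr) lists by sweeping a threshold from high to low scores."""
    pos = sum(1 for y in y_true if y)
    neg = len(y_true) - pos
    # Indices sorted by decreasing score.
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    fpr, tpr = [0.0], [0.0]
    tp = fp = 0
    for rank, i in enumerate(order):
        if y_true[i]:
            tp += 1
        else:
            fp += 1
        # Emit a point only when the next score differs, so ties share a threshold.
        is_last = rank == len(order) - 1
        if is_last or scores[order[rank + 1]] != scores[i]:
            fpr.append(fp / neg)
            tpr.append(tp / pos)
    return fpr, tpr


def trapezoid_auc(fpr, tpr):
    """Area under the piecewise-linear curve through the (fpr, tpr) points."""
    return sum(
        (fpr[k] - fpr[k - 1]) * (tpr[k] + tpr[k - 1]) / 2
        for k in range(1, len(fpr))
    )


fpr, tpr = roc_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
roc_auc = trapezoid_auc(fpr, tpr)
print(fpr)      # [0.0, 0.0, 0.5, 0.5, 1.0]
print(tpr)      # [0.0, 0.5, 0.5, 1.0, 1.0]
print(roc_auc)  # 0.75
```

Once these three values exist, drawing the curve again on another axes is only a plotting call, which is why the display object can be reused without touching the estimator or the data.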

examples/model_selection/plot_roc_crossval.py

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@
3636
import matplotlib.pyplot as plt
3737

3838
from sklearn import svm, datasets
39-
from sklearn.metrics import roc_curve, auc
39+
from sklearn.metrics import auc
40+
from sklearn.metrics import plot_roc_curve
4041
from sklearn.model_selection import StratifiedKFold
4142

4243
# #############################################################################
@@ -65,40 +66,35 @@
6566
aucs = []
6667
mean_fpr = np.linspace(0, 1, 100)
6768

68-
i = 0
69-
for train, test in cv.split(X, y):
70-
probas_ = classifier.fit(X[train], y[train]).predict_proba(X[test])
71-
# Compute ROC curve and area the curve
72-
fpr, tpr, thresholds = roc_curve(y[test], probas_[:, 1])
73-
tprs.append(interp(mean_fpr, fpr, tpr))
74-
tprs[-1][0] = 0.0
75-
roc_auc = auc(fpr, tpr)
76-
aucs.append(roc_auc)
77-
plt.plot(fpr, tpr, lw=1, alpha=0.3,
78-
label='ROC fold %d (AUC = %0.2f)' % (i, roc_auc))
79-
80-
i += 1
81-
plt.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r',
82-
label='Chance', alpha=.8)
69+
fig, ax = plt.subplots()
70+
for i, (train, test) in enumerate(cv.split(X, y)):
71+
classifier.fit(X[train], y[train])
72+
viz = plot_roc_curve(classifier, X[test], y[test],
73+
name='ROC fold {}'.format(i),
74+
alpha=0.3, lw=1, ax=ax)
75+
interp_tpr = interp(mean_fpr, viz.fpr, viz.tpr)
76+
interp_tpr[0] = 0.0
77+
tprs.append(interp_tpr)
78+
aucs.append(viz.roc_auc)
79+
80+
ax.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r',
81+
label='Chance', alpha=.8)
8382

8483
mean_tpr = np.mean(tprs, axis=0)
8584
mean_tpr[-1] = 1.0
8685
mean_auc = auc(mean_fpr, mean_tpr)
8786
std_auc = np.std(aucs)
88-
plt.plot(mean_fpr, mean_tpr, color='b',
89-
label=r'Mean ROC (AUC = %0.2f $\pm$ %0.2f)' % (mean_auc, std_auc),
90-
lw=2, alpha=.8)
87+
ax.plot(mean_fpr, mean_tpr, color='b',
88+
label=r'Mean ROC (AUC = %0.2f $\pm$ %0.2f)' % (mean_auc, std_auc),
89+
lw=2, alpha=.8)
9190

9291
std_tpr = np.std(tprs, axis=0)
9392
tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
9493
tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
95-
plt.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2,
96-
label=r'$\pm$ 1 std. dev.')
97-
98-
plt.xlim([-0.05, 1.05])
99-
plt.ylim([-0.05, 1.05])
100-
plt.xlabel('False Positive Rate')
101-
plt.ylabel('True Positive Rate')
102-
plt.title('Receiver operating characteristic example')
103-
plt.legend(loc="lower right")
94+
ax.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2,
95+
label=r'$\pm$ 1 std. dev.')
96+
97+
ax.set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05],
98+
title="Receiver operating characteristic example")
99+
ax.legend(loc="lower right")
104100
plt.show()
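
The averaging step in the loop above relies on interpolating each fold's curve onto the shared `mean_fpr` grid. The sketch below shows that arithmetic in plain Python with two made-up folds. `interp_curve` is a hypothetical helper standing in for the `interp` call, written out so the linear interpolation is visible; it assumes each fold's `fpr` list is sorted in non-decreasing order.

```python
from bisect import bisect_right


def interp_curve(grid, xs, ys):
    """Linearly interpolate the curve (xs, ys) at every point of grid."""
    out = []
    for x in grid:
        if x <= xs[0]:
            out.append(ys[0])
        elif x >= xs[-1]:
            out.append(ys[-1])
        else:
            hi = bisect_right(xs, x)  # first index with xs[hi] > x
            lo = hi - 1
            t = (x - xs[lo]) / (xs[hi] - xs[lo])
            out.append(ys[lo] + t * (ys[hi] - ys[lo]))
    return out


# Shared grid of false positive rates (the example uses 100 points).
mean_fpr = [0.0, 0.25, 0.5, 0.75, 1.0]

# Two made-up folds, each given as (fpr, tpr).
folds = [
    ([0.0, 0.5, 1.0], [0.0, 1.0, 1.0]),
    ([0.0, 1.0], [0.0, 1.0]),
]

# One interpolated TPR row per fold, then a column-wise mean.
tprs = [interp_curve(mean_fpr, fpr, tpr) for fpr, tpr in folds]
mean_tpr = [sum(col) / len(col) for col in zip(*tprs)]
print(mean_tpr)  # [0.0, 0.375, 0.75, 0.875, 1.0]
```

Interpolating first is what makes the column-wise mean meaningful: each fold produces a different set of `fpr` points, so the raw `tpr` arrays cannot be averaged element by element.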
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"""
2+
================================
3+
ROC Curve with Visualization API
4+
================================
5+
Scikit-learn defines a simple API for creating visualizations for machine
6+
learning. The key feature of this API is to allow for quick plotting and
7+
visual adjustments without recalculation. In this example, we will demonstrate
8+
how to use the visualization API by comparing ROC curves.
9+
"""
10+
print(__doc__)
11+
12+
##############################################################################
13+
# Load Data and Train a SVC
14+
# -------------------------
15+
# First, we load the wine dataset and convert it to a binary classification
16+
# problem. Then, we train a support vector classifier on a training dataset.
17+
import matplotlib.pyplot as plt
18+
from sklearn.svm import SVC
19+
from sklearn.ensemble import RandomForestClassifier
20+
from sklearn.metrics import plot_roc_curve
21+
from sklearn.datasets import load_wine
22+
from sklearn.model_selection import train_test_split
23+
24+
X, y = load_wine(return_X_y=True)
25+
y = y == 2
26+
27+
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
28+
svc = SVC(random_state=42)
29+
svc.fit(X_train, y_train)
30+
31+
##############################################################################
32+
# Plotting the ROC Curve
33+
# ----------------------
34+
# Next, we plot the ROC curve with a single call to
35+
# :func:`sklearn.metrics.plot_roc_curve`. The returned `svc_disp` object allows
36+
# us to continue using the already computed ROC curve for the SVC in future
37+
# plots.
38+
svc_disp = plot_roc_curve(svc, X_test, y_test)
39+
plt.show()
40+
41+
##############################################################################
42+
# Training a Random Forest and Plotting the ROC Curve
43+
# --------------------------------------------------------
44+
# We train a random forest classifier and create a plot comparing it to the SVC
45+
# ROC curve. Notice how `svc_disp` uses
46+
# :func:`~sklearn.metrics.RocCurveDisplay.plot` to plot the SVC ROC curve
47+
# without recomputing the values of the ROC curve itself. Furthermore, we
48+
# pass `alpha=0.8` to the plot functions to adjust the alpha values of the
49+
# curves.
50+
rfc = RandomForestClassifier(n_estimators=10, random_state=42)
51+
rfc.fit(X_train, y_train)
52+
ax = plt.gca()
53+
rfc_disp = plot_roc_curve(rfc, X_test, y_test, ax=ax, alpha=0.8)
54+
svc_disp.plot(ax=ax, alpha=0.8)
55+
plt.show()

sklearn/metrics/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@
7474
from .scorer import SCORERS
7575
from .scorer import get_scorer
7676

77+
from ._plot.roc_curve import plot_roc_curve
78+
from ._plot.roc_curve import RocCurveDisplay
79+
80+
7781
__all__ = [
7882
'accuracy_score',
7983
'adjusted_mutual_info_score',
@@ -125,11 +129,13 @@
125129
'pairwise_distances_argmin_min',
126130
'pairwise_distances_chunked',
127131
'pairwise_kernels',
132+
'plot_roc_curve',
128133
'precision_recall_curve',
129134
'precision_recall_fscore_support',
130135
'precision_score',
131136
'r2_score',
132137
'recall_score',
138+
'RocCurveDisplay',
133139
'roc_auc_score',
134140
'roc_curve',
135141
'SCORERS',

sklearn/metrics/_plot/__init__.py

Whitespace-only changes.
