|
36 | 36 | import matplotlib.pyplot as plt |
37 | 37 |
|
38 | 38 | from sklearn import svm, datasets |
39 | | -from sklearn.metrics import roc_curve, auc |
| 39 | +from sklearn.metrics import auc |
| 40 | +from sklearn.metrics import plot_roc_curve |
40 | 41 | from sklearn.model_selection import StratifiedKFold |
41 | 42 |
|
42 | 43 | # ############################################################################# |
|
65 | 66 | aucs = [] |
66 | 67 | mean_fpr = np.linspace(0, 1, 100) |
67 | 68 |
|
68 | | -i = 0 |
69 | | -for train, test in cv.split(X, y): |
70 | | - probas_ = classifier.fit(X[train], y[train]).predict_proba(X[test]) |
71 | | - # Compute ROC curve and area the curve |
72 | | - fpr, tpr, thresholds = roc_curve(y[test], probas_[:, 1]) |
73 | | - tprs.append(interp(mean_fpr, fpr, tpr)) |
74 | | - tprs[-1][0] = 0.0 |
75 | | - roc_auc = auc(fpr, tpr) |
76 | | - aucs.append(roc_auc) |
77 | | - plt.plot(fpr, tpr, lw=1, alpha=0.3, |
78 | | - label='ROC fold %d (AUC = %0.2f)' % (i, roc_auc)) |
79 | | - |
80 | | - i += 1 |
81 | | -plt.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', |
82 | | - label='Chance', alpha=.8) |
| 69 | +fig, ax = plt.subplots() |
| 70 | +for i, (train, test) in enumerate(cv.split(X, y)): |
| 71 | + classifier.fit(X[train], y[train]) |
| 72 | + viz = plot_roc_curve(classifier, X[test], y[test], |
| 73 | + name='ROC fold {}'.format(i), |
| 74 | + alpha=0.3, lw=1, ax=ax) |
| 75 | + interp_tpr = interp(mean_fpr, viz.fpr, viz.tpr) |
| 76 | + interp_tpr[0] = 0.0 |
| 77 | + tprs.append(interp_tpr) |
| 78 | + aucs.append(viz.roc_auc) |
| 79 | + |
| 80 | +ax.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', |
| 81 | + label='Chance', alpha=.8) |
83 | 82 |
|
84 | 83 | mean_tpr = np.mean(tprs, axis=0) |
85 | 84 | mean_tpr[-1] = 1.0 |
86 | 85 | mean_auc = auc(mean_fpr, mean_tpr) |
87 | 86 | std_auc = np.std(aucs) |
88 | | -plt.plot(mean_fpr, mean_tpr, color='b', |
89 | | - label=r'Mean ROC (AUC = %0.2f $\pm$ %0.2f)' % (mean_auc, std_auc), |
90 | | - lw=2, alpha=.8) |
| 87 | +ax.plot(mean_fpr, mean_tpr, color='b', |
| 88 | + label=r'Mean ROC (AUC = %0.2f $\pm$ %0.2f)' % (mean_auc, std_auc), |
| 89 | + lw=2, alpha=.8) |
91 | 90 |
|
92 | 91 | std_tpr = np.std(tprs, axis=0) |
93 | 92 | tprs_upper = np.minimum(mean_tpr + std_tpr, 1) |
94 | 93 | tprs_lower = np.maximum(mean_tpr - std_tpr, 0) |
95 | | -plt.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2, |
96 | | - label=r'$\pm$ 1 std. dev.') |
97 | | - |
98 | | -plt.xlim([-0.05, 1.05]) |
99 | | -plt.ylim([-0.05, 1.05]) |
100 | | -plt.xlabel('False Positive Rate') |
101 | | -plt.ylabel('True Positive Rate') |
102 | | -plt.title('Receiver operating characteristic example') |
103 | | -plt.legend(loc="lower right") |
| 94 | +ax.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2, |
| 95 | + label=r'$\pm$ 1 std. dev.') |
| 96 | + |
| 97 | +ax.set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05], |
| 98 | + title="Receiver operating characteristic example") |
| 99 | +ax.legend(loc="lower right") |
104 | 100 | plt.show() |
0 commit comments