|
27 | 27 | print(__doc__) |
28 | 28 |
|
29 | 29 | import numpy as np |
| 30 | +import matplotlib.pyplot as plt |
30 | 31 |
|
31 | 32 | from sklearn import svm, datasets |
32 | 33 | from sklearn.cross_validation import train_test_split |
33 | 34 | from sklearn.metrics import confusion_matrix |
34 | 35 |
|
35 | | -import matplotlib.pyplot as plt |
36 | | - |
37 | 36 | # import some data to play with |
38 | 37 | iris = datasets.load_iris() |
39 | 38 | X = iris.data |
|
47 | 46 | classifier = svm.SVC(kernel='linear', C=0.01) |
48 | 47 | y_pred = classifier.fit(X_train, y_train).predict(X_test) |
49 | 48 |
|
| 49 | + |
| 50 | +def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues): |
| 51 | + plt.imshow(cm, interpolation='nearest', cmap=cmap) |
| 52 | + plt.title(title) |
| 53 | + plt.colorbar() |
| 54 | + tick_marks = np.arange(len(iris.target_names)) |
| 55 | + plt.xticks(tick_marks, iris.target_names, rotation=45) |
| 56 | + plt.yticks(tick_marks, iris.target_names) |
| 57 | + plt.tight_layout() |
| 58 | + plt.ylabel('True label') |
| 59 | + plt.xlabel('Predicted label') |
| 60 | + |
| 61 | + |
50 | 62 | # Compute confusion matrix |
51 | 63 | cm = confusion_matrix(y_test, y_pred) |
| 64 | +np.set_printoptions(precision=2) |
52 | 65 | print('Confusion matrix, without normalization') |
53 | 66 | print(cm) |
54 | | - |
55 | | -# Show confusion matrix in a separate window |
56 | | -plt.imshow(cm, interpolation='nearest', cmap=plt.cm.binary) |
57 | | -plt.title('Confusion matrix') |
58 | | -plt.set_cmap('Blues') |
59 | | -plt.colorbar() |
60 | | -tick_marks = np.arange(len(iris.target_names)) |
61 | | -plt.xticks(tick_marks, iris.target_names, rotation=60) |
62 | | -plt.yticks(tick_marks, iris.target_names) |
63 | | -plt.ylabel('True label') |
64 | | -plt.xlabel('Predicted label') |
65 | | -# Convenience function to adjust plot parameters for a clear layout. |
66 | | -plt.tight_layout() |
| 67 | +plt.figure() |
| 68 | +plot_confusion_matrix(cm) |
67 | 69 |
|
68 | 70 | # Normalize the confusion matrix by row (i.e by the number of samples |
69 | 71 | # in each class) |
70 | 72 | cm_normalized = cm.astype('float') / cm.sum(axis=1) |
71 | | - |
72 | 73 | print('Normalized confusion matrix') |
73 | 74 | print(cm_normalized) |
74 | | - |
75 | | -# Show normalized confusion matrix in a separate window |
76 | 75 | plt.figure() |
77 | | -plt.imshow(cm_normalized, interpolation='nearest', cmap=plt.cm.binary) |
78 | | -plt.title('Normalized confusion matrix') |
79 | | -plt.set_cmap('Blues') |
80 | | -plt.colorbar() |
81 | | -plt.xticks(tick_marks, iris.target_names, rotation=60) |
82 | | -plt.yticks(tick_marks, iris.target_names) |
83 | | -plt.ylabel('True label') |
84 | | -plt.xlabel('Predicted label') |
85 | | -plt.tight_layout() |
| 76 | +plot_confusion_matrix(cm_normalized, title='Normalized confusion matrix') |
| 77 | + |
86 | 78 | plt.show() |
0 commit comments