Skip to content

Commit 8d3f4df

Browse files
committed
ENH cosmetic reorg of the confusion matrix example
1 parent 77ff749 commit 8d3f4df

File tree

1 file changed

+19
-27
lines changed

1 file changed

+19
-27
lines changed

examples/model_selection/plot_confusion_matrix.py

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,12 @@
2727
print(__doc__)
2828

2929
import numpy as np
30+
import matplotlib.pyplot as plt
3031

3132
from sklearn import svm, datasets
3233
from sklearn.cross_validation import train_test_split
3334
from sklearn.metrics import confusion_matrix
3435

35-
import matplotlib.pyplot as plt
36-
3736
# import some data to play with
3837
iris = datasets.load_iris()
3938
X = iris.data
@@ -47,40 +46,33 @@
4746
classifier = svm.SVC(kernel='linear', C=0.01)
4847
y_pred = classifier.fit(X_train, y_train).predict(X_test)
4948

49+
50+
def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues):
51+
plt.imshow(cm, interpolation='nearest', cmap=cmap)
52+
plt.title(title)
53+
plt.colorbar()
54+
tick_marks = np.arange(len(iris.target_names))
55+
plt.xticks(tick_marks, iris.target_names, rotation=45)
56+
plt.yticks(tick_marks, iris.target_names)
57+
plt.tight_layout()
58+
plt.ylabel('True label')
59+
plt.xlabel('Predicted label')
60+
61+
5062
# Compute confusion matrix
5163
cm = confusion_matrix(y_test, y_pred)
64+
np.set_printoptions(precision=2)
5265
print('Confusion matrix, without normalization')
5366
print(cm)
54-
55-
# Show confusion matrix in a separate window
56-
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.binary)
57-
plt.title('Confusion matrix')
58-
plt.set_cmap('Blues')
59-
plt.colorbar()
60-
tick_marks = np.arange(len(iris.target_names))
61-
plt.xticks(tick_marks, iris.target_names, rotation=60)
62-
plt.yticks(tick_marks, iris.target_names)
63-
plt.ylabel('True label')
64-
plt.xlabel('Predicted label')
65-
# Convenience function to adjust plot parameters for a clear layout.
66-
plt.tight_layout()
67+
plt.figure()
68+
plot_confusion_matrix(cm)
6769

6870
# Normalize the confusion matrix by row (i.e by the number of samples
6971
# in each class)
7072
cm_normalized = cm.astype('float') / cm.sum(axis=1)
71-
7273
print('Normalized confusion matrix')
7374
print(cm_normalized)
74-
75-
# Show normalized confusion matrix in a separate window
7675
plt.figure()
77-
plt.imshow(cm_normalized, interpolation='nearest', cmap=plt.cm.binary)
78-
plt.title('Normalized confusion matrix')
79-
plt.set_cmap('Blues')
80-
plt.colorbar()
81-
plt.xticks(tick_marks, iris.target_names, rotation=60)
82-
plt.yticks(tick_marks, iris.target_names)
83-
plt.ylabel('True label')
84-
plt.xlabel('Predicted label')
85-
plt.tight_layout()
76+
plot_confusion_matrix(cm_normalized, title='Normalized confusion matrix')
77+
8678
plt.show()

0 commit comments

Comments
 (0)