|
26 | 26 |
|
27 | 27 | print(__doc__) |
28 | 28 |
|
| 29 | +import itertools |
29 | 30 | import numpy as np |
30 | 31 | import matplotlib.pyplot as plt |
31 | 32 |
|
|
37 | 38 | iris = datasets.load_iris() |
38 | 39 | X = iris.data |
39 | 40 | y = iris.target |
| 41 | +class_names = iris.target_names |
40 | 42 |
|
41 | 43 | # Split the data into a training set and a test set |
42 | 44 | X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) |
|
47 | 49 | y_pred = classifier.fit(X_train, y_train).predict(X_test) |
48 | 50 |
|
49 | 51 |
|
50 | | -def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues): |
| 52 | +def plot_confusion_matrix(cm, classes, |
| 53 | + normalize=False, |
| 54 | + title='Confusion matrix', |
| 55 | + cmap=plt.cm.Blues): |
| 56 | + """ |
| 57 | + This function prints and plots the confusion matrix. |
| 58 | + Normalization can be applied by setting `normalize=True`. |
| 59 | + """ |
51 | 60 | plt.imshow(cm, interpolation='nearest', cmap=cmap) |
52 | 61 | plt.title(title) |
53 | 62 | plt.colorbar() |
54 | | - tick_marks = np.arange(len(iris.target_names)) |
55 | | - plt.xticks(tick_marks, iris.target_names, rotation=45) |
56 | | - plt.yticks(tick_marks, iris.target_names) |
| 63 | + tick_marks = np.arange(len(classes)) |
| 64 | + plt.xticks(tick_marks, classes, rotation=45) |
| 65 | + plt.yticks(tick_marks, classes) |
| 66 | + |
| 67 | + if normalize: |
| 68 | + cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] |
| 69 | + print("Normalized confusion matrix") |
| 70 | + else: |
| 71 | + print('Confusion matrix, without normalization') |
| 72 | + |
| 73 | + print(cm) |
| 74 | + |
| 75 | + thresh = cm.max() / 2. |
| 76 | + for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): |
| 77 | + plt.text(j, i, cm[i, j], |
| 78 | + horizontalalignment="center", |
| 79 | + color="white" if cm[i, j] > thresh else "black") |
| 80 | + |
57 | 81 | plt.tight_layout() |
58 | 82 | plt.ylabel('True label') |
59 | 83 | plt.xlabel('Predicted label') |
60 | 84 |
|
61 | | - |
62 | 85 | # Compute confusion matrix |
63 | | -cm = confusion_matrix(y_test, y_pred) |
| 86 | +cnf_matrix = confusion_matrix(y_test, y_pred) |
64 | 87 | np.set_printoptions(precision=2) |
65 | | -print('Confusion matrix, without normalization') |
66 | | -print(cm) |
| 88 | + |
| 89 | +# Plot non-normalized confusion matrix |
67 | 90 | plt.figure() |
68 | | -plot_confusion_matrix(cm) |
| 91 | +plot_confusion_matrix(cnf_matrix, classes=class_names, |
| 92 | + title='Confusion matrix, without normalization') |
69 | 93 |
|
70 | | -# Normalize the confusion matrix by row (i.e by the number of samples |
71 | | -# in each class) |
72 | | -cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] |
73 | | -print('Normalized confusion matrix') |
74 | | -print(cm_normalized) |
| 94 | +# Plot normalized confusion matrix |
75 | 95 | plt.figure() |
76 | | -plot_confusion_matrix(cm_normalized, title='Normalized confusion matrix') |
| 96 | +plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True, |
| 97 | + title='Normalized confusion matrix') |
77 | 98 |
|
78 | 99 | plt.show() |
0 commit comments