Skip to content

Commit 55e4b80

Browse files
committed
ENH Compute & display a confusion matrix
1 parent 23cc505 commit 55e4b80

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

ch02/seeds_knn_sklearn.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,24 @@
6767
crossed = cross_val_score(classifier, features, labels)
6868
print('Result with prescaling: {}'.format(crossed))
6969

70+
71+
# Now, generate & print a cross-validated confusion matrix for the same result
72+
from sklearn.metrics import confusion_matrix
73+
names = list(set(labels))
74+
labels = np.array([names.index(ell) for ell in labels])
75+
preds = labels.copy()
76+
preds[:] = -1
77+
for train, test in kf:
78+
classifier.fit(features[train], labels[train])
79+
preds[test] = classifier.predict(features[test])
80+
81+
cmat = confusion_matrix(labels, preds)
82+
print()
83+
print('Confusion matrix: [rows represent true outcome, columns predicted outcome]')
84+
print(cmat)
85+
86+
# The explicit float() conversion is necessary in Python 2
87+
# (Otherwise, result is rounded to 0)
88+
acc = cmat.trace()/float(cmat.sum())
89+
print('Accuracy: {0:.1%}'.format(acc))
90+

0 commit comments

Comments
 (0)