Skip to content

Commit 4a4ea60

Browse files
committed
Cleaning up mnist_svm
1 parent 43ce4a2 commit 4a4ea60

File tree

2 files changed

+2
-24
lines changed

2 files changed

+2
-24
lines changed

code/mnist_svm.py

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,13 @@
33
~~~~~~~~~
44
55
A classifier program for recognizing handwritten digits from the MNIST
6-
data set, using an SVM classifier.
7-
8-
The program first trains the classifier, and then applies the
9-
classifier to the MNIST test data to see how many digits are correctly
10-
classified. Finally, the program uses matplotlib to create an image
11-
of the first 10 digits which are incorrectly classified."""
6+
data set, using an SVM classifier."""
127

138
#### Libraries
149
# My libraries
15-
import mnist_loader # to load the MNIST data. For details on the
16-
# format the data is loaded in, see the module's
17-
# code
10+
import mnist_loader
1811

1912
# Third-party libraries
20-
import matplotlib
21-
import matplotlib.pyplot as plt
22-
import numpy as np
2313
from sklearn import svm
2414

2515
def svm_baseline():
@@ -32,18 +22,6 @@ def svm_baseline():
3222
num_correct = sum(int(a == y) for a, y in zip(predictions, test_data[1]))
3323
print "Baseline classifier using an SVM."
3424
print "%s of %s values correct." % (num_correct, len(test_data[1]))
35-
# finally, plot the first ten images where the classifier fails
36-
failure_indices = [j for (j, z) in enumerate(zip(predictions, test_data[1]))
37-
if z[0] != z[1]]
38-
failed_images = [np.reshape(test_data[0][failure_indices[j]], (-1, 28))
39-
for j in xrange(10)]
40-
fig = plt.figure()
41-
for j in xrange(1, 11):
42-
ax = fig.add_subplot(1, 10, j)
43-
ax.matshow(failed_images[j-1], cmap = matplotlib.cm.binary)
44-
plt.xticks(np.array([]))
45-
plt.yticks(np.array([]))
46-
plt.show()
4725

4826
if __name__ == "__main__":
4927
svm_baseline()

code/mnist_svm_failures.png

-9.92 KB
Binary file not shown.

0 commit comments

Comments
 (0)