3
3
~~~~~~~~~
4
4
5
5
A classifier program for recognizing handwritten digits from the MNIST
6
- data set, using an SVM classifier.
7
-
8
- The program first trains the classifier, and then applies the
9
- classifier to the MNIST test data to see how many digits are correctly
10
- classified. Finally, the program uses matplotlib to create an image
11
- of the first 10 digits which are incorrectly classified."""
6
+ data set, using an SVM classifier."""
12
7
13
8
#### Libraries
14
9
# My libraries
15
- import mnist_loader # to load the MNIST data. For details on the
16
- # format the data is loaded in, see the module's
17
- # code
10
+ import mnist_loader
18
11
19
12
# Third-party libraries
20
- import matplotlib
21
- import matplotlib .pyplot as plt
22
- import numpy as np
23
13
from sklearn import svm
24
14
25
15
def svm_baseline ():
@@ -32,18 +22,6 @@ def svm_baseline():
32
22
num_correct = sum (int (a == y ) for a , y in zip (predictions , test_data [1 ]))
33
23
print "Baseline classifier using an SVM."
34
24
print "%s of %s values correct." % (num_correct , len (test_data [1 ]))
35
- # finally, plot the first ten images where the classifier fails
36
- failure_indices = [j for (j , z ) in enumerate (zip (predictions , test_data [1 ]))
37
- if z [0 ] != z [1 ]]
38
- failed_images = [np .reshape (test_data [0 ][failure_indices [j ]], (- 1 , 28 ))
39
- for j in xrange (10 )]
40
- fig = plt .figure ()
41
- for j in xrange (1 , 11 ):
42
- ax = fig .add_subplot (1 , 10 , j )
43
- ax .matshow (failed_images [j - 1 ], cmap = matplotlib .cm .binary )
44
- plt .xticks (np .array ([]))
45
- plt .yticks (np .array ([]))
46
- plt .show ()
47
25
48
26
if __name__ == "__main__" :
49
27
svm_baseline ()
0 commit comments