"""A neural network for recognizing the hand-drawn digits in the MNIST
database. This follows the ideas from Franklin Ch 14 and the book by
Rashid.

"""

# Note: this uses progressbar2 to show status during training.  Do
# pip3 install progressbar2 --user

import numpy as np
import matplotlib.pyplot as plt
import progressbar
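
# The data are expected as "mnist_train.csv" (60,000 training digits) and
# "mnist_test.csv" (10,000 test digits) in the working directory.  Each
# line holds one digit: the first entry is the label (0-9) and the
# remaining 784 entries are the 28x28 pixel values (0-255).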

class TrainingDigit(object):
    """a handwritten digit from the MNIST training set"""

    def __init__(self, raw_string):
        """we feed this a single line from the MNIST data set"""
        self.raw_string = raw_string

        # make the data range from 0.01 to 1.00
        _tmp = raw_string.split(",")
        self.scaled = np.asarray(_tmp[1:], dtype=np.float64)/255.0 * 0.99 + 0.01

        # the correct answer
        self.num = int(_tmp[0])

        # the output for the NN as a bit array -- make this lie in [0.01, 0.99]
        self.out = np.zeros(10) + 0.01
        self.out[self.num] = 0.99

    def plot(self, outfile=None, output=None):
        """plot the digit"""
        plt.clf()
        plt.imshow(self.scaled.reshape((28, 28)),
                   cmap="Greys", interpolation="nearest")
        if output is not None:
            dstr = ["{}: {:6.4f}".format(n, v) for n, v in enumerate(output)]
            ostr = "correct digit: {}\n".format(self.num)
            ostr += " ".join(dstr[0:5]) + "\n" + " ".join(dstr[5:])
            plt.title(ostr, fontsize="x-small")
        if outfile is not None:
            plt.savefig(outfile, dpi=150)

class TrainingSet(object):
    """Manage reading digits from the MNIST training set csv and provide
    methods for returning the next digit, reading only what is needed
    each time.

    """

    def __init__(self):
        self.f = open("mnist_train.csv", "r")

    def get_next(self):
        """return the next digit from the training database"""
        return TrainingDigit(self.f.readline())

    def reset(self):
        """reset the training set to the first digit in the database"""
        self.f.seek(0)

    def close(self):
        """close the training set file"""
        # should implement a context manager
        self.f.close()

class UnknownDigit(TrainingDigit):
    """A digit from the MNIST test database. This provides a method to
    compare a NN result to the correct answer.

    """
    def __init__(self, raw_string):
        super().__init__(raw_string)
        self.out = None

    def check_output(self, out):
        """given the output array from the NN, return True if it is
        correct for this digit"""
        guess = np.argmax(out)
        return guess == self.num

class TestSet(object):
    """read the next digit from the test set csv file and return an
    UnknownDigit object"""

    def __init__(self):
        self.f = open("mnist_test.csv", "r")

    def get_next(self):
        """return the next digit from the test database"""
        return UnknownDigit(self.f.readline())

    def close(self):
        """close the test set file"""
        # should implement a context manager
        self.f.close()


class NeuralNetwork(object):
    """A neural network class with a single hidden layer."""

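    # the network maps an input vector x (n = 784 pixels for MNIST) to a
    # hidden layer z_tilde = g(B x) with k nodes and then to an output
    # layer z = g(A z_tilde) with m = 10 nodes, one per digit
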
    def __init__(self, num_training_unique=100, n_epochs=10,
                 learning_rate=0.1,
                 hidden_layer_size=100):

        self.num_training_unique = num_training_unique
        self.n_epochs = n_epochs

        self.train_set = TrainingSet()

        # learning rate
        self.eta = learning_rate

        # we get the size of the layers from the length of the input
        # and output
        d = self.train_set.get_next()

        # the number of nodes/neurons on the output layer
        self.m = len(d.out)

        # the number of nodes/neurons on the input layer
        self.n = len(d.scaled)

        # the number of nodes/neurons on the hidden layer
        self.k = hidden_layer_size

        # we will initialize the weights with Gaussian normal random
        # numbers centered on 0 with a width of 1/sqrt(n_in), where
        # n_in is the number of nodes feeding into the layer

        # A is the set of weights between the hidden layer and output layer
        self.A = np.random.normal(0.0, 1.0/np.sqrt(self.k), (self.m, self.k))

        # B is the set of weights between the input layer and hidden layer
        self.B = np.random.normal(0.0, 1.0/np.sqrt(self.n), (self.k, self.n))

    def g(self, p):
        """our sigmoid function that operates on the hidden layer"""
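        # note: the derivative of the sigmoid is g'(p) = g(p) * (1 - g(p));
        # the z*(1-z) and z_tilde*(1-z_tilde) factors in train() use this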
        return 1.0/(1.0 + np.exp(-p))

    def train(self):
        """Train the neural network by doing gradient descent with back
        propagation to set the matrix elements in B (the weights
        between the input and hidden layer) and A (the weights between
        the hidden layer and output layer)

        """

        for i in range(self.n_epochs):
            self.train_set.reset()
            print("epoch {} of {}".format(i+1, self.n_epochs))
            bar = progressbar.ProgressBar()
            for q in bar(range(self.num_training_unique)):

                model = self.train_set.get_next()

                x = model.scaled.reshape(self.n, 1)
                y = model.out.reshape(self.m, 1)

                z_tilde = self.g(self.B @ x)
                z = self.g(self.A @ z_tilde)

                e = z - y
                e_tilde = self.A.T @ e

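                # gradient descent with backpropagation: for the
                # sum-of-squares error E = |z - y|^2 and the sigmoid
                # derivative g'(p) = g(p)(1 - g(p)), the output-layer
                # gradient is dE/dA = 2 [e * z(1-z)] z_tilde^T; the
                # error backpropagated to the hidden layer is taken as
                # e_tilde = A^T e, giving the analogous expression for
                # dB.  Each weight matrix is then stepped by -eta times
                # its gradient.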
                dA = -2*self.eta * (e * z*(1-z)) @ z_tilde.T
                dB = -2*self.eta * (e_tilde * z_tilde*(1-z_tilde)) @ x.T

                self.A[:, :] += dA
                self.B[:, :] += dB

        self.train_set.close()

    def predict(self, model):
        """predict the outcome using our trained matrices A and B"""
        y = self.g(self.A @ (self.g(self.B @ model.scaled)))
        return y


def main(n_epochs=5, hidden_layer_size=100,
         num_training_unique=60000,
         learning_rate=0.1,
         do_plots=True):
    """a driver for the NN"""

    nn = NeuralNetwork(num_training_unique=num_training_unique,
                       learning_rate=learning_rate,
                       hidden_layer_size=hidden_layer_size,
                       n_epochs=n_epochs)

    # train
    nn.train()

    # histogram of the matrix elements
    if do_plots:
        plt.clf()
        plt.hist(nn.A.flatten(), bins=20)
        plt.title(r"${\bf A}$")
        plt.savefig("A_hist.png", dpi=150, bbox_inches="tight")

        plt.clf()
        plt.hist(nn.B.flatten(), bins=20)
        plt.title(r"${\bf B}$")
        plt.savefig("B_hist.png", dpi=150, bbox_inches="tight")

    # now try it out on the unseen data from the test database
    unk = TestSet()
    n_test = 10000
    n_correct = 0
    for t in range(n_test):
        d = unk.get_next()
        num_guess = nn.predict(d)
        if d.check_output(num_guess):
            n_correct += 1
        else:
            if do_plots:
                d.plot(outfile="incorrect_digit_{:04d}.png".format(t),
                       output=num_guess)

    unk.close()

    return n_correct/n_test


if __name__ == "__main__":
    f_right = main(do_plots=False)

    print("correct fraction: {}".format(f_right))
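
    # for a quicker (and less accurate) experiment one could instead call,
    # e.g., main(n_epochs=1, num_training_unique=1000, do_plots=False)
    # (these parameter values are only illustrative)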