Skip to content

Commit 3129917

Browse files
lucasmouranorvig
authored andcommitted
Update load_MNIST on learning.ipynb (aimacode#339)
Make load_MNIST function easier to read
1 parent 3f5f856 commit 3129917

File tree

1 file changed

+13
-19
lines changed

1 file changed

+13
-19
lines changed

learning.ipynb

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -374,27 +374,21 @@
374374
"source": [
375375
"def load_MNIST(path=\"aima-data/MNIST\"):\n",
376376
" \"helper function to load MNIST data\"\n",
377-
" train_img_file = open(os.path.join(path, \"train-images-idx3-ubyte\"), \"rb\")\n",
378-
" train_lbl_file = open(os.path.join(path, \"train-labels-idx1-ubyte\"), \"rb\")\n",
379-
" test_img_file = open(os.path.join(path, \"t10k-images-idx3-ubyte\"), \"rb\")\n",
380-
" test_lbl_file = open(os.path.join(path, 't10k-labels-idx1-ubyte'), \"rb\")\n",
377+
" with open(os.path.join(path, \"train-images-idx3-ubyte\"), \"rb\") as train_img_file:\n",
378+
" magic_nr, tr_size, tr_rows, tr_cols = struct.unpack(\">IIII\", train_img_file.read(16))\n",
379+
" tr_img = array.array(\"B\", train_img_file.read())\n",
381380
" \n",
382-
" magic_nr, tr_size, tr_rows, tr_cols = struct.unpack(\">IIII\", train_img_file.read(16))\n",
383-
" tr_img = array.array(\"B\", train_img_file.read())\n",
384-
" train_img_file.close() \n",
385-
" magic_nr, tr_size = struct.unpack(\">II\", train_lbl_file.read(8))\n",
386-
" tr_lbl = array.array(\"b\", train_lbl_file.read())\n",
387-
" train_lbl_file.close()\n",
381+
" with open(os.path.join(path, \"train-labels-idx1-ubyte\"), \"rb\") as train_lbl_file:\n",
382+
" magic_nr, tr_size = struct.unpack(\">II\", train_lbl_file.read(8))\n",
383+
" tr_lbl = array.array(\"b\", train_lbl_file.read())\n",
388384
" \n",
389-
" magic_nr, te_size, te_rows, te_cols = struct.unpack(\">IIII\", test_img_file.read(16))\n",
390-
" te_img = array.array(\"B\", test_img_file.read())\n",
391-
" test_img_file.close()\n",
392-
" magic_nr, te_size = struct.unpack(\">II\", test_lbl_file.read(8))\n",
393-
" te_lbl = array.array(\"b\", test_lbl_file.read())\n",
394-
" test_lbl_file.close()\n",
395-
"\n",
396-
"# print(len(tr_img), len(tr_lbl), tr_size)\n",
397-
"# print(len(te_img), len(te_lbl), te_size)\n",
385+
" with open(os.path.join(path, \"t10k-images-idx3-ubyte\"), \"rb\") as test_img_file:\n",
386+
" magic_nr, te_size, te_rows, te_cols = struct.unpack(\">IIII\", test_img_file.read(16))\n",
387+
" te_img = array.array(\"B\", test_img_file.read())\n",
388+
" \n",
389+
" with open(os.path.join(path, \"t10k-labels-idx1-ubyte\"), \"rb\") as test_lbl_file:\n",
390+
" magic_nr, te_size = struct.unpack(\">II\", test_lbl_file.read(8))\n",
391+
" te_lbl = array.array(\"b\", test_lbl_file.read())\n",
398392
" \n",
399393
" train_img = np.zeros((tr_size, tr_rows*tr_cols), dtype=np.int16)\n",
400394
" train_lbl = np.zeros((tr_size,), dtype=np.int8)\n",

0 commit comments

Comments
 (0)