|
374 | 374 | "source": [ |
375 | 375 | "def load_MNIST(path=\"aima-data/MNIST\"):\n", |
376 | 376 | " \"helper function to load MNIST data\"\n", |
377 | | - " train_img_file = open(os.path.join(path, \"train-images-idx3-ubyte\"), \"rb\")\n", |
378 | | - " train_lbl_file = open(os.path.join(path, \"train-labels-idx1-ubyte\"), \"rb\")\n", |
379 | | - " test_img_file = open(os.path.join(path, \"t10k-images-idx3-ubyte\"), \"rb\")\n", |
380 | | - " test_lbl_file = open(os.path.join(path, 't10k-labels-idx1-ubyte'), \"rb\")\n", |
| 377 | + " with open(os.path.join(path, \"train-images-idx3-ubyte\"), \"rb\") as train_img_file:\n", |
| 378 | + " magic_nr, tr_size, tr_rows, tr_cols = struct.unpack(\">IIII\", train_img_file.read(16))\n", |
| 379 | + " tr_img = array.array(\"B\", train_img_file.read())\n", |
381 | 380 | " \n", |
382 | | - " magic_nr, tr_size, tr_rows, tr_cols = struct.unpack(\">IIII\", train_img_file.read(16))\n", |
383 | | - " tr_img = array.array(\"B\", train_img_file.read())\n", |
384 | | - " train_img_file.close() \n", |
385 | | - " magic_nr, tr_size = struct.unpack(\">II\", train_lbl_file.read(8))\n", |
386 | | - " tr_lbl = array.array(\"b\", train_lbl_file.read())\n", |
387 | | - " train_lbl_file.close()\n", |
| 381 | + " with open(os.path.join(path, \"train-labels-idx1-ubyte\"), \"rb\") as train_lbl_file:\n", |
| 382 | + " magic_nr, tr_size = struct.unpack(\">II\", train_lbl_file.read(8))\n", |
| 383 | + " tr_lbl = array.array(\"b\", train_lbl_file.read())\n", |
388 | 384 | " \n", |
389 | | - " magic_nr, te_size, te_rows, te_cols = struct.unpack(\">IIII\", test_img_file.read(16))\n", |
390 | | - " te_img = array.array(\"B\", test_img_file.read())\n", |
391 | | - " test_img_file.close()\n", |
392 | | - " magic_nr, te_size = struct.unpack(\">II\", test_lbl_file.read(8))\n", |
393 | | - " te_lbl = array.array(\"b\", test_lbl_file.read())\n", |
394 | | - " test_lbl_file.close()\n", |
395 | | - "\n", |
396 | | - "# print(len(tr_img), len(tr_lbl), tr_size)\n", |
397 | | - "# print(len(te_img), len(te_lbl), te_size)\n", |
| 385 | + " with open(os.path.join(path, \"t10k-images-idx3-ubyte\"), \"rb\") as test_img_file:\n", |
| 386 | + " magic_nr, te_size, te_rows, te_cols = struct.unpack(\">IIII\", test_img_file.read(16))\n", |
| 387 | + " te_img = array.array(\"B\", test_img_file.read())\n", |
| 388 | + " \n", |
| 389 | + " with open(os.path.join(path, \"t10k-labels-idx1-ubyte\"), \"rb\") as test_lbl_file:\n", |
| 390 | + " magic_nr, te_size = struct.unpack(\">II\", test_lbl_file.read(8))\n", |
| 391 | + " te_lbl = array.array(\"b\", test_lbl_file.read())\n", |
398 | 392 | " \n", |
399 | 393 | " train_img = np.zeros((tr_size, tr_rows*tr_cols), dtype=np.int16)\n", |
400 | 394 | " train_lbl = np.zeros((tr_size,), dtype=np.int8)\n", |
|
0 commit comments