Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions notebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,15 @@ def show_iris(i=0, j=1, k=2):
# MNIST


def load_MNIST(path="aima-data/MNIST"):
def load_MNIST(path="aima-data/MNIST/Digits", fashion=False):
import os, struct
import array
import numpy as np
from collections import Counter

if fashion:
path = "aima-data/MNIST/Fashion"

plt.rcParams.update(plt.rcParamsDefault)
plt.rcParams['figure.figsize'] = (10.0, 8.0)
plt.rcParams['image.interpolation'] = 'nearest'
Expand Down Expand Up @@ -143,8 +146,17 @@ def load_MNIST(path="aima-data/MNIST"):
return(train_img, train_lbl, test_img, test_lbl)


def show_MNIST(labels, images, samples=8):
classes = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
digit_classes = [str(i) for i in range(10)]
fashion_classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
"Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]


def show_MNIST(labels, images, samples=8, fashion=False):
if not fashion:
classes = digit_classes
else:
classes = fashion_classes

num_classes = len(classes)

for y, cls in enumerate(classes):
Expand All @@ -161,13 +173,19 @@ def show_MNIST(labels, images, samples=8):
plt.show()


def show_ave_MNIST(labels, images):
classes = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
def show_ave_MNIST(labels, images, fashion=False):
if not fashion:
item_type = "Digit"
classes = digit_classes
else:
item_type = "Apparel"
classes = fashion_classes

num_classes = len(classes)

for y, cls in enumerate(classes):
idxs = np.nonzero([i == y for i in labels])
print("Digit", y, ":", len(idxs[0]), "images.")
print(item_type, y, ":", len(idxs[0]), "images.")

ave_img = np.mean(np.vstack([images[i] for i in idxs[0]]), axis = 0)
#print(ave_img.shape)
Expand Down