@@ -95,12 +95,15 @@ def show_iris(i=0, j=1, k=2):
9595# MNIST
9696
9797
98- def load_MNIST (path = "aima-data/MNIST" ):
98+ def load_MNIST (path = "aima-data/MNIST/Digits" , fashion = False ):
9999 import os , struct
100100 import array
101101 import numpy as np
102102 from collections import Counter
103103
104+ if fashion :
105+ path = "aima-data/MNIST/Fashion"
106+
104107 plt .rcParams .update (plt .rcParamsDefault )
105108 plt .rcParams ['figure.figsize' ] = (10.0 , 8.0 )
106109 plt .rcParams ['image.interpolation' ] = 'nearest'
@@ -143,8 +146,17 @@ def load_MNIST(path="aima-data/MNIST"):
143146 return (train_img , train_lbl , test_img , test_lbl )
144147
145148
146- def show_MNIST (labels , images , samples = 8 ):
147- classes = ["0" , "1" , "2" , "3" , "4" , "5" , "6" , "7" , "8" , "9" ]
149+ digit_classes = [str (i ) for i in range (10 )]
150+ fashion_classes = ["T-shirt/top" , "Trouser" , "Pullover" , "Dress" , "Coat" ,
151+ "Sandal" , "Shirt" , "Sneaker" , "Bag" , "Ankle boot" ]
152+
153+
154+ def show_MNIST (labels , images , samples = 8 , fashion = False ):
155+ if not fashion :
156+ classes = digit_classes
157+ else :
158+ classes = fashion_classes
159+
148160 num_classes = len (classes )
149161
150162 for y , cls in enumerate (classes ):
@@ -161,13 +173,19 @@ def show_MNIST(labels, images, samples=8):
161173 plt .show ()
162174
163175
164- def show_ave_MNIST (labels , images ):
165- classes = ["0" , "1" , "2" , "3" , "4" , "5" , "6" , "7" , "8" , "9" ]
176+ def show_ave_MNIST (labels , images , fashion = False ):
177+ if not fashion :
178+ item_type = "Digit"
179+ classes = digit_classes
180+ else :
181+ item_type = "Apparel"
182+ classes = fashion_classes
183+
166184 num_classes = len (classes )
167185
168186 for y , cls in enumerate (classes ):
169187 idxs = np .nonzero ([i == y for i in labels ])
170- print ("Digit" , y , ":" , len (idxs [0 ]), "images." )
188+ print (item_type , y , ":" , len (idxs [0 ]), "images." )
171189
172190 ave_img = np .mean (np .vstack ([images [i ] for i in idxs [0 ]]), axis = 0 )
173191 #print(ave_img.shape)
0 commit comments