@@ -95,12 +95,15 @@ def show_iris(i=0, j=1, k=2):
9595# MNIST 
9696
9797
98- def  load_MNIST (path = "aima-data/MNIST"  ):
98+ def  load_MNIST (path = "aima-data/MNIST/Digits"  ,  fashion = False ):
9999    import  os , struct 
100100    import  array 
101101    import  numpy  as  np 
102102    from  collections  import  Counter 
103103
104+     if  fashion :
105+         path  =  "aima-data/MNIST/Fashion" 
106+ 
104107    plt .rcParams .update (plt .rcParamsDefault )
105108    plt .rcParams ['figure.figsize' ] =  (10.0 , 8.0 )
106109    plt .rcParams ['image.interpolation' ] =  'nearest' 
@@ -143,8 +146,17 @@ def load_MNIST(path="aima-data/MNIST"):
143146    return (train_img , train_lbl , test_img , test_lbl )
144147
145148
146- def  show_MNIST (labels , images , samples = 8 ):
147-     classes  =  ["0" , "1" , "2" , "3" , "4" , "5" , "6" , "7" , "8" , "9" ]
149+ digit_classes  =  [str (i ) for  i  in  range (10 )]
150+ fashion_classes  =  ["T-shirt/top" , "Trouser" , "Pullover" , "Dress" , "Coat" ,
151+                    "Sandal" , "Shirt" , "Sneaker" , "Bag" , "Ankle boot" ]
152+ 
153+ 
154+ def  show_MNIST (labels , images , samples = 8 , fashion = False ):
155+     if  not  fashion :
156+         classes  =  digit_classes 
157+     else :
158+         classes  =  fashion_classes 
159+ 
148160    num_classes  =  len (classes )
149161
150162    for  y , cls  in  enumerate (classes ):
@@ -161,13 +173,19 @@ def show_MNIST(labels, images, samples=8):
161173    plt .show ()
162174
163175
164- def  show_ave_MNIST (labels , images ):
165-     classes  =  ["0" , "1" , "2" , "3" , "4" , "5" , "6" , "7" , "8" , "9" ]
176+ def  show_ave_MNIST (labels , images , fashion = False ):
177+     if  not  fashion :
178+         item_type  =  "Digit" 
179+         classes  =  digit_classes 
180+     else :
181+         item_type  =  "Apparel" 
182+         classes  =  fashion_classes 
183+ 
166184    num_classes  =  len (classes )
167185
168186    for  y , cls  in  enumerate (classes ):
169187        idxs  =  np .nonzero ([i  ==  y  for  i  in  labels ])
170-         print ("Digit" , y , ":" , len (idxs [0 ]), "images." )
188+         print (item_type , y , ":" , len (idxs [0 ]), "images." )
171189
172190        ave_img  =  np .mean (np .vstack ([images [i ] for  i  in  idxs [0 ]]), axis  =  0 )
173191        #print(ave_img.shape) 
0 commit comments