Skip to content

Commit 0bb4069

Browse files
antmarakisnorvig
authored andcommitted
Learning Notebook: Iris/MNIST Visualization + Learner Evalutation (aimacode#568)
* Update learning.ipynb * Update notebook.py * Update .travis.yml
1 parent 8e5043c commit 0bb4069

File tree

3 files changed

+397
-180
lines changed

3 files changed

+397
-180
lines changed

.travis.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ install:
1111
- pip install six
1212
- pip install flake8
1313
- pip install ipython
14+
- pip install matplotlib
1415

1516
script:
1617
- py.test

learning.ipynb

Lines changed: 254 additions & 179 deletions
Large diffs are not rendered by default.

notebook.py

Lines changed: 142 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,147 @@
22
from utils import argmax, argmin
33
from games import TicTacToe, alphabeta_player, random_player, Fig52Extended, infinity
44
from logic import parse_definite_clause, standardize_variables, unify, subst
5+
from learning import DataSet
6+
from mpl_toolkits.mplot3d import Axes3D
7+
import matplotlib.pyplot as plt
8+
9+
import os, struct
10+
import array
11+
import numpy as np
12+
from collections import Counter
13+
14+
15+
# ______________________________________________________________________________
16+
17+
18+
def show_iris(i=0, j=1, k=2):
19+
'''Plots the iris dataset in a 3D plot.
20+
The three axes are given by i, j and k,
21+
which correspond to three of the four iris features.'''
22+
plt.rcParams.update(plt.rcParamsDefault)
23+
24+
fig = plt.figure()
25+
ax = fig.add_subplot(111, projection='3d')
26+
27+
iris = DataSet(name="iris")
28+
buckets = iris.split_values_by_classes()
29+
30+
features = ["Sepal Length", "Sepal Width", "Petal Length", "Petal Width"]
31+
f1, f2, f3 = features[i], features[j], features[k]
32+
33+
a_setosa = [v[i] for v in buckets["setosa"]]
34+
b_setosa = [v[j] for v in buckets["setosa"]]
35+
c_setosa = [v[k] for v in buckets["setosa"]]
36+
37+
a_virginica = [v[i] for v in buckets["virginica"]]
38+
b_virginica = [v[j] for v in buckets["virginica"]]
39+
c_virginica = [v[k] for v in buckets["virginica"]]
40+
41+
a_versicolor = [v[i] for v in buckets["versicolor"]]
42+
b_versicolor = [v[j] for v in buckets["versicolor"]]
43+
c_versicolor = [v[k] for v in buckets["versicolor"]]
44+
45+
46+
for c, m, sl, sw, pl in [('b', 's', a_setosa, b_setosa, c_setosa),
47+
('g', '^', a_virginica, b_virginica, c_virginica),
48+
('r', 'o', a_versicolor, b_versicolor, c_versicolor)]:
49+
ax.scatter(sl, sw, pl, c=c, marker=m)
50+
51+
ax.set_xlabel(f1)
52+
ax.set_ylabel(f2)
53+
ax.set_zlabel(f3)
54+
55+
plt.show()
56+
57+
# ______________________________________________________________________________
58+
59+
60+
def load_MNIST(path="aima-data/MNIST"):
61+
import os, struct
62+
import array
63+
import numpy as np
64+
from collections import Counter
65+
66+
plt.rcParams.update(plt.rcParamsDefault)
67+
plt.rcParams['figure.figsize'] = (10.0, 8.0)
68+
plt.rcParams['image.interpolation'] = 'nearest'
69+
plt.rcParams['image.cmap'] = 'gray'
70+
71+
train_img_file = open(os.path.join(path, "train-images-idx3-ubyte"), "rb")
72+
train_lbl_file = open(os.path.join(path, "train-labels-idx1-ubyte"), "rb")
73+
test_img_file = open(os.path.join(path, "t10k-images-idx3-ubyte"), "rb")
74+
test_lbl_file = open(os.path.join(path, 't10k-labels-idx1-ubyte'), "rb")
75+
76+
magic_nr, tr_size, tr_rows, tr_cols = struct.unpack(">IIII", train_img_file.read(16))
77+
tr_img = array.array("B", train_img_file.read())
78+
train_img_file.close()
79+
magic_nr, tr_size = struct.unpack(">II", train_lbl_file.read(8))
80+
tr_lbl = array.array("b", train_lbl_file.read())
81+
train_lbl_file.close()
82+
83+
magic_nr, te_size, te_rows, te_cols = struct.unpack(">IIII", test_img_file.read(16))
84+
te_img = array.array("B", test_img_file.read())
85+
test_img_file.close()
86+
magic_nr, te_size = struct.unpack(">II", test_lbl_file.read(8))
87+
te_lbl = array.array("b", test_lbl_file.read())
88+
test_lbl_file.close()
89+
90+
#print(len(tr_img), len(tr_lbl), tr_size)
91+
#print(len(te_img), len(te_lbl), te_size)
92+
93+
train_img = np.zeros((tr_size, tr_rows*tr_cols), dtype=np.int16)
94+
train_lbl = np.zeros((tr_size,), dtype=np.int8)
95+
for i in range(tr_size):
96+
train_img[i] = np.array(tr_img[i*tr_rows*tr_cols : (i+1)*tr_rows*tr_cols]).reshape((tr_rows*te_cols))
97+
train_lbl[i] = tr_lbl[i]
98+
99+
test_img = np.zeros((te_size, te_rows*te_cols), dtype=np.int16)
100+
test_lbl = np.zeros((te_size,), dtype=np.int8)
101+
for i in range(te_size):
102+
test_img[i] = np.array(te_img[i*te_rows*te_cols : (i+1)*te_rows*te_cols]).reshape((te_rows*te_cols))
103+
test_lbl[i] = te_lbl[i]
104+
105+
return(train_img, train_lbl, test_img, test_lbl)
106+
107+
108+
def show_MNIST(labels, images, samples=8):
109+
classes = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
110+
num_classes = len(classes)
111+
112+
for y, cls in enumerate(classes):
113+
idxs = np.nonzero([i == y for i in labels])
114+
idxs = np.random.choice(idxs[0], samples, replace=False)
115+
for i , idx in enumerate(idxs):
116+
plt_idx = i * num_classes + y + 1
117+
plt.subplot(samples, num_classes, plt_idx)
118+
plt.imshow(images[idx].reshape((28, 28)))
119+
plt.axis("off")
120+
if i == 0:
121+
plt.title(cls)
122+
123+
plt.show()
124+
125+
126+
def show_ave_MNIST(labels, images):
127+
classes = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
128+
num_classes = len(classes)
129+
130+
for y, cls in enumerate(classes):
131+
idxs = np.nonzero([i == y for i in labels])
132+
print("Digit", y, ":", len(idxs[0]), "images.")
133+
134+
ave_img = np.mean(np.vstack([images[i] for i in idxs[0]]), axis = 0)
135+
#print(ave_img.shape)
136+
137+
plt.subplot(1, num_classes, y+1)
138+
plt.imshow(ave_img.reshape((28, 28)))
139+
plt.axis("off")
140+
plt.title(cls)
141+
142+
plt.show()
143+
144+
# ______________________________________________________________________________
145+
5146

6147
_canvas = """
7148
<script type="text/javascript" src="./js/canvas.js"></script>
@@ -132,7 +273,7 @@ def display_html(html_string):
132273

133274

134275
################################################################################
135-
276+
136277

137278
class Canvas_TicTacToe(Canvas):
138279
"""Play a 3x3 TicTacToe game on HTML canvas

0 commit comments

Comments
 (0)