Skip to content

Commit cdf4617

Browse files
committed
mnist_cnn: Prototype code for image downscaling
1 parent 22ac65c commit cdf4617

File tree

3 files changed

+75
-0
lines changed

3 files changed

+75
-0
lines changed

examples/mnist_cnn/convert.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
2+
import sys
3+
import numpy
4+
import skimage.io
5+
6+
p = sys.argv[1]
7+
o = sys.argv[2]
8+
a = numpy.load(p)
9+
print(a.shape)
10+
11+
skimage.io.imsave(o, a)

examples/mnist_cnn/downscale.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
2+
import array
3+
4+
def average2d(inp, rowstride, x, y, size):
5+
acc = 0
6+
for r in range(y, y+size):
7+
for c in range(x, x+size):
8+
acc += inp[(r*rowstride)+c]
9+
10+
avg = acc // (size*size)
11+
return avg
12+
13+
def downscale(inp, out, in_size, out_size):
14+
15+
# assumes square, single-channel (grayscale) images
16+
assert len(inp) == in_size*in_size
17+
assert len(out) == out_size*out_size
18+
assert (in_size % out_size) == 0, (in_size, out_size)
19+
factor = in_size // out_size
20+
21+
for row in range(out_size):
22+
for col in range(out_size):
23+
o = (row * out_size) + col
24+
out[o] = average2d(inp, in_size, col*factor, row*factor, factor)
25+
26+
27+
def test_downscale():
28+
29+
import npyfile
30+
31+
shape, data = npyfile.load('inp.npy')
32+
print(data)
33+
34+
npyfile.save('orig.npy', data, shape)
35+
36+
insize = 96
37+
outsize = 32
38+
out = array.array('B', (0 for _ in range(outsize*outsize)))
39+
40+
downscale(data, out, insize, outsize)
41+
42+
npyfile.save('out.npy', out, (outsize, outsize))
43+
44+
45+
46+
if __name__ == '__main__':
47+
test_downscale()

examples/mnist_cnn/downscalecheck.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
2+
import skimage.io
3+
import skimage.color
4+
import skimage.transform
5+
import numpy
6+
7+
p = 'data/rps-cv-images/rock/dnss2tOuxRmL0ZjZ.png'
8+
a = skimage.io.imread(p)
9+
print(a.shape, a.dtype)
10+
11+
o = skimage.color.rgb2gray(a)
12+
o = skimage.transform.resize(o, (96, 96))
13+
o = (o * 255).astype(numpy.uint8)
14+
15+
print(o.shape, o.dtype)
16+
17+
numpy.save('inp.npy', o, allow_pickle=False)

0 commit comments

Comments
 (0)