Skip to content

Commit 1dc371d

Browse files
committed
kmeans: Also test bytearray
1 parent e39ad3b commit 1dc371d

File tree

1 file changed

+31
-12
lines changed

1 file changed

+31
-12
lines changed

tests/test_kmeans.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,47 @@
44
import array
55
import gc
66

7-
def test_kmeans_two_clusters():
8-
"""
9-
Data that is grouped into high/low should be clusterable into 2
10-
"""
7+
def make_two_cluster_data(typecode):
118

129
n_features = 3
13-
dataset = array.array('B', [
10+
dataset = [
11+
# cluster1
1412
0, 0, 0,
1513
10, 5, 2,
16-
14+
15+
# cluster2
1716
200, 50, 100,
1817
255, 255, 255,
19-
])
18+
]
2019

21-
centroids = array.array('B', [
20+
centroids = [
2221
0, 0, 0,
2322
200, 200, 200,
24-
])
23+
]
24+
25+
if typecode == 'bytearray':
26+
dataset = bytearray(dataset)
27+
centroids = bytearray(centroids)
28+
else:
29+
dataset = array.array(typecode, dataset)
30+
centroids = array.array(typecode, centroids)
31+
32+
return dataset, centroids
33+
34+
35+
def test_kmeans_two_clusters():
36+
"""
37+
Data that is grouped into high/low should be clusterable into 2
38+
"""
39+
40+
n_features = 3
41+
# test both with "bytearray" and "array.array"
42+
for typecode in ['bytearray', 'B']:
43+
dataset, centroids = make_two_cluster_data(typecode)
2544

26-
assignments = emlkmeans.cluster(dataset, centroids, channels=n_features)
27-
assert len(assignments) == len(dataset)/n_features
28-
assert list(assignments) == [0, 0, 1, 1], assignments
45+
assignments = emlkmeans.cluster(dataset, centroids, channels=n_features)
46+
assert len(assignments) == len(dataset)/n_features
47+
assert list(assignments) == [0, 0, 1, 1], assignments
2948

3049

3150
test_kmeans_two_clusters()

0 commit comments

Comments
 (0)