Skip to content

Commit d7e66f5

Browse files
committed
kmeans: Make core algorithm accessible as generator
1 parent 80a8a7e commit d7e66f5

File tree

3 files changed

+42
-11
lines changed

3 files changed

+42
-11
lines changed

examples/color_quantize_kmeans/color_quantize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def quantize_path(inp, outp, palette, n_samples=100):
103103

104104
# Learn a palette
105105
start = time.ticks_us()
106-
emlkmeans.cluster(samples, palette, max_iter=20)
106+
emlkmeans.cluster(samples, palette, features=3, max_iter=20)
107107
dur = (time.ticks_diff(time.ticks_us(), start) / 1000.0)
108108
print('cluster duration (ms)', dur)
109109

src/emlkmeans/kmeans.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
# when the native code emitter is enabled. Which is critical for performance...
66
# maybe we can move the inner part into a kmeans_cluster_step done in C
77

8+
89
#@micropython.native
9-
def cluster(values, centroids,
10-
channels=3, max_iter=10, stop_changes=0):
10+
def cluster_iter(values, centroids, assignments, features,
11+
max_iter=10, stop_changes=0):
1112
"""
1213
Perform K-Means clustering of @values
1314
@@ -16,16 +17,14 @@ def cluster(values, centroids,
1617
NOTE: will mutate @centroids
1718
"""
1819

20+
channels = features
1921
n_clusters = len(centroids) // channels
2022
n_samples = len(values) // channels
2123

22-
assert channels == 3, 'only support 3 channels for now'
23-
2424
assert channels < 255, channels
2525
assert n_clusters < 255, n_clusters
2626
assert n_samples < 65535, n_samples
2727

28-
assignments = array.array('B', (255 for _ in range(n_samples)))
2928
cluster_sums = array.array('L', (0 for _ in range(n_clusters*channels)))
3029
cluster_counts = array.array('H', (0 for _ in range(n_clusters)))
3130

@@ -36,18 +35,24 @@ def cluster(values, centroids,
3635
for s in range(n_samples):
3736
v = values[s*channels:(s+1)*channels]
3837

38+
# PERF: considering taking all N points at the same time, filling indices and (optionally) distances
3939
idx, dist = euclidean_argmin(centroids, v)
4040
#idx, dist = 0, 0
4141

4242
if idx != assignments[s]:
4343
changes += 1
4444
assignments[s] = idx
4545

46-
print('iter', i, changes)
46+
# Pass control back to caller
47+
# So one can do other work between the iterations
48+
# or implement custom stopping criteria
49+
yield changes
50+
4751
if changes <= stop_changes:
4852
break
4953

5054
## update cluster centroids
55+
# PERF: consider moving this to C. With a update_centroids() function
5156
# reset old values
5257
for c in range(n_clusters*channels):
5358
cluster_sums[c] = 0
@@ -70,9 +75,18 @@ def cluster(values, centroids,
7075

7176
for i in range(channels):
7277
centroids[(c*channels)+i] = cluster_sums[(c*channels)+i] // count
73-
74-
#yield assignments
75-
# TODO: make this into a generator? so other work can be done in between
7678

7779

80+
81+
def cluster(values, centroids, features, **kwargs):
82+
"""Convenience wrapper around cluster_iter"""
83+
84+
n_samples = len(values) // features
85+
assignments = array.array('B', (255 for _ in range(n_samples)))
86+
87+
generator = cluster_iter(values, centroids, assignments, features, **kwargs)
88+
for changes in generator:
89+
pass
90+
7891
return assignments
92+

tests/test_kmeans.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,26 @@ def test_kmeans_two_clusters():
4242
for typecode in ['bytearray', 'B']:
4343
dataset, centroids = make_two_cluster_data(typecode)
4444

45-
assignments = emlkmeans.cluster(dataset, centroids, channels=n_features)
45+
assignments = emlkmeans.cluster(dataset, centroids, features=n_features)
4646
assert len(assignments) == len(dataset)/n_features
4747
assert list(assignments) == [0, 0, 1, 1], assignments
4848

4949

50+
def test_kmeans_many_features():
51+
52+
n_features = 10
53+
n_samples = 100
54+
n_clusters = 10
55+
typecode = 'B'
56+
# TODO: actually throw out some clusters, see we can find them
57+
dataset = array.array(typecode, (0 for _ in range(n_features*n_samples)) )
58+
centroids = array.array(typecode, (0 for _ in range(n_features*n_clusters)) )
59+
60+
assignments = emlkmeans.cluster(dataset, centroids, features=n_features, max_iter=2)
61+
assert len(assignments) == len(dataset)/n_features
62+
assert min(assignments) >= 0
63+
assert max(assignments) < n_clusters
64+
5065
test_kmeans_two_clusters()
66+
test_kmeans_many_features()
67+

0 commit comments

Comments
 (0)