Skip to content

Commit ef173a4

Browse files
committed
Eliminate quality() function. Create transpose() function. Improve docstrings
1 parent fd67194 commit ef173a4

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

kmeans.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
from random import sample
33
from math import fsum, sqrt
44
from collections import defaultdict
5-
# from functools import partial
65

76
def partial(func, *args):
7+
"Rewrite functools.partial() in a way that doesn't confuse mypy"
88
def inner(*moreargs):
99
return func(*args, *moreargs)
1010
return inner
@@ -17,6 +17,10 @@ def mean(data: Iterable[float]) -> float:
1717
data = list(data)
1818
return fsum(data) / len(data)
1919

20+
def transpose(matrix: Iterable[Iterable]) -> Iterable[tuple]:
21+
'Swap rows with columns for a 2-D array'
22+
return zip(*matrix)
23+
2024
def dist(p: Point, q: Point, sqrt=sqrt, fsum=fsum, zip=zip) -> float:
2125
'Euclidean distance'
2226
return sqrt(fsum((x1 - x2) ** 2.0 for x1, x2 in zip(p, q)))
@@ -31,11 +35,7 @@ def assign_data(centroids: Sequence[Centroid], data: Iterable[Point]) -> Dict[Ce
3135

3236
def compute_centroids(groups: Iterable[Sequence[Point]]) -> List[Centroid]:
3337
'Compute the centroid of each group'
34-
return [tuple(map(mean, zip(*group))) for group in groups]
35-
36-
def quality(labeled: Dict[Centroid, Sequence[Point]]) -> float:
37-
'Mean value of squared distances from data to its assigned centroid'
38-
return mean(dist(c, p) ** 2 for c, pts in labeled.items() for p in pts)
38+
return [tuple(map(mean, transpose(group))) for group in groups]
3939

4040
def k_means(data: Iterable[Point], k:int=2, iterations:int=10) -> List[Point]:
4141
'Return k-centroids for the data'
@@ -92,6 +92,5 @@ def k_means(data: Iterable[Point], k:int=2, iterations:int=10) -> List[Point]:
9292
# 5583 1338 1202 668 611 409 463
9393
centroids = k_means(data, k=4, iterations=20)
9494
d = assign_data(centroids, data)
95-
print(quality(d))
9695
pprint(d)
9796

0 commit comments

Comments
 (0)