Skip to content

Commit af2d732

Browse files
IcaVesticaakondas
authored andcommitted
KMeans associative clustering (#262)
* KMeans associative clustering added * fix travis error * KMeans will return provided keys as point label if they are provided * fix travis * fix travis
1 parent 0d80c78 commit af2d732

File tree

6 files changed

+48
-10
lines changed

6 files changed

+48
-10
lines changed

docs/machine-learning/clustering/k-means.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@ To divide the samples into clusters simply use `cluster` method. It's return the
1919

2020
```
2121
$samples = [[1, 1], [8, 7], [1, 2], [7, 8], [2, 1], [8, 9]];
22+
Or if you need to keep your indentifiers along with yours samples you can use array keys as labels.
23+
$samples = [ 'Label1' => [1, 1], 'Label2' => [8, 7], 'Label3' => [1, 2]];
2224
2325
$kmeans = new KMeans(2);
2426
$kmeans->cluster($samples);
25-
// return [0=>[[1, 1], ...], 1=>[[8, 7], ...]]
27+
// return [0=>[[1, 1], ...], 1=>[[8, 7], ...]] or [0=>['Label1' => [1, 1], 'Label3' => [1, 2], ...], 1=>['Label2' => [8, 7], ...]]
2628
```
2729

2830
### Initialization methods

src/Clustering/KMeans.php

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ public function __construct(int $clustersNumber, int $initialization = self::INI
3535

3636
public function cluster(array $samples): array
3737
{
38-
$space = new Space(count($samples[0]));
39-
foreach ($samples as $sample) {
40-
$space->addPoint($sample);
38+
$space = new Space(count(reset($samples)));
39+
foreach ($samples as $key => $sample) {
40+
$space->addPoint($sample, $key);
4141
}
4242

4343
$clusters = [];

src/Clustering/KMeans/Cluster.php

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,11 @@ public function getPoints(): array
3232
{
3333
$points = [];
3434
foreach ($this->points as $point) {
35-
$points[] = $point->toArray();
35+
if (!empty($point->label)) {
36+
$points[$point->label] = $point->toArray();
37+
} else {
38+
$points[] = $point->toArray();
39+
}
3640
}
3741

3842
return $points;

src/Clustering/KMeans/Point.php

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,16 @@ class Point implements ArrayAccess
1818
*/
1919
protected $coordinates = [];
2020

21-
public function __construct(array $coordinates)
21+
/**
22+
* @var mixed
23+
*/
24+
protected $label;
25+
26+
public function __construct(array $coordinates, $label = null)
2227
{
2328
$this->dimension = count($coordinates);
2429
$this->coordinates = $coordinates;
30+
$this->label = $label;
2531
}
2632

2733
public function toArray(): array

src/Clustering/KMeans/Space.php

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,21 @@ public function toArray(): array
3535
return ['points' => $points];
3636
}
3737

38-
public function newPoint(array $coordinates): Point
38+
public function newPoint(array $coordinates, $label = null): Point
3939
{
4040
if (count($coordinates) != $this->dimension) {
4141
throw new LogicException('('.implode(',', $coordinates).') is not a point of this space');
4242
}
4343

44-
return new Point($coordinates);
44+
return new Point($coordinates, $label);
4545
}
4646

4747
/**
4848
* @param null $data
4949
*/
50-
public function addPoint(array $coordinates, $data = null): void
50+
public function addPoint(array $coordinates, $label = null, $data = null): void
5151
{
52-
$this->attach($this->newPoint($coordinates), $data);
52+
$this->attach($this->newPoint($coordinates, $label), $data);
5353
}
5454

5555
/**

tests/Clustering/KMeansTest.php

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,32 @@ public function testKMeansSamplesClustering(): void
2828
$this->assertCount(0, $samples);
2929
}
3030

31+
public function testKMeansSamplesLabeledClustering(): void
32+
{
33+
$samples = [
34+
'555' => [1, 1],
35+
'666' => [8, 7],
36+
'ABC' => [1, 2],
37+
'DEF' => [7, 8],
38+
668 => [2, 1],
39+
[8, 9],
40+
];
41+
42+
$kmeans = new KMeans(2);
43+
$clusters = $kmeans->cluster($samples);
44+
45+
$this->assertCount(2, $clusters);
46+
47+
foreach ($samples as $index => $sample) {
48+
if (in_array($sample, $clusters[0], true) || in_array($sample, $clusters[1], true)) {
49+
$this->assertArrayHasKey($index, $clusters[0] + $clusters[1]);
50+
unset($samples[$index]);
51+
}
52+
}
53+
54+
$this->assertCount(0, $samples);
55+
}
56+
3157
public function testKMeansInitializationMethods(): void
3258
{
3359
$samples = [

0 commit comments

Comments
 (0)