Skip to content

Commit d30c212

Browse files
authored
Check if feature exist when predict target in NaiveBayes (#327)
* Check if feature exist when predict target in NaiveBayes * Fix typo
1 parent 18c36b9 commit d30c212

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

src/Classification/NaiveBayes.php

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
namespace Phpml\Classification;
66

7+
use Phpml\Exception\InvalidArgumentException;
78
use Phpml\Helper\Predictable;
89
use Phpml\Helper\Trainable;
910
use Phpml\Math\Statistic\Mean;
@@ -137,6 +138,10 @@ private function calculateStatistics(string $label, array $samples): void
137138
*/
138139
private function sampleProbability(array $sample, int $feature, string $label): float
139140
{
141+
if (!isset($sample[$feature])) {
142+
throw new InvalidArgumentException('Missing feature. All samples must have equal number of features');
143+
}
144+
140145
$value = $sample[$feature];
141146
if ($this->dataType[$label][$feature] == self::NOMINAL) {
142147
if (!isset($this->discreteProb[$label][$feature][$value]) ||

tests/Classification/NaiveBayesTest.php

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
namespace Phpml\Tests\Classification;
66

77
use Phpml\Classification\NaiveBayes;
8+
use Phpml\Exception\InvalidArgumentException;
89
use Phpml\ModelManager;
910
use PHPUnit\Framework\TestCase;
1011

@@ -125,4 +126,19 @@ public function testSaveAndRestoreNumericLabels(): void
125126
self::assertEquals($classifier, $restoredClassifier);
126127
self::assertEquals($predicted, $restoredClassifier->predict($testSamples));
127128
}
129+
130+
public function testInconsistentFeaturesInSamples(): void
131+
{
132+
$trainSamples = [[5, 1, 1], [1, 5, 1], [1, 1, 5]];
133+
$trainLabels = ['1996', '1997', '1998'];
134+
135+
$testSamples = [[3, 1, 1], [5, 1], [4, 3, 8]];
136+
137+
$classifier = new NaiveBayes();
138+
$classifier->train($trainSamples, $trainLabels);
139+
140+
$this->expectException(InvalidArgumentException::class);
141+
142+
$classifier->predict($testSamples);
143+
}
128144
}

0 commit comments

Comments
 (0)