Skip to content

Commit 856e8d9

Browse files
antmarakis authored and norvig committed
Implementation: Continuous Naive Bayes (aimacode#435)
* Add Gaussian Function * Added Tests Add tests for Continuous Naive Bayes + Means/Standard Deviation * Update learning.py * Commenting Fix * Add test for gaussian * test for every class * Update test_learning.py * Round float results to make sure test passes
1 parent e9c2d07 commit 856e8d9

File tree

4 files changed

+107
-7
lines changed

4 files changed

+107
-7
lines changed

learning.py

Lines changed: 75 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Learn to estimate functions from examples. (Chapters 18-20)"""
22

33
from utils import (
4-
removeall, unique, product, mode, argmax, argmax_random_tie, isclose,
4+
removeall, unique, product, mode, argmax, argmax_random_tie, isclose, gaussian,
55
dotproduct, vector_add, scalar_vector_product, weighted_sample_with_replacement,
66
weighted_sampler, num_or_str, normalize, clip, sigmoid, print_table, DataFile
77
)
@@ -11,7 +11,7 @@
1111
import math
1212
import random
1313

14-
from statistics import mean
14+
from statistics import mean, stdev
1515
from collections import defaultdict
1616

1717
# ______________________________________________________________________________
@@ -178,6 +178,45 @@ def remove_examples(self, value=""):
178178
self.examples = [x for x in self.examples if value not in x]
179179
self.update_values()
180180

181+
def split_values_by_classes(self):
    """Split the examples into buckets according to their class.

    Returns a defaultdict that maps each target value found in
    self.examples to a list of items. An item is the example with its
    target entry removed; the remaining values keep their original order.
    """
    buckets = defaultdict(list)

    for example in self.examples:
        # Drop the target by position, not by value: filtering on the value
        # would also discard any feature that happens to equal a class label
        # (e.g. 0/1 features with 0/1 classes) and shift the other features.
        target_index = self.target if self.target >= 0 else len(example) + self.target
        item = [a for i, a in enumerate(example) if i != target_index]
        buckets[example[self.target]].append(item)  # Add item to bucket of its class

    return buckets
191+
192+
def find_means_and_deviations(self):
    """Find the per-class means and standard deviations of the features.

    Returns a pair (means, deviations):
    means     : A dictionary for each class/target. Holds a list of the means
                of the features for the class.
    deviations: A dictionary for each class/target. Holds a list of the sample
                standard deviations of the features for the class.

    Both dictionaries are defaultdicts whose default entry is a list of
    zeros, one per input feature. statistics.StatisticsError propagates when
    a class has too few examples (none for the mean, fewer than two for the
    sample standard deviation).
    """
    target_names = self.values[self.target]
    feature_numbers = len(self.inputs)

    item_buckets = self.split_values_by_classes()

    means = defaultdict(lambda: [0] * feature_numbers)
    deviations = defaultdict(lambda: [0] * feature_numbers)

    for t in target_names:
        bucket = item_buckets[t]
        for i in range(feature_numbers):
            # Gather the i-th feature of every item in class t in a single
            # pass (no re-copying of the growing lists for each item).
            column = [item[i] for item in bucket]

            # Calculate mean and deviation of this feature for the class
            means[t][i] = mean(column)
            deviations[t][i] = stdev(column)

    return means, deviations
218+
219+
181220
def __repr__(self):
182221
return '<DataSet({}): {:d} examples, {:d} attributes>'.format(
183222
self.name, len(self.examples), len(self.attrs))
@@ -267,15 +306,22 @@ def predict(example):
267306
# ______________________________________________________________________________
268307

269308

270-
def NaiveBayesLearner(dataset):
309+
def NaiveBayesLearner(dataset, continuous=True):
    """Build a naive Bayes predictor for dataset.

    Delegates to NaiveBayesContinuous when continuous is truthy (the
    default) and to NaiveBayesDiscrete otherwise.
    """
    learner = NaiveBayesContinuous if continuous else NaiveBayesDiscrete
    return learner(dataset)
314+
315+
316+
def NaiveBayesDiscrete(dataset):
271317
"""Just count how many times each value of each input attribute
272318
occurs, conditional on the target value. Count the different
273319
target values too."""
274320

275-
targetvals = dataset.values[dataset.target]
276-
target_dist = CountingProbDist(targetvals)
321+
target_vals = dataset.values[dataset.target]
322+
target_dist = CountingProbDist(target_vals)
277323
attr_dists = {(gv, attr): CountingProbDist(dataset.values[attr])
278-
for gv in targetvals
324+
for gv in target_vals
279325
for attr in dataset.inputs}
280326
for example in dataset.examples:
281327
targetval = example[dataset.target]
@@ -290,7 +336,29 @@ def class_probability(targetval):
290336
return (target_dist[targetval] *
291337
product(attr_dists[targetval, attr][example[attr]]
292338
for attr in dataset.inputs))
293-
return argmax(targetvals, key=class_probability)
339+
return argmax(target_vals, key=class_probability)
340+
341+
return predict
342+
343+
344+
def NaiveBayesContinuous(dataset):
    """Naive Bayes learner for numeric input attributes.

    Count how many times each target value occurs.
    Also, find the means and deviations of input attribute values for each
    target value. Returns a predict(example) function that picks the most
    likely target value for an example.
    """
    means, deviations = dataset.find_means_and_deviations()

    target_vals = dataset.values[dataset.target]
    target_dist = CountingProbDist(target_vals)
    # Tally the class of every training example so the class prior reflects
    # how often each target value occurs, as the docstring promises.
    # NOTE(review): relies on CountingProbDist.add, defined elsewhere in
    # this file -- confirm.
    for example in dataset.examples:
        target_dist.add(example[dataset.target])

    def predict(example):
        """Predict the target value for example. Consider each possible value,
        and pick the most likely by looking at each attribute independently."""
        def class_probability(targetval):
            # Class prior times the Gaussian value of every input attribute,
            # using the mean/deviation learned for this target value.
            prob = target_dist[targetval]
            for attr in dataset.inputs:
                prob *= gaussian(means[targetval][attr], deviations[targetval][attr], example[attr])
            return prob

        return argmax(target_vals, key=class_probability)

    return predict
296364

tests/test_learning.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,20 @@ def test_weighted_replicate():
3535
assert weighted_replicate('ABC', [1, 2, 1], 4) == ['A', 'B', 'B', 'C']
3636

3737

38+
def test_means_and_deviation():
    """Check the per-class mean and deviation of the first iris feature."""
    iris = DataSet(name="iris")

    means, deviations = iris.find_means_and_deviations()

    # (class label, expected value rounded to three decimals)
    expected_means = [("setosa", 5.006), ("versicolor", 5.936), ("virginica", 6.588)]
    expected_deviations = [("setosa", 0.352), ("versicolor", 0.516), ("virginica", 0.636)]

    for label, expected in expected_means:
        assert round(means[label][0], 3) == expected

    for label, expected in expected_deviations:
        assert round(deviations[label][0], 3) == expected
50+
51+
3852
def test_plurality_learner():
3953
zoo = DataSet(name="zoo")
4054

@@ -48,6 +62,14 @@ def test_naive_bayes():
4862
# Discrete
4963
nBD = NaiveBayesLearner(iris)
5064
assert nBD([5, 3, 1, 0.1]) == "setosa"
65+
assert nBD([6, 5, 3, 1.5]) == "versicolor"
66+
assert nBD([7, 3, 6.5, 2]) == "virginica"
67+
68+
# Continuous
69+
nBC = NaiveBayesLearner(iris, continuous=True)
70+
assert nBC([5, 3, 1, 0.1]) == "setosa"
71+
assert nBC([6, 5, 3, 1.5]) == "versicolor"
72+
assert nBC([7, 3, 6.5, 2]) == "virginica"
5173

5274

5375
def test_k_nearest_neighbors():

tests/test_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,12 @@ def test_sigmoid():
148148
assert isclose(0.2689414213699951, sigmoid(-1))
149149

150150

151+
def test_gaussian():
    """gaussian(mean, st_dev, x) matches known reference values.

    Floats are compared with isclose rather than ==, so the test does not
    depend on the last bit of the platform's floating-point arithmetic.
    """
    assert isclose(gaussian(1, 0.5, 0.7), 0.6664492057835993)
    assert isclose(gaussian(5, 2, 4.5), 0.19333405840142462)
    assert isclose(gaussian(3, 1, 3), 0.3989422804014327)
155+
156+
151157
def test_step():
152158
assert step(1) == step(0.5) == 1
153159
assert step(0) == 1

utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,10 @@ def step(x):
258258
"""Return activation value of x with sign function"""
259259
return 1 if x >= 0 else 0
260260

261+
def gaussian(mean, st_dev, x):
    """Return the probability density at x of the normal distribution
    with the given mean and standard deviation."""
    # Normalising coefficient 1 / (sqrt(2*pi) * st_dev)
    coefficient = 1 / (math.sqrt(2 * math.pi) * st_dev)
    # Distance of x from the mean, measured in standard deviations
    z = float(x - mean) / st_dev
    exponent = -0.5 * z ** 2
    return coefficient * math.e ** exponent
264+
261265

262266
try: # math.isclose was added in Python 3.5; but we might be in 3.4
263267
from math import isclose

0 commit comments

Comments
 (0)