@@ -306,13 +306,35 @@ def predict(example):
306306# ______________________________________________________________________________
307307
308308
309- def NaiveBayesLearner (dataset , continuous = True ):
309+ def NaiveBayesLearner (dataset , continuous = True , simple = False ):
310+ if simple :
311+ return NaiveBayesSimple (dataset )
310312 if (continuous ):
311313 return NaiveBayesContinuous (dataset )
312314 else :
313315 return NaiveBayesDiscrete (dataset )
314316
315317
318+ def NaiveBayesSimple (distribution ):
319+ """A simple naive bayes classifier that takes as input a dictionary of
320+ CountingProbDist objects and classifies items according to these distributions.
321+ The input dictionary is in the following form:
322+ (ClassName, ClassProb): CountingProbDist"""
323+ target_dist = {c_name : prob for c_name , prob in distribution .keys ()}
324+ attr_dists = {c_name : count_prob for (c_name , _ ), count_prob in distribution .items ()}
325+
326+ def predict (example ):
327+ """Predict the target value for example. Calculate probabilities for each
328+ class and pick the max."""
329+ def class_probability (targetval ):
330+ attr_dist = attr_dists [targetval ]
331+ return target_dist [targetval ] * product (attr_dist [a ] for a in example )
332+
333+ return argmax (target_dist .keys (), key = class_probability )
334+
335+ return predict
336+
337+
316338def NaiveBayesDiscrete (dataset ):
317339 """Just count how many times each value of each input attribute
318340 occurs, conditional on the target value. Count the different
0 commit comments