@@ -306,13 +306,35 @@ def predict(example):
306306# ______________________________________________________________________________ 
307307
308308
309- def  NaiveBayesLearner (dataset , continuous = True ):
309+ def  NaiveBayesLearner (dataset , continuous = True , simple = False ):
310+     if  simple :
311+         return  NaiveBayesSimple (dataset )
310312    if (continuous ):
311313        return  NaiveBayesContinuous (dataset )
312314    else :
313315        return  NaiveBayesDiscrete (dataset )
314316
315317
318+ def  NaiveBayesSimple (distribution ):
319+     """A simple naive bayes classifier that takes as input a dictionary of 
320+     CountingProbDist objects and classifies items according to these distributions. 
321+     The input dictionary is in the following form: 
322+         (ClassName, ClassProb): CountingProbDist""" 
323+     target_dist  =  {c_name : prob  for  c_name , prob  in  distribution .keys ()}
324+     attr_dists  =  {c_name : count_prob  for  (c_name , _ ), count_prob  in  distribution .items ()}
325+ 
326+     def  predict (example ):
327+         """Predict the target value for example. Calculate probabilities for each 
328+         class and pick the max.""" 
329+         def  class_probability (targetval ):
330+             attr_dist  =  attr_dists [targetval ]
331+             return  target_dist [targetval ] *  product (attr_dist [a ] for  a  in  example )
332+ 
333+         return  argmax (target_dist .keys (), key = class_probability )
334+ 
335+     return  predict 
336+ 
337+ 
316338def  NaiveBayesDiscrete (dataset ):
317339    """Just count how many times each value of each input attribute 
318340    occurs, conditional on the target value. Count the different 
0 commit comments