11"""Learn to estimate functions from examples. (Chapters 18-20)""" 
22
33from  utils  import  (
4-     removeall , unique , product , mode , argmax , argmax_random_tie , isclose ,
4+     removeall , unique , product , mode , argmax , argmax_random_tie , isclose ,  gaussian , 
55    dotproduct , vector_add , scalar_vector_product , weighted_sample_with_replacement ,
66    weighted_sampler , num_or_str , normalize , clip , sigmoid , print_table , DataFile 
77)
1111import  math 
1212import  random 
1313
14- from  statistics  import  mean 
14+ from  statistics  import  mean ,  stdev 
1515from  collections  import  defaultdict 
1616
1717# ______________________________________________________________________________ 
@@ -178,6 +178,45 @@ def remove_examples(self, value=""):
178178        self .examples  =  [x  for  x  in  self .examples  if  value  not  in   x ]
179179        self .update_values ()
180180
def split_values_by_classes(self):
    """Group the examples by class, dropping the target from each example.

    Returns a defaultdict mapping each class value found in self.examples
    to a list of items, where an item is the example's values with the
    entry at position self.target removed (all other values keep their
    original order).
    """
    buckets = defaultdict(list)

    for example in self.examples:
        # Normalise a possibly negative target index so that the positional
        # comparison below matches the entry read by example[self.target].
        target = self.target if self.target >= 0 else len(example) + self.target

        # Remove the target by position, not by value: a feature whose value
        # happens to equal a class label must stay in the item.
        item = [value for i, value in enumerate(example) if i != target]
        buckets[example[self.target]].append(item)

    return buckets
191+ 
def find_means_and_deviations(self):
    """Find per-class means and standard deviations of the input features.

    Returns a pair (means, deviations):
    means     : A dictionary keyed by class/target value. Holds a list with
                the mean of each feature over the examples of that class.
    deviations: A dictionary keyed by class/target value. Holds a list with
                the sample standard deviation of each feature for the class.

    Both lists have len(self.inputs) entries and are indexed by the
    feature's position in the items returned by split_values_by_classes()
    (the example with its target removed), not by attribute number.

    A class with no examples keeps 0 for every mean and deviation, and a
    class with a single example keeps 0 for every deviation, because the
    sample standard deviation needs at least two values. Callers that
    divide by a deviation must be prepared for 0.
    """
    target_names = self.values[self.target]
    feature_numbers = len(self.inputs)

    item_buckets = self.split_values_by_classes()

    means = defaultdict(lambda: [0] * feature_numbers)
    deviations = defaultdict(lambda: [0] * feature_numbers)

    for t in target_names:
        items = item_buckets.get(t, [])
        # Touch both entries so every known class appears in the result,
        # even when it has no examples.
        class_means = means[t]
        class_deviations = deviations[t]

        for i in range(feature_numbers):
            # All values of feature i among the items of class t.
            column = [item[i] for item in items]
            if column:
                class_means[i] = mean(column)
            if len(column) >= 2:
                # stdev() raises StatisticsError on fewer than two values.
                class_deviations[i] = stdev(column)

    return means, deviations
218+ 
219+ 
def __repr__(self):
    """Return a short summary: the name plus example and attribute counts."""
    template = '<DataSet({}): {:d} examples, {:d} attributes>'
    example_count = len(self.examples)
    attribute_count = len(self.attrs)
    return template.format(self.name, example_count, attribute_count)
@@ -267,15 +306,22 @@ def predict(example):
267306# ______________________________________________________________________________ 
268307
269308
270- def  NaiveBayesLearner (dataset ):
def NaiveBayesLearner(dataset, continuous=True):
    """Build a naive Bayes predictor for dataset.

    Delegates to NaiveBayesContinuous when continuous is truthy and to
    NaiveBayesDiscrete otherwise, returning whatever that learner returns.
    """
    build_learner = NaiveBayesContinuous if continuous else NaiveBayesDiscrete
    return build_learner(dataset)
314+ 
315+ 
316+ def  NaiveBayesDiscrete (dataset ):
271317    """Just count how many times each value of each input attribute 
272318    occurs, conditional on the target value. Count the different 
273319    target values too.""" 
274320
275-     targetvals  =  dataset .values [dataset .target ]
276-     target_dist  =  CountingProbDist (targetvals )
321+     target_vals  =  dataset .values [dataset .target ]
322+     target_dist  =  CountingProbDist (target_vals )
277323    attr_dists  =  {(gv , attr ): CountingProbDist (dataset .values [attr ])
278-                   for  gv  in  targetvals 
324+                   for  gv  in  target_vals 
279325                  for  attr  in  dataset .inputs }
280326    for  example  in  dataset .examples :
281327        targetval  =  example [dataset .target ]
@@ -290,7 +336,29 @@ def class_probability(targetval):
290336            return  (target_dist [targetval ] * 
291337                    product (attr_dists [targetval , attr ][example [attr ]]
292338                            for  attr  in  dataset .inputs ))
293-         return  argmax (targetvals , key = class_probability )
339+         return  argmax (target_vals , key = class_probability )
340+ 
341+     return  predict 
342+ 
343+ 
def NaiveBayesContinuous(dataset):
    """Count how many times each target value occurs.
    Also, find the means and deviations of input attribute values for each target value.

    Returns a predict(example) function that picks the target value with
    the highest product of the class prior and, for each input attribute,
    the Gaussian density of the example's value under that class's mean and
    standard deviation.
    """
    means, deviations = dataset.find_means_and_deviations()

    target_vals = dataset.values[dataset.target]
    # Seed every class once, then tally the class of each example so the
    # prior reflects class frequencies instead of being uniform.
    # NOTE(review): assumes CountingProbDist counts every item of the
    # iterable it is given, duplicates included -- confirm.
    observed_classes = [example[dataset.target] for example in dataset.examples]
    target_dist = CountingProbDist(list(target_vals) + observed_classes)

    def predict(example):
        """Predict the target value for example. Consider each possible value,
        and pick the most likely by looking at each attribute independently."""
        def class_probability(targetval):
            prob = target_dist[targetval]
            # The per-class statistics are lists ordered like dataset.inputs,
            # so they are read by position, while the example is read by
            # attribute index. The two only coincide when the target is the
            # last attribute.
            for position, attr in enumerate(dataset.inputs):
                prob *= gaussian(means[targetval][position],
                                 deviations[targetval][position],
                                 example[attr])
            return prob

        return argmax(target_vals, key=class_probability)

    return predict
296364
0 commit comments