@@ -287,26 +287,73 @@ def enumerate_all(vars, e, bn):
287287
288288#______________________________________________________________________________
289289
290- def elimination_ask (X , e , bn , order = reversed ):
291- "[Fig. 14.11]"
290+ def elimination_ask (X , e , bn ):
291+ """[Fig. 14.11]
292+ >>> elimination_ask('Burglary', dict(JohnCalls=T, MaryCalls=T), burglary
293+ ... ).show_approx()
294+ 'False: 0.716, True: 0.284'"""
292295 factors = []
293- for var in order (bn .vars ):
294- factors .append (Factor (var , e ))
296+ for var in reversed (bn .vars ):
297+ factors .append (make_factor (var , e , bn ))
295298 if is_hidden (var , X , e ):
296- factors = sum_out (var , factors )
297- return pointwise_product (factors ).normalize ()
299+ factors = sum_out (var , factors , bn )
300+ return pointwise_product (factors , bn ).normalize ()
298301
299302def is_hidden (var , X , e ):
300303 return var != X and var not in e
301304
302- def Factor (var , e ):
303- unimplemented ()
305+ def make_factor (var , e , bn ):
306+ node = bn .variable_node (var )
307+ vars = [X for X in [var ] + node .parents if X not in e ]
308+ cpt = dict ((event_values (e1 , vars ), node .p (e1 [var ], e1 ))
309+ for e1 in all_events (vars , bn , e ))
310+ return Factor (vars , cpt )
311+
312+ def pointwise_product (factors , bn ):
313+ return reduce (lambda f , g : f .pointwise_product (g , bn ), factors )
314+
315+ def sum_out (var , factors , bn ):
316+ result , var_factors = [], []
317+ for f in factors :
318+ (var_factors if var in f .vars else result ).append (f )
319+ result .append (pointwise_product (var_factors , bn ).sum_out (var , bn ))
320+ return result
321+
322+ class Factor :
323+
324+ def __init__ (self , vars , cpt ):
325+ update (self , vars = vars , cpt = cpt )
326+
327+ def pointwise_product (self , other , bn ):
328+ vars = list (set (self .vars ) | set (other .vars ))
329+ cpt = dict ((event_values (e , vars ), self .p (e ) * other .p (e ))
330+ for e in all_events (vars , bn , {}))
331+ return Factor (vars , cpt )
332+
333+ def sum_out (self , var , bn ):
334+ vars = [X for X in self .vars if X != var ]
335+ cpt = dict ((event_values (e , vars ),
336+ sum (self .p (extend (e , var , val ))
337+ for val in bn .variable_values (var )))
338+ for e in all_events (vars , bn , {}))
339+ return Factor (vars , cpt )
304340
305- def pointwise_product (factors ):
306- unimplemented ()
341+ def normalize (self ):
342+ assert len (self .vars ) == 1
343+ return ProbDist (self .vars [0 ],
344+ dict ((k , v ) for ((k ,), v ) in self .cpt .items ()))
307345
308- def sum_out (var , factors ):
309- unimplemented ()
346+ def p (self , e ):
347+ return self .cpt [event_values (e , self .vars )]
348+
349+ def all_events (vars , bn , e1 ):
350+ if not vars :
351+ yield e1
352+ else :
353+ X , rest = vars [0 ], vars [1 :]
354+ for e in all_events (rest , bn , e1 ):
355+ for x in bn .variable_values (X ):
356+ yield extend (e , X , x )
310357
311358#______________________________________________________________________________
312359
0 commit comments