@@ -886,3 +886,159 @@ def draw_table(self):
886886 self .fill (0 , 0 , 0 )
887887 self .text_n (self .table [self .context [0 ]][self .context [1 ]] if self .context else "Click for text" , 0.025 , 0.975 )
888888 self .update ()
889+
890+ ############################################################################################################
891+
892+ ##################### Functions to assist plotting in search.ipynb ####################
893+
894+ ############################################################################################################
895+ import networkx as nx
896+ import matplotlib .pyplot as plt
897+ from matplotlib import lines
898+
899+ from ipywidgets import interact
900+ import ipywidgets as widgets
901+ from IPython .display import display
902+ import time
903+ from search import GraphProblem , romania_map
904+
905+ def show_map (graph_data , node_colors = None ):
906+ G = nx .Graph (graph_data ['graph_dict' ])
907+ node_colors = node_colors or graph_data ['node_colors' ]
908+ node_positions = graph_data ['node_positions' ]
909+ node_label_pos = graph_data ['node_label_positions' ]
910+ edge_weights = graph_data ['edge_weights' ]
911+
912+ # set the size of the plot
913+ plt .figure (figsize = (18 ,13 ))
914+ # draw the graph (both nodes and edges) with locations from romania_locations
915+ nx .draw (G , pos = {k : node_positions [k ] for k in G .nodes ()},
916+ node_color = [node_colors [node ] for node in G .nodes ()], linewidths = 0.3 , edgecolors = 'k' )
917+
918+ # draw labels for nodes
919+ node_label_handles = nx .draw_networkx_labels (G , pos = node_label_pos , font_size = 14 )
920+
921+ # add a white bounding box behind the node labels
922+ [label .set_bbox (dict (facecolor = 'white' , edgecolor = 'none' )) for label in node_label_handles .values ()]
923+
924+ # add edge lables to the graph
925+ nx .draw_networkx_edge_labels (G , pos = node_positions , edge_labels = edge_weights , font_size = 14 )
926+
927+ # add a legend
928+ white_circle = lines .Line2D ([], [], color = "white" , marker = 'o' , markersize = 15 , markerfacecolor = "white" )
929+ orange_circle = lines .Line2D ([], [], color = "orange" , marker = 'o' , markersize = 15 , markerfacecolor = "orange" )
930+ red_circle = lines .Line2D ([], [], color = "red" , marker = 'o' , markersize = 15 , markerfacecolor = "red" )
931+ gray_circle = lines .Line2D ([], [], color = "gray" , marker = 'o' , markersize = 15 , markerfacecolor = "gray" )
932+ green_circle = lines .Line2D ([], [], color = "green" , marker = 'o' , markersize = 15 , markerfacecolor = "green" )
933+ plt .legend ((white_circle , orange_circle , red_circle , gray_circle , green_circle ),
934+ ('Un-explored' , 'Frontier' , 'Currently Exploring' , 'Explored' , 'Final Solution' ),
935+ numpoints = 1 ,prop = {'size' :16 }, loc = (.8 ,.75 ))
936+
937+ # show the plot. No need to use in notebooks. nx.draw will show the graph itself.
938+ plt .show ()
939+
940+ ## helper functions for visualisations
941+
942+ def final_path_colors (initial_node_colors , problem , solution ):
943+ "returns a node_colors dict of the final path provided the problem and solution"
944+
945+ # get initial node colors
946+ final_colors = dict (initial_node_colors )
947+ # color all the nodes in solution and starting node to green
948+ final_colors [problem .initial ] = "green"
949+ for node in solution :
950+ final_colors [node ] = "green"
951+ return final_colors
952+
953+ def display_visual (graph_data , user_input , algorithm = None , problem = None ):
954+ initial_node_colors = graph_data ['node_colors' ]
955+ if user_input == False :
956+ def slider_callback (iteration ):
957+ # don't show graph for the first time running the cell calling this function
958+ try :
959+ show_map (graph_data , node_colors = all_node_colors [iteration ])
960+ except :
961+ pass
962+ def visualize_callback (Visualize ):
963+ if Visualize is True :
964+ button .value = False
965+
966+ global all_node_colors
967+
968+ iterations , all_node_colors , node = algorithm (problem )
969+ solution = node .solution ()
970+ all_node_colors .append (final_path_colors (all_node_colors [0 ], problem , solution ))
971+
972+ slider .max = len (all_node_colors ) - 1
973+
974+ for i in range (slider .max + 1 ):
975+ slider .value = i
976+ #time.sleep(.5)
977+
978+ slider = widgets .IntSlider (min = 0 , max = 1 , step = 1 , value = 0 )
979+ slider_visual = widgets .interactive (slider_callback , iteration = slider )
980+ display (slider_visual )
981+
982+ button = widgets .ToggleButton (value = False )
983+ button_visual = widgets .interactive (visualize_callback , Visualize = button )
984+ display (button_visual )
985+
986+ if user_input == True :
987+ node_colors = dict (initial_node_colors )
988+ if isinstance (algorithm , dict ):
989+ assert set (algorithm .keys ()).issubset (set (["Breadth First Tree Search" ,
990+ "Depth First Tree Search" ,
991+ "Breadth First Search" ,
992+ "Depth First Graph Search" ,
993+ "Uniform Cost Search" ,
994+ "A-star Search" ]))
995+
996+ algo_dropdown = widgets .Dropdown (description = "Search algorithm: " ,
997+ options = sorted (list (algorithm .keys ())),
998+ value = "Breadth First Tree Search" )
999+ display (algo_dropdown )
1000+ elif algorithm is None :
1001+ print ("No algorithm to run." )
1002+ return 0
1003+
1004+ def slider_callback (iteration ):
1005+ # don't show graph for the first time running the cell calling this function
1006+ try :
1007+ show_map (graph_data , node_colors = all_node_colors [iteration ])
1008+ except :
1009+ pass
1010+
1011+ def visualize_callback (Visualize ):
1012+ if Visualize is True :
1013+ button .value = False
1014+
1015+ problem = GraphProblem (start_dropdown .value , end_dropdown .value , romania_map )
1016+ global all_node_colors
1017+
1018+ user_algorithm = algorithm [algo_dropdown .value ]
1019+
1020+ iterations , all_node_colors , node = user_algorithm (problem )
1021+ solution = node .solution ()
1022+ all_node_colors .append (final_path_colors (all_node_colors [0 ], problem , solution ))
1023+
1024+ slider .max = len (all_node_colors ) - 1
1025+
1026+ for i in range (slider .max + 1 ):
1027+ slider .value = i
1028+ #time.sleep(.5)
1029+
1030+ start_dropdown = widgets .Dropdown (description = "Start city: " ,
1031+ options = sorted (list (node_colors .keys ())), value = "Arad" )
1032+ display (start_dropdown )
1033+
1034+ end_dropdown = widgets .Dropdown (description = "Goal city: " ,
1035+ options = sorted (list (node_colors .keys ())), value = "Fagaras" )
1036+ display (end_dropdown )
1037+
1038+ button = widgets .ToggleButton (value = False )
1039+ button_visual = widgets .interactive (visualize_callback , Visualize = button )
1040+ display (button_visual )
1041+
1042+ slider = widgets .IntSlider (min = 0 , max = 1 , step = 1 , value = 0 )
1043+ slider_visual = widgets .interactive (slider_callback , iteration = slider )
1044+ display (slider_visual )
0 commit comments