Skip to content

Commit 2bf0dce

Browse files
authored
fix unittest animation bugs (AtsushiSakai#429)
* fix unittest animation bugs * exstract a function
1 parent 48b6ee3 commit 2bf0dce

File tree

5 files changed

+92
-64
lines changed

5 files changed

+92
-64
lines changed

PathPlanning/AStar/a_star_variants.py

Lines changed: 75 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
"""
2-
astar variants
2+
a star variants
33
author: Sarim Mehdi([email protected])
44
Source: http://theory.stanford.edu/~amitp/GameProgramming/Variations.html
55
"""
66

77
import numpy as np
88
import matplotlib.pyplot as plt
99

10-
show_animation = False
10+
show_animation = True
1111
use_beam_search = False
1212
use_iterative_deepening = False
1313
use_dynamic_weighting = False
@@ -89,7 +89,8 @@ def key_points(o_dict):
8989
obs_count += 1
9090
if obs_count == 3 or obs_count == 1:
9191
c_list.append((x, y))
92-
plt.plot(x, y, ".y")
92+
if show_animation:
93+
plt.plot(x, y, ".y")
9394
break
9495
if only_corners:
9596
return c_list
@@ -148,7 +149,8 @@ def __init__(self, obs_grid, goal_x, goal_y, start_x, start_y,
148149
'open': True, 'in_open_list': True}
149150
self.open_set.append(self.all_nodes[tuple(self.start_pt)])
150151

151-
def get_hval(self, x1, y1, x2, y2):
152+
@staticmethod
153+
def get_hval(x1, y1, x2, y2):
152154
x, y = x1, y1
153155
val = 0
154156
while x != x2 or y != y2:
@@ -178,7 +180,6 @@ def get_farthest_point(self, x, y, i, j):
178180
def jump_point(self):
179181
"""Jump point: Instead of exploring all empty spaces of the
180182
map, just explore the corners."""
181-
plt.title('Jump Point')
182183

183184
goal_found = False
184185
while len(self.open_set) > 0:
@@ -232,35 +233,40 @@ def jump_point(self):
232233
if not self.all_nodes[cand_pt]['in_open_list']:
233234
self.open_set.append(self.all_nodes[cand_pt])
234235
self.all_nodes[cand_pt]['in_open_list'] = True
235-
plt.plot(cand_pt[0], cand_pt[1], "r*")
236+
if show_animation:
237+
plt.plot(cand_pt[0], cand_pt[1], "r*")
236238

237239
if goal_found:
238240
break
239-
plt.pause(0.001)
241+
if show_animation:
242+
plt.pause(0.001)
240243
if goal_found:
241244
current_node = self.all_nodes[tuple(self.goal_pt)]
242245
while goal_found:
243246
if current_node['pred'] is None:
244247
break
245248
x = [current_node['pos'][0], current_node['pred'][0]]
246249
y = [current_node['pos'][1], current_node['pred'][1]]
247-
plt.plot(x, y, "b")
248250
current_node = self.all_nodes[tuple(current_node['pred'])]
249-
plt.pause(0.001)
251+
if show_animation:
252+
plt.plot(x, y, "b")
253+
plt.pause(0.001)
250254
if goal_found:
251255
break
252256

253257
current_node['open'] = False
254258
current_node['in_open_list'] = False
255-
plt.plot(current_node['pos'][0], current_node['pos'][1], "g*")
259+
if show_animation:
260+
plt.plot(current_node['pos'][0], current_node['pos'][1], "g*")
256261
del self.open_set[p]
257262
current_node['fcost'], current_node['hcost'] = np.inf, np.inf
258263
if show_animation:
264+
plt.title('Jump Point')
259265
plt.show()
260266

261-
def astar(self):
267+
def a_star(self):
262268
"""Beam search: Maintain an open list of just 30 nodes.
263-
If more than 30 nodes, then get rid of ndoes with high
269+
If more than 30 nodes, then get rid of nodes with high
264270
f values.
265271
Iterative deepening: At every iteration, get a cut-off
266272
value for the f cost. This cut-off is minimum of the f
@@ -275,21 +281,23 @@ def astar(self):
275281
one neighbor at a time. In fact, you can look for the
276282
next node as far out as you can as long as there is a
277283
clear line of sight from your current node to that node."""
278-
if use_beam_search:
279-
plt.title('A* with beam search')
280-
elif use_iterative_deepening:
281-
plt.title('A* with iterative deepening')
282-
elif use_dynamic_weighting:
283-
plt.title('A* with dynamic weighting')
284-
elif use_theta_star:
285-
plt.title('Theta*')
286-
else:
287-
plt.title('A*')
284+
if show_animation:
285+
if use_beam_search:
286+
plt.title('A* with beam search')
287+
elif use_iterative_deepening:
288+
plt.title('A* with iterative deepening')
289+
elif use_dynamic_weighting:
290+
plt.title('A* with dynamic weighting')
291+
elif use_theta_star:
292+
plt.title('Theta*')
293+
else:
294+
plt.title('A*')
288295

289296
goal_found = False
290297
curr_f_thresh = np.inf
291298
depth = 0
292299
no_valid_f = False
300+
w = None
293301
while len(self.open_set) > 0:
294302
self.open_set = sorted(self.open_set, key=lambda x: x['fcost'])
295303
lowest_f = self.open_set[0]['fcost']
@@ -351,30 +359,14 @@ def astar(self):
351359
break
352360

353361
cand_pt = tuple(cand_pt)
354-
if not self.obs_grid[tuple(cand_pt)] and \
355-
self.all_nodes[cand_pt]['open']:
356-
g_cost = offset + current_node['gcost']
357-
h_cost = self.all_nodes[cand_pt]['hcost']
358-
if use_dynamic_weighting:
359-
h_cost = h_cost * w
360-
f_cost = g_cost + h_cost
361-
if f_cost < self.all_nodes[cand_pt]['fcost'] and \
362-
f_cost <= curr_f_thresh:
363-
f_cost_list.append(f_cost)
364-
self.all_nodes[cand_pt]['pred'] = \
365-
current_node['pos']
366-
self.all_nodes[cand_pt]['gcost'] = g_cost
367-
self.all_nodes[cand_pt]['fcost'] = f_cost
368-
if not self.all_nodes[cand_pt]['in_open_list']:
369-
self.open_set.append(self.all_nodes[cand_pt])
370-
self.all_nodes[cand_pt]['in_open_list'] = True
371-
plt.plot(cand_pt[0], cand_pt[1], "r*")
372-
if curr_f_thresh < f_cost < \
373-
self.all_nodes[cand_pt]['fcost']:
374-
no_valid_f = True
362+
no_valid_f = self.update_node_cost(cand_pt, curr_f_thresh,
363+
current_node,
364+
f_cost_list, no_valid_f,
365+
offset, w)
375366
if goal_found:
376367
break
377-
plt.pause(0.001)
368+
if show_animation:
369+
plt.pause(0.001)
378370
if goal_found:
379371
current_node = self.all_nodes[tuple(self.goal_pt)]
380372
while goal_found:
@@ -383,12 +375,13 @@ def astar(self):
383375
if use_theta_star or use_jump_point:
384376
x, y = [current_node['pos'][0], current_node['pred'][0]], \
385377
[current_node['pos'][1], current_node['pred'][1]]
386-
plt.plot(x, y, "b")
378+
if show_animation:
379+
plt.plot(x, y, "b")
387380
else:
388-
plt.plot(current_node['pred'][0],
389-
current_node['pred'][1], "b*")
381+
if show_animation:
382+
plt.plot(current_node['pred'][0],
383+
current_node['pred'][1], "b*")
390384
current_node = self.all_nodes[tuple(current_node['pred'])]
391-
plt.pause(0.001)
392385
if goal_found:
393386
break
394387

@@ -402,13 +395,40 @@ def astar(self):
402395

403396
current_node['open'] = False
404397
current_node['in_open_list'] = False
405-
plt.plot(current_node['pos'][0], current_node['pos'][1], "g*")
398+
if show_animation:
399+
plt.plot(current_node['pos'][0], current_node['pos'][1], "g*")
406400
del self.open_set[p]
407401
current_node['fcost'], current_node['hcost'] = np.inf, np.inf
408402
depth += 1
409403
if show_animation:
410404
plt.show()
411405

406+
def update_node_cost(self, cand_pt, curr_f_thresh, current_node,
407+
f_cost_list, no_valid_f, offset, w):
408+
if not self.obs_grid[tuple(cand_pt)] and \
409+
self.all_nodes[cand_pt]['open']:
410+
g_cost = offset + current_node['gcost']
411+
h_cost = self.all_nodes[cand_pt]['hcost']
412+
if use_dynamic_weighting:
413+
h_cost = h_cost * w
414+
f_cost = g_cost + h_cost
415+
if f_cost < self.all_nodes[cand_pt]['fcost'] and \
416+
f_cost <= curr_f_thresh:
417+
f_cost_list.append(f_cost)
418+
self.all_nodes[cand_pt]['pred'] = \
419+
current_node['pos']
420+
self.all_nodes[cand_pt]['gcost'] = g_cost
421+
self.all_nodes[cand_pt]['fcost'] = f_cost
422+
if not self.all_nodes[cand_pt]['in_open_list']:
423+
self.open_set.append(self.all_nodes[cand_pt])
424+
self.all_nodes[cand_pt]['in_open_list'] = True
425+
if show_animation:
426+
plt.plot(cand_pt[0], cand_pt[1], "r*")
427+
if curr_f_thresh < f_cost < \
428+
self.all_nodes[cand_pt]['fcost']:
429+
no_valid_f = True
430+
return no_valid_f
431+
412432

413433
def main():
414434
# set obstacle positions
@@ -443,10 +463,11 @@ def main():
443463
for x, y, l in zip(all_x, all_y, all_len):
444464
draw_horizontal_line(x, y, l, o_x, o_y, obs_dict)
445465

446-
plt.plot(o_x, o_y, ".k")
447-
plt.plot(s_x, s_y, "og")
448-
plt.plot(g_x, g_y, "xb")
449-
plt.grid(True)
466+
if show_animation:
467+
plt.plot(o_x, o_y, ".k")
468+
plt.plot(s_x, s_y, "og")
469+
plt.plot(g_x, g_y, "xb")
470+
plt.grid(True)
450471

451472
if use_jump_point:
452473
keypoint_list = key_points(obs_dict)
@@ -455,7 +476,7 @@ def main():
455476
search_obj.jump_point()
456477
else:
457478
search_obj = SearchAlgo(obs_dict, g_x, g_y, s_x, s_y, 101, 101)
458-
search_obj.astar()
479+
search_obj.a_star()
459480

460481

461482
if __name__ == '__main__':

PathPlanning/FlowField/flowfield.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,12 @@ def assign_vectors(self):
129129
def follow_vectors(self):
130130
curr_x, curr_y = self.start_pt
131131
while curr_x is not None and curr_y is not None:
132-
plt.plot(curr_x, curr_y, "b*")
133132
curr_x, curr_y = self.vector_field[(curr_x, curr_y)]
134-
plt.pause(0.001)
133+
134+
if show_animation:
135+
plt.plot(curr_x, curr_y, "b*")
136+
plt.pause(0.001)
137+
135138
if show_animation:
136139
plt.show()
137140

@@ -208,12 +211,13 @@ def main():
208211
for x, y, l in zip(all_x, all_y, all_len):
209212
draw_horizontal_line(x, y, l, h_x, h_y, obs_dict, 'hard')
210213

211-
plt.plot(o_x, o_y, "sr")
212-
plt.plot(m_x, m_y, "sg")
213-
plt.plot(h_x, h_y, "sy")
214-
plt.plot(s_x, s_y, "og")
215-
plt.plot(g_x, g_y, "o")
216-
plt.grid(True)
214+
if show_animation:
215+
plt.plot(o_x, o_y, "sr")
216+
plt.plot(m_x, m_y, "sg")
217+
plt.plot(h_x, h_y, "sy")
218+
plt.plot(s_x, s_y, "og")
219+
plt.plot(g_x, g_y, "o")
220+
plt.grid(True)
217221

218222
flow_obj = FlowField(obs_dict, g_x, g_y, s_x, s_y, 50, 50)
219223
flow_obj.find_path()

tests/test_a_star_searching_two_side.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
sys.path.append(os.path.dirname(__file__) + '/../')
66

77
try:
8-
from PathPlanning.AStar import A_Star_searching_from_two_side as m
8+
from PathPlanning.AStar import a_star_searching_from_two_side as m
99
except ImportError:
1010
raise
1111

tests/test_a_star_variants.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ class Test(TestCase):
1010
def test(self):
1111
# A* with beam search
1212
astar.show_animation = False
13+
1314
astar.use_beam_search = True
1415
astar.main()
1516
self.reset_all()
@@ -34,7 +35,9 @@ def test(self):
3435
astar.main()
3536
self.reset_all()
3637

37-
def reset_all(self):
38+
@staticmethod
39+
def reset_all():
40+
astar.show_animation = False
3841
astar.use_beam_search = False
3942
astar.use_iterative_deepening = False
4043
astar.use_dynamic_weighting = False

0 commit comments

Comments
 (0)