1212
1313import matplotlib .pyplot as plt
1414import numpy as np
15-
1615from scipy import ndimage
1716
18-
1917do_animation = True
2018
2119
2220def transform (
23- gridmap , src , distance_type = 'chessboard' ,
24- transform_type = 'path' , alpha = 0.01
21+ grid_map , src , distance_type = 'chessboard' ,
22+ transform_type = 'path' , alpha = 0.01
2523):
2624 """transform
2725
2826 calculating transform of transform_type from src
2927 in given distance_type
3028
31- :param gridmap : 2d binary map
29+ :param grid_map : 2d binary map
3230 :param src: distance transform source
3331 :param distance_type: type of distance used
3432 :param transform_type: type of transform used
35- :param alpha: weight of Obstacle Transform usedwhen using path_transform
33+ :param alpha: weight of Obstacle Transform used when using path_transform
3634 """
3735
38- nrows , ncols = gridmap .shape
36+ n_rows , n_cols = grid_map .shape
3937
40- if nrows == 0 or ncols == 0 :
41- sys .exit ('Empty gridmap .' )
38+ if n_rows == 0 or n_cols == 0 :
39+ sys .exit ('Empty grid_map .' )
4240
4341 inc_order = [[0 , 1 ], [1 , 1 ], [1 , 0 ], [1 , - 1 ],
4442 [0 , - 1 ], [- 1 , - 1 ], [- 1 , 0 ], [- 1 , 1 ]]
@@ -49,30 +47,30 @@ def transform(
4947 else :
5048 sys .exit ('Unsupported distance type.' )
5149
52- transform_matrix = float ('inf' ) * np .ones_like (gridmap , dtype = np .float )
50+ transform_matrix = float ('inf' ) * np .ones_like (grid_map , dtype = np .float )
5351 transform_matrix [src [0 ], src [1 ]] = 0
5452 if transform_type == 'distance' :
55- eT = np .zeros_like (gridmap )
53+ eT = np .zeros_like (grid_map )
5654 elif transform_type == 'path' :
57- eT = ndimage .distance_transform_cdt (1 - gridmap , distance_type )
55+ eT = ndimage .distance_transform_cdt (1 - grid_map , distance_type )
5856 else :
5957 sys .exit ('Unsupported transform type.' )
6058
6159 # set obstacle transform_matrix value to infinity
62- for i in range (nrows ):
63- for j in range (ncols ):
64- if gridmap [i ][j ] == 1.0 :
60+ for i in range (n_rows ):
61+ for j in range (n_cols ):
62+ if grid_map [i ][j ] == 1.0 :
6563 transform_matrix [i ][j ] = float ('inf' )
6664 is_visited = np .zeros_like (transform_matrix , dtype = bool )
6765 is_visited [src [0 ], src [1 ]] = True
6866 traversal_queue = [src ]
69- calculated = [(src [0 ]- 1 ) * ncols + src [1 ]]
67+ calculated = [(src [0 ] - 1 ) * n_cols + src [1 ]]
7068
71- def is_valid_neighbor (i , j ):
72- return ni >= 0 and ni < nrows and nj >= 0 and nj < ncols \
73- and not gridmap [ ni ][ nj ]
69+ def is_valid_neighbor (g_i , g_j ):
70+ return 0 <= g_i < n_rows and 0 <= g_j < n_cols \
71+ and not grid_map [ g_i ][ g_j ]
7472
75- while traversal_queue != [] :
73+ while traversal_queue :
7674 i , j = traversal_queue .pop (0 )
7775 for k , inc in enumerate (inc_order ):
7876 ni = i + inc [0 ]
@@ -83,16 +81,35 @@ def is_valid_neighbor(i, j):
8381 # update transform_matrix
8482 transform_matrix [i ][j ] = min (
8583 transform_matrix [i ][j ],
86- transform_matrix [ni ][nj ] + cost [k ] + alpha * eT [ni ][nj ])
84+ transform_matrix [ni ][nj ] + cost [k ] + alpha * eT [ni ][nj ])
8785
8886 if not is_visited [ni ][nj ] \
89- and ((ni - 1 ) * ncols + nj ) not in calculated :
87+ and ((ni - 1 ) * n_cols + nj ) not in calculated :
9088 traversal_queue .append ((ni , nj ))
91- calculated .append ((ni - 1 ) * ncols + nj )
89+ calculated .append ((ni - 1 ) * n_cols + nj )
9290
9391 return transform_matrix
9492
9593
94+ def get_search_order_increment (start , goal ):
95+ if start [0 ] >= goal [0 ] and start [1 ] >= goal [1 ]:
96+ order = [[1 , 0 ], [0 , 1 ], [- 1 , 0 ], [0 , - 1 ],
97+ [1 , 1 ], [1 , - 1 ], [- 1 , 1 ], [- 1 , - 1 ]]
98+ elif start [0 ] <= goal [0 ] and start [1 ] >= goal [1 ]:
99+ order = [[- 1 , 0 ], [0 , 1 ], [1 , 0 ], [0 , - 1 ],
100+ [- 1 , 1 ], [- 1 , - 1 ], [1 , 1 ], [1 , - 1 ]]
101+ elif start [0 ] >= goal [0 ] and start [1 ] <= goal [1 ]:
102+ order = [[1 , 0 ], [0 , - 1 ], [- 1 , 0 ], [0 , 1 ],
103+ [1 , - 1 ], [- 1 , - 1 ], [1 , 1 ], [- 1 , 1 ]]
104+ elif start [0 ] <= goal [0 ] and start [1 ] <= goal [1 ]:
105+ order = [[- 1 , 0 ], [0 , - 1 ], [0 , 1 ], [1 , 0 ],
106+ [- 1 , - 1 ], [- 1 , 1 ], [1 , - 1 ], [1 , 1 ]]
107+ else :
108+ sys .exit ('get_search_order_increment: cannot determine \
109+ start=>goal increment order' )
110+ return order
111+
112+
96113def wavefront (transform_matrix , start , goal ):
97114 """wavefront
98115
@@ -104,32 +121,14 @@ def wavefront(transform_matrix, start, goal):
104121 """
105122
106123 path = []
107- nrows , ncols = transform_matrix .shape
108-
109- def get_search_order_increment (start , goal ):
110- if start [0 ] >= goal [0 ] and start [1 ] >= goal [1 ]:
111- order = [[1 , 0 ], [0 , 1 ], [- 1 , 0 ], [0 , - 1 ],
112- [1 , 1 ], [1 , - 1 ], [- 1 , 1 ], [- 1 , - 1 ]]
113- elif start [0 ] <= goal [0 ] and start [1 ] >= goal [1 ]:
114- order = [[- 1 , 0 ], [0 , 1 ], [1 , 0 ], [0 , - 1 ],
115- [- 1 , 1 ], [- 1 , - 1 ], [1 , 1 ], [1 , - 1 ]]
116- elif start [0 ] >= goal [0 ] and start [1 ] <= goal [1 ]:
117- order = [[1 , 0 ], [0 , - 1 ], [- 1 , 0 ], [0 , 1 ],
118- [1 , - 1 ], [- 1 , - 1 ], [1 , 1 ], [- 1 , 1 ]]
119- elif start [0 ] <= goal [0 ] and start [1 ] <= goal [1 ]:
120- order = [[- 1 , 0 ], [0 , - 1 ], [0 , 1 ], [1 , 0 ],
121- [- 1 , - 1 ], [- 1 , 1 ], [1 , - 1 ], [1 , 1 ]]
122- else :
123- sys .exit ('get_search_order_increment: cannot determine \
124- start=>goal increment order' )
125- return order
124+ n_rows , n_cols = transform_matrix .shape
126125
127- def is_valid_neighbor (i , j ):
128- is_i_valid_bounded = i >= 0 and i < nrows
129- is_j_valid_bounded = j >= 0 and j < ncols
126+ def is_valid_neighbor (g_i , g_j ):
127+ is_i_valid_bounded = 0 <= g_i < n_rows
128+ is_j_valid_bounded = 0 <= g_j < n_cols
130129 if is_i_valid_bounded and is_j_valid_bounded :
131- return not is_visited [i ][ j ] and \
132- transform_matrix [i ][ j ] != float ('inf' )
130+ return not is_visited [g_i ][ g_j ] and \
131+ transform_matrix [g_i ][ g_j ] != float ('inf' )
133132 return False
134133
135134 inc_order = get_search_order_increment (start , goal )
@@ -146,7 +145,7 @@ def is_valid_neighbor(i, j):
146145 i_max = (- 1 , - 1 )
147146 i_last = 0
148147 for i_last in range (len (path )):
149- current_node = path [- 1 - i_last ] # get lastest node in path
148+ current_node = path [- 1 - i_last ] # get latest node in path
150149 for ci , cj in inc_order :
151150 ni , nj = current_node [0 ] + ci , current_node [1 ] + cj
152151 if is_valid_neighbor (ni , nj ) and \
@@ -168,10 +167,10 @@ def is_valid_neighbor(i, j):
168167 return path
169168
170169
171- def visualize_path (grid_map , start , goal , path , resolution ):
170+ def visualize_path (grid_map , start , goal , path ): # pragma: no cover
172171 oy , ox = start
173172 gy , gx = goal
174- px , py = np .transpose (np .flipud (np .fliplr (( path ) )))
173+ px , py = np .transpose (np .flipud (np .fliplr (path )))
175174
176175 if not do_animation :
177176 plt .imshow (grid_map , cmap = 'Greys' )
@@ -207,12 +206,12 @@ def main():
207206 # distance transform wavefront
208207 DT = transform (img , goal , transform_type = 'distance' )
209208 DT_path = wavefront (DT , start , goal )
210- visualize_path (img , start , goal , DT_path , 1 )
209+ visualize_path (img , start , goal , DT_path )
211210
212211 # path transform wavefront
213212 PT = transform (img , goal , transform_type = 'path' , alpha = 0.01 )
214213 PT_path = wavefront (PT , start , goal )
215- visualize_path (img , start , goal , PT_path , 1 )
214+ visualize_path (img , start , goal , PT_path )
216215
217216
218217if __name__ == "__main__" :
0 commit comments