@@ -74,24 +74,24 @@ def informed_rrt_star_search(self, animation=True):
7474 newNode = self .get_new_node (theta , nind , nearestNode )
7575 d = self .line_cost (nearestNode , newNode )
7676
77- isCollision = self .collision_check (newNode , self .obstacle_list )
78- isCollisionEx = self .check_collision_extend (nearestNode , theta , d )
77+ noCollision = self .check_collision (nearestNode , theta , d )
7978
80- if isCollision and isCollisionEx :
79+ if noCollision :
8180 nearInds = self .find_near_nodes (newNode )
8281 newNode = self .choose_parent (newNode , nearInds )
8382
8483 self .node_list .append (newNode )
8584 self .rewire (newNode , nearInds )
8685
8786 if self .is_near_goal (newNode ):
88- solutionSet .add (newNode )
89- lastIndex = len (self .node_list ) - 1
90- tempPath = self .get_final_course (lastIndex )
91- tempPathLen = self .get_path_len (tempPath )
92- if tempPathLen < pathLen :
93- path = tempPath
94- cBest = tempPathLen
87+ if self .check_segment_collision (newNode .x , newNode .y , self .goal .x , self .goal .y ):
88+ solutionSet .add (newNode )
89+ lastIndex = len (self .node_list ) - 1
90+ tempPath = self .get_final_course (lastIndex )
91+ tempPathLen = self .get_path_len (tempPath )
92+ if tempPathLen < pathLen :
93+ path = tempPath
94+ cBest = tempPathLen
9595
9696 if animation :
9797 self .draw_graph (xCenter = xCenter ,
@@ -110,7 +110,7 @@ def choose_parent(self, newNode, nearInds):
110110 dy = newNode .y - self .node_list [i ].y
111111 d = math .sqrt (dx ** 2 + dy ** 2 )
112112 theta = math .atan2 (dy , dx )
113- if self .check_collision_extend (self .node_list [i ], theta , d ):
113+ if self .check_collision (self .node_list [i ], theta , d ):
114114 dList .append (self .node_list [i ].cost + d )
115115 else :
116116 dList .append (float ('inf' ))
@@ -194,17 +194,6 @@ def get_nearest_list_index(nodes, rnd):
194194 minIndex = dList .index (min (dList ))
195195 return minIndex
196196
197- @staticmethod
198- def collision_check (newNode , obstacleList ):
199- for (ox , oy , size ) in obstacleList :
200- dx = ox - newNode .x
201- dy = oy - newNode .y
202- d = dx * dx + dy * dy
203- if d <= 1.1 * size ** 2 :
204- return False # collision
205-
206- return True # safe
207-
208197 def get_new_node (self , theta , nind , nearestNode ):
209198 newNode = copy .deepcopy (nearestNode )
210199
@@ -234,20 +223,40 @@ def rewire(self, newNode, nearInds):
234223 if nearNode .cost > scost :
235224 theta = math .atan2 (newNode .y - nearNode .y ,
236225 newNode .x - nearNode .x )
237- if self .check_collision_extend (nearNode , theta , d ):
226+ if self .check_collision (nearNode , theta , d ):
238227 nearNode .parent = n_node - 1
239228 nearNode .cost = scost
229+
230+ @staticmethod
231+ def distance_squared_point_to_segment (v , w , p ):
232+ # Return minimum distance between line segment vw and point p
233+ if (np .array_equal (v , w )):
234+ return (p - v ).dot (p - v ) # v == w case
235+ l2 = (w - v ).dot (w - v ) # i.e. |w-v|^2 - avoid a sqrt
236+ # Consider the line extending the segment, parameterized as v + t (w - v).
237+ # We find projection of point p onto the line.
238+ # It falls where t = [(p-v) . (w-v)] / |w-v|^2
239+ # We clamp t from [0,1] to handle points outside the segment vw.
240+ t = max (0 , min (1 , (p - v ).dot (w - v ) / l2 ))
241+ projection = v + t * (w - v ) # Projection falls on the segment
242+ return (p - projection ).dot (p - projection )
243+
244+ def check_segment_collision (self , x1 , y1 , x2 , y2 ):
245+ for (ox , oy , size ) in self .obstacle_list :
246+ dd = self .distance_squared_point_to_segment (
247+ np .array ([x1 , y1 ]),
248+ np .array ([x2 , y2 ]),
249+ np .array ([ox , oy ]))
250+ if dd <= size ** 2 :
251+ return False # collision
252+ return True
240253
241- def check_collision_extend (self , nearNode , theta , d ):
242- tmpNode = copy .deepcopy (nearNode )
243-
244- for i in range (int (d / self .expand_dis )):
245- tmpNode .x += self .expand_dis * math .cos (theta )
246- tmpNode .y += self .expand_dis * math .sin (theta )
247- if not self .collision_check (tmpNode , self .obstacle_list ):
248- return False
249254
250- return True
255+ def check_collision (self , nearNode , theta , d ):
256+ tmpNode = copy .deepcopy (nearNode )
257+ endx = tmpNode .x + math .cos (theta )* d
258+ endy = tmpNode .y + math .sin (theta )* d
259+ return self .check_segment_collision (tmpNode .x , tmpNode .y , endx , endy )
251260
252261 def get_final_course (self , lastIndex ):
253262 path = [[self .goal .x , self .goal .y ]]
0 commit comments