Skip to content

Commit 097a289

Browse files
committed
code clean up for informed rrt star
1 parent 38373ac commit 097a289

File tree

1 file changed

+94
-87
lines changed

1 file changed

+94
-87
lines changed

PathPlanning/InformedRRTStar/informed_rrt_star.py

Lines changed: 94 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -4,39 +4,40 @@
44
author: Karan Chawla
55
Atsushi Sakai(@Atsushi_twi)
66
7-
Reference: Informed RRT*: Optimal Sampling-based Path Planning Focused via
7+
Reference: Informed RRT*: Optimal Sampling-based Path planning Focused via
88
Direct Sampling of an Admissible Ellipsoidal Heuristichttps://arxiv.org/pdf/1404.2334.pdf
99
1010
"""
1111

12-
13-
import random
14-
import numpy as np
15-
import math
1612
import copy
13+
import math
14+
import random
15+
1716
import matplotlib.pyplot as plt
17+
import numpy as np
1818

1919
show_animation = True
2020

2121

22-
class InformedRRTStar():
22+
class InformedRRTStar:
2323

2424
def __init__(self, start, goal,
2525
obstacleList, randArea,
2626
expandDis=0.5, goalSampleRate=10, maxIter=200):
2727

2828
self.start = Node(start[0], start[1])
2929
self.goal = Node(goal[0], goal[1])
30-
self.minrand = randArea[0]
31-
self.maxrand = randArea[1]
32-
self.expandDis = expandDis
33-
self.goalSampleRate = goalSampleRate
34-
self.maxIter = maxIter
35-
self.obstacleList = obstacleList
30+
self.min_rand = randArea[0]
31+
self.max_rand = randArea[1]
32+
self.expand_dis = expandDis
33+
self.goal_sample_rate = goalSampleRate
34+
self.max_iter = maxIter
35+
self.obstacle_list = obstacleList
36+
self.node_list = None
3637

37-
def InformedRRTStarSearch(self, animation=True):
38+
def informed_rrt_star_search(self, animation=True):
3839

39-
self.nodeList = [self.start]
40+
self.node_list = [self.start]
4041
# max length we expect to find in our 'informed' sample space, starts as infinite
4142
cBest = float('inf')
4243
pathLen = float('inf')
@@ -55,62 +56,62 @@ def InformedRRTStarSearch(self, animation=True):
5556
# first column of idenity matrix transposed
5657
id1_t = np.array([1.0, 0.0, 0.0]).reshape(1, 3)
5758
M = a1 @ id1_t
58-
U, S, Vh = np.linalg.svd(M, 1, 1)
59+
U, S, Vh = np.linalg.svd(M, True, True)
5960
C = np.dot(np.dot(U, np.diag(
6061
[1.0, 1.0, np.linalg.det(U) * np.linalg.det(np.transpose(Vh))])), Vh)
6162

62-
for i in range(self.maxIter):
63+
for i in range(self.max_iter):
6364
# Sample space is defined by cBest
6465
# cMin is the minimum distance between the start point and the goal
6566
# xCenter is the midpoint between the start and the goal
6667
# cBest changes when a new path is found
6768

6869
rnd = self.informed_sample(cBest, cMin, xCenter, C)
69-
nind = self.getNearestListIndex(self.nodeList, rnd)
70-
nearestNode = self.nodeList[nind]
70+
nind = self.get_nearest_list_index(self.node_list, rnd)
71+
nearestNode = self.node_list[nind]
7172
# steer
7273
theta = math.atan2(rnd[1] - nearestNode.y, rnd[0] - nearestNode.x)
73-
newNode = self.getNewNode(theta, nind, nearestNode)
74-
d = self.lineCost(nearestNode, newNode)
74+
newNode = self.get_new_node(theta, nind, nearestNode)
75+
d = self.line_cost(nearestNode, newNode)
7576

76-
isCollision = self.__CollisionCheck(newNode, self.obstacleList)
77+
isCollision = self.collision_check(newNode, self.obstacle_list)
7778
isCollisionEx = self.check_collision_extend(nearestNode, theta, d)
7879

7980
if isCollision and isCollisionEx:
80-
nearInds = self.findNearNodes(newNode)
81-
newNode = self.chooseParent(newNode, nearInds)
81+
nearInds = self.find_near_nodes(newNode)
82+
newNode = self.choose_parent(newNode, nearInds)
8283

83-
self.nodeList.append(newNode)
84+
self.node_list.append(newNode)
8485
self.rewire(newNode, nearInds)
8586

86-
if self.isNearGoal(newNode):
87+
if self.is_near_goal(newNode):
8788
solutionSet.add(newNode)
88-
lastIndex = len(self.nodeList) - 1
89-
tempPath = self.getFinalCourse(lastIndex)
90-
tempPathLen = self.getPathLen(tempPath)
89+
lastIndex = len(self.node_list) - 1
90+
tempPath = self.get_final_course(lastIndex)
91+
tempPathLen = self.get_path_len(tempPath)
9192
if tempPathLen < pathLen:
9293
path = tempPath
9394
cBest = tempPathLen
9495

9596
if animation:
96-
self.drawGraph(xCenter=xCenter,
97-
cBest=cBest, cMin=cMin,
98-
etheta=etheta, rnd=rnd)
97+
self.draw_graph(xCenter=xCenter,
98+
cBest=cBest, cMin=cMin,
99+
etheta=etheta, rnd=rnd)
99100

100101
return path
101102

102-
def chooseParent(self, newNode, nearInds):
103+
def choose_parent(self, newNode, nearInds):
103104
if len(nearInds) == 0:
104105
return newNode
105106

106107
dList = []
107108
for i in nearInds:
108-
dx = newNode.x - self.nodeList[i].x
109-
dy = newNode.y - self.nodeList[i].y
109+
dx = newNode.x - self.node_list[i].x
110+
dy = newNode.y - self.node_list[i].y
110111
d = math.sqrt(dx ** 2 + dy ** 2)
111112
theta = math.atan2(dy, dx)
112-
if self.check_collision_extend(self.nodeList[i], theta, d):
113-
dList.append(self.nodeList[i].cost + d)
113+
if self.check_collision_extend(self.node_list[i], theta, d):
114+
dList.append(self.node_list[i].cost + d)
114115
else:
115116
dList.append(float('inf'))
116117

@@ -126,29 +127,30 @@ def chooseParent(self, newNode, nearInds):
126127

127128
return newNode
128129

129-
def findNearNodes(self, newNode):
130-
nnode = len(self.nodeList)
130+
def find_near_nodes(self, newNode):
131+
nnode = len(self.node_list)
131132
r = 50.0 * math.sqrt((math.log(nnode) / nnode))
132133
dlist = [(node.x - newNode.x) ** 2
133-
+ (node.y - newNode.y) ** 2 for node in self.nodeList]
134+
+ (node.y - newNode.y) ** 2 for node in self.node_list]
134135
nearinds = [dlist.index(i) for i in dlist if i <= r ** 2]
135136
return nearinds
136137

137138
def informed_sample(self, cMax, cMin, xCenter, C):
138139
if cMax < float('inf'):
139140
r = [cMax / 2.0,
140-
math.sqrt(cMax**2 - cMin**2) / 2.0,
141-
math.sqrt(cMax**2 - cMin**2) / 2.0]
141+
math.sqrt(cMax ** 2 - cMin ** 2) / 2.0,
142+
math.sqrt(cMax ** 2 - cMin ** 2) / 2.0]
142143
L = np.diag(r)
143-
xBall = self.sampleUnitBall()
144+
xBall = self.sample_unit_ball()
144145
rnd = np.dot(np.dot(C, L), xBall) + xCenter
145146
rnd = [rnd[(0, 0)], rnd[(1, 0)]]
146147
else:
147-
rnd = self.sampleFreeSpace()
148+
rnd = self.sample_free_space()
148149

149150
return rnd
150151

151-
def sampleUnitBall(self):
152+
@staticmethod
153+
def sample_unit_ball():
152154
a = random.random()
153155
b = random.random()
154156

@@ -159,114 +161,118 @@ def sampleUnitBall(self):
159161
b * math.sin(2 * math.pi * a / b))
160162
return np.array([[sample[0]], [sample[1]], [0]])
161163

162-
def sampleFreeSpace(self):
163-
if random.randint(0, 100) > self.goalSampleRate:
164-
rnd = [random.uniform(self.minrand, self.maxrand),
165-
random.uniform(self.minrand, self.maxrand)]
164+
def sample_free_space(self):
165+
if random.randint(0, 100) > self.goal_sample_rate:
166+
rnd = [random.uniform(self.min_rand, self.max_rand),
167+
random.uniform(self.min_rand, self.max_rand)]
166168
else:
167169
rnd = [self.goal.x, self.goal.y]
168170

169171
return rnd
170172

171-
def getPathLen(self, path):
173+
@staticmethod
174+
def get_path_len(path):
172175
pathLen = 0
173176
for i in range(1, len(path)):
174177
node1_x = path[i][0]
175178
node1_y = path[i][1]
176179
node2_x = path[i - 1][0]
177180
node2_y = path[i - 1][1]
178181
pathLen += math.sqrt((node1_x - node2_x)
179-
** 2 + (node1_y - node2_y)**2)
182+
** 2 + (node1_y - node2_y) ** 2)
180183

181184
return pathLen
182185

183-
def lineCost(self, node1, node2):
184-
return math.sqrt((node1.x - node2.x)**2 + (node1.y - node2.y)**2)
186+
@staticmethod
187+
def line_cost(node1, node2):
188+
return math.sqrt((node1.x - node2.x) ** 2 + (node1.y - node2.y) ** 2)
185189

186-
def getNearestListIndex(self, nodes, rnd):
187-
dList = [(node.x - rnd[0])**2
188-
+ (node.y - rnd[1])**2 for node in nodes]
190+
@staticmethod
191+
def get_nearest_list_index(nodes, rnd):
192+
dList = [(node.x - rnd[0]) ** 2
193+
+ (node.y - rnd[1]) ** 2 for node in nodes]
189194
minIndex = dList.index(min(dList))
190195
return minIndex
191196

192-
def __CollisionCheck(self, newNode, obstacleList):
197+
@staticmethod
198+
def collision_check(newNode, obstacleList):
193199
for (ox, oy, size) in obstacleList:
194200
dx = ox - newNode.x
195201
dy = oy - newNode.y
196202
d = dx * dx + dy * dy
197-
if d <= 1.1 * size**2:
203+
if d <= 1.1 * size ** 2:
198204
return False # collision
199205

200206
return True # safe
201207

202-
def getNewNode(self, theta, nind, nearestNode):
208+
def get_new_node(self, theta, nind, nearestNode):
203209
newNode = copy.deepcopy(nearestNode)
204210

205-
newNode.x += self.expandDis * math.cos(theta)
206-
newNode.y += self.expandDis * math.sin(theta)
211+
newNode.x += self.expand_dis * math.cos(theta)
212+
newNode.y += self.expand_dis * math.sin(theta)
207213

208-
newNode.cost += self.expandDis
214+
newNode.cost += self.expand_dis
209215
newNode.parent = nind
210216
return newNode
211217

212-
def isNearGoal(self, node):
213-
d = self.lineCost(node, self.goal)
214-
if d < self.expandDis:
218+
def is_near_goal(self, node):
219+
d = self.line_cost(node, self.goal)
220+
if d < self.expand_dis:
215221
return True
216222
return False
217223

218224
def rewire(self, newNode, nearInds):
219-
nnode = len(self.nodeList)
225+
n_node = len(self.node_list)
220226
for i in nearInds:
221-
nearNode = self.nodeList[i]
227+
nearNode = self.node_list[i]
222228

223-
d = math.sqrt((nearNode.x - newNode.x)**2
224-
+ (nearNode.y - newNode.y)**2)
229+
d = math.sqrt((nearNode.x - newNode.x) ** 2
230+
+ (nearNode.y - newNode.y) ** 2)
225231

226232
scost = newNode.cost + d
227233

228234
if nearNode.cost > scost:
229235
theta = math.atan2(newNode.y - nearNode.y,
230236
newNode.x - nearNode.x)
231237
if self.check_collision_extend(nearNode, theta, d):
232-
nearNode.parent = nnode - 1
238+
nearNode.parent = n_node - 1
233239
nearNode.cost = scost
234240

235241
def check_collision_extend(self, nearNode, theta, d):
236242
tmpNode = copy.deepcopy(nearNode)
237243

238-
for i in range(int(d / self.expandDis)):
239-
tmpNode.x += self.expandDis * math.cos(theta)
240-
tmpNode.y += self.expandDis * math.sin(theta)
241-
if not self.__CollisionCheck(tmpNode, self.obstacleList):
244+
for i in range(int(d / self.expand_dis)):
245+
tmpNode.x += self.expand_dis * math.cos(theta)
246+
tmpNode.y += self.expand_dis * math.sin(theta)
247+
if not self.collision_check(tmpNode, self.obstacle_list):
242248
return False
243249

244250
return True
245251

246-
def getFinalCourse(self, lastIndex):
252+
def get_final_course(self, lastIndex):
247253
path = [[self.goal.x, self.goal.y]]
248-
while self.nodeList[lastIndex].parent is not None:
249-
node = self.nodeList[lastIndex]
254+
while self.node_list[lastIndex].parent is not None:
255+
node = self.node_list[lastIndex]
250256
path.append([node.x, node.y])
251257
lastIndex = node.parent
252258
path.append([self.start.x, self.start.y])
253259
return path
254260

255-
def drawGraph(self, xCenter=None, cBest=None, cMin=None, etheta=None, rnd=None):
261+
def draw_graph(self, xCenter=None, cBest=None, cMin=None, etheta=None, rnd=None):
256262

257263
plt.clf()
258264
if rnd is not None:
259265
plt.plot(rnd[0], rnd[1], "^k")
260266
if cBest != float('inf'):
261267
self.plot_ellipse(xCenter, cBest, cMin, etheta)
262268

263-
for node in self.nodeList:
269+
for node in self.node_list:
264270
if node.parent is not None:
265271
if node.x or node.y is not None:
266-
plt.plot([node.x, self.nodeList[node.parent].x], [
267-
node.y, self.nodeList[node.parent].y], "-g")
272+
plt.plot([node.x, self.node_list[node.parent].x], [
273+
node.y, self.node_list[node.parent].y], "-g")
268274

269-
for (ox, oy, size) in self.obstacleList:
275+
for (ox, oy, size) in self.obstacle_list:
270276
plt.plot(ox, oy, "ok", ms=30 * size)
271277

272278
plt.plot(self.start.x, self.start.y, "xr")
@@ -275,9 +281,10 @@ def drawGraph(self, xCenter=None, cBest=None, cMin=None, etheta=None, rnd=None):
275281
plt.grid(True)
276282
plt.pause(0.01)
277283

278-
def plot_ellipse(self, xCenter, cBest, cMin, etheta): # pragma: no cover
284+
@staticmethod
285+
def plot_ellipse(xCenter, cBest, cMin, etheta): # pragma: no cover
279286

280-
a = math.sqrt(cBest**2 - cMin**2) / 2.0
287+
a = math.sqrt(cBest ** 2 - cMin ** 2) / 2.0
281288
b = cBest / 2.0
282289
angle = math.pi / 2.0 - etheta
283290
cx = xCenter[0]
@@ -295,7 +302,7 @@ def plot_ellipse(self, xCenter, cBest, cMin, etheta): # pragma: no cover
295302
plt.plot(px, py, "--c")
296303

297304

298-
class Node():
305+
class Node:
299306

300307
def __init__(self, x, y):
301308
self.x = x
@@ -320,17 +327,17 @@ def main():
320327
# Set params
321328
rrt = InformedRRTStar(start=[0, 0], goal=[5, 10],
322329
randArea=[-2, 15], obstacleList=obstacleList)
323-
path = rrt.InformedRRTStarSearch(animation=show_animation)
330+
path = rrt.informed_rrt_star_search(animation=show_animation)
324331
print("Done!!")
325332

326333
# Plot path
327334
if show_animation:
328-
rrt.drawGraph()
335+
rrt.draw_graph()
329336
plt.plot([x for (x, y) in path], [y for (x, y) in path], '-r')
330337
plt.grid(True)
331338
plt.pause(0.01)
332339
plt.show()
333340

334341

335342
if __name__ == '__main__':
336-
main()
343+
main()

0 commit comments

Comments
 (0)