Skip to content

Commit b3ab809

Browse files
Add Cubature Kalman Filter (AtsushiSakai#410)
* cubature kalman filter * Revert "cubature kalman filter" This reverts commit 1727728. * add ckf test * update flags for CI * update flags for CI * update flags for CI * remove comments * remove comments * change postpross * changes requested * remove comments * Changes to comments, remove linear_update * changes to comments * removed comments * change comments * update comments * update comments * update comments * update comments * fix comment
1 parent b10bfbb commit b3ab809

File tree

2 files changed

+260
-0
lines changed

2 files changed

+260
-0
lines changed
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
"""
2+
Cubature Kalman filter using Constant Turn Rate and Velocity (CTRV) model
3+
Fuse sensor data from IMU and GPS to obtain accurate position
4+
5+
https://ieeexplore.ieee.org/document/4982682
6+
7+
Author: Raghuram Shankar
8+
9+
state matrix: 2D x-y position, yaw, velocity and yaw rate
10+
measurement matrix: 2D x-y position, velocity and yaw rate
11+
12+
dt: Duration of time step
13+
N: Number of time steps
14+
show_final: Flag for showing final result
15+
show_animation: Flag for showing each animation frame
16+
show_ellipse: Flag for showing covariance ellipse
17+
z_noise: Measurement noise
18+
x_0: Prior state estimate matrix
19+
P_0: Prior state estimate covariance matrix
20+
q: Process noise covariance
21+
hx: Measurement model matrix
22+
r: Sensor noise covariance
23+
SP: Sigma Points
24+
W: Weights
25+
26+
x_est: State estimate
27+
P_est: State estimate covariance
28+
x_true: Ground truth value of state
29+
x_true_cat: Concatenate all ground truth states
30+
x_est_cat: Concatenate all state estimates
31+
z_cat: Concatenate all measurements
32+
33+
"""
34+
35+
import math
36+
import matplotlib.pyplot as plt
37+
import numpy as np
38+
from scipy.linalg import sqrtm
39+
40+
41+
dt = 0.1
42+
N = 100
43+
44+
show_final = 1
45+
show_animation = 0
46+
show_ellipse = 0
47+
48+
49+
z_noise = np.array([[0.1, 0.0, 0.0, 0.0], # x position [m]
50+
[0.0, 0.1, 0.0, 0.0], # y position [m]
51+
[0.0, 0.0, 0.1, 0.0], # velocity [m/s]
52+
[0.0, 0.0, 0.0, 0.1]]) # yaw rate [rad/s]
53+
54+
55+
x_0 = np.array([[0.0], # x position [m]
56+
[0.0], # y position [m]
57+
[0.0], # yaw [rad]
58+
[1.0], # velocity [m/s]
59+
[0.1]]) # yaw rate [rad/s]
60+
61+
62+
p_0 = np.array([[1e-3, 0.0, 0.0, 0.0, 0.0],
63+
[0.0, 1e-3, 0.0, 0.0, 0.0],
64+
[0.0, 0.0, 1.0, 0.0, 0.0],
65+
[0.0, 0.0, 0.0, 1.0, 0.0],
66+
[0.0, 0.0, 0.0, 0.0, 1.0]])
67+
68+
69+
q = np.array([[1e-11, 0.0, 0.0, 0.0, 0.0],
70+
[0.0, 1e-11, 0.0, 0.0, 0.0],
71+
[0.0, 0.0, np.deg2rad(1e-4), 0.0, 0.0],
72+
[0.0, 0.0, 0.0, 1e-4, 0.0],
73+
[0.0, 0.0, 0.0, 0.0, np.deg2rad(1e-4)]])
74+
75+
76+
hx = np.array([[1.0, 0.0, 0.0, 0.0, 0.0],
77+
[0.0, 1.0, 0.0, 0.0, 0.0],
78+
[0.0, 0.0, 0.0, 1.0, 0.0],
79+
[0.0, 0.0, 0.0, 0.0, 1.0]])
80+
81+
82+
r = np.array([[0.015, 0.0, 0.0, 0.0],
83+
[0.0, 0.010, 0.0, 0.0],
84+
[0.0, 0.0, 0.1, 0.0],
85+
[0.0, 0.0, 0.0, 0.01]])**2
86+
87+
88+
def cubature_kalman_filter(x_est, p_est, z):
89+
x_pred, p_pred = cubature_prediction(x_est, p_est)
90+
x_upd, p_upd = cubature_update(x_pred, p_pred, z)
91+
return x_upd, p_upd
92+
93+
94+
def f(x):
95+
"""
96+
Motion Model
97+
References:
98+
http://fusion.isif.org/proceedings/fusion08CD/papers/1569107835.pdf
99+
https://github.com/balzer82/Kalman
100+
"""
101+
x[0] = x[0] + (x[3]/x[4]) * (np.sin(x[4] * dt + x[2]) - np.sin(x[2]))
102+
x[1] = x[1] + (x[3]/x[4]) * (- np.cos(x[4] * dt + x[2]) + np.cos(x[2]))
103+
x[2] = x[2] + x[4] * dt
104+
x[3] = x[3]
105+
x[4] = x[4]
106+
return x
107+
108+
109+
def h(x):
110+
"""Measurement Model"""
111+
x = hx @ x
112+
return x
113+
114+
115+
def sigma(x, p):
116+
"""
117+
Unscented Transform with Cubature Rule
118+
Generate 2n Sigma Points to represent the nonlinear motion
119+
Assign Weights to each Sigma Point, Wi = 1/2n
120+
Cubature Rule - Special Case of Unscented Transform
121+
W0 = 0, no extra tuning parameters, no negative weights
122+
"""
123+
n = np.shape(x)[0]
124+
SP = np.zeros((n, 2*n))
125+
W = np.zeros((1, 2*n))
126+
for i in range(n):
127+
SD = sqrtm(p)
128+
SP[:, i] = (x + (math.sqrt(n) * SD[:, i]).reshape((n, 1))).flatten()
129+
SP[:, i+n] = (x - (math.sqrt(n) * SD[:, i]).reshape((n, 1))).flatten()
130+
W[:, i] = 1/(2*n)
131+
W[:, i+n] = W[:, i]
132+
return SP, W
133+
134+
135+
def cubature_prediction(x_pred, p_pred):
136+
n = np.shape(x_pred)[0]
137+
[SP, W] = sigma(x_pred, p_pred)
138+
x_pred = np.zeros((n, 1))
139+
p_pred = q
140+
for i in range(2*n):
141+
x_pred = x_pred + (f(SP[:, i]).reshape((n, 1)) * W[0, i])
142+
for i in range(2*n):
143+
p_step = (f(SP[:, i]).reshape((n, 1)) - x_pred)
144+
p_pred = p_pred + (p_step @ p_step.T * W[0, i])
145+
return x_pred, p_pred
146+
147+
148+
def cubature_update(x_pred, p_pred, z):
149+
n = np.shape(x_pred)[0]
150+
m = np.shape(z)[0]
151+
[SP, W] = sigma(x_pred, p_pred)
152+
y_k = np.zeros((m, 1))
153+
P_xy = np.zeros((n, m))
154+
s = r
155+
for i in range(2*n):
156+
y_k = y_k + (h(SP[:, i]).reshape((m, 1)) * W[0, i])
157+
for i in range(2*n):
158+
p_step = (h(SP[:, i]).reshape((m, 1)) - y_k)
159+
P_xy = P_xy + ((SP[:, i]).reshape((n, 1)) -
160+
x_pred) @ p_step.T * W[0, i]
161+
s = s + p_step @ p_step.T * W[0, i]
162+
x_pred = x_pred + P_xy @ np.linalg.pinv(s) @ (z - y_k)
163+
p_pred = p_pred - P_xy @ np.linalg.pinv(s) @ P_xy.T
164+
return x_pred, p_pred
165+
166+
167+
def generate_measurement(x_true):
168+
gz = hx @ x_true
169+
z = gz + z_noise @ np.random.randn(4, 1)
170+
return z
171+
172+
173+
def plot_animation(i, x_true_cat, x_est_cat, z):
174+
if i == 0:
175+
plt.plot(x_true_cat[0], x_true_cat[1], '.r')
176+
plt.plot(x_est_cat[0], x_est_cat[1], '.b')
177+
else:
178+
plt.plot(x_true_cat[0:, 0], x_true_cat[0:, 1], 'r')
179+
plt.plot(x_est_cat[0:, 0], x_est_cat[0:, 1], 'b')
180+
plt.plot(z[0], z[1], '+g')
181+
plt.grid(True)
182+
plt.pause(0.001)
183+
184+
185+
def plot_ellipse(x_est, p_est):
186+
phi = np.linspace(0, 2 * math.pi, 100)
187+
p_ellipse = np.array(
188+
[[p_est[0, 0], p_est[0, 1]], [p_est[1, 0], p_est[1, 1]]])
189+
x0 = 3 * sqrtm(p_ellipse)
190+
xy_1 = np.array([])
191+
xy_2 = np.array([])
192+
for i in range(100):
193+
arr = np.array([[math.sin(phi[i])], [math.cos(phi[i])]])
194+
arr = x0 @ arr
195+
xy_1 = np.hstack([xy_1, arr[0]])
196+
xy_2 = np.hstack([xy_2, arr[1]])
197+
plt.plot(xy_1 + x_est[0], xy_2 + x_est[1], 'r')
198+
plt.pause(0.00001)
199+
200+
201+
def plot_final(x_true_cat, x_est_cat, z_cat):
202+
fig = plt.figure()
203+
subplot = fig.add_subplot(111)
204+
subplot.plot(x_true_cat[0:, 0], x_true_cat[0:, 1],
205+
'r', label='True Position')
206+
subplot.plot(x_est_cat[0:, 0], x_est_cat[0:, 1],
207+
'b', label='Estimated Position')
208+
subplot.plot(z_cat[0:, 0], z_cat[0:, 1], '+g', label='Noisy Measurements')
209+
subplot.set_xlabel('x [m]')
210+
subplot.set_ylabel('y [m]')
211+
subplot.set_title('Cubature Kalman Filter - CTRV Model')
212+
subplot.legend(loc='upper left', shadow=True, fontsize='large')
213+
plt.grid(True)
214+
plt.show()
215+
216+
217+
def main():
218+
print(__file__ + " start!!")
219+
x_est = x_0
220+
p_est = p_0
221+
x_true = x_0
222+
x_true_cat = np.array([x_0[0, 0], x_0[1, 0]])
223+
x_est_cat = np.array([x_0[0, 0], x_0[1, 0]])
224+
z_cat = np.array([x_0[0, 0], x_0[1, 0]])
225+
for i in range(N):
226+
x_true = f(x_true)
227+
z = generate_measurement(x_true)
228+
if i == (N - 1) and show_final == 1:
229+
show_final_flag = 1
230+
else:
231+
show_final_flag = 0
232+
if show_animation == 1:
233+
plot_animation(i, x_true_cat, x_est_cat, z)
234+
if show_ellipse == 1:
235+
plot_ellipse(x_est[0:2], p_est)
236+
if show_final_flag == 1:
237+
plot_final(x_true_cat, x_est_cat, z_cat)
238+
x_est, p_est = cubature_kalman_filter(x_est, p_est, z)
239+
x_true_cat = np.vstack((x_true_cat, x_true[0:2].T))
240+
x_est_cat = np.vstack((x_est_cat, x_est[0:2].T))
241+
z_cat = np.vstack((z_cat, z[0:2].T))
242+
print('CKF Over')
243+
244+
245+
if __name__ == '__main__':
246+
main()
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from unittest import TestCase
2+
3+
from Localization.cubature_kalman_filter import cubature_kalman_filter as m
4+
5+
print(__file__)
6+
7+
8+
class Test(TestCase):
9+
10+
def test1(self):
11+
m.show_final = False
12+
m.show_animation = False
13+
m.show_ellipse = False
14+
m.main()

0 commit comments

Comments
 (0)