Skip to content

Commit 03e8969

Browse files
committed
inverted pendulum mpc control is added
1 parent 994e0e3 commit 03e8969

File tree

1 file changed

+171
-0
lines changed

1 file changed

+171
-0
lines changed
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
"""
2+
Inverted Pendulum MPC control
3+
author: Atsushi Sakai
4+
"""
5+
6+
import matplotlib.pyplot as plt
7+
import numpy as np
8+
import math
9+
import time
10+
import cvxpy
11+
12+
# Model parameters
13+
14+
l_bar = 2.0 # length of bar
15+
M = 1.0 # [kg]
16+
m = 0.3 # [kg]
17+
g = 9.8 # [m/s^2]
18+
19+
Q = np.diag([0.0, 1.0, 1.0, 0.0])
20+
R = np.diag([0.01])
21+
nx = 4 # number of state
22+
nu = 1 # number of input
23+
T = 30 # Horizon length
24+
delta_t = 0.1 # time tick
25+
26+
animation = True
27+
28+
29+
def main():
30+
x0 = np.array([
31+
[0.0],
32+
[0.0],
33+
[0.3],
34+
[0.0]
35+
])
36+
37+
x = np.copy(x0)
38+
39+
for i in range(50):
40+
41+
# calc control input
42+
optimized_x, optimized_delta_x, optimized_theta, optimized_delta_theta, optimized_input = mpc_control(x)
43+
44+
# get input
45+
u = optimized_input[0]
46+
47+
# simulate inverted pendulum cart
48+
x = simulation(x, u)
49+
50+
if animation:
51+
plt.clf()
52+
px = float(x[0])
53+
theta = float(x[2])
54+
plot_cart(px, theta)
55+
plt.xlim([-5.0, 2.0])
56+
plt.pause(0.001)
57+
58+
59+
def simulation(x, u):
60+
A, B = get_model_matrix()
61+
62+
x = np.dot(A, x) + np.dot(B, u)
63+
64+
return x
65+
66+
67+
def mpc_control(x0):
68+
x = cvxpy.Variable((nx, T + 1))
69+
u = cvxpy.Variable((nu, T))
70+
71+
A, B = get_model_matrix()
72+
73+
cost = 0.0
74+
constr = []
75+
for t in range(T):
76+
cost += cvxpy.quad_form(x[:, t + 1], Q)
77+
cost += cvxpy.quad_form(u[:, t], R)
78+
constr += [x[:, t + 1] == A * x[:, t] + B * u[:, t]]
79+
80+
constr += [x[:, 0] == x0[:, 0]]
81+
prob = cvxpy.Problem(cvxpy.Minimize(cost), constr)
82+
83+
start = time.time()
84+
prob.solve(verbose=False)
85+
elapsed_time = time.time() - start
86+
print("calc time:{0} [sec]".format(elapsed_time))
87+
88+
if prob.status == cvxpy.OPTIMAL:
89+
ox = get_nparray_from_matrix(x.value[0, :])
90+
dx = get_nparray_from_matrix(x.value[1, :])
91+
theta = get_nparray_from_matrix(x.value[2, :])
92+
dtheta = get_nparray_from_matrix(x.value[3, :])
93+
94+
ou = get_nparray_from_matrix(u.value[0, :])
95+
96+
return ox, dx, theta, dtheta, ou
97+
98+
99+
def get_nparray_from_matrix(x):
100+
"""
101+
get build-in list from matrix
102+
"""
103+
return np.array(x).flatten()
104+
105+
106+
def get_model_matrix():
107+
A = np.array([
108+
[0.0, 1.0, 0.0, 0.0],
109+
[0.0, 0.0, m * g / M, 0.0],
110+
[0.0, 0.0, 0.0, 1.0],
111+
[0.0, 0.0, g * (M + m) / (l_bar * M), 0.0]
112+
])
113+
A = np.eye(nx) + delta_t * A
114+
115+
B = np.array([
116+
[0.0],
117+
[1.0 / M],
118+
[0.0],
119+
[1.0 / (l_bar * M)]
120+
])
121+
B = delta_t * B
122+
123+
return A, B
124+
125+
126+
def flatten(a):
127+
return np.array(a).flatten()
128+
129+
130+
def plot_cart(xt, theta):
131+
cart_w = 1.0
132+
cart_h = 0.5
133+
radius = 0.1
134+
135+
cx = np.array([-cart_w / 2.0, cart_w / 2.0, cart_w /
136+
2.0, -cart_w / 2.0, -cart_w / 2.0])
137+
cy = np.array([0.0, 0.0, cart_h, cart_h, 0.0])
138+
cy += radius * 2.0
139+
140+
cx = cx + xt
141+
142+
bx = np.array([0.0, l_bar * math.sin(-theta)])
143+
bx += xt
144+
by = np.array([cart_h, l_bar * math.cos(-theta) + cart_h])
145+
by += radius * 2.0
146+
147+
angles = np.arange(0.0, math.pi * 2.0, math.radians(3.0))
148+
ox = [radius * math.cos(a) for a in angles]
149+
oy = [radius * math.sin(a) for a in angles]
150+
151+
rwx = np.copy(ox) + cart_w / 4.0 + xt
152+
rwy = np.copy(oy) + radius
153+
lwx = np.copy(ox) - cart_w / 4.0 + xt
154+
lwy = np.copy(oy) + radius
155+
156+
wx = np.copy(ox) + float(bx[0, -1])
157+
wy = np.copy(oy) + float(by[0, -1])
158+
159+
plt.plot(flatten(cx), flatten(cy), "-b")
160+
plt.plot(flatten(bx), flatten(by), "-k")
161+
plt.plot(flatten(rwx), flatten(rwy), "-k")
162+
plt.plot(flatten(lwx), flatten(lwy), "-k")
163+
plt.plot(flatten(wx), flatten(wy), "-k")
164+
plt.title("x:" + str(round(xt, 2)) + ",theta:" +
165+
str(round(math.degrees(theta), 2)))
166+
167+
plt.axis("equal")
168+
169+
170+
if __name__ == '__main__':
171+
main()

0 commit comments

Comments
 (0)