Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 1ed489c

Browse files
author
Shunichi09
committed
Update: Cartpole Env and Cartpole models
1 parent e523720 commit 1ed489c

File tree

12 files changed

+475
-43
lines changed

12 files changed

+475
-43
lines changed

‎Environments.md‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ X_g denote the goal states.
4141

4242
## [CartpoleEnv (Swing up)](PythonLinearNonlinearControl/envs/cartpole.py)
4343

44-
System equation.
44+
## System equation.
4545

4646
<img src="assets/cartpole.png" width="600">
4747

‎PythonLinearNonlinearControl/configs/cartpole.py‎

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@ class CartPoleConfigModule():
1010
INPUT_SIZE = 1
1111
DT = 0.02
1212
# cost parameters
13-
R = np.diag([0.01])
13+
R = np.diag([1.]) # 0.01 works for MPPI, CEM and MPPIWilliams
14+
# 1. works for iLQR
15+
Terminal_Weight = 1.
16+
Q = None
17+
Sf = None
1418
# bounds
1519
INPUT_LOWER_BOUND = np.array([-3.])
1620
INPUT_UPPER_BOUND = np.array([3.])
@@ -128,12 +132,14 @@ def terminal_state_cost_fn(terminal_x, terminal_g_x):
128132
return (6. * (terminal_x[:, 0]**2) \
129133
+ 12. * ((np.cos(terminal_x[:, 2]) + 1.)**2) \
130134
+ 0.1 * (terminal_x[:, 1]**2) \
131-
+ 0.1 * (terminal_x[:, 3]**2))[:, np.newaxis]
135+
+ 0.1 * (terminal_x[:, 3]**2))[:, np.newaxis] \
136+
* CartPoleConfigModule.Terminal_Weight
132137

133-
return 6. * (terminal_x[0]**2) \
138+
return (6. * (terminal_x[0]**2) \
134139
+ 12. * ((np.cos(terminal_x[2]) + 1.)**2) \
135140
+ 0.1 * (terminal_x[1]**2) \
136-
+ 0.1 * (terminal_x[3]**2)
141+
+ 0.1 * (terminal_x[3]**2)) \
142+
* CartPoleConfigModule.Terminal_Weight
137143

138144
@staticmethod
139145
def gradient_cost_fn_with_state(x, g_x, terminal=False):
@@ -148,9 +154,21 @@ def gradient_cost_fn_with_state(x, g_x, terminal=False):
148154
or shape(1, state_size)
149155
"""
150156
if not terminal:
151-
return None
157+
cost_dx0 = 12. * x[:, 0]
158+
cost_dx1 = 0.2 * x[:, 1]
159+
cost_dx2 = 24. * (1 + np.cos(x[:, 2])) * -np.sin(x[:, 2])
160+
cost_dx3 = 0.2 * x[:, 3]
161+
cost_dx = np.stack((cost_dx0, cost_dx1,\
162+
cost_dx2, cost_dx3), axis=1)
163+
return cost_dx
152164

153-
return None
165+
cost_dx0 = 12. * x[0]
166+
cost_dx1 = 0.2 * x[1]
167+
cost_dx2 = 24. * (1 + np.cos(x[2])) * -np.sin(x[2])
168+
cost_dx3 = 0.2 * x[3]
169+
cost_dx = np.array([[cost_dx0, cost_dx1, cost_dx2, cost_dx3]])
170+
171+
return cost_dx * CartPoleConfigModule.Terminal_Weight
154172

155173
@staticmethod
156174
def gradient_cost_fn_with_input(x, u):
@@ -163,7 +181,7 @@ def gradient_cost_fn_with_input(x, u):
163181
Returns:
164182
l_u (numpy.ndarray): gradient of cost, shape(pred_len, input_size)
165183
"""
166-
return None
184+
return 2.*u*np.diag(CartPoleConfigModule.R)
167185

168186
@staticmethod
169187
def hessian_cost_fn_with_state(x, g_x, terminal=False):
@@ -179,10 +197,30 @@ def hessian_cost_fn_with_state(x, g_x, terminal=False):
179197
shape(1, state_size, state_size) or
180198
"""
181199
if not terminal:
182-
(pred_len, _) = x.shape
183-
return None
200+
(pred_len, state_size) = x.shape
201+
hessian = np.eye(state_size)
202+
hessian = np.tile(hessian, (pred_len, 1, 1))
203+
hessian[:, 0, 0] = 12.
204+
hessian[:, 1, 1] = 0.2
205+
hessian[:, 2, 2] = 24. * -np.sin(x[:, 2]) \
206+
* (-np.sin(x[:, 2])) \
207+
+ 24. * (1. + np.cos(x[:, 2])) \
208+
* -np.cos(x[:, 2])
209+
hessian[:, 3, 3] = 0.2
210+
211+
return hessian
184212

185-
return None
213+
state_size = len(x)
214+
hessian = np.eye(state_size)
215+
hessian[0, 0] = 12.
216+
hessian[1, 1] = 0.2
217+
hessian[2, 2] = 24. * -np.sin(x[2]) \
218+
* (-np.sin(x[2])) \
219+
+ 24. * (1. + np.cos(x[2])) \
220+
* -np.cos(x[2])
221+
hessian[3, 3] = 0.2
222+
223+
return hessian[np.newaxis, :, :] * CartPoleConfigModule.Terminal_Weight
186224

187225
@staticmethod
188226
def hessian_cost_fn_with_input(x, u):
@@ -198,7 +236,7 @@ def hessian_cost_fn_with_input(x, u):
198236
"""
199237
(pred_len, _) = u.shape
200238

201-
return None
239+
return np.tile(2.*CartPoleConfigModule.R, (pred_len, 1, 1))
202240

203241
@staticmethod
204242
def hessian_cost_fn_with_input_state(x, u):

‎PythonLinearNonlinearControl/configs/first_order_lag.py‎

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,9 @@ def hessian_cost_fn_with_state(x, g_x, terminal=False):
159159
"""
160160
if not terminal:
161161
(pred_len, _) = x.shape
162-
return -g_x[:, :, np.newaxis] \
163-
* np.tile(2.*FirstOrderLagConfigModule.Q, (pred_len, 1, 1))
162+
return np.tile(2.*FirstOrderLagConfigModule.Q, (pred_len, 1, 1))
164163

165-
return -g_x[:, np.newaxis] \
166-
* np.tile(2.*FirstOrderLagConfigModule.Sf, (1, 1, 1))
164+
return np.tile(2.*FirstOrderLagConfigModule.Sf, (1, 1, 1))
167165

168166
@staticmethod
169167
def hessian_cost_fn_with_input(x, u):

‎PythonLinearNonlinearControl/configs/two_wheeled.py‎

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,11 +153,9 @@ def hessian_cost_fn_with_state(x, g_x, terminal=False):
153153
"""
154154
if not terminal:
155155
(pred_len, _) = x.shape
156-
return -g_x[:, :, np.newaxis] \
157-
* np.tile(2.*TwoWheeledConfigModule.Q, (pred_len, 1, 1))
156+
return np.tile(2.*TwoWheeledConfigModule.Q, (pred_len, 1, 1))
158157

159-
return -g_x[:, np.newaxis] \
160-
* np.tile(2.*TwoWheeledConfigModule.Sf, (1, 1, 1))
158+
return np.tile(2.*TwoWheeledConfigModule.Sf, (1, 1, 1))
161159

162160
@staticmethod
163161
def hessian_cost_fn_with_input(x, u):

‎PythonLinearNonlinearControl/controllers/ilqr.py‎

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,6 @@ def __init__(self, config, model):
5050
self.input_size = config.INPUT_SIZE
5151
self.dt = config.DT
5252

53-
# cost parameters
54-
self.Q = config.Q
55-
self.R = config.R
56-
self.Sf = config.Sf
57-
5853
# initialize
5954
self.prev_sol = np.zeros((self.pred_len, self.input_size))
6055

‎PythonLinearNonlinearControl/envs/cartpole.py‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ def reset(self, init_x=None):
3737
"""
3838
self.step_count = 0
3939

40-
self.curr_x = np.array([0., 0., 0., 0.])
40+
theta = np.random.randn(1)
41+
self.curr_x = np.array([0., 0., theta[0], 0.])
4142

4243
if init_x is not None:
4344
self.curr_x = init_x

‎PythonLinearNonlinearControl/models/cartpole.py‎

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,41 @@ def calc_f_x(self, xs, us, dt):
9090

9191
f_x = np.zeros((pred_len, state_size, state_size))
9292

93-
f_x[:, 0, 2] = -np.sin(xs[:, 2]) * us[:, 0]
94-
f_x[:, 1, 2] = np.cos(xs[:, 2]) * us[:, 0]
93+
# f_x_dot
94+
f_x[:, 0, 1] = np.ones(pred_len)
95+
96+
# f_theta
97+
tmp = ((self.mc + self.mp * np.sin(xs[:, 2])**2)**(-2)) \
98+
* self.mp * 2. * np.sin(xs[:, 2]) * np.cos(xs[:, 2])
99+
tmp2 = 1. / (self.mc + self.mp * (np.sin(xs[:, 2])**2))
100+
101+
f_x[:, 1, 2] = - us[:, 0] * tmp \
102+
- tmp * (self.mp * np.sin(xs[:, 2]) \
103+
* (self.l * xs[:, 3]**2 \
104+
+ self.g * np.cos(xs[:, 2]))) \
105+
+ tmp2 * (self.mp * np.cos(xs[:, 2]) * self.l \
106+
* xs[:, 3]**2 \
107+
+ self.mp * self.g * (np.cos(xs[:, 2])**2 \
108+
- np.sin(xs[:, 2])**2))
109+
f_x[:, 3, 2] = - 1. / self.l * tmp \
110+
* (-us[:, 0] * np.cos(xs[:, 2]) \
111+
- self.mp * self.l * (xs[:, 3]**2) \
112+
* np.cos(xs[:, 2]) * np.sin(xs[:, 2]) \
113+
- (self.mc + self.mp) * self.g * np.sin(xs[:, 2])) \
114+
+ 1. / self.l * tmp2 \
115+
* (us[:, 0] * np.sin(xs[:, 2]) \
116+
- self.mp * self.l * xs[:, 3]**2 \
117+
* (np.cos(xs[:, 2])**2 - np.sin(xs[:, 2])**2) \
118+
- (self.mc + self.mp) \
119+
* self.g * np.cos(xs[:, 2]))
120+
121+
# f_theta_dot
122+
f_x[:, 1, 3] = tmp2 * (self.mp * np.sin(xs[:, 2]) \
123+
* self.l * 2 * xs[:, 3])
124+
f_x[:, 2, 3] = np.ones(pred_len)
125+
f_x[:, 3, 3] = 1. / self.l * tmp2 \
126+
* (-2. * self.mp * self.l * xs[:, 3] \
127+
* np.cos(xs[:, 2]) * np.sin(xs[:, 2]))
95128

96129
return f_x * dt + np.eye(state_size) # to discrete form
97130

@@ -139,10 +172,7 @@ def calc_f_xx(self, xs, us, dt):
139172

140173
f_xx = np.zeros((pred_len, state_size, state_size, state_size))
141174

142-
f_xx[:, 0, 2, 2] = -np.cos(xs[:, 2]) * us[:, 0]
143-
f_xx[:, 1, 2, 2] = -np.sin(xs[:, 2]) * us[:, 0]
144-
145-
return f_xx * dt
175+
raise NotImplementedError
146176

147177
def calc_f_ux(self, xs, us, dt):
148178
""" hessian of model with respect to state and input in batch form
@@ -161,11 +191,8 @@ def calc_f_ux(self, xs, us, dt):
161191

162192
f_ux = np.zeros((pred_len, state_size, input_size, state_size))
163193

164-
f_ux[:, 0, 0, 2] = -np.sin(xs[:, 2])
165-
f_ux[:, 1, 0, 2] = np.cos(xs[:, 2])
194+
raise NotImplementedError
166195

167-
return f_ux * dt
168-
169196
def calc_f_uu(self, xs, us, dt):
170197
""" hessian of model with respect to input in batch form
171198
@@ -183,4 +210,4 @@ def calc_f_uu(self, xs, us, dt):
183210

184211
f_uu = np.zeros((pred_len, state_size, input_size, input_size))
185212

186-
return f_uu * dt
213+
raise NotImplementedError

‎scripts/simple_run.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ def run(args):
4242
def main():
4343
parser = argparse.ArgumentParser()
4444

45-
parser.add_argument("--controller_type", type=str, default="CEM")
45+
parser.add_argument("--controller_type", type=str, default="DDP")
4646
parser.add_argument("--planner_type", type=str, default="const")
47-
parser.add_argument("--env", type=str, default="TwoWheeledConst")
47+
parser.add_argument("--env", type=str, default="CartPole")
4848
parser.add_argument("--result_dir", type=str, default="./result")
4949

5050
args = parser.parse_args()

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /