Skip to content

Commit 4d89991

Browse files
authored
Merge pull request #800 from murrayrm/benchmarks-24Aug2022
Update benchmarks to help with optimal control tuning
2 parents ed4ff84 + 54de2a3 commit 4d89991

File tree

9 files changed

+285
-138
lines changed

9 files changed

+285
-138
lines changed

benchmarks/README

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ you can use the following command from the root directory of the repository:
1111

1212
PYTHONPATH=`pwd` asv run --python=python
1313

14-
You can also run benchmarks against specific commits usuing
14+
You can also run benchmarks against specific commits using
1515

1616
asv run <range>
1717

benchmarks/flatsys_bench.py

Lines changed: 71 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
import control.flatsys as flat
1212
import control.optimal as opt
1313

14+
#
15+
# System setup: vehicle steering (bicycle model)
16+
#
17+
1418
# Vehicle steering dynamics
1519
def vehicle_update(t, x, u, params):
1620
# Get the parameters for the model
@@ -67,11 +71,28 @@ def vehicle_reverse(zflag, params={}):
6771
# Define the time points where the cost/constraints will be evaluated
6872
timepts = np.linspace(0, Tf, 10, endpoint=True)
6973

70-
def time_steering_point_to_point(basis_name, basis_size):
71-
if basis_name == 'poly':
72-
basis = flat.PolyFamily(basis_size)
73-
elif basis_name == 'bezier':
74-
basis = flat.BezierFamily(basis_size)
74+
#
75+
# Benchmark test parameters
76+
#
77+
78+
basis_params = (['poly', 'bezier', 'bspline'], [8, 10, 12])
79+
basis_param_names = ["basis", "size"]
80+
81+
def get_basis(name, size):
82+
if name == 'poly':
83+
basis = flat.PolyFamily(size, T=Tf)
84+
elif name == 'bezier':
85+
basis = flat.BezierFamily(size, T=Tf)
86+
elif name == 'bspline':
87+
basis = flat.BSplineFamily([0, Tf/2, Tf], size)
88+
return basis
89+
90+
#
91+
# Benchmarks
92+
#
93+
94+
def time_point_to_point(basis_name, basis_size):
95+
basis = get_basis(basis_name, basis_size)
7596

7697
# Find trajectory between initial and final conditions
7798
traj = flat.point_to_point(vehicle, Tf, x0, u0, xf, uf, basis=basis)
@@ -80,13 +101,16 @@ def time_steering_point_to_point(basis_name, basis_size):
80101
x, u = traj.eval([0, Tf])
81102
np.testing.assert_array_almost_equal(x0, x[:, 0])
82103
np.testing.assert_array_almost_equal(u0, u[:, 0])
83-
np.testing.assert_array_almost_equal(xf, x[:, 1])
84-
np.testing.assert_array_almost_equal(uf, u[:, 1])
104+
np.testing.assert_array_almost_equal(xf, x[:, -1])
105+
np.testing.assert_array_almost_equal(uf, u[:, -1])
106+
107+
time_point_to_point.params = basis_params
108+
time_point_to_point.param_names = basis_param_names
85109

86-
time_steering_point_to_point.params = (['poly', 'bezier'], [6, 8])
87-
time_steering_point_to_point.param_names = ["basis", "size"]
88110

89-
def time_steering_cost():
111+
def time_point_to_point_with_cost(basis_name, basis_size):
112+
basis = get_basis(basis_name, basis_size)
113+
90114
# Define cost and constraints
91115
traj_cost = opt.quadratic_cost(
92116
vehicle, None, np.diag([0.1, 1]), u0=uf)
@@ -95,13 +119,47 @@ def time_steering_cost():
95119

96120
traj = flat.point_to_point(
97121
vehicle, timepts, x0, u0, xf, uf,
98-
cost=traj_cost, constraints=constraints, basis=flat.PolyFamily(8)
122+
cost=traj_cost, constraints=constraints, basis=basis,
99123
)
100124

101125
# Verify that the trajectory computation is correct
102126
x, u = traj.eval([0, Tf])
103127
np.testing.assert_array_almost_equal(x0, x[:, 0])
104128
np.testing.assert_array_almost_equal(u0, u[:, 0])
105-
np.testing.assert_array_almost_equal(xf, x[:, 1])
106-
np.testing.assert_array_almost_equal(uf, u[:, 1])
129+
np.testing.assert_array_almost_equal(xf, x[:, -1])
130+
np.testing.assert_array_almost_equal(uf, u[:, -1])
131+
132+
time_point_to_point_with_cost.params = basis_params
133+
time_point_to_point_with_cost.param_names = basis_param_names
134+
135+
136+
def time_solve_flat_ocp_terminal_cost(method, basis_name, basis_size):
137+
basis = get_basis(basis_name, basis_size)
138+
139+
# Define cost and constraints
140+
traj_cost = opt.quadratic_cost(
141+
vehicle, None, np.diag([0.1, 1]), u0=uf)
142+
term_cost = opt.quadratic_cost(
143+
vehicle, np.diag([1e3, 1e3, 1e3]), None, x0=xf)
144+
constraints = [
145+
opt.input_range_constraint(vehicle, [8, -0.1], [12, 0.1]) ]
146+
147+
# Initial guess = straight line
148+
initial_guess = np.array(
149+
[x0[i] + (xf[i] - x0[i]) * timepts/Tf for i in (0, 1)])
150+
151+
traj = flat.solve_flat_ocp(
152+
vehicle, timepts, x0, u0, basis=basis, initial_guess=initial_guess,
153+
trajectory_cost=traj_cost, constraints=constraints,
154+
terminal_cost=term_cost, minimize_method=method,
155+
)
156+
157+
# Verify that the trajectory computation is correct
158+
x, u = traj.eval([0, Tf])
159+
np.testing.assert_array_almost_equal(x0, x[:, 0])
160+
np.testing.assert_array_almost_equal(xf, x[:, -1], decimal=2)
107161

162+
time_solve_flat_ocp_terminal_cost.params = tuple(
163+
[['slsqp', 'trust-constr']] + list(basis_params))
164+
time_solve_flat_ocp_terminal_cost.param_names = tuple(
165+
['method'] + basis_param_names)

0 commit comments

Comments
 (0)