Skip to content

Commit 47262f5

Browse files
committed
update coeff handling to allow multi-variable basis
1 parent 2901cbe commit 47262f5

File tree

2 files changed

+43
-43
lines changed

2 files changed

+43
-43
lines changed

control/optimal.py

+30-35
Original file line numberDiff line numberDiff line change
@@ -268,17 +268,16 @@ def _cost_function(self, coeffs):
268268
start_time = time.process_time()
269269
logging.info("_cost_function called at: %g", start_time)
270270

271-
# Retrieve the initial state and reshape the input vector
271+
# Retrieve the saved initial state
272272
x = self.x
273-
coeffs = coeffs.reshape((self.system.ninputs, -1))
274273

275-
# Compute time points (if basis present)
274+
# Compute inputs
276275
if self.basis:
277276
if self.log:
278277
logging.debug("coefficients = " + str(coeffs))
279278
inputs = self._coeffs_to_inputs(coeffs)
280279
else:
281-
inputs = coeffs
280+
inputs = coeffs.reshape((self.system.ninputs, -1))
282281

283282
# See if we already have a simulation for this condition
284283
if np.array_equal(coeffs, self.last_coeffs) and \
@@ -391,15 +390,14 @@ def _constraint_function(self, coeffs):
391390
start_time = time.process_time()
392391
logging.info("_constraint_function called at: %g", start_time)
393392

394-
# Retrieve the initial state and reshape the input vector
393+
# Retrieve the initial state
395394
x = self.x
396-
coeffs = coeffs.reshape((self.system.ninputs, -1))
397395

398-
# Compute time points (if basis present)
396+
# Compute input at time points
399397
if self.basis:
400398
inputs = self._coeffs_to_inputs(coeffs)
401399
else:
402-
inputs = coeffs
400+
inputs = coeffs.reshape((self.system.ninputs, -1))
403401

404402
# See if we already have a simulation for this condition
405403
if np.array_equal(coeffs, self.last_coeffs) \
@@ -473,15 +471,14 @@ def _eqconst_function(self, coeffs):
473471
start_time = time.process_time()
474472
logging.info("_eqconst_function called at: %g", start_time)
475473

476-
# Retrieve the initial state and reshape the input vector
474+
# Retrieve the initial state
477475
x = self.x
478-
coeffs = coeffs.reshape((self.system.ninputs, -1))
479476

480-
# Compute time points (if basis present)
477+
# Compute input at time points
481478
if self.basis:
482479
inputs = self._coeffs_to_inputs(coeffs)
483480
else:
484-
inputs = coeffs
481+
inputs = coeffs.reshape((self.system.ninputs, -1))
485482

486483
# See if we already have a simulation for this condition
487484
if np.array_equal(coeffs, self.last_coeffs) and \
@@ -609,34 +606,36 @@ def _inputs_to_coeffs(self, inputs):
609606
return inputs
610607

611608
# Solve least squares problems (M x = b) for coeffs on each input
612-
coeffs = np.zeros((self.system.ninputs, self.basis.N))
609+
coeffs = []
613610
for i in range(self.system.ninputs):
614611
# Set up the matrices to get inputs
615-
M = np.zeros((self.timepts.size, self.basis.N))
612+
M = np.zeros((self.timepts.size, self.basis.var_ncoefs(i)))
616613
b = np.zeros(self.timepts.size)
617614

618615
# Evaluate at each time point and for each basis function
619616
# TODO: vectorize
620617
for j, t in enumerate(self.timepts):
621-
for k in range(self.basis.N):
618+
for k in range(self.basis.var_ncoefs(i)):
622619
M[j, k] = self.basis(k, t)
623-
b[j] = inputs[i, j]
620+
b[j] = inputs[i, j]
624621

625622
# Solve a least squares problem for the coefficients
626623
alpha, residuals, rank, s = np.linalg.lstsq(M, b, rcond=None)
627-
coeffs[i, :] = alpha
624+
coeffs.append(alpha)
628625

629-
return coeffs
626+
return np.hstack(coeffs)
630627

631628
# Utility function to convert coefficient vector to input vector
632629
def _coeffs_to_inputs(self, coeffs):
633630
# TODO: vectorize
634631
inputs = np.zeros((self.system.ninputs, self.timepts.size))
635-
for i, t in enumerate(self.timepts):
636-
for k in range(self.basis.N):
637-
phi_k = self.basis(k, t)
638-
for inp in range(self.system.ninputs):
639-
inputs[inp, i] += coeffs[inp, k] * phi_k
632+
offset = 0
633+
for i in range(self.system.ninputs):
634+
length = self.basis.var_ncoefs(i)
635+
for j, t in enumerate(self.timepts):
636+
for k in range(length):
637+
inputs[i, j] += coeffs[offset + k] * self.basis(k, t)
638+
offset += length
640639
return inputs
641640

642641
#
@@ -680,7 +679,7 @@ def _print_statistics(self, reset=True):
680679

681680
# Compute the optimal trajectory from the current state
682681
def compute_trajectory(
683-
self, x, squeeze=None, transpose=None, return_states=None,
682+
self, x, squeeze=None, transpose=None, return_states=True,
684683
initial_guess=None, print_summary=True, **kwargs):
685684
"""Compute the optimal input at state x
686685
@@ -689,8 +688,7 @@ def compute_trajectory(
689688
x : array-like or number, optional
690689
Initial state for the system.
691690
return_states : bool, optional
692-
If True, return the values of the state at each time (default =
693-
False).
691+
If True (default), return the values of the state at each time.
694692
squeeze : bool, optional
695693
If True and if the system has a single output, return the system
696694
output as a 1D array rather than a 2D array. If False, return the
@@ -837,7 +835,7 @@ class OptimalControlResult(sp.optimize.OptimizeResult):
837835
838836
"""
839837
def __init__(
840-
self, ocp, res, return_states=False, print_summary=False,
838+
self, ocp, res, return_states=True, print_summary=False,
841839
transpose=None, squeeze=None):
842840
"""Create a OptimalControlResult object"""
843841

@@ -848,14 +846,11 @@ def __init__(
848846
# Remember the optimal control problem that we solved
849847
self.problem = ocp
850848

851-
# Reshape and process the input vector
852-
coeffs = res.x.reshape((ocp.system.ninputs, -1))
853-
854-
# Compute time points (if basis present)
849+
# Compute input at time points
855850
if ocp.basis:
856-
inputs = ocp._coeffs_to_inputs(coeffs)
851+
inputs = ocp._coeffs_to_inputs(res.x)
857852
else:
858-
inputs = coeffs
853+
inputs = res.x.reshape((ocp.system.ninputs, -1))
859854

860855
# See if we got an answer
861856
if not res.success:
@@ -894,7 +889,7 @@ def __init__(
894889
def solve_ocp(
895890
sys, horizon, X0, cost, trajectory_constraints=None, terminal_cost=None,
896891
terminal_constraints=[], initial_guess=None, basis=None, squeeze=None,
897-
transpose=None, return_states=False, log=False, **kwargs):
892+
transpose=None, return_states=True, log=False, **kwargs):
898893

899894
"""Compute the solution to an optimal control problem
900895
@@ -949,7 +944,7 @@ def solve_ocp(
949944
If `True`, turn on logging messages (using Python logging module).
950945
951946
return_states : bool, optional
952-
If True, return the values of the state at each time (default = False).
947+
If True, return the values of the state at each time (default = True).
953948
954949
squeeze : bool, optional
955950
If True and if the system has a single output, return the system

control/tests/optimal_test.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -300,8 +300,9 @@ def test_terminal_constraints(sys_args):
300300
np.testing.assert_almost_equal(res.inputs, u1)
301301

302302
# Re-run using a basis function and see if we get the same answer
303-
res = opt.solve_ocp(sys, time, x0, cost, terminal_constraints=final_point,
304-
basis=flat.BezierFamily(8, Tf))
303+
res = opt.solve_ocp(
304+
sys, time, x0, cost, terminal_constraints=final_point,
305+
basis=flat.BezierFamily(8, Tf))
305306

306307
# Final point doesn't affect cost => don't need to test
307308
np.testing.assert_almost_equal(
@@ -471,8 +472,12 @@ def test_ocp_argument_errors():
471472
sys, time, x0, cost, terminal_constraints=constraints)
472473

473474

474-
def test_optimal_basis_simple():
475-
sys = ct.ss2io(ct.ss([[1, 1], [0, 1]], [[1], [0.5]], np.eye(2), 0, 1))
475+
@pytest.mark.parametrize("basis", [
476+
flat.PolyFamily(4), flat.PolyFamily(6),
477+
flat.BezierFamily(4), flat.BSplineFamily([0, 4, 8], 6)
478+
])
479+
def test_optimal_basis_simple(basis):
480+
sys = ct.ss([[1, 1], [0, 1]], [[1], [0.5]], np.eye(2), 0, 1)
476481

477482
# State and input constraints
478483
constraints = [
@@ -492,7 +497,7 @@ def test_optimal_basis_simple():
492497
# Basic optimal control problem
493498
res1 = opt.solve_ocp(
494499
sys, time, x0, cost, constraints,
495-
basis=flat.BezierFamily(4, Tf), return_x=True)
500+
terminal_cost=cost, basis=basis, return_x=True)
496501
assert res1.success
497502

498503
# Make sure the constraints were satisfied
@@ -503,14 +508,14 @@ def test_optimal_basis_simple():
503508
# Pass an initial guess and rerun
504509
res2 = opt.solve_ocp(
505510
sys, time, x0, cost, constraints, initial_guess=0.99*res1.inputs,
506-
basis=flat.BezierFamily(4, Tf), return_x=True)
511+
terminal_cost=cost, basis=basis, return_x=True)
507512
assert res2.success
508513
np.testing.assert_allclose(res2.inputs, res1.inputs, atol=0.01, rtol=0.01)
509514

510515
# Run with logging turned on for code coverage
511516
res3 = opt.solve_ocp(
512-
sys, time, x0, cost, constraints,
513-
basis=flat.BezierFamily(4, Tf), return_x=True, log=True)
517+
sys, time, x0, cost, constraints, terminal_cost=cost,
518+
basis=basis, return_x=True, log=True)
514519
assert res3.success
515520
np.testing.assert_almost_equal(res3.inputs, res1.inputs, decimal=3)
516521

0 commit comments

Comments
 (0)