Skip to content

Commit 5ba44b5

Browse files
authored
Merge pull request #1018 from murrayrm/timeresp_improvements-01Jun2024
Time response plot improvements
2 parents 8c1ddec + 10f009b commit 5ba44b5

13 files changed

+1626
-279
lines changed

control/ctrlplot.py

Lines changed: 90 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
#
44
# Collection of functions that are used by various plotting functions.
55

6+
from os.path import commonprefix
7+
68
import matplotlib.pyplot as plt
79
import numpy as np
810

911
from . import config
1012

11-
__all__ = ['suptitle']
13+
__all__ = ['suptitle', 'get_plot_axes']
1214

1315

1416
def suptitle(
@@ -44,7 +46,6 @@ def suptitle(
4446

4547
elif frame == 'axes':
4648
# TODO: move common plotting params to 'ctrlplot'
47-
rcParams = config._get_param('freqplot', 'rcParams', rcParams)
4849
with plt.rc_context(rcParams):
4950
plt.tight_layout() # Put the figure into proper layout
5051
xc, _ = _find_axes_center(fig, fig.get_axes())
@@ -56,6 +57,93 @@ def suptitle(
5657
raise ValueError(f"unknown frame '{frame}'")
5758

5859

60+
# Create vectorized function to find axes from lines
61+
def get_plot_axes(line_array):
62+
"""Get a list of axes from an array of lines.
63+
64+
This function can be used to return the set of axes corresponding to
65+
the line array that is returned by `time_response_plot`. This is useful for
66+
generating an axes array that can be passed to subsequent plotting
67+
calls.
68+
69+
Parameters
70+
----------
71+
line_array : array of list of Line2D
72+
A 2D array with elements corresponding to a list of lines appearing
73+
in an axes, matching the return type of a time response data plot.
74+
75+
Returns
76+
-------
77+
axes_array : array of list of Axes
78+
A 2D array with elements corresponding to the Axes assocated with
79+
the lines in `line_array`.
80+
81+
Notes
82+
-----
83+
Only the first element of each array entry is used to determine the axes.
84+
85+
"""
86+
_get_axes = np.vectorize(lambda lines: lines[0].axes)
87+
return _get_axes(line_array)
88+
89+
#
90+
# Utility functions
91+
#
92+
93+
94+
# Utility function to make legend labels
95+
def _make_legend_labels(labels, ignore_common=False):
96+
97+
# Look for a common prefix (up to a space)
98+
common_prefix = commonprefix(labels)
99+
last_space = common_prefix.rfind(', ')
100+
if last_space < 0 or ignore_common:
101+
common_prefix = ''
102+
elif last_space > 0:
103+
common_prefix = common_prefix[:last_space]
104+
prefix_len = len(common_prefix)
105+
106+
# Look for a common suffix (up to a space)
107+
common_suffix = commonprefix(
108+
[label[::-1] for label in labels])[::-1]
109+
suffix_len = len(common_suffix)
110+
# Only chop things off after a comma or space
111+
while suffix_len > 0 and common_suffix[-suffix_len] != ',':
112+
suffix_len -= 1
113+
114+
# Strip the labels of common information
115+
if suffix_len > 0 and not ignore_common:
116+
labels = [label[prefix_len:-suffix_len] for label in labels]
117+
else:
118+
labels = [label[prefix_len:] for label in labels]
119+
120+
return labels
121+
122+
123+
def _update_suptitle(fig, title, rcParams=None, frame='axes'):
124+
if fig is not None and isinstance(title, str):
125+
# Get the current title, if it exists
126+
old_title = None if fig._suptitle is None else fig._suptitle._text
127+
128+
if old_title is not None:
129+
# Find the common part of the titles
130+
common_prefix = commonprefix([old_title, title])
131+
132+
# Back up to the last space
133+
last_space = common_prefix.rfind(' ')
134+
if last_space > 0:
135+
common_prefix = common_prefix[:last_space]
136+
common_len = len(common_prefix)
137+
138+
# Add the new part of the title (usually the system name)
139+
if old_title[common_len:] != title[common_len:]:
140+
separator = ',' if len(common_prefix) > 0 else ';'
141+
title = old_title + separator + title[common_len:]
142+
143+
# Add the title
144+
suptitle(title, fig=fig, rcParams=rcParams, frame=frame)
145+
146+
59147
def _find_axes_center(fig, axs):
60148
"""Find the midpoint between axes in display coordinates.
61149

control/frdata.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,14 @@ def plot(self, plot_type=None, *args, **kwargs):
653653

654654
# Convert to pandas
655655
def to_pandas(self):
656+
"""Convert response data to pandas data frame.
657+
658+
Creates a pandas data frame for the value of the frequency
659+
response at each `omega`. The frequency response values are
660+
labeled in the form "H_{<out>, <in>}" where "<out>" and "<in>"
661+
are replaced with the output and input labels for the system.
662+
663+
"""
656664
if not pandas_check():
657665
ImportError('pandas not installed')
658666
import pandas

control/freqplot.py

Lines changed: 13 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@
1919

2020
from . import config
2121
from .bdalg import feedback
22-
from .ctrlplot import suptitle, _find_axes_center
22+
from .ctrlplot import suptitle, _find_axes_center, _make_legend_labels, \
23+
_update_suptitle
2324
from .ctrlutil import unwrap
2425
from .exception import ControlMIMONotImplemented
2526
from .frdata import FrequencyResponseData
2627
from .lti import LTI, _process_frequency_response, frequency_response
2728
from .margins import stability_margins
2829
from .statesp import StateSpace
29-
from .timeplot import _make_legend_labels
3030
from .xferfcn import TransferFunction
3131

3232
__all__ = ['bode_plot', 'NyquistResponseData', 'nyquist_response',
@@ -954,28 +954,7 @@ def gen_zero_centered_series(val_min, val_max, period):
954954
else:
955955
title = data[0].title
956956

957-
if fig is not None and isinstance(title, str):
958-
# Get the current title, if it exists
959-
old_title = None if fig._suptitle is None else fig._suptitle._text
960-
new_title = title
961-
962-
if old_title is not None:
963-
# Find the common part of the titles
964-
common_prefix = commonprefix([old_title, new_title])
965-
966-
# Back up to the last space
967-
last_space = common_prefix.rfind(' ')
968-
if last_space > 0:
969-
common_prefix = common_prefix[:last_space]
970-
common_len = len(common_prefix)
971-
972-
# Add the new part of the title (usually the system name)
973-
if old_title[common_len:] != new_title[common_len:]:
974-
separator = ',' if len(common_prefix) > 0 else ';'
975-
new_title = old_title + separator + new_title[common_len:]
976-
977-
# Add the title
978-
suptitle(title, fig=fig, rcParams=rcParams, frame=suptitle_frame)
957+
_update_suptitle(fig, title, rcParams=rcParams, frame=suptitle_frame)
979958

980959
#
981960
# Create legends
@@ -2717,12 +2696,13 @@ def _get_line_labels(ax, use_color=True):
27172696

27182697

27192698
# Turn label keyword into array indexed by trace, output, input
2720-
def _process_line_labels(label, nsys, ninputs=0, noutputs=0):
2699+
# TODO: move to ctrlutil.py and update parameter names to reflect general use
2700+
def _process_line_labels(label, ntraces, ninputs=0, noutputs=0):
27212701
if label is None:
27222702
return None
27232703

27242704
if isinstance(label, str):
2725-
label = [label]
2705+
label = [label] * ntraces # single label for all traces
27262706

27272707
# Convert to an ndarray, if not done aleady
27282708
try:
@@ -2734,12 +2714,14 @@ def _process_line_labels(label, nsys, ninputs=0, noutputs=0):
27342714
# TODO: allow more sophisticated broadcasting (and error checking)
27352715
try:
27362716
if ninputs > 0 and noutputs > 0:
2737-
if line_labels.ndim == 1:
2738-
line_labels = line_labels.reshape(nsys, 1, 1)
2739-
line_labels = np.broadcast_to(
2740-
line_labels,(nsys, ninputs, noutputs))
2717+
if line_labels.ndim == 1 and line_labels.size == ntraces:
2718+
line_labels = line_labels.reshape(ntraces, 1, 1)
2719+
line_labels = np.broadcast_to(
2720+
line_labels, (ntraces, ninputs, noutputs))
2721+
else:
2722+
line_labels = line_labels.reshape(ntraces, ninputs, noutputs)
27412723
except:
2742-
if line_labels.shape[0] != nsys:
2724+
if line_labels.shape[0] != ntraces:
27432725
raise ValueError("number of labels must match number of traces")
27442726
else:
27452727
raise ValueError("labels must be given for each input/output pair")

control/nlsys.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
from . import config
2828
from .iosys import InputOutputSystem, _parse_spec, _process_iosys_keywords, \
2929
_process_signal_list, common_timebase, isctime, isdtime
30-
from .timeresp import TimeResponseData, _check_convert_array, \
31-
_process_time_response
30+
from .timeresp import _check_convert_array, _process_time_response, \
31+
TimeResponseData, TimeResponseList
3232

3333
__all__ = ['NonlinearIOSystem', 'InterconnectedSystem', 'nlsys',
3434
'input_output_response', 'find_eqpt', 'linearize',
@@ -1327,8 +1327,8 @@ def input_output_response(
13271327
13281328
Parameters
13291329
----------
1330-
sys : InputOutputSystem
1331-
Input/output system to simulate.
1330+
sys : NonlinearIOSystem or list of NonlinearIOSystem
1331+
I/O system(s) for which input/output response is simulated.
13321332
13331333
T : array-like
13341334
Time steps at which the input is defined; values must be evenly spaced.
@@ -1448,6 +1448,16 @@ def input_output_response(
14481448
if kwargs:
14491449
raise TypeError("unrecognized keyword(s): ", str(kwargs))
14501450

1451+
# If passed a list, recursively call individual responses with given T
1452+
if isinstance(sys, (list, tuple)):
1453+
sysdata, responses = sys, []
1454+
for sys in sysdata:
1455+
responses.append(input_output_response(
1456+
sys, T, U=U, X0=X0, params=params, transpose=transpose,
1457+
return_x=return_x, squeeze=squeeze, t_eval=t_eval,
1458+
solve_ivp_kwargs=solve_ivp_kwargs, **kwargs))
1459+
return TimeResponseList(responses)
1460+
14511461
# Sanity checking on the input
14521462
if not isinstance(sys, NonlinearIOSystem):
14531463
raise TypeError("System of type ", type(sys), " not valid")

control/tests/kwargs_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ def test_response_plot_kwargs(data_fcn, plot_fcn, mimo):
308308
'StateSpace.sample': test_unrecognized_kwargs,
309309
'TimeResponseData.__call__': trdata_test.test_response_copy,
310310
'TimeResponseData.plot': timeplot_test.test_errors,
311+
'TimeResponseList.plot': timeplot_test.test_errors,
311312
'TransferFunction.__init__': test_unrecognized_kwargs,
312313
'TransferFunction.sample': test_unrecognized_kwargs,
313314
'optimal.OptimalControlProblem.__init__':

0 commit comments

Comments
 (0)