Skip to content

Commit 0b6348f

Browse files
committed
add label keyword to singular_value_plot and nyquist_plot
1 parent b9acc99 commit 0b6348f

File tree

2 files changed

+59
-35
lines changed

2 files changed

+59
-35
lines changed

control/freqplot.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1501,7 +1501,7 @@ def nyquist_response(
15011501

15021502

15031503
def nyquist_plot(
1504-
data, omega=None, plot=None, label_freq=0, color=None,
1504+
data, omega=None, plot=None, label_freq=0, color=None, label=None,
15051505
return_contour=None, title=None, legend_loc='upper right', **kwargs):
15061506
"""Nyquist plot for a system.
15071507
@@ -1590,6 +1590,11 @@ def nyquist_plot(
15901590
imaginary axis. Portions of the Nyquist plot corresponding to indented
15911591
portions of the contour are plotted using a different line style.
15921592
1593+
label : str or array-like of str
1594+
If present, replace automatically generated label(s) with the given
1595+
label(s). If sysdata is a list, strings should be specified for each
1596+
system.
1597+
15931598
label_freq : int, optiona
15941599
Label every nth frequency on the plot. If not specified, no labels
15951600
are generated.
@@ -1739,6 +1744,9 @@ def _parse_linestyle(style_name, allow_false=False):
17391744
if not isinstance(data, (list, tuple)):
17401745
data = [data]
17411746

1747+
# Process label keyword
1748+
line_labels = _process_line_labels(label, len(data))
1749+
17421750
# If we are passed a list of systems, compute response first
17431751
if all([isinstance(
17441752
sys, (StateSpace, TransferFunction, FrequencyResponseData))
@@ -1804,12 +1812,14 @@ def _parse_linestyle(style_name, allow_false=False):
18041812
reg_mask, abs(resp) > max_curve_magnitude)
18051813
resp[rescale] *= max_curve_magnitude / abs(resp[rescale])
18061814

1815+
# Get the label to use for the line
1816+
label = response.sysname if line_labels is None else line_labels[idx]
1817+
18071818
# Plot the regular portions of the curve (and grab the color)
18081819
x_reg = np.ma.masked_where(reg_mask, resp.real)
18091820
y_reg = np.ma.masked_where(reg_mask, resp.imag)
18101821
p = plt.plot(
1811-
x_reg, y_reg, primary_style[0], color=color,
1812-
label=response.sysname, **kwargs)
1822+
x_reg, y_reg, primary_style[0], color=color, label=label, **kwargs)
18131823
c = p[0].get_color()
18141824
out[idx] += p
18151825

@@ -2211,7 +2221,7 @@ def singular_values_response(
22112221

22122222
def singular_values_plot(
22132223
data, omega=None, *fmt, plot=None, omega_limits=None, omega_num=None,
2214-
title=None, legend_loc='center right', **kwargs):
2224+
label=None, title=None, legend_loc='center right', **kwargs):
22152225
"""Plot the singular values for a system.
22162226
22172227
Plot the singular values as a function of frequency for a system or
@@ -2257,6 +2267,10 @@ def singular_values_plot(
22572267
grid : bool
22582268
If True, plot grid lines on gain and phase plots. Default is set by
22592269
`config.defaults['freqplot.grid']`.
2270+
label : str or array-like of str
2271+
If present, replace automatically generated label(s) with the given
2272+
label(s). If sysdata is a list, strings should be specified for each
2273+
system.
22602274
omega_limits : array_like of two values
22612275
Set limits for plotted frequency range. If Hz=True the limits
22622276
are in Hz otherwise in rad/s.
@@ -2306,6 +2320,9 @@ def singular_values_plot(
23062320

23072321
responses = data
23082322

2323+
# Process label keyword
2324+
line_labels = _process_line_labels(label, len(data))
2325+
23092326
# Process (legacy) plot keyword
23102327
if plot is not None:
23112328
warnings.warn(
@@ -2385,11 +2402,14 @@ def singular_values_plot(
23852402
with plt.rc_context(freqplot_rcParams):
23862403
out[idx_sys] = ax_sigma.semilogx(
23872404
omega, 20 * np.log10(sigma), *fmt,
2388-
label=sysname, **color_arg, **kwargs)
2405+
label=label, **color_arg, **kwargs)
23892406
else:
23902407
with plt.rc_context(freqplot_rcParams):
23912408
out[idx_sys] = ax_sigma.loglog(
2392-
omega, sigma, label=sysname, *fmt, **color_arg, **kwargs)
2409+
omega, sigma, label=label, *fmt, **color_arg, **kwargs)
2410+
2411+
# Get the label to use for the line
2412+
label = sysname if line_labels is None else line_labels[idx]
23932413

23942414
# Plot the Nyquist frequency
23952415
if nyq_freq is not None:
@@ -2653,7 +2673,7 @@ def _get_line_labels(ax, use_color=True):
26532673

26542674

26552675
# Turn label keyword into array indexed by trace, output, input
2656-
def _process_line_labels(label, nsys, ninput, noutput):
2676+
def _process_line_labels(label, nsys, ninputs=0, noutputs=0):
26572677
if label is None:
26582678
return None
26592679

@@ -2669,7 +2689,8 @@ def _process_line_labels(label, nsys, ninput, noutput):
26692689
# Turn the data into a 3D array of appropriate shape
26702690
# TODO: allow more sophisticated broadcasting
26712691
try:
2672-
line_labels = line_labels.reshape(nsys, ninput, noutput)
2692+
if ninputs > 0 and noutputs > 0:
2693+
line_labels = line_labels.reshape(nsys, ninputs, noutputs)
26732694
except:
26742695
if line_labels.shape[0] != nsys:
26752696
raise ValueError("number of labels must match number of traces")

control/tests/freqplot_test.py

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -346,57 +346,60 @@ def _get_visible_limits(ax):
346346
_get_visible_limits(ax.reshape(-1)[0]), np.array([1, 100]))
347347

348348

349-
def test_freqplot_trace_labels():
349+
@pytest.mark.parametrize(
350+
"plt_fcn", [ct.bode_plot, ct.singular_values_plot, ct.nyquist_plot])
351+
def test_bode_trace_labels(plt_fcn):
350352
sys1 = ct.rss(2, 1, 1, name='sys1')
351353
sys2 = ct.rss(3, 1, 1, name='sys2')
352354

353355
# Make sure default labels are as expected
354-
out = ct.bode_plot([sys1, sys2])
356+
out = ct.plt_fcn([sys1, sys2])
355357
axs = ct.get_plot_axes(out)
356358
legend = axs[0, 0].get_legend().get_texts()
357359
assert legend[0].get_text() == 'sys1'
358360
assert legend[1].get_text() == 'sys2'
359361
plt.close()
360362

361363
# Override labels all at once
362-
out = ct.bode_plot([sys1, sys2], label=['line1', 'line2'])
364+
out = ct.plt_fcn([sys1, sys2], label=['line1', 'line2'])
363365
axs = ct.get_plot_axes(out)
364366
legend = axs[0, 0].get_legend().get_texts()
365367
assert legend[0].get_text() == 'line1'
366368
assert legend[1].get_text() == 'line2'
367369
plt.close()
368370

369371
# Override labels one at a time
370-
out = ct.bode_plot(sys1, label='line1')
371-
out = ct.bode_plot(sys2, label='line2')
372+
out = ct.plt_fcn(sys1, label='line1')
373+
out = ct.plt_fcn(sys2, label='line2')
372374
axs = ct.get_plot_axes(out)
373375
legend = axs[0, 0].get_legend().get_texts()
374376
assert legend[0].get_text() == 'line1'
375377
assert legend[1].get_text() == 'line2'
376378
plt.close()
377379

378-
# Multi-dimensional data
379-
sys1 = ct.rss(2, 2, 2, name='sys1')
380-
sys2 = ct.rss(3, 2, 2, name='sys2')
381-
382-
# Check out some errors first
383-
with pytest.raises(ValueError, match="number of labels must match"):
384-
ct.bode_plot([sys1, sys2], label=['line1'])
385-
with pytest.raises(ValueError, match="labels must be given for each"):
386-
ct.bode_plot(sys1, label=['line1'])
387-
388-
# Now do things that should work
389-
out = ct.bode_plot(
390-
[sys1, sys2],
391-
label=[
392-
[['line1', 'line1'], ['line1', 'line1']],
393-
[['line2', 'line2'], ['line2', 'line2']],
394-
])
395-
axs = ct.get_plot_axes(out)
396-
legend = axs[0, -1].get_legend().get_texts()
397-
assert legend[0].get_text() == 'line1'
398-
assert legend[1].get_text() == 'line2'
399-
plt.close()
380+
if plt_fcn == ct.bode_plot:
381+
# Multi-dimensional data
382+
sys1 = ct.rss(2, 2, 2, name='sys1')
383+
sys2 = ct.rss(3, 2, 2, name='sys2')
384+
385+
# Check out some errors first
386+
with pytest.raises(ValueError, match="number of labels must match"):
387+
ct.bode_plot([sys1, sys2], label=['line1'])
388+
with pytest.raises(ValueError, match="labels must be given for each"):
389+
ct.bode_plot(sys1, label=['line1'])
390+
391+
# Now do things that should work
392+
out = ct.bode_plot(
393+
[sys1, sys2],
394+
label=[
395+
[['line1', 'line1'], ['line1', 'line1']],
396+
[['line2', 'line2'], ['line2', 'line2']],
397+
])
398+
axs = ct.get_plot_axes(out)
399+
legend = axs[0, -1].get_legend().get_texts()
400+
assert legend[0].get_text() == 'line1'
401+
assert legend[1].get_text() == 'line2'
402+
plt.close()
400403

401404

402405
@pytest.mark.parametrize("plt_fcn", [ct.bode_plot, ct.singular_values_plot])

0 commit comments

Comments
 (0)