Skip to content

Commit 6f6c70d

Browse files
committed
refactoring/regularization of ax keyword processing
1 parent 402b45f commit 6f6c70d

File tree

2 files changed

+89
-58
lines changed

2 files changed

+89
-58
lines changed

control/freqplot.py

Lines changed: 58 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -458,47 +458,13 @@ def bode_plot(
458458
(noutputs if plot_phase else 0)
459459
ncols = ninputs
460460

461-
# See if we can use the current figure axes
462-
fig = plt.gcf() # get current figure (or create new one)
463-
if ax is None and plt.get_fignums():
464-
ax = fig.get_axes()
465-
if len(ax) == nrows * ncols:
466-
# Assume that the shape is right (no easy way to infer this)
467-
ax = np.array(ax).reshape(nrows, ncols)
468-
469-
# Clear out any old text from the current figure
470-
for text in fig.texts:
471-
text.set_visible(False) # turn off the text
472-
del text # get rid of it completely
473-
474-
elif len(ax) != 0:
475-
# Need to generate a new figure
476-
fig, ax = plt.figure(), None
477-
478-
else:
479-
# Blank figure, just need to recreate axes
480-
ax = None
481-
482-
# Create new axes, if needed, and customize them
483461
if ax is None:
484-
with plt.rc_context(_freqplot_rcParams):
485-
ax_array = fig.subplots(nrows, ncols, squeeze=False)
486-
fig.set_layout_engine('tight')
487-
fig.align_labels()
488-
489462
# Set up default sharing of axis limits if not specified
490463
for kw in ['share_magnitude', 'share_phase', 'share_frequency']:
491464
if kw not in kwargs or kwargs[kw] is None:
492465
kwargs[kw] = config.defaults['freqplot.' + kw]
493466

494-
else:
495-
# Make sure the axes are the right shape
496-
if ax.shape != (nrows, ncols):
497-
raise ValueError(
498-
"specified axes are not the right shape; "
499-
f"got {ax.shape} but expecting ({nrows}, {ncols})")
500-
ax_array = ax
501-
fig = ax_array[0, 0].figure # just in case this is not gcf()
467+
fig, ax_array = _process_ax_keyword(ax, (nrows, ncols), squeeze=False)
502468

503469
# Get the values for sharing axes limits
504470
share_magnitude = kwargs.pop('share_magnitude', None)
@@ -1780,11 +1746,8 @@ def _parse_linestyle(style_name, allow_false=False):
17801746
# Return counts and (optionally) the contour we used
17811747
return (counts, contours) if return_contour else counts
17821748

1783-
# Get the figure and axes to use
1784-
if ax is None:
1785-
fig, ax = plt.gcf(), plt.gca()
1786-
else:
1787-
fig = ax.figure
1749+
fig, ax = _process_ax_keyword(
1750+
ax, shape=(1, 1), squeeze=True, rcParams=_freqplot_rcParams)
17881751

17891752
# Create a list of lines for the output
17901753
out = np.empty(len(nyquist_responses), dtype=object)
@@ -2235,7 +2198,7 @@ def singular_values_response(
22352198

22362199
def singular_values_plot(
22372200
data, omega=None, *fmt, plot=None, omega_limits=None, omega_num=None,
2238-
label=None, title=None, legend_loc='center right', **kwargs):
2201+
ax=None, label=None, title=None, legend_loc='center right', **kwargs):
22392202
"""Plot the singular values for a system.
22402203
22412204
Plot the singular values as a function of frequency for a system or
@@ -2364,22 +2327,8 @@ def singular_values_plot(
23642327
else:
23652328
return sigmas, omegas
23662329

2367-
fig = plt.gcf() # get current figure (or create new one)
2368-
ax_sigma = None # axes for plotting singular values
2369-
2370-
# Get the current axes if they already exist
2371-
for ax in fig.axes:
2372-
if ax.get_label() == 'control-sigma':
2373-
ax_sigma = ax
2374-
2375-
# If no axes present, create them from scratch
2376-
if ax_sigma is None:
2377-
if len(fig.axes) > 0:
2378-
# Create a new figure to avoid overwriting in the old one
2379-
fig = plt.figure()
2380-
2381-
with plt.rc_context(_freqplot_rcParams):
2382-
ax_sigma = plt.subplot(111, label='control-sigma')
2330+
fig, ax_sigma = _process_ax_keyword(ax, shape=(1, 1), squeeze=True)
2331+
ax_sigma.set_label('control-sigma') # TODO: deprecate?
23832332

23842333
# Handle color cycle manually as all singular values
23852334
# of the same systems are expected to be of the same color
@@ -2475,7 +2424,7 @@ def singular_values_plot(
24752424
# Utility functions
24762425
#
24772426
# This section of the code contains some utility functions for
2478-
# generating frequency domain plots
2427+
# generating frequency domain plots.
24792428
#
24802429

24812430

@@ -2742,6 +2691,57 @@ def _process_line_labels(label, nsys, ninputs=0, noutputs=0):
27422691
return line_labels
27432692

27442693

2694+
def _process_ax_keyword(axs, shape=(1, 1), rcParams=None, squeeze=False):
2695+
"""Utility function to process ax keyword to plotting commands.
2696+
2697+
This function processes the `ax` keyword to plotting commands. If no
2698+
ax keyword is passed, the current figure is checked to see if it has
2699+
the correct shape. If the shape matches the desired shape, then the
2700+
current figure and axes are returned. Otherwise a new figure is
2701+
created with axes of the desired shape.
2702+
2703+
Legacy behavior: some of the older plotting commands use a axes label
2704+
to identify the proper axes for plotting. This behavior is supported
2705+
through the use of the label keyword, but will only work if shape ==
2706+
(1, 1) and squeeze == True.
2707+
2708+
"""
2709+
if axs is None:
2710+
fig = plt.gcf() # get current figure (or create new one)
2711+
axs = fig.get_axes()
2712+
2713+
# Check to see if axes are the right shape; if not, create new figure
2714+
# Note: can't actually check the shape, just the total number of axes
2715+
if len(axs) != np.prod(shape):
2716+
with plt.rc_context(rcParams):
2717+
if len(axs) != 0:
2718+
# Create a new figure
2719+
fig, axs = plt.subplots(*shape, squeeze=False)
2720+
else:
2721+
# Create new axes on (empty) figure
2722+
axs = fig.subplots(*shape, squeeze=False)
2723+
fig.set_layout_engine('tight')
2724+
fig.align_labels()
2725+
else:
2726+
# Use the existing axes, properly reshaped
2727+
axs = np.asarray(axs).reshape(*shape)
2728+
else:
2729+
try:
2730+
axs = np.asarray(axs).reshape(shape)
2731+
except ValueError:
2732+
raise ValueError(
2733+
"specified axes are not the right shape; "
2734+
f"got {axs.shape} but expecting {shape}")
2735+
fig = axs[0, 0].figure
2736+
2737+
# Process the squeeze keyword
2738+
if squeeze and shape == (1, 1):
2739+
axs = axs[0, 0] # Just return the single axes object
2740+
elif squeeze:
2741+
axs = axs.squeeze()
2742+
2743+
return fig, axs
2744+
27452745
#
27462746
# Utility functions to create nice looking labels (KLD 5/23/11)
27472747
#

control/tests/freqplot_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,37 @@ def test_freqplot_trace_labels(plt_fcn):
462462
plt.close()
463463

464464

465+
466+
@pytest.mark.parametrize(
467+
"plt_fcn", [ct.bode_plot, ct.singular_values_plot, ct.nyquist_plot])
468+
@pytest.mark.parametrize(
469+
"ninputs, noutputs", [(1, 1), (1, 2), (2, 1), (2, 3)])
470+
def test_freqplot_ax_keyword(plt_fcn, ninputs, noutputs):
471+
if plt_fcn == ct.nyquist_plot and (ninputs != 1 or noutputs != 1):
472+
pytest.skip("MIMO not implemented for Nyquist")
473+
474+
# System to use
475+
sys = ct.rss(4, ninputs, noutputs)
476+
477+
# Create an initial figure
478+
out1 = plt_fcn(sys)
479+
480+
# Draw again on the same figure, using array
481+
axs = ct.get_plot_axes(out1)
482+
out2 = plt_fcn(sys, ax=axs)
483+
np.testing.assert_equal(ct.get_plot_axes(out1), ct.get_plot_axes(out2))
484+
485+
# Pass things in as a list instead
486+
axs_list = axs.tolist()
487+
out3 = plt_fcn(sys, ax=axs)
488+
np.testing.assert_equal(ct.get_plot_axes(out1), ct.get_plot_axes(out3))
489+
490+
# Flatten the list
491+
axs_list = axs.squeeze().tolist()
492+
out3 = plt_fcn(sys, ax=axs_list)
493+
np.testing.assert_equal(ct.get_plot_axes(out1), ct.get_plot_axes(out3))
494+
495+
465496
@pytest.mark.parametrize("plt_fcn", [ct.bode_plot, ct.singular_values_plot])
466497
def test_freqplot_errors(plt_fcn):
467498
if plt_fcn == ct.bode_plot:

0 commit comments

Comments
 (0)