Skip to content

Move ctrlplot code prior to upcoming PR #1033

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 227 additions & 5 deletions control/ctrlplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,36 @@

from os.path import commonprefix

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

from . import config

__all__ = ['suptitle', 'get_plot_axes']

#
# Style parameters
#

_ctrlplot_rcParams = mpl.rcParams.copy()
_ctrlplot_rcParams.update({
'axes.labelsize': 'small',
'axes.titlesize': 'small',
'figure.titlesize': 'medium',
'legend.fontsize': 'x-small',
'xtick.labelsize': 'small',
'ytick.labelsize': 'small',
})


#
# User functions
#
# The functions below can be used by users to modify ctrl plots or get
# information about them.
#


def suptitle(
title, fig=None, frame='axes', **kwargs):
Expand All @@ -35,7 +58,7 @@ def suptitle(
Additional keywords (passed to matplotlib).

"""
rcParams = config._get_param('freqplot', 'rcParams', kwargs, pop=True)
rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

if fig is None:
fig = plt.gcf()
Expand All @@ -61,10 +84,10 @@ def suptitle(
def get_plot_axes(line_array):
"""Get a list of axes from an array of lines.

This function can be used to return the set of axes corresponding to
the line array that is returned by `time_response_plot`. This is useful for
generating an axes array that can be passed to subsequent plotting
calls.
This function can be used to return the set of axes corresponding
to the line array that is returned by `time_response_plot`. This
is useful for generating an axes array that can be passed to
subsequent plotting calls.

Parameters
----------
Expand All @@ -89,6 +112,125 @@ def get_plot_axes(line_array):
#
# Utility functions
#
# These functions are used by plotting routines to provide a consistent way
# of processing and displaying information.
#


def _process_ax_keyword(
axs, shape=(1, 1), rcParams=None, squeeze=False, clear_text=False):
"""Utility function to process ax keyword to plotting commands.

This function processes the `ax` keyword to plotting commands. If no
ax keyword is passed, the current figure is checked to see if it has
the correct shape. If the shape matches the desired shape, then the
current figure and axes are returned. Otherwise a new figure is
created with axes of the desired shape.

Legacy behavior: some of the older plotting commands use a axes label
to identify the proper axes for plotting. This behavior is supported
through the use of the label keyword, but will only work if shape ==
(1, 1) and squeeze == True.

"""
if axs is None:
fig = plt.gcf() # get current figure (or create new one)
axs = fig.get_axes()

# Check to see if axes are the right shape; if not, create new figure
# Note: can't actually check the shape, just the total number of axes
if len(axs) != np.prod(shape):
with plt.rc_context(rcParams):
if len(axs) != 0:
# Create a new figure
fig, axs = plt.subplots(*shape, squeeze=False)
else:
# Create new axes on (empty) figure
axs = fig.subplots(*shape, squeeze=False)
fig.set_layout_engine('tight')
fig.align_labels()
else:
# Use the existing axes, properly reshaped
axs = np.asarray(axs).reshape(*shape)

if clear_text:
# Clear out any old text from the current figure
for text in fig.texts:
text.set_visible(False) # turn off the text
del text # get rid of it completely
else:
try:
axs = np.asarray(axs).reshape(shape)
except ValueError:
raise ValueError(
"specified axes are not the right shape; "
f"got {axs.shape} but expecting {shape}")
fig = axs[0, 0].figure

# Process the squeeze keyword
if squeeze and shape == (1, 1):
axs = axs[0, 0] # Just return the single axes object
elif squeeze:
axs = axs.squeeze()

return fig, axs


# Turn label keyword into array indexed by trace, output, input
# TODO: move to ctrlutil.py and update parameter names to reflect general use
def _process_line_labels(label, ntraces, ninputs=0, noutputs=0):
if label is None:
return None

if isinstance(label, str):
label = [label] * ntraces # single label for all traces

# Convert to an ndarray, if not done aleady
try:
line_labels = np.asarray(label)
except ValueError:
raise ValueError("label must be a string or array_like")

# Turn the data into a 3D array of appropriate shape
# TODO: allow more sophisticated broadcasting (and error checking)
try:
if ninputs > 0 and noutputs > 0:
if line_labels.ndim == 1 and line_labels.size == ntraces:
line_labels = line_labels.reshape(ntraces, 1, 1)
line_labels = np.broadcast_to(
line_labels, (ntraces, ninputs, noutputs))
else:
line_labels = line_labels.reshape(ntraces, ninputs, noutputs)
except ValueError:
if line_labels.shape[0] != ntraces:
raise ValueError("number of labels must match number of traces")
else:
raise ValueError("labels must be given for each input/output pair")

return line_labels


# Get labels for all lines in an axes
def _get_line_labels(ax, use_color=True):
labels, lines = [], []
last_color, counter = None, 0 # label unknown systems
for i, line in enumerate(ax.get_lines()):
label = line.get_label()
if use_color and label.startswith("Unknown"):
label = f"Unknown-{counter}"
if last_color is None:
last_color = line.get_color()
elif last_color != line.get_color():
counter += 1
last_color = line.get_color()
elif label[0] == '_':
continue

if label not in labels:
lines.append(line)
labels.append(label)

return lines, labels


# Utility function to make legend labels
Expand Down Expand Up @@ -160,3 +302,83 @@ def _find_axes_center(fig, axs):
ylim = [min(ll[1], ylim[0]), max(ur[1], ylim[1])]

return (np.sum(xlim)/2, np.sum(ylim)/2)


# Internal function to add arrows to a curve
def _add_arrows_to_line2D(
axes, line, arrow_locs=[0.2, 0.4, 0.6, 0.8],
arrowstyle='-|>', arrowsize=1, dir=1):
"""
Add arrows to a matplotlib.lines.Line2D at selected locations.

Parameters:
-----------
axes: Axes object as returned by axes command (or gca)
line: Line2D object as returned by plot command
arrow_locs: list of locations where to insert arrows, % of total length
arrowstyle: style of the arrow
arrowsize: size of the arrow

Returns:
--------
arrows: list of arrows

Based on https://stackoverflow.com/questions/26911898/

"""
# Get the coordinates of the line, in plot coordinates
if not isinstance(line, mpl.lines.Line2D):
raise ValueError("expected a matplotlib.lines.Line2D object")
x, y = line.get_xdata(), line.get_ydata()

# Determine the arrow properties
arrow_kw = {"arrowstyle": arrowstyle}

color = line.get_color()
use_multicolor_lines = isinstance(color, np.ndarray)
if use_multicolor_lines:
raise NotImplementedError("multicolor lines not supported")
else:
arrow_kw['color'] = color

linewidth = line.get_linewidth()
if isinstance(linewidth, np.ndarray):
raise NotImplementedError("multiwidth lines not supported")
else:
arrow_kw['linewidth'] = linewidth

# Figure out the size of the axes (length of diagonal)
xlim, ylim = axes.get_xlim(), axes.get_ylim()
ul, lr = np.array([xlim[0], ylim[0]]), np.array([xlim[1], ylim[1]])
diag = np.linalg.norm(ul - lr)

# Compute the arc length along the curve
s = np.cumsum(np.sqrt(np.diff(x) ** 2 + np.diff(y) ** 2))

# Truncate the number of arrows if the curve is short
# TODO: figure out a smarter way to do this
frac = min(s[-1] / diag, 1)
if len(arrow_locs) and frac < 0.05:
arrow_locs = [] # too short; no arrows at all
elif len(arrow_locs) and frac < 0.2:
arrow_locs = [0.5] # single arrow in the middle

# Plot the arrows (and return list if patches)
arrows = []
for loc in arrow_locs:
n = np.searchsorted(s, s[-1] * loc)

if dir == 1 and n == 0:
# Move the arrow forward by one if it is at start of a segment
n = 1

# Place the head of the arrow at the desired location
arrow_head = [x[n], y[n]]
arrow_tail = [x[n - dir], y[n - dir]]

p = mpl.patches.FancyArrowPatch(
arrow_tail, arrow_head, transform=axes.transData, lw=0,
**arrow_kw)
axes.add_patch(p)
arrows.append(p)
return arrows
Loading
Loading