Skip to content

Commit 2764889

Browse files
committed
fix ax processing bug in {nyquist,nichols,describing_function}_plot
1 parent dc7d71b commit 2764889

File tree

4 files changed

+33
-24
lines changed

4 files changed

+33
-24
lines changed

control/descfcn.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -521,16 +521,17 @@ def describing_function_plot(
521521

522522
# Plot the Nyquist response
523523
cplt = dfresp.response.plot(**kwargs)
524+
ax = cplt.axes[0, 0] # Get the axes where the plot was made
524525
lines[0] = cplt.lines[0] # Return Nyquist lines for first system
525526

526527
# Add the describing function curve to the plot
527-
lines[1] = plt.plot(dfresp.N_vals.real, dfresp.N_vals.imag)
528+
lines[1] = ax.plot(dfresp.N_vals.real, dfresp.N_vals.imag)
528529

529530
# Label the intersection points
530531
if point_label:
531532
for pos, (a, omega) in zip(dfresp.positions, dfresp.intersections):
532533
# Add labels to the intersection points
533-
plt.text(pos.real, pos.imag, point_label % (a, omega))
534+
ax.text(pos.real, pos.imag, point_label % (a, omega))
534535

535536
return ControlPlot(lines, cplt.axes, cplt.figure)
536537

control/freqplot.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1913,7 +1913,7 @@ def _parse_linestyle(style_name, allow_false=False):
19131913
# Plot the regular portions of the curve (and grab the color)
19141914
x_reg = np.ma.masked_where(reg_mask, resp.real)
19151915
y_reg = np.ma.masked_where(reg_mask, resp.imag)
1916-
p = plt.plot(
1916+
p = ax.plot(
19171917
x_reg, y_reg, primary_style[0], color=color, label=label, **kwargs)
19181918
c = p[0].get_color()
19191919
out[idx] += p
@@ -1928,7 +1928,7 @@ def _parse_linestyle(style_name, allow_false=False):
19281928
x_scl = np.ma.masked_where(scale_mask, resp.real)
19291929
y_scl = np.ma.masked_where(scale_mask, resp.imag)
19301930
if x_scl.count() >= 1 and y_scl.count() >= 1:
1931-
out[idx] += plt.plot(
1931+
out[idx] += ax.plot(
19321932
x_scl * (1 + curve_offset),
19331933
y_scl * (1 + curve_offset),
19341934
primary_style[1], color=c, **kwargs)
@@ -1939,20 +1939,19 @@ def _parse_linestyle(style_name, allow_false=False):
19391939
x, y = resp.real.copy(), resp.imag.copy()
19401940
x[reg_mask] *= (1 + curve_offset[reg_mask])
19411941
y[reg_mask] *= (1 + curve_offset[reg_mask])
1942-
p = plt.plot(x, y, linestyle='None', color=c)
1942+
p = ax.plot(x, y, linestyle='None', color=c)
19431943

19441944
# Add arrows
1945-
ax = plt.gca()
19461945
_add_arrows_to_line2D(
19471946
ax, p[0], arrow_pos, arrowstyle=arrow_style, dir=1)
19481947

19491948
# Plot the mirror image
19501949
if mirror_style is not False:
19511950
# Plot the regular and scaled segments
1952-
out[idx] += plt.plot(
1951+
out[idx] += ax.plot(
19531952
x_reg, -y_reg, mirror_style[0], color=c, **kwargs)
19541953
if x_scl.count() >= 1 and y_scl.count() >= 1:
1955-
out[idx] += plt.plot(
1954+
out[idx] += ax.plot(
19561955
x_scl * (1 - curve_offset),
19571956
-y_scl * (1 - curve_offset),
19581957
mirror_style[1], color=c, **kwargs)
@@ -1963,19 +1962,19 @@ def _parse_linestyle(style_name, allow_false=False):
19631962
x, y = resp.real.copy(), resp.imag.copy()
19641963
x[reg_mask] *= (1 - curve_offset[reg_mask])
19651964
y[reg_mask] *= (1 - curve_offset[reg_mask])
1966-
p = plt.plot(x, -y, linestyle='None', color=c, **kwargs)
1965+
p = ax.plot(x, -y, linestyle='None', color=c, **kwargs)
19671966
_add_arrows_to_line2D(
19681967
ax, p[0], arrow_pos, arrowstyle=arrow_style, dir=-1)
19691968
else:
19701969
out[idx] += [None, None]
19711970

19721971
# Mark the start of the curve
19731972
if start_marker:
1974-
plt.plot(resp[0].real, resp[0].imag, start_marker,
1973+
ax.plot(resp[0].real, resp[0].imag, start_marker,
19751974
color=c, markersize=start_marker_size)
19761975

19771976
# Mark the -1 point
1978-
plt.plot([-1], [0], 'r+')
1977+
ax.plot([-1], [0], 'r+')
19791978

19801979
#
19811980
# Draw circles for gain crossover and sensitivity functions
@@ -1987,16 +1986,16 @@ def _parse_linestyle(style_name, allow_false=False):
19871986

19881987
# Display the unit circle, to read gain crossover frequency
19891988
if unit_circle:
1990-
plt.plot(cos, sin, **config.defaults['nyquist.circle_style'])
1989+
ax.plot(cos, sin, **config.defaults['nyquist.circle_style'])
19911990

19921991
# Draw circles for given magnitudes of sensitivity
19931992
if ms_circles is not None:
19941993
for ms in ms_circles:
19951994
pos_x = -1 + (1/ms)*cos
19961995
pos_y = (1/ms)*sin
1997-
plt.plot(
1996+
ax.plot(
19981997
pos_x, pos_y, **config.defaults['nyquist.circle_style'])
1999-
plt.text(pos_x[label_pos], pos_y[label_pos], ms)
1998+
ax.text(pos_x[label_pos], pos_y[label_pos], ms)
20001999

20012000
# Draw circles for given magnitudes of complementary sensitivity
20022001
if mt_circles is not None:
@@ -2006,17 +2005,17 @@ def _parse_linestyle(style_name, allow_false=False):
20062005
rt = mt/(mt**2-1) # Mt radius
20072006
pos_x = ct+rt*cos
20082007
pos_y = rt*sin
2009-
plt.plot(
2008+
ax.plot(
20102009
pos_x, pos_y,
20112010
**config.defaults['nyquist.circle_style'])
2012-
plt.text(pos_x[label_pos], pos_y[label_pos], mt)
2011+
ax.text(pos_x[label_pos], pos_y[label_pos], mt)
20132012
else:
2014-
_, _, ymin, ymax = plt.axis()
2013+
_, _, ymin, ymax = ax.axis()
20152014
pos_y = np.linspace(ymin, ymax, 100)
2016-
plt.vlines(
2015+
ax.vlines(
20172016
-0.5, ymin=ymin, ymax=ymax,
20182017
**config.defaults['nyquist.circle_style'])
2019-
plt.text(-0.5, pos_y[label_pos], 1)
2018+
ax.text(-0.5, pos_y[label_pos], 1)
20202019

20212020
# Label the frequencies of the points on the Nyquist curve
20222021
if label_freq:
@@ -2039,7 +2038,7 @@ def _parse_linestyle(style_name, allow_false=False):
20392038
# np.round() is used because 0.99... appears
20402039
# instead of 1.0, and this would otherwise be
20412040
# truncated to 0.
2042-
plt.text(xpt, ypt, ' ' +
2041+
ax.text(xpt, ypt, ' ' +
20432042
str(int(np.round(f / 1000 ** pow1000, 0))) + ' ' +
20442043
prefix + 'Hz')
20452044

control/nichols.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,15 +132,15 @@ def nichols_plot(
132132
out[idx] = ax_nichols.plot(x, y, *fmt, label=label_, **kwargs)
133133

134134
# Label the plot axes
135-
plt.xlabel('Phase [deg]')
136-
plt.ylabel('Magnitude [dB]')
135+
ax_nichols.set_xlabel('Phase [deg]')
136+
ax_nichols.set_ylabel('Magnitude [dB]')
137137

138138
# Mark the -180 point
139-
plt.plot([-180], [0], 'r+')
139+
ax_nichols.plot([-180], [0], 'r+')
140140

141141
# Add grid
142142
if grid:
143-
nichols_grid()
143+
nichols_grid(ax=ax_nichols)
144144

145145
# List of systems that are included in this plot
146146
lines, labels = _get_line_labels(ax_nichols)

control/tests/ctrlplot_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,15 @@ def test_plot_ax_processing(resp_fcn, plot_fcn):
243243
# No response function available; just plot the data
244244
plot_fcn(*args, **kwargs, **plot_kwargs, ax=ax)
245245

246+
# Make sure the plot ended up in the right place
247+
assert len(axs[0, 0].get_lines()) == 0 # upper left
248+
assert len(axs[0, 1].get_lines()) != 0 # top middle
249+
assert len(axs[1, 0].get_lines()) == 0 # lower left
250+
if resp_fcn != ct.gangof4_response:
251+
assert len(axs[1, 2].get_lines()) == 0 # lower right (normally empty)
252+
else:
253+
assert len(axs[1, 2].get_lines()) != 0 # gangof4 uses this axes
254+
246255
# Check to make sure original settings did not change
247256
assert fig._suptitle.get_text() == title
248257
assert fig._suptitle.get_fontsize() == titlesize

0 commit comments

Comments
 (0)