Skip to content

Commit 136d6f4

Browse files
committed
add signal labels + more unit tests/coverage + docstring tweaks
1 parent ce5a95c commit 136d6f4

File tree

4 files changed

+295
-19
lines changed

4 files changed

+295
-19
lines changed

control/iosys.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -1527,7 +1527,9 @@ def input_output_response(
15271527
15281528
The return value of the system can also be accessed by assigning the
15291529
function to a tuple of length 2 (time, output) or of length 3 (time,
1530-
output, state) if ``return_x`` is ``True``.
1530+
output, state) if ``return_x`` is ``True``. If the input/output
1531+
system signals are named, these names will be used as labels for the
1532+
time response.
15311533
15321534
Other parameters
15331535
----------------
@@ -1590,7 +1592,8 @@ def input_output_response(
15901592
u = U[i] if len(U.shape) == 1 else U[:, i]
15911593
y[:, i] = sys._out(T[i], [], u)
15921594
return TimeResponseData(
1593-
T, y, None, None, issiso=sys.issiso(),
1595+
T, y, None, U, issiso=sys.issiso(),
1596+
output_labels=sys.output_index, input_labels=sys.input_index,
15941597
transpose=transpose, return_x=return_x, squeeze=squeeze)
15951598

15961599
# create X0 if not given, test if X0 has correct shape
@@ -1687,6 +1690,8 @@ def ivp_rhs(t, x):
16871690

16881691
return TimeResponseData(
16891692
soln.t, y, soln.y, U, issiso=sys.issiso(),
1693+
output_labels=sys.output_index, input_labels=sys.input_index,
1694+
state_labels=sys.state_index,
16901695
transpose=transpose, return_x=return_x, squeeze=squeeze)
16911696

16921697

control/tests/timeresp_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1117,7 +1117,7 @@ def test_squeeze(self, fcn, nstate, nout, ninp, squeeze, shape1, shape2):
11171117
@pytest.mark.parametrize("fcn", [ct.ss, ct.tf, ct.ss2io])
11181118
def test_squeeze_exception(self, fcn):
11191119
sys = fcn(ct.rss(2, 1, 1))
1120-
with pytest.raises(ValueError, match="unknown squeeze value"):
1120+
with pytest.raises(ValueError, match="Unknown squeeze value"):
11211121
step_response(sys, squeeze=1)
11221122

11231123
@pytest.mark.usefixtures("editsdefaults")

control/tests/trdata_test.py

+185-4
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
[2, 1, None],
2626
[2, 1, True],
2727
[2, 1, False],
28-
[2, 2, None],
29-
[2, 2, True],
30-
[2, 2, False],
28+
[2, 3, None],
29+
[2, 3, True],
30+
[2, 3, False],
3131
])
3232
def test_trdata_shapes(nin, nout, squeeze):
3333
# SISO, single trace
@@ -48,6 +48,12 @@ def test_trdata_shapes(nin, nout, squeeze):
4848
assert res.x.shape == (sys.nstates, ntimes)
4949
assert res.u is None
5050

51+
# Check dimensions of the response
52+
assert res.ntraces == 0 # single trace
53+
assert res.ninputs == 0 # no input for initial response
54+
assert res.noutputs == sys.noutputs
55+
assert res.nstates == sys.nstates
56+
5157
# Check shape of class properties
5258
if sys.issiso():
5359
assert res.outputs.shape == (ntimes,)
@@ -78,6 +84,12 @@ def test_trdata_shapes(nin, nout, squeeze):
7884
assert res.x.shape == (sys.nstates, sys.ninputs, ntimes)
7985
assert res.u.shape == (sys.ninputs, sys.ninputs, ntimes)
8086

87+
# Check shape of class members
88+
assert res.ntraces == sys.ninputs
89+
assert res.ninputs == sys.ninputs
90+
assert res.noutputs == sys.noutputs
91+
assert res.nstates == sys.nstates
92+
8193
# Check shape of inputs and outputs
8294
if sys.issiso() and squeeze is not False:
8395
assert res.outputs.shape == (ntimes, )
@@ -108,11 +120,19 @@ def test_trdata_shapes(nin, nout, squeeze):
108120
res = ct.forced_response(sys, T, U, X0, squeeze=squeeze)
109121
ntimes = res.time.shape[0]
110122

123+
# Check shape of class members
111124
assert len(res.time.shape) == 1
112125
assert res.y.shape == (sys.noutputs, ntimes)
113126
assert res.x.shape == (sys.nstates, ntimes)
114127
assert res.u.shape == (sys.ninputs, ntimes)
115128

129+
# Check dimensions of the response
130+
assert res.ntraces == 0 # single trace
131+
assert res.ninputs == sys.ninputs
132+
assert res.noutputs == sys.noutputs
133+
assert res.nstates == sys.nstates
134+
135+
# Check shape of inputs and outputs
116136
if sys.issiso() and squeeze is not False:
117137
assert res.outputs.shape == (ntimes,)
118138
assert res.states.shape == (sys.nstates, ntimes)
@@ -176,6 +196,167 @@ def test_response_copy():
176196
with pytest.raises(ValueError, match="not enough"):
177197
t, y, x = response_mimo
178198

199+
# Labels
200+
assert response_mimo.output_labels is None
201+
assert response_mimo.state_labels is None
202+
assert response_mimo.input_labels is None
203+
response = response_mimo(
204+
output_labels=['y1', 'y2'], input_labels='u',
205+
state_labels=["x[%d]" % i for i in range(4)])
206+
assert response.output_labels == ['y1', 'y2']
207+
assert response.state_labels == ['x[0]', 'x[1]', 'x[2]', 'x[3]']
208+
assert response.input_labels == ['u']
209+
179210
# Unknown keyword
180-
with pytest.raises(ValueError, match="unknown"):
211+
with pytest.raises(ValueError, match="Unknown parameter(s)*"):
181212
response_bad_kw = response_mimo(input=0)
213+
214+
215+
def test_trdata_labels():
216+
# Create an I/O system with labels
217+
sys = ct.rss(4, 3, 2)
218+
iosys = ct.LinearIOSystem(sys)
219+
220+
T = np.linspace(1, 10, 10)
221+
U = [np.sin(T), np.cos(T)]
222+
223+
# Create a response
224+
response = ct.input_output_response(iosys, T, U)
225+
226+
# Make sure the labels got created
227+
np.testing.assert_equal(
228+
response.output_labels, ["y[%d]" % i for i in range(sys.noutputs)])
229+
np.testing.assert_equal(
230+
response.state_labels, ["x[%d]" % i for i in range(sys.nstates)])
231+
np.testing.assert_equal(
232+
response.input_labels, ["u[%d]" % i for i in range(sys.ninputs)])
233+
234+
235+
def test_trdata_multitrace():
236+
#
237+
# Output signal processing
238+
#
239+
240+
# Proper call of multi-trace data w/ ambiguous 2D output
241+
response = ct.TimeResponseData(
242+
np.zeros(5), np.ones((2, 5)), np.zeros((3, 2, 5)),
243+
np.ones((4, 2, 5)), multi_trace=True)
244+
assert response.ntraces == 2
245+
assert response.noutputs == 1
246+
assert response.nstates == 3
247+
assert response.ninputs == 4
248+
249+
# Proper call of single trace w/ ambiguous 2D output
250+
response = ct.TimeResponseData(
251+
np.zeros(5), np.ones((2, 5)), np.zeros((3, 5)),
252+
np.ones((4, 5)), multi_trace=False)
253+
assert response.ntraces == 0
254+
assert response.noutputs == 2
255+
assert response.nstates == 3
256+
assert response.ninputs == 4
257+
258+
# Proper call of multi-trace data w/ ambiguous 1D output
259+
response = ct.TimeResponseData(
260+
np.zeros(5), np.ones(5), np.zeros((3, 5)),
261+
np.ones((4, 5)), multi_trace=False)
262+
assert response.ntraces == 0
263+
assert response.noutputs == 1
264+
assert response.nstates == 3
265+
assert response.ninputs == 4
266+
assert response.y.shape == (1, 5) # Make sure reshape occured
267+
268+
# Output vector not the right shape
269+
with pytest.raises(ValueError, match="Output vector is the wrong shape"):
270+
response = ct.TimeResponseData(
271+
np.zeros(5), np.ones((1, 2, 3, 5)), None, None)
272+
273+
# Inconsistent output vector: different number of time points
274+
with pytest.raises(ValueError, match="Output vector does not match time"):
275+
response = ct.TimeResponseData(
276+
np.zeros(5), np.ones(6), np.zeros(5), np.zeros(5))
277+
278+
#
279+
# State signal processing
280+
#
281+
282+
# For multi-trace, state must be 3D
283+
with pytest.raises(ValueError, match="State vector is the wrong shape"):
284+
response = ct.TimeResponseData(
285+
np.zeros(5), np.ones((1, 5)), np.zeros((3, 5)), multi_trace=True)
286+
287+
# If not multi-trace, state must be 2D
288+
with pytest.raises(ValueError, match="State vector is the wrong shape"):
289+
response = ct.TimeResponseData(
290+
np.zeros(5), np.ones(5), np.zeros((3, 1, 5)), multi_trace=False)
291+
292+
# State vector in the wrong shape
293+
with pytest.raises(ValueError, match="State vector is the wrong shape"):
294+
response = ct.TimeResponseData(
295+
np.zeros(5), np.ones((1, 2, 5)), np.zeros((2, 1, 5)))
296+
297+
# Inconsistent state vector: different number of time points
298+
with pytest.raises(ValueError, match="State vector does not match time"):
299+
response = ct.TimeResponseData(
300+
np.zeros(5), np.ones(5), np.zeros((1, 6)), np.zeros(5))
301+
302+
#
303+
# Input signal processing
304+
#
305+
306+
# Proper call of multi-trace data with 2D input
307+
response = ct.TimeResponseData(
308+
np.zeros(5), np.ones((2, 5)), np.zeros((3, 2, 5)),
309+
np.ones((2, 5)), multi_trace=True)
310+
assert response.ntraces == 2
311+
assert response.noutputs == 1
312+
assert response.nstates == 3
313+
assert response.ninputs == 1
314+
315+
# Input vector in the wrong shape
316+
with pytest.raises(ValueError, match="Input vector is the wrong shape"):
317+
response = ct.TimeResponseData(
318+
np.zeros(5), np.ones((1, 2, 5)), None, np.zeros((2, 1, 5)))
319+
320+
# Inconsistent input vector: different number of time points
321+
with pytest.raises(ValueError, match="Input vector does not match time"):
322+
response = ct.TimeResponseData(
323+
np.zeros(5), np.ones(5), np.zeros((1, 5)), np.zeros(6))
324+
325+
326+
def test_trdata_exceptions():
327+
# Incorrect dimension for time vector
328+
with pytest.raises(ValueError, match="Time vector must be 1D"):
329+
ct.TimeResponseData(np.zeros((2,2)), np.zeros(2), None)
330+
331+
# Infer SISO system from inputs and outputs
332+
response = ct.TimeResponseData(
333+
np.zeros(5), np.ones(5), None, np.ones(5))
334+
assert response.issiso
335+
336+
response = ct.TimeResponseData(
337+
np.zeros(5), np.ones((1, 5)), None, np.ones((1, 5)))
338+
assert response.issiso
339+
340+
response = ct.TimeResponseData(
341+
np.zeros(5), np.ones((1, 2, 5)), None, np.ones((1, 2, 5)))
342+
assert response.issiso
343+
344+
# Not enough input to infer whether SISO
345+
with pytest.raises(ValueError, match="Can't determine if system is SISO"):
346+
response = ct.TimeResponseData(
347+
np.zeros(5), np.ones((1, 2, 5)), np.ones((4, 2, 5)), None)
348+
349+
# Not enough input to infer whether SISO
350+
with pytest.raises(ValueError, match="Keyword `issiso` does not match"):
351+
response = ct.TimeResponseData(
352+
np.zeros(5), np.ones((2, 5)), None, np.ones((1, 5)), issiso=True)
353+
354+
# Unknown squeeze keyword value
355+
with pytest.raises(ValueError, match="Unknown squeeze value"):
356+
response=ct.TimeResponseData(
357+
np.zeros(5), np.ones(5), None, np.ones(5), squeeze=1)
358+
359+
# Legacy interface index error
360+
response[0], response[1], response[2]
361+
with pytest.raises(IndexError):
362+
response[3]

0 commit comments

Comments
 (0)