@@ -196,15 +196,20 @@ def test_response_copy():
196
196
with pytest .raises (ValueError , match = "not enough" ):
197
197
t , y , x = response_mimo
198
198
199
- # Labels
200
- assert response_mimo .output_labels is None
201
- assert response_mimo .state_labels is None
202
- assert response_mimo .input_labels is None
199
+ # Make sure labels are transferred to the response
200
+ assert response_siso .output_labels == sys_siso .output_labels
201
+ assert response_siso .state_labels == sys_siso .state_labels
202
+ assert response_siso .input_labels == sys_siso .input_labels
203
+ assert response_mimo .output_labels == sys_mimo .output_labels
204
+ assert response_mimo .state_labels == sys_mimo .state_labels
205
+ assert response_mimo .input_labels == sys_mimo .input_labels
206
+
207
+ # Check relabelling
203
208
response = response_mimo (
204
209
output_labels = ['y1' , 'y2' ], input_labels = 'u' ,
205
- state_labels = ["x[%d] " % i for i in range (4 )])
210
+ state_labels = ["x%d " % i for i in range (4 )])
206
211
assert response .output_labels == ['y1' , 'y2' ]
207
- assert response .state_labels == ['x[0] ' , 'x[1] ' , 'x[2] ' , 'x[3] ' ]
212
+ assert response .state_labels == ['x0 ' , 'x1 ' , 'x2 ' , 'x3 ' ]
208
213
assert response .input_labels == ['u' ]
209
214
210
215
# Unknown keyword
@@ -231,6 +236,17 @@ def test_trdata_labels():
231
236
np .testing .assert_equal (
232
237
response .input_labels , ["u[%d]" % i for i in range (sys .ninputs )])
233
238
239
+ # Make sure the selected input and output are both correctly transferred to the response
240
+ for nu in range (sys .ninputs ):
241
+ for ny in range (sys .noutputs ):
242
+ step_response = ct .step_response (sys , T , input = nu , output = ny )
243
+ assert step_response .input_labels == [sys .input_labels [nu ]]
244
+ assert step_response .output_labels == [sys .output_labels [ny ]]
245
+
246
+ init_response = ct .initial_response (sys , T , input = nu , output = ny )
247
+ assert init_response .input_labels == None
248
+ assert init_response .output_labels == [sys .output_labels [ny ]]
249
+
234
250
235
251
def test_trdata_multitrace ():
236
252
#
0 commit comments