Skip to content

Commit e5394c4

Browse files
authored
Merge pull request #1012 from guptavaibhav0/main
Add slicing access for state-space models with tests
2 parents ad6b49e + a0fc6bc commit e5394c4

File tree

3 files changed

+55
-16
lines changed

3 files changed

+55
-16
lines changed

control/statesp.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
import math
5151
from copy import deepcopy
5252
from warnings import warn
53+
from collections.abc import Iterable
5354

5455
import numpy as np
5556
import scipy as sp
@@ -289,9 +290,9 @@ def __init__(self, *args, **kwargs):
289290
raise ValueError("A and B must have the same number of rows.")
290291
if self.nstates != C.shape[1]:
291292
raise ValueError("A and C must have the same number of columns.")
292-
if self.ninputs != B.shape[1]:
293+
if self.ninputs != B.shape[1] or self.ninputs != D.shape[1]:
293294
raise ValueError("B and D must have the same number of columns.")
294-
if self.noutputs != C.shape[0]:
295+
if self.noutputs != C.shape[0] or self.noutputs != D.shape[0]:
295296
raise ValueError("C and D must have the same number of rows.")
296297

297298
#
@@ -1215,17 +1216,23 @@ def append(self, other):
12151216

12161217
def __getitem__(self, indices):
12171218
"""Array style access"""
1218-
if len(indices) != 2:
1219+
if not isinstance(indices, Iterable) or len(indices) != 2:
12191220
raise IOError('must provide indices of length 2 for state space')
1220-
outdx = indices[0] if isinstance(indices[0], list) else [indices[0]]
1221-
inpdx = indices[1] if isinstance(indices[1], list) else [indices[1]]
1221+
outdx, inpdx = indices
1222+
1223+
# Convert int to slice to ensure that numpy doesn't drop the dimension
1224+
if isinstance(outdx, int): outdx = slice(outdx, outdx+1, 1)
1225+
if isinstance(inpdx, int): inpdx = slice(inpdx, inpdx+1, 1)
1226+
1227+
if not isinstance(outdx, slice) or not isinstance(inpdx, slice):
1228+
raise TypeError(f"system indices must be integers or slices")
1229+
12221230
sysname = config.defaults['iosys.indexed_system_name_prefix'] + \
12231231
self.name + config.defaults['iosys.indexed_system_name_suffix']
12241232
return StateSpace(
12251233
self.A, self.B[:, inpdx], self.C[outdx, :], self.D[outdx, inpdx],
1226-
self.dt, name=sysname,
1227-
inputs=[self.input_labels[i] for i in list(inpdx)],
1228-
outputs=[self.output_labels[i] for i in list(outdx)])
1234+
self.dt, name=sysname,
1235+
inputs=self.input_labels[inpdx], outputs=self.output_labels[outdx])
12291236

12301237
def sample(self, Ts, method='zoh', alpha=None, prewarp_frequency=None,
12311238
name=None, copy_names=True, **kwargs):

control/tests/statesp_test.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -463,28 +463,53 @@ def test_append_tf(self):
463463
np.testing.assert_array_almost_equal(sys3c.A[:3, 3:], np.zeros((3, 2)))
464464
np.testing.assert_array_almost_equal(sys3c.A[3:, :3], np.zeros((2, 3)))
465465

466-
def test_array_access_ss(self):
467-
466+
def test_array_access_ss_failure(self):
468467
sys1 = StateSpace(
469468
[[1., 2.], [3., 4.]],
470469
[[5., 6.], [6., 8.]],
471470
[[9., 10.], [11., 12.]],
472471
[[13., 14.], [15., 16.]], 1,
473472
inputs=['u0', 'u1'], outputs=['y0', 'y1'])
473+
with pytest.raises(IOError):
474+
sys1[0]
475+
476+
@pytest.mark.parametrize("outdx, inpdx",
477+
[(0, 1),
478+
(slice(0, 1, 1), 1),
479+
(0, slice(1, 2, 1)),
480+
(slice(0, 1, 1), slice(1, 2, 1)),
481+
(slice(None, None, -1), 1),
482+
(0, slice(None, None, -1)),
483+
(slice(None, 2, None), 1),
484+
(slice(None, None, 1), slice(None, None, 2)),
485+
(0, slice(1, 2, 1)),
486+
(slice(0, 1, 1), slice(1, 2, 1))])
487+
def test_array_access_ss(self, outdx, inpdx):
488+
sys1 = StateSpace(
489+
[[1., 2.], [3., 4.]],
490+
[[5., 6.], [7., 8.]],
491+
[[9., 10.], [11., 12.]],
492+
[[13., 14.], [15., 16.]], 1,
493+
inputs=['u0', 'u1'], outputs=['y0', 'y1'])
474494

475-
sys1_01 = sys1[0, 1]
495+
sys1_01 = sys1[outdx, inpdx]
496+
497+
# Convert int to slice to ensure that numpy doesn't drop the dimension
498+
if isinstance(outdx, int): outdx = slice(outdx, outdx+1, 1)
499+
if isinstance(inpdx, int): inpdx = slice(inpdx, inpdx+1, 1)
500+
476501
np.testing.assert_array_almost_equal(sys1_01.A,
477502
sys1.A)
478503
np.testing.assert_array_almost_equal(sys1_01.B,
479-
sys1.B[:, 1:2])
504+
sys1.B[:, inpdx])
480505
np.testing.assert_array_almost_equal(sys1_01.C,
481-
sys1.C[0:1, :])
506+
sys1.C[outdx, :])
482507
np.testing.assert_array_almost_equal(sys1_01.D,
483-
sys1.D[0, 1])
508+
sys1.D[outdx, inpdx])
484509

485510
assert sys1.dt == sys1_01.dt
486-
assert sys1_01.input_labels == ['u1']
487-
assert sys1_01.output_labels == ['y0']
511+
assert sys1_01.input_labels == sys1.input_labels[inpdx]
512+
assert sys1_01.output_labels == sys1.output_labels[outdx]
488513
assert sys1_01.name == sys1.name + "$indexed"
489514

490515
def test_dc_gain_cont(self):

control/xferfcn.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
4848
"""
4949

50+
from collections.abc import Iterable
51+
5052
# External function declarations
5153
import numpy as np
5254
from numpy import angle, array, empty, finfo, ndarray, ones, \
@@ -758,7 +760,12 @@ def __pow__(self, other):
758760
return (TransferFunction([1], [1]) / self) * (self**(other + 1))
759761

760762
def __getitem__(self, key):
763+
if not isinstance(key, Iterable) or len(key) != 2:
764+
raise IOError('must provide indices of length 2 for transfer functions')
765+
761766
key1, key2 = key
767+
if not isinstance(key1, (int, slice)) or not isinstance(key2, (int, slice)):
768+
raise TypeError(f"system indices must be integers or slices")
762769

763770
# pre-process
764771
if isinstance(key1, int):

0 commit comments

Comments
 (0)