Skip to content

Commit 6b61ed0

Browse files
committed
Add slicing access for state-space models with tests
1 parent ad6b49e commit 6b61ed0

File tree

3 files changed

+31
-10
lines changed

3 files changed

+31
-10
lines changed

control/statesp.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
import math
5151
from copy import deepcopy
5252
from warnings import warn
53+
from collections.abc import Iterable
5354

5455
import numpy as np
5556
import scipy as sp
@@ -1215,17 +1216,16 @@ def append(self, other):
12151216

12161217
def __getitem__(self, indices):
12171218
"""Array style access"""
1218-
if len(indices) != 2:
1219+
if not isinstance(indices, Iterable) or len(indices) != 2:
12191220
raise IOError('must provide indices of length 2 for state space')
1220-
outdx = indices[0] if isinstance(indices[0], list) else [indices[0]]
1221-
inpdx = indices[1] if isinstance(indices[1], list) else [indices[1]]
1221+
outdx, inpdx = indices
1222+
if not isinstance(outdx, (int, slice)) or not isinstance(inpdx, (int, slice)):
1223+
raise TypeError(f"system indices must be integers or slices")
12221224
sysname = config.defaults['iosys.indexed_system_name_prefix'] + \
12231225
self.name + config.defaults['iosys.indexed_system_name_suffix']
12241226
return StateSpace(
12251227
self.A, self.B[:, inpdx], self.C[outdx, :], self.D[outdx, inpdx],
1226-
self.dt, name=sysname,
1227-
inputs=[self.input_labels[i] for i in list(inpdx)],
1228-
outputs=[self.output_labels[i] for i in list(outdx)])
1228+
self.dt, name=sysname, inputs=self.input_labels[inpdx], outputs=self.output_labels[outdx])
12291229

12301230
def sample(self, Ts, method='zoh', alpha=None, prewarp_frequency=None,
12311231
name=None, copy_names=True, **kwargs):

control/tests/statesp_test.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -463,24 +463,38 @@ def test_append_tf(self):
463463
np.testing.assert_array_almost_equal(sys3c.A[:3, 3:], np.zeros((3, 2)))
464464
np.testing.assert_array_almost_equal(sys3c.A[3:, :3], np.zeros((2, 3)))
465465

466-
def test_array_access_ss(self):
467-
466+
def test_array_access_ss_failure(self):
467+
sys1 = StateSpace(
468+
[[1., 2.], [3., 4.]],
469+
[[5., 6.], [6., 8.]],
470+
[[9., 10.], [11., 12.]],
471+
[[13., 14.], [15., 16.]], 1,
472+
inputs=['u0', 'u1'], outputs=['y0', 'y1'])
473+
with pytest.raises(IOError):
474+
sys1[0]
475+
476+
@pytest.mark.parametrize("outdx, inpdx",
477+
[(0, 1),
478+
(slice(0, 1, 1), 1),
479+
(0, slice(1, 2, 1)),
480+
(slice(0, 1, 1), slice(1, 2, 1))])
481+
def test_array_access_ss(self, outdx, inpdx):
468482
sys1 = StateSpace(
469483
[[1., 2.], [3., 4.]],
470484
[[5., 6.], [6., 8.]],
471485
[[9., 10.], [11., 12.]],
472486
[[13., 14.], [15., 16.]], 1,
473487
inputs=['u0', 'u1'], outputs=['y0', 'y1'])
474488

475-
sys1_01 = sys1[0, 1]
489+
sys1_01 = sys1[outdx, inpdx]
476490
np.testing.assert_array_almost_equal(sys1_01.A,
477491
sys1.A)
478492
np.testing.assert_array_almost_equal(sys1_01.B,
479493
sys1.B[:, 1:2])
480494
np.testing.assert_array_almost_equal(sys1_01.C,
481495
sys1.C[0:1, :])
482496
np.testing.assert_array_almost_equal(sys1_01.D,
483-
sys1.D[0, 1])
497+
sys1.D[0:1, 1:2])
484498

485499
assert sys1.dt == sys1_01.dt
486500
assert sys1_01.input_labels == ['u1']

control/xferfcn.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
4848
"""
4949

50+
from collections.abc import Iterable
51+
5052
# External function declarations
5153
import numpy as np
5254
from numpy import angle, array, empty, finfo, ndarray, ones, \
@@ -758,7 +760,12 @@ def __pow__(self, other):
758760
return (TransferFunction([1], [1]) / self) * (self**(other + 1))
759761

760762
def __getitem__(self, key):
763+
if not isinstance(key, Iterable) or len(key) != 2:
764+
raise IOError('must provide indices of length 2 for state space')
765+
761766
key1, key2 = key
767+
if not isinstance(key1, (int, slice)) or not isinstance(key2, (int, slice)):
768+
raise TypeError(f"system indices must be integers or slices")
762769

763770
# pre-process
764771
if isinstance(key1, int):

0 commit comments

Comments
 (0)