Skip to content

Commit 51c797e

Browse files
authored
Merge pull request #445 from bnavigator/extend-tf2scipylti
Extend returnScipySignalLTI() to discrete systems
2 parents 26206dc + d6d77dc commit 51c797e

File tree

4 files changed

+161
-16
lines changed

4 files changed

+161
-16
lines changed

control/statesp.py

+37-8
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@
5959
from numpy.linalg import solve, eigvals, matrix_rank
6060
from numpy.linalg.linalg import LinAlgError
6161
import scipy as sp
62-
from scipy.signal import lti, cont2discrete
62+
from scipy.signal import cont2discrete
63+
from scipy.signal import StateSpace as signalStateSpace
6364
from warnings import warn
6465
from .lti import LTI, timebase, timebaseEqual, isdtime
6566
from . import config
@@ -200,7 +201,7 @@ def __init__(self, *args, **kw):
200201
raise ValueError("Needs 1 or 4 arguments; received %i." % len(args))
201202

202203
# Process keyword arguments
203-
remove_useless = kw.get('remove_useless',
204+
remove_useless = kw.get('remove_useless',
204205
config.defaults['statesp.remove_useless_states'])
205206

206207
# Convert all matrices to standard form
@@ -798,9 +799,7 @@ def minreal(self, tol=0.0):
798799
else:
799800
return StateSpace(self)
800801

801-
802-
# TODO: add discrete time check
803-
def returnScipySignalLTI(self):
802+
def returnScipySignalLTI(self, strict=True):
804803
"""Return a list of a list of :class:`scipy.signal.lti` objects.
805804
806805
For instance,
@@ -809,15 +808,45 @@ def returnScipySignalLTI(self):
809808
>>> out[3][5]
810809
811810
is a :class:`scipy.signal.lti` object corresponding to the transfer
812-
function from the 6th input to the 4th output."""
811+
function from the 6th input to the 4th output.
812+
813+
Parameters
814+
----------
815+
strict : bool, optional
816+
True (default):
817+
The timebase `ssobject.dt` cannot be None; it must
818+
be continuous (0) or discrete (True or > 0).
819+
False:
820+
If `ssobject.dt` is None, continuous time
821+
:class:`scipy.signal.lti` objects are returned.
822+
823+
Returns
824+
-------
825+
out : list of list of :class:`scipy.signal.StateSpace`
826+
continuous time (inheriting from :class:`scipy.signal.lti`)
827+
or discrete time (inheriting from :class:`scipy.signal.dlti`)
828+
SISO objects
829+
"""
830+
if strict and self.dt is None:
831+
raise ValueError("with strict=True, dt cannot be None")
832+
833+
if self.dt:
834+
kwdt = {'dt': self.dt}
835+
else:
836+
# scipy convention for continuous time lti systems: call without
837+
# dt keyword argument
838+
kwdt = {}
813839

814840
# Preallocate the output.
815841
out = [[[] for _ in range(self.inputs)] for _ in range(self.outputs)]
816842

817843
for i in range(self.outputs):
818844
for j in range(self.inputs):
819-
out[i][j] = lti(asarray(self.A), asarray(self.B[:, j]),
820-
asarray(self.C[i, :]), self.D[i, j])
845+
out[i][j] = signalStateSpace(asarray(self.A),
846+
asarray(self.B[:, j:j + 1]),
847+
asarray(self.C[i:i + 1, :]),
848+
asarray(self.D[i:i + 1, j:j + 1]),
849+
**kwdt)
821850

822851
return out
823852

control/tests/statesp_matrix_test.py

+52
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import unittest
77
import numpy as np
8+
import pytest
89
from numpy.linalg import solve
910
from scipy.linalg import eigvals, block_diag
1011
from control import matlab
@@ -673,5 +674,56 @@ def test_sample_system_prewarping(self):
673674
decimal=4)
674675

675676

677+
class TestLTIConverter:
678+
"""Test returnScipySignalLTI method"""
679+
680+
@pytest.fixture
681+
def mimoss(self, request):
682+
"""Test system with various dt values"""
683+
n = 5
684+
m = 3
685+
p = 2
686+
bx, bu = np.mgrid[1:n + 1, 1:m + 1]
687+
cy, cx = np.mgrid[1:p + 1, 1:n + 1]
688+
dy, du = np.mgrid[1:p + 1, 1:m + 1]
689+
return StateSpace(np.eye(5) + np.eye(5, 5, 1),
690+
bx * bu,
691+
cy * cx,
692+
dy * du,
693+
request.param)
694+
695+
@pytest.mark.parametrize("mimoss",
696+
[None,
697+
0,
698+
0.1,
699+
1,
700+
True],
701+
indirect=True)
702+
def test_returnScipySignalLTI(self, mimoss):
703+
"""Test returnScipySignalLTI method with strict=False"""
704+
sslti = mimoss.returnScipySignalLTI(strict=False)
705+
for i in range(mimoss.outputs):
706+
for j in range(mimoss.inputs):
707+
np.testing.assert_allclose(sslti[i][j].A, mimoss.A)
708+
np.testing.assert_allclose(sslti[i][j].B, mimoss.B[:,
709+
j:j + 1])
710+
np.testing.assert_allclose(sslti[i][j].C, mimoss.C[i:i + 1,
711+
:])
712+
np.testing.assert_allclose(sslti[i][j].D, mimoss.D[i:i + 1,
713+
j:j + 1])
714+
if mimoss.dt == 0:
715+
assert sslti[i][j].dt is None
716+
else:
717+
assert sslti[i][j].dt == mimoss.dt
718+
719+
@pytest.mark.parametrize("mimoss", [None], indirect=True)
720+
def test_returnScipySignalLTI_error(self, mimoss):
721+
"""Test returnScipySignalLTI method with dt=None and strict=True"""
722+
with pytest.raises(ValueError):
723+
mimoss.returnScipySignalLTI()
724+
with pytest.raises(ValueError):
725+
mimoss.returnScipySignalLTI(strict=True)
726+
727+
676728
if __name__ == "__main__":
677729
unittest.main()

control/tests/xferfcn_test.py

+41
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# RMM, 30 Mar 2011 (based on TestXferFcn from v0.4a)
55

66
import unittest
7+
import pytest
8+
79
import sys as pysys
810
import numpy as np
911
from control.statesp import StateSpace, _convertToStateSpace, rss
@@ -934,5 +936,44 @@ def test_sample_system_prewarping(self):
934936
decimal=4)
935937

936938

939+
class TestLTIConverter:
940+
"""Test returnScipySignalLTI method"""
941+
942+
@pytest.fixture
943+
def mimotf(self, request):
944+
"""Test system with various dt values"""
945+
return TransferFunction([[[11], [12], [13]],
946+
[[21], [22], [23]]],
947+
[[[1, -1]] * 3] * 2,
948+
request.param)
949+
950+
@pytest.mark.parametrize("mimotf",
951+
[None,
952+
0,
953+
0.1,
954+
1,
955+
True],
956+
indirect=True)
957+
def test_returnScipySignalLTI(self, mimotf):
958+
"""Test returnScipySignalLTI method with strict=False"""
959+
sslti = mimotf.returnScipySignalLTI(strict=False)
960+
for i in range(2):
961+
for j in range(3):
962+
np.testing.assert_allclose(sslti[i][j].num, mimotf.num[i][j])
963+
np.testing.assert_allclose(sslti[i][j].den, mimotf.den[i][j])
964+
if mimotf.dt == 0:
965+
assert sslti[i][j].dt is None
966+
else:
967+
assert sslti[i][j].dt == mimotf.dt
968+
969+
@pytest.mark.parametrize("mimotf", [None], indirect=True)
970+
def test_returnScipySignalLTI_error(self, mimotf):
971+
"""Test returnScipySignalLTI method with dt=None and strict=True"""
972+
with pytest.raises(ValueError):
973+
mimotf.returnScipySignalLTI()
974+
with pytest.raises(ValueError):
975+
mimotf.returnScipySignalLTI(strict=True)
976+
977+
937978
if __name__ == "__main__":
938979
unittest.main()

control/xferfcn.py

+31-8
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@
5757
polyadd, polymul, polyval, roots, sqrt, zeros, squeeze, exp, pi, \
5858
where, delete, real, poly, nonzero
5959
import scipy as sp
60-
from scipy.signal import lti, tf2zpk, zpk2tf, cont2discrete
60+
from scipy.signal import tf2zpk, zpk2tf, cont2discrete
61+
from scipy.signal import TransferFunction as signalTransferFunction
6162
from copy import deepcopy
6263
from warnings import warn
6364
from itertools import chain
@@ -801,30 +802,52 @@ def minreal(self, tol=None):
801802
# end result
802803
return TransferFunction(num, den, self.dt)
803804

804-
def returnScipySignalLTI(self):
805+
def returnScipySignalLTI(self, strict=True):
805806
"""Return a list of a list of :class:`scipy.signal.lti` objects.
806807
807808
For instance,
808809
809810
>>> out = tfobject.returnScipySignalLTI()
810811
>>> out[3][5]
811812
812-
is a class:`scipy.signal.lti` object corresponding to the
813+
is a :class:`scipy.signal.lti` object corresponding to the
813814
transfer function from the 6th input to the 4th output.
814815
816+
Parameters
817+
----------
818+
strict : bool, optional
819+
True (default):
820+
The timebase `tfobject.dt` cannot be None; it must be
821+
continuous (0) or discrete (True or > 0).
822+
False:
823+
if `tfobject.dt` is None, continuous time
824+
:class:`scipy.signal.lti`objects are returned
825+
826+
Returns
827+
-------
828+
out : list of list of :class:`scipy.signal.TransferFunction`
829+
continuous time (inheriting from :class:`scipy.signal.lti`)
830+
or discrete time (inheriting from :class:`scipy.signal.dlti`)
831+
SISO objects
815832
"""
833+
if strict and self.dt is None:
834+
raise ValueError("with strict=True, dt cannot be None")
816835

817-
# TODO: implement for discrete time systems
818-
if self.dt != 0 and self.dt is not None:
819-
raise NotImplementedError("Function not \
820-
implemented in discrete time")
836+
if self.dt:
837+
kwdt = {'dt': self.dt}
838+
else:
839+
# scipy convention for continuous time lti systems: call without
840+
# dt keyword argument
841+
kwdt = {}
821842

822843
# Preallocate the output.
823844
out = [[[] for j in range(self.inputs)] for i in range(self.outputs)]
824845

825846
for i in range(self.outputs):
826847
for j in range(self.inputs):
827-
out[i][j] = lti(self.num[i][j], self.den[i][j])
848+
out[i][j] = signalTransferFunction(self.num[i][j],
849+
self.den[i][j],
850+
**kwdt)
828851

829852
return out
830853

0 commit comments

Comments
 (0)