Skip to content

Commit ab77e02

Browse files
committed
Merge branch 'extend-tf2scipylti' into array-matrix-tests
2 parents 4824143 + a8aa41e commit ab77e02

File tree

4 files changed

+146
-17
lines changed

4 files changed

+146
-17
lines changed

control/statesp.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@
5959
from numpy.linalg import solve, eigvals, matrix_rank
6060
from numpy.linalg.linalg import LinAlgError
6161
import scipy as sp
62-
from scipy.signal import lti, cont2discrete
62+
from scipy.signal import cont2discrete
63+
from scipy.signal import StateSpace as signalStateSpace
6364
from warnings import warn
6465
from .lti import LTI, common_timebase, isdtime
6566
from . import config
@@ -802,26 +803,51 @@ def minreal(self, tol=0.0):
802803
else:
803804
return StateSpace(self)
804805

805-
806-
# TODO: add discrete time check
807-
def returnScipySignalLTI(self):
808-
"""Return a list of a list of scipy.signal.lti objects.
806+
def returnScipySignalLTI(self, strict=True):
807+
"""Return a list of a list of SISO scipy.signal.lti objects.
809808
810809
For instance,
811810
812811
>>> out = ssobject.returnScipySignalLTI()
813812
>>> out[3][5]
814813
815-
is a signal.scipy.lti object corresponding to the transfer function from
816-
the 6th input to the 4th output."""
814+
is a signal.scipy.lti object corresponding to the transfer function
815+
from the 6th input to the 4th output.
816+
817+
Parameters
818+
----------
819+
strict : bool, optional
820+
True (default):
821+
`ssobject` must be continuous or discrete. `tfobject.dt` cannot
822+
be None.
823+
False:
824+
if `ssobject.dt` is None, continuous time signal.StateSpace
825+
objects are returned
826+
827+
Returns
828+
-------
829+
out : list of list of scipy.signal.StateSpace
830+
"""
831+
if strict and self.dt is None:
832+
raise ValueError("with strict=True, dt cannot be None")
833+
834+
if self.dt:
835+
kwdt = {'dt': self.dt}
836+
else:
837+
# scipy convention for continuous time lti systems: call without
838+
# dt keyword argument
839+
kwdt = {}
817840

818841
# Preallocate the output.
819842
out = [[[] for _ in range(self.inputs)] for _ in range(self.outputs)]
820843

821844
for i in range(self.outputs):
822845
for j in range(self.inputs):
823-
out[i][j] = lti(asarray(self.A), asarray(self.B[:, j]),
824-
asarray(self.C[i, :]), self.D[i, j])
846+
out[i][j] = signalStateSpace(asarray(self.A),
847+
asarray(self.B[:, j]),
848+
asarray(self.C[i, :]),
849+
self.D[i, j],
850+
**kwdt)
825851

826852
return out
827853

control/tests/statesp_test.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99

1010

1111
import numpy as np
12-
from numpy.linalg import solve
1312
import pytest
13+
14+
from numpy.linalg import solve
1415
from scipy.linalg import eigvals, block_diag
1516

1617
from control.statesp import StateSpace, _convertToStateSpace, drss, rss, tf2ss
@@ -698,3 +699,49 @@ def test_pole(self, states, outputs, inputs):
698699
assert abs(z) < 1
699700

700701

702+
class TestLTIConverter:
703+
"""Test the LTI system return function"""
704+
705+
@pytest.fixture
706+
def mimoss(self, request):
707+
"""Test system with various dt values"""
708+
n = 5
709+
m = 3
710+
p = 2
711+
bx, bu = np.mgrid[1:n + 1, 1:m + 1]
712+
cy, cx = np.mgrid[1:p + 1, 1:n + 1]
713+
dy, du = np.mgrid[1:p + 1, 1:m + 1]
714+
return StateSpace(np.eye(5),
715+
bx * bu,
716+
cy * cx,
717+
dy * du,
718+
request.param)
719+
720+
@pytest.mark.parametrize("mimoss",
721+
[None,
722+
0,
723+
0.1,
724+
1,
725+
True],
726+
indirect=True)
727+
def test_returnScipySignalLTI(self, mimoss):
728+
"""Test returnScipySignalLTI method with strict=False"""
729+
sslti = mimoss.returnScipySignalLTI(strict=False)
730+
for i in range(2):
731+
for j in range(3):
732+
np.testing.assert_allclose(sslti[i][j].A, mimoss.A)
733+
np.testing.assert_allclose(sslti[i][j].B, mimoss.B[:, j])
734+
np.testing.assert_allclose(sslti[i][j].C, mimoss.C[i, :])
735+
np.testing.assert_allclose(sslti[i][j].D, mimoss.D[i, j])
736+
if mimoss.dt == 0:
737+
assert sslti[i][j].dt is None
738+
else:
739+
assert sslti[i][j].dt == mimoss.dt
740+
741+
@pytest.mark.parametrize("mimoss", [None], indirect=True)
742+
def test_returnScipySignalLTI_error(self, mimoss):
743+
"""Test returnScipySignalLTI method with dt=None and strict=True"""
744+
with pytest.raises(ValueError):
745+
mimoss.returnScipySignalLTI()
746+
with pytest.raises(ValueError):
747+
mimoss.returnScipySignalLTI(strict=True)

control/tests/xferfcn_test.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -957,3 +957,39 @@ def test_repr(self, Hargs, ref):
957957
np.testing.assert_array_almost_equal(H.num[p][m], H2.num[p][m])
958958
np.testing.assert_array_almost_equal(H.den[p][m], H2.den[p][m])
959959
assert H.dt == H2.dt
960+
961+
@pytest.fixture
962+
def mimotf(self, request):
963+
"""Test system with various dt values"""
964+
return TransferFunction([[[11], [12], [13]],
965+
[[21], [22], [23]]],
966+
[[[1, -1]] * 3] * 2,
967+
request.param)
968+
969+
@pytest.mark.parametrize("mimotf",
970+
[None,
971+
0,
972+
0.1,
973+
1,
974+
True],
975+
indirect=True)
976+
def test_returnScipySignalLTI(self, mimotf):
977+
"""Test returnScipySignalLTI method with strict=False"""
978+
sslti = mimotf.returnScipySignalLTI(strict=False)
979+
for i in range(2):
980+
for j in range(3):
981+
np.testing.assert_allclose(sslti[i][j].num, mimotf.num[i][j])
982+
np.testing.assert_allclose(sslti[i][j].den, mimotf.den[i][j])
983+
if mimotf.dt == 0:
984+
assert sslti[i][j].dt is None
985+
else:
986+
assert sslti[i][j].dt == mimotf.dt
987+
988+
@pytest.mark.parametrize("mimotf", [None], indirect=True)
989+
def test_returnScipySignalLTI_error(self, mimotf):
990+
"""Test returnScipySignalLTI method with dt=None and default strict=True"""
991+
with pytest.raises(ValueError):
992+
mimotf.returnScipySignalLTI()
993+
with pytest.raises(ValueError):
994+
mimotf.returnScipySignalLTI(strict=True)
995+

control/xferfcn.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@
5757
polyadd, polymul, polyval, roots, sqrt, zeros, squeeze, exp, pi, \
5858
where, delete, real, poly, nonzero
5959
import scipy as sp
60-
from scipy.signal import lti, tf2zpk, zpk2tf, cont2discrete
60+
from scipy.signal import tf2zpk, zpk2tf, cont2discrete
61+
from scipy.signal import TransferFunction as signalTransferFunction
6162
from copy import deepcopy
6263
from warnings import warn
6364
from itertools import chain
@@ -788,7 +789,7 @@ def minreal(self, tol=None):
788789
# end result
789790
return TransferFunction(num, den, self.dt)
790791

791-
def returnScipySignalLTI(self):
792+
def returnScipySignalLTI(self, strict=True):
792793
"""Return a list of a list of scipy.signal.lti objects.
793794
794795
For instance,
@@ -799,19 +800,38 @@ def returnScipySignalLTI(self):
799800
is a signal.scipy.lti object corresponding to the
800801
transfer function from the 6th input to the 4th output.
801802
803+
Parameters
804+
----------
805+
strict : bool, optional
806+
True (default):
807+
`tfobject` must be continuous or discrete.
808+
`tfobject.dt`cannot be None.
809+
False:
810+
if `tfobject.dt` is None, continuous time signal.TransferFunction
811+
objects are is returned
812+
813+
Returns
814+
-------
815+
out : list of list of scipy.signal.TransferFunction
802816
"""
817+
if strict and self.dt is None:
818+
raise ValueError("with strict=True, dt cannot be None")
803819

804-
# TODO: implement for discrete time systems
805-
if self.dt != 0 and self.dt is not None:
806-
raise NotImplementedError("Function not \
807-
implemented in discrete time")
820+
if self.dt:
821+
kwdt = {'dt': self.dt}
822+
else:
823+
# scipy convention for continuous time lti systems: call without
824+
# dt keyword argument
825+
kwdt = {}
808826

809827
# Preallocate the output.
810828
out = [[[] for j in range(self.inputs)] for i in range(self.outputs)]
811829

812830
for i in range(self.outputs):
813831
for j in range(self.inputs):
814-
out[i][j] = lti(self.num[i][j], self.den[i][j])
832+
out[i][j] = signalTransferFunction(self.num[i][j],
833+
self.den[i][j],
834+
**kwdt)
815835

816836
return out
817837

0 commit comments

Comments
 (0)