Skip to content

Commit cf0b866

Browse files
committed
returnScipySignalLTI for discrete systems
1 parent ba7817c commit cf0b866

File tree

4 files changed

+157
-44
lines changed

4 files changed

+157
-44
lines changed

control/statesp.py

+40-14
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@
5959
from numpy.linalg import solve, eigvals, matrix_rank
6060
from numpy.linalg.linalg import LinAlgError
6161
import scipy as sp
62-
from scipy.signal import lti, cont2discrete
62+
from scipy.signal import cont2discrete
63+
from scipy.signal import StateSpace as signalStateSpace
6364
from warnings import warn
6465
from .lti import LTI, timebase, timebaseEqual, isdtime
6566
from . import config
@@ -72,7 +73,7 @@
7273
_statesp_defaults = {
7374
'statesp.use_numpy_matrix': True,
7475
'statesp.default_dt': None,
75-
'statesp.remove_useless_states': True,
76+
'statesp.remove_useless_states': True,
7677
}
7778

7879

@@ -149,7 +150,7 @@ class StateSpace(LTI):
149150
Setting dt = 0 specifies a continuous system, while leaving dt = None
150151
means the system timebase is not specified. If 'dt' is set to True, the
151152
system will be treated as a discrete time system with unspecified sampling
152-
time. The default value of 'dt' is None and can be changed by changing the
153+
time. The default value of 'dt' is None and can be changed by changing the
153154
value of ``control.config.defaults['statesp.default_dt']``.
154155
155156
"""
@@ -785,26 +786,51 @@ def minreal(self, tol=0.0):
785786
else:
786787
return StateSpace(self)
787788

788-
789-
# TODO: add discrete time check
790-
def returnScipySignalLTI(self):
791-
"""Return a list of a list of scipy.signal.lti objects.
789+
def returnScipySignalLTI(self, strict=True):
790+
"""Return a list of a list of SISO scipy.signal.lti objects.
792791
793792
For instance,
794793
795794
>>> out = ssobject.returnScipySignalLTI()
796795
>>> out[3][5]
797796
798-
is a signal.scipy.lti object corresponding to the transfer function from
799-
the 6th input to the 4th output."""
797+
is a signal.scipy.lti object corresponding to the transfer function
798+
from the 6th input to the 4th output.
799+
800+
Parameters
801+
----------
802+
strict : bool, optional
803+
True (default):
804+
`ssobject` must be continuous or discrete. `tfobject.dt` cannot
805+
be None.
806+
False:
807+
if `ssobject.dt` is None, continuous time signal.StateSpace
808+
objects are returned
809+
810+
Returns
811+
-------
812+
out : list of list of scipy.signal.StateSpace
813+
"""
814+
if strict and self.dt is None:
815+
raise ValueError("with strict=True, dt cannot be None")
816+
817+
if self.dt:
818+
kwdt = {'dt': self.dt}
819+
else:
820+
# scipy convention for continuous time lti systems: call without
821+
# dt keyword argument
822+
kwdt = {}
800823

801824
# Preallocate the output.
802825
out = [[[] for _ in range(self.inputs)] for _ in range(self.outputs)]
803826

804827
for i in range(self.outputs):
805828
for j in range(self.inputs):
806-
out[i][j] = lti(asarray(self.A), asarray(self.B[:, j]),
807-
asarray(self.C[i, :]), self.D[i, j])
829+
out[i][j] = signalStateSpace(asarray(self.A),
830+
asarray(self.B[:, j]),
831+
asarray(self.C[i, :]),
832+
self.D[i, j],
833+
**kwdt)
808834

809835
return out
810836

@@ -870,8 +896,8 @@ def sample(self, Ts, method='zoh', alpha=None, prewarp_frequency=None):
870896
871897
prewarp_frequency : float within [0, infinity)
872898
The frequency [rad/s] at which to match with the input continuous-
873-
time system's magnitude and phase (the gain=1 crossover frequency,
874-
for example). Should only be specified with method='bilinear' or
899+
time system's magnitude and phase (the gain=1 crossover frequency,
900+
for example). Should only be specified with method='bilinear' or
875901
'gbt' with alpha=0.5 and ignored otherwise.
876902
877903
Returns
@@ -896,7 +922,7 @@ def sample(self, Ts, method='zoh', alpha=None, prewarp_frequency=None):
896922
if (method=='bilinear' or (method=='gbt' and alpha==0.5)) and \
897923
prewarp_frequency is not None:
898924
Twarp = 2*np.tan(prewarp_frequency*Ts/2)/prewarp_frequency
899-
else:
925+
else:
900926
Twarp = Ts
901927
Ad, Bd, C, D, _ = cont2discrete(sys, Twarp, method, alpha)
902928
return StateSpace(Ad, Bd, C, D, Ts)

control/tests/statesp_test.py

+51-3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import unittest
77
import numpy as np
8+
import pytest
89
from numpy.linalg import solve
910
from scipy.linalg import eigvals, block_diag
1011
from control import matlab
@@ -653,7 +654,7 @@ def test_copy_constructor(self):
653654
linsys.A[0, 0] = -3
654655
np.testing.assert_array_equal(cpysys.A, [[-1]]) # original value
655656

656-
def test_sample_system_prewarping(self):
657+
def test_sample_system_prewarping(self):
657658
"""test that prewarping works when converting from cont to discrete time system"""
658659
A = np.array([
659660
[ 0.00000000e+00, 1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
@@ -668,10 +669,57 @@ def test_sample_system_prewarping(self):
668669
plant = StateSpace(A,B,C,0)
669670
plant_d_warped = plant.sample(Ts, 'bilinear', prewarp_frequency=wwarp)
670671
np.testing.assert_array_almost_equal(
671-
evalfr(plant, wwarp*1j),
672-
evalfr(plant_d_warped, np.exp(wwarp*1j*Ts)),
672+
evalfr(plant, wwarp*1j),
673+
evalfr(plant_d_warped, np.exp(wwarp*1j*Ts)),
673674
decimal=4)
674675

675676

677+
@pytest.fixture
678+
def mimoss(request):
679+
"""Test system with various dt values"""
680+
n = 5
681+
m = 3
682+
p = 2
683+
bx, bu = np.mgrid[1:n + 1, 1:m + 1]
684+
cy, cx = np.mgrid[1:p + 1, 1:n + 1]
685+
dy, du = np.mgrid[1:p + 1, 1:m + 1]
686+
return StateSpace(np.eye(5),
687+
bx * bu,
688+
cy * cx,
689+
dy * du,
690+
request.param)
691+
692+
693+
@pytest.mark.parametrize("mimoss",
694+
[None,
695+
0,
696+
0.1,
697+
1,
698+
True],
699+
indirect=True)
700+
def test_returnScipySignalLTI(mimoss):
701+
"""Test returnScipySignalLTI method with strict=False"""
702+
sslti = mimoss.returnScipySignalLTI(strict=False)
703+
for i in range(2):
704+
for j in range(3):
705+
np.testing.assert_allclose(sslti[i][j].A, mimoss.A)
706+
np.testing.assert_allclose(sslti[i][j].B, mimoss.B[:, j])
707+
np.testing.assert_allclose(sslti[i][j].C, mimoss.C[i, :])
708+
np.testing.assert_allclose(sslti[i][j].D, mimoss.D[i, j])
709+
if mimoss.dt == 0:
710+
assert sslti[i][j].dt is None
711+
else:
712+
assert sslti[i][j].dt == mimoss.dt
713+
714+
715+
@pytest.mark.parametrize("mimoss", [None], indirect=True)
716+
def test_returnScipySignalLTI_error(mimoss):
717+
"""Test returnScipySignalLTI method with dt=None and default strict=True"""
718+
with pytest.raises(ValueError):
719+
mimoss.returnScipySignalLTI()
720+
with pytest.raises(ValueError):
721+
mimoss.returnScipySignalLTI(strict=True)
722+
723+
676724
if __name__ == "__main__":
677725
unittest.main()

control/tests/xferfcn_test.py

+33-14
Original file line numberDiff line numberDiff line change
@@ -935,24 +935,43 @@ def test_sample_system_prewarping(self):
935935
evalfr(plant_d_warped, np.exp(wwarp*1j*Ts)),
936936
decimal=4)
937937

938-
@pytest.mark.parametrize("dt",
938+
939+
@pytest.fixture
940+
def mimotf(request):
941+
"""Test system with various dt values"""
942+
return TransferFunction([[[11], [12], [13]],
943+
[[21], [22], [23]]],
944+
[[[1, -1]] * 3] * 2,
945+
request.param)
946+
947+
948+
@pytest.mark.parametrize("mimotf",
939949
[None,
940950
0,
941-
pytest.param(1, marks=pytest.mark.xfail(
942-
reason="not implemented")),
943-
pytest.param(1, marks=pytest.mark.xfail(
944-
reason="not implemented"))])
945-
def test_returnScipySignalLTI(dt):
946-
"""Test returnScipySignalLTI method"""
947-
sys = TransferFunction([[[11], [12], [13]],
948-
[[21], [22], [23]]],
949-
[[[1, -1]] * 3] * 2,
950-
dt)
951-
sslti = sys.returnScipySignalLTI()
951+
0.1,
952+
1,
953+
True],
954+
indirect=True)
955+
def test_returnScipySignalLTI(mimotf):
956+
"""Test returnScipySignalLTI method with strict=False"""
957+
sslti = mimotf.returnScipySignalLTI(strict=False)
952958
for i in range(2):
953959
for j in range(3):
954-
np.testing.assert_allclose(sslti[i][j].num, sys.num[i][j])
955-
np.testing.assert_allclose(sslti[i][j].den, sys.den[i][j])
960+
np.testing.assert_allclose(sslti[i][j].num, mimotf.num[i][j])
961+
np.testing.assert_allclose(sslti[i][j].den, mimotf.den[i][j])
962+
if mimotf.dt == 0:
963+
assert sslti[i][j].dt is None
964+
else:
965+
assert sslti[i][j].dt == mimotf.dt
966+
967+
968+
@pytest.mark.parametrize("mimotf", [None], indirect=True)
969+
def test_returnScipySignalLTI_error(mimotf):
970+
"""Test returnScipySignalLTI method with dt=None and default strict=True"""
971+
with pytest.raises(ValueError):
972+
mimotf.returnScipySignalLTI()
973+
with pytest.raises(ValueError):
974+
mimotf.returnScipySignalLTI(strict=True)
956975

957976

958977
if __name__ == "__main__":

control/xferfcn.py

+33-13
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@
5757
polyadd, polymul, polyval, roots, sqrt, zeros, squeeze, exp, pi, \
5858
where, delete, real, poly, nonzero
5959
import scipy as sp
60-
from scipy.signal import lti, tf2zpk, zpk2tf, cont2discrete
60+
from scipy.signal import tf2zpk, zpk2tf, cont2discrete
61+
from scipy.signal import TransferFunction as signalTransferFunction
6162
from copy import deepcopy
6263
from warnings import warn
6364
from itertools import chain
@@ -93,8 +94,8 @@ class TransferFunction(LTI):
9394
instance variable and setting it to something other than 'None'. If 'dt'
9495
has a non-zero value, then it must match whenever two transfer functions
9596
are combined. If 'dt' is set to True, the system will be treated as a
96-
discrete time system with unspecified sampling time. The default value of
97-
'dt' is None and can be changed by changing the value of
97+
discrete time system with unspecified sampling time. The default value of
98+
'dt' is None and can be changed by changing the value of
9899
``control.config.defaults['xferfcn.default_dt']``.
99100
100101
The TransferFunction class defines two constants ``s`` and ``z`` that
@@ -801,7 +802,7 @@ def minreal(self, tol=None):
801802
# end result
802803
return TransferFunction(num, den, self.dt)
803804

804-
def returnScipySignalLTI(self):
805+
def returnScipySignalLTI(self, strict=True):
805806
"""Return a list of a list of scipy.signal.lti objects.
806807
807808
For instance,
@@ -812,19 +813,38 @@ def returnScipySignalLTI(self):
812813
is a signal.scipy.lti object corresponding to the
813814
transfer function from the 6th input to the 4th output.
814815
816+
Parameters
817+
----------
818+
strict : bool, optional
819+
True (default):
820+
`tfobject` must be continuous or discrete.
821+
`tfobject.dt`cannot be None.
822+
False:
823+
if `tfobject.dt` is None, continuous time signal.TransferFunction
824+
objects are is returned
825+
826+
Returns
827+
-------
828+
out : list of list of scipy.signal.TransferFunction
815829
"""
830+
if strict and self.dt is None:
831+
raise ValueError("with strict=True, dt cannot be None")
816832

817-
# TODO: implement for discrete time systems
818-
if self.dt != 0 and self.dt is not None:
819-
raise NotImplementedError("Function not \
820-
implemented in discrete time")
833+
if self.dt:
834+
kwdt = {'dt': self.dt}
835+
else:
836+
# scipy convention for continuous time lti systems: call without
837+
# dt keyword argument
838+
kwdt = {}
821839

822840
# Preallocate the output.
823841
out = [[[] for j in range(self.inputs)] for i in range(self.outputs)]
824842

825843
for i in range(self.outputs):
826844
for j in range(self.inputs):
827-
out[i][j] = lti(self.num[i][j], self.den[i][j])
845+
out[i][j] = signalTransferFunction(self.num[i][j],
846+
self.den[i][j],
847+
**kwdt)
828848

829849
return out
830850

@@ -1016,11 +1036,11 @@ def sample(self, Ts, method='zoh', alpha=None, prewarp_frequency=None):
10161036
The generalized bilinear transformation weighting parameter, which
10171037
should only be specified with method="gbt", and is ignored
10181038
otherwise.
1019-
1039+
10201040
prewarp_frequency : float within [0, infinity)
10211041
The frequency [rad/s] at which to match with the input continuous-
1022-
time system's magnitude and phase (the gain=1 crossover frequency,
1023-
for example). Should only be specified with method='bilinear' or
1042+
time system's magnitude and phase (the gain=1 crossover frequency,
1043+
for example). Should only be specified with method='bilinear' or
10241044
'gbt' with alpha=0.5 and ignored otherwise.
10251045
10261046
Returns
@@ -1050,7 +1070,7 @@ def sample(self, Ts, method='zoh', alpha=None, prewarp_frequency=None):
10501070
if (method=='bilinear' or (method=='gbt' and alpha==0.5)) and \
10511071
prewarp_frequency is not None:
10521072
Twarp = 2*np.tan(prewarp_frequency*Ts/2)/prewarp_frequency
1053-
else:
1073+
else:
10541074
Twarp = Ts
10551075
numd, dend, _ = cont2discrete(sys, Twarp, method, alpha)
10561076
return TransferFunction(numd[0, :], dend, Ts)

0 commit comments

Comments
 (0)