diff --git a/control/statesp.py b/control/statesp.py index 5af916bf0..d23fbd7be 100644 --- a/control/statesp.py +++ b/control/statesp.py @@ -59,7 +59,8 @@ from numpy.linalg import solve, eigvals, matrix_rank from numpy.linalg.linalg import LinAlgError import scipy as sp -from scipy.signal import lti, cont2discrete +from scipy.signal import cont2discrete +from scipy.signal import StateSpace as signalStateSpace from warnings import warn from .lti import LTI, timebase, timebaseEqual, isdtime from . import config @@ -200,7 +201,7 @@ def __init__(self, *args, **kw): raise ValueError("Needs 1 or 4 arguments; received %i." % len(args)) # Process keyword arguments - remove_useless = kw.get('remove_useless', + remove_useless = kw.get('remove_useless', config.defaults['statesp.remove_useless_states']) # Convert all matrices to standard form @@ -798,9 +799,7 @@ def minreal(self, tol=0.0): else: return StateSpace(self) - - # TODO: add discrete time check - def returnScipySignalLTI(self): + def returnScipySignalLTI(self, strict=True): """Return a list of a list of :class:`scipy.signal.lti` objects. For instance, @@ -809,15 +808,45 @@ def returnScipySignalLTI(self): >>> out[3][5] is a :class:`scipy.signal.lti` object corresponding to the transfer - function from the 6th input to the 4th output.""" + function from the 6th input to the 4th output. + + Parameters + ---------- + strict : bool, optional + True (default): + The timebase `ssobject.dt` cannot be None; it must + be continuous (0) or discrete (True or > 0). + False: + If `ssobject.dt` is None, continuous time + :class:`scipy.signal.lti` objects are returned. + + Returns + ------- + out : list of list of :class:`scipy.signal.StateSpace` + continuous time (inheriting from :class:`scipy.signal.lti`) + or discrete time (inheriting from :class:`scipy.signal.dlti`) + SISO objects + """ + if strict and self.dt is None: + raise ValueError("with strict=True, dt cannot be None") + + if self.dt: + kwdt = {'dt': self.dt} + else: + # scipy convention for continuous time lti systems: call without + # dt keyword argument + kwdt = {} # Preallocate the output. out = [[[] for _ in range(self.inputs)] for _ in range(self.outputs)] for i in range(self.outputs): for j in range(self.inputs): - out[i][j] = lti(asarray(self.A), asarray(self.B[:, j]), - asarray(self.C[i, :]), self.D[i, j]) + out[i][j] = signalStateSpace(asarray(self.A), + asarray(self.B[:, j:j + 1]), + asarray(self.C[i:i + 1, :]), + asarray(self.D[i:i + 1, j:j + 1]), + **kwdt) return out diff --git a/control/tests/statesp_matrix_test.py b/control/tests/statesp_matrix_test.py index 34a17f992..e7e91364a 100644 --- a/control/tests/statesp_matrix_test.py +++ b/control/tests/statesp_matrix_test.py @@ -5,6 +5,7 @@ import unittest import numpy as np +import pytest from numpy.linalg import solve from scipy.linalg import eigvals, block_diag from control import matlab @@ -673,5 +674,56 @@ def test_sample_system_prewarping(self): decimal=4) +class TestLTIConverter: + """Test returnScipySignalLTI method""" + + @pytest.fixture + def mimoss(self, request): + """Test system with various dt values""" + n = 5 + m = 3 + p = 2 + bx, bu = np.mgrid[1:n + 1, 1:m + 1] + cy, cx = np.mgrid[1:p + 1, 1:n + 1] + dy, du = np.mgrid[1:p + 1, 1:m + 1] + return StateSpace(np.eye(5) + np.eye(5, 5, 1), + bx * bu, + cy * cx, + dy * du, + request.param) + + @pytest.mark.parametrize("mimoss", + [None, + 0, + 0.1, + 1, + True], + indirect=True) + def test_returnScipySignalLTI(self, mimoss): + """Test returnScipySignalLTI method with strict=False""" + sslti = mimoss.returnScipySignalLTI(strict=False) + for i in range(mimoss.outputs): + for j in range(mimoss.inputs): + np.testing.assert_allclose(sslti[i][j].A, mimoss.A) + np.testing.assert_allclose(sslti[i][j].B, mimoss.B[:, + j:j + 1]) + np.testing.assert_allclose(sslti[i][j].C, mimoss.C[i:i + 1, + :]) + np.testing.assert_allclose(sslti[i][j].D, mimoss.D[i:i + 1, + j:j + 1]) + if mimoss.dt == 0: + assert sslti[i][j].dt is None + else: + assert sslti[i][j].dt == mimoss.dt + + @pytest.mark.parametrize("mimoss", [None], indirect=True) + def test_returnScipySignalLTI_error(self, mimoss): + """Test returnScipySignalLTI method with dt=None and strict=True""" + with pytest.raises(ValueError): + mimoss.returnScipySignalLTI() + with pytest.raises(ValueError): + mimoss.returnScipySignalLTI(strict=True) + + if __name__ == "__main__": unittest.main() diff --git a/control/tests/xferfcn_test.py b/control/tests/xferfcn_test.py index 02e6c2b37..17e602090 100644 --- a/control/tests/xferfcn_test.py +++ b/control/tests/xferfcn_test.py @@ -4,6 +4,8 @@ # RMM, 30 Mar 2011 (based on TestXferFcn from v0.4a) import unittest +import pytest + import sys as pysys import numpy as np from control.statesp import StateSpace, _convertToStateSpace, rss @@ -934,5 +936,44 @@ def test_sample_system_prewarping(self): decimal=4) +class TestLTIConverter: + """Test returnScipySignalLTI method""" + + @pytest.fixture + def mimotf(self, request): + """Test system with various dt values""" + return TransferFunction([[[11], [12], [13]], + [[21], [22], [23]]], + [[[1, -1]] * 3] * 2, + request.param) + + @pytest.mark.parametrize("mimotf", + [None, + 0, + 0.1, + 1, + True], + indirect=True) + def test_returnScipySignalLTI(self, mimotf): + """Test returnScipySignalLTI method with strict=False""" + sslti = mimotf.returnScipySignalLTI(strict=False) + for i in range(2): + for j in range(3): + np.testing.assert_allclose(sslti[i][j].num, mimotf.num[i][j]) + np.testing.assert_allclose(sslti[i][j].den, mimotf.den[i][j]) + if mimotf.dt == 0: + assert sslti[i][j].dt is None + else: + assert sslti[i][j].dt == mimotf.dt + + @pytest.mark.parametrize("mimotf", [None], indirect=True) + def test_returnScipySignalLTI_error(self, mimotf): + """Test returnScipySignalLTI method with dt=None and strict=True""" + with pytest.raises(ValueError): + mimotf.returnScipySignalLTI() + with pytest.raises(ValueError): + mimotf.returnScipySignalLTI(strict=True) + + if __name__ == "__main__": unittest.main() diff --git a/control/xferfcn.py b/control/xferfcn.py index 1cba50bd7..4077080e3 100644 --- a/control/xferfcn.py +++ b/control/xferfcn.py @@ -57,7 +57,8 @@ polyadd, polymul, polyval, roots, sqrt, zeros, squeeze, exp, pi, \ where, delete, real, poly, nonzero import scipy as sp -from scipy.signal import lti, tf2zpk, zpk2tf, cont2discrete +from scipy.signal import tf2zpk, zpk2tf, cont2discrete +from scipy.signal import TransferFunction as signalTransferFunction from copy import deepcopy from warnings import warn from itertools import chain @@ -801,7 +802,7 @@ def minreal(self, tol=None): # end result return TransferFunction(num, den, self.dt) - def returnScipySignalLTI(self): + def returnScipySignalLTI(self, strict=True): """Return a list of a list of :class:`scipy.signal.lti` objects. For instance, @@ -809,22 +810,44 @@ def returnScipySignalLTI(self): >>> out = tfobject.returnScipySignalLTI() >>> out[3][5] - is a class:`scipy.signal.lti` object corresponding to the + is a :class:`scipy.signal.lti` object corresponding to the transfer function from the 6th input to the 4th output. + Parameters + ---------- + strict : bool, optional + True (default): + The timebase `tfobject.dt` cannot be None; it must be + continuous (0) or discrete (True or > 0). + False: + if `tfobject.dt` is None, continuous time + :class:`scipy.signal.lti`objects are returned + + Returns + ------- + out : list of list of :class:`scipy.signal.TransferFunction` + continuous time (inheriting from :class:`scipy.signal.lti`) + or discrete time (inheriting from :class:`scipy.signal.dlti`) + SISO objects """ + if strict and self.dt is None: + raise ValueError("with strict=True, dt cannot be None") - # TODO: implement for discrete time systems - if self.dt != 0 and self.dt is not None: - raise NotImplementedError("Function not \ - implemented in discrete time") + if self.dt: + kwdt = {'dt': self.dt} + else: + # scipy convention for continuous time lti systems: call without + # dt keyword argument + kwdt = {} # Preallocate the output. out = [[[] for j in range(self.inputs)] for i in range(self.outputs)] for i in range(self.outputs): for j in range(self.inputs): - out[i][j] = lti(self.num[i][j], self.den[i][j]) + out[i][j] = signalTransferFunction(self.num[i][j], + self.den[i][j], + **kwdt) return out