Skip to content

Commit f367cf6

Browse files
committed
unify frequency response processing + unit tests
1 parent 848112d commit f367cf6

File tree

5 files changed

+75
-38
lines changed

5 files changed

+75
-38
lines changed

control/frdata.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
from numpy import angle, array, empty, ones, \
5151
real, imag, absolute, eye, linalg, where, dot, sort
5252
from scipy.interpolate import splprep, splev
53-
from .lti import LTI
53+
from .lti import LTI, _process_frequency_response
5454
from . import config
5555

5656
__all__ = ['FrequencyResponseData', 'FRD', 'frd']
@@ -391,14 +391,10 @@ def eval(self, omega, squeeze=None):
391391
for k, w in enumerate(omega_array):
392392
frraw = splev(w, self.ifunc[i, j], der=0)
393393
out[i, j, k] = frraw[0] + 1.0j * frraw[1]
394-
if not hasattr(omega, '__len__'):
395-
# omega is a scalar, squeeze down array along last dim
396-
out = np.squeeze(out, axis=2)
397-
if squeeze and self.issiso():
398-
out = out[0][0]
399-
return out
400-
401-
def __call__(self, s, squeeze=True):
394+
395+
return _process_frequency_response(self, omega, out, squeeze=squeeze)
396+
397+
def __call__(self, s, squeeze=None):
402398
"""Evaluate system's transfer function at complex frequencies.
403399
404400
Returns the complex frequency response `sys(s)` of system `sys` with

control/lti.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import numpy as np
1616
from numpy import absolute, real, angle, abs
1717
from warnings import warn
18+
from . import config
1819

1920
__all__ = ['issiso', 'timebase', 'common_timebase', 'timebaseEqual',
2021
'isdtime', 'isctime', 'pole', 'zero', 'damp', 'evalfr',
@@ -596,3 +597,18 @@ def dcgain(sys):
596597
at the origin
597598
"""
598599
return sys.dcgain()
600+
601+
602+
# Process frequency responses in a uniform way
603+
def _process_frequency_response(sys, omega, out, squeeze=None):
604+
if not hasattr(omega, '__len__'):
605+
# received a scalar x, squeeze down the array along last dim
606+
out = np.squeeze(out, axis=2)
607+
608+
# Get rid of unneeded dimensions
609+
if squeeze is None:
610+
squeeze = config.defaults['control.squeeze']
611+
if squeeze and sys.issiso():
612+
return out[0][0]
613+
else:
614+
return out

control/statesp.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
from scipy.signal import cont2discrete
6363
from scipy.signal import StateSpace as signalStateSpace
6464
from warnings import warn
65-
from .lti import LTI, common_timebase, isdtime
65+
from .lti import LTI, common_timebase, isdtime, _process_frequency_response
6666
from . import config
6767
from copy import deepcopy
6868

@@ -646,9 +646,9 @@ def __call__(self, x, squeeze=None):
646646
Returns the complex frequency response `sys(x)` where `x` is `s` for
647647
continuous-time systems and `z` for discrete-time systems.
648648
649-
In general the system may be multiple input, multiple output (MIMO), where
650-
`m = self.inputs` number of inputs and `p = self.outputs` number of
651-
outputs.
649+
In general the system may be multiple input, multiple output
650+
(MIMO), where `m = self.inputs` number of inputs and `p =
651+
self.outputs` number of outputs.
652652
653653
To evaluate at a frequency omega in radians per second, enter
654654
``x = omega * 1j``, for continuous-time systems, or
@@ -671,19 +671,9 @@ def __call__(self, x, squeeze=None):
671671
only if system is SISO and ``squeeze=True``.
672672
673673
"""
674-
# Set value of squeeze argument if not set
675-
if squeeze is None:
676-
squeeze = config.defaults['control.squeeze']
677-
678674
# Use Slycot if available
679675
out = self.horner(x)
680-
if not hasattr(x, '__len__'):
681-
# received a scalar x, squeeze down the array along last dim
682-
out = np.squeeze(out, axis=2)
683-
if squeeze and self.issiso():
684-
return out[0][0]
685-
else:
686-
return out
676+
return _process_frequency_response(self, x, out, squeeze=squeeze)
687677

688678
def slycot_laub(self, x):
689679
"""Evaluate system's transfer function at complex frequency

control/tests/lti_test.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22

33
import numpy as np
44
import pytest
5+
from .conftest import editsdefaults
56

7+
import control as ct
68
from control import c2d, tf, tf2ss, NonlinearIOSystem
79
from control.lti import (LTI, common_timebase, damp, dcgain, isctime, isdtime,
810
issiso, pole, timebaseEqual, zero)
911
from control.tests.conftest import slycotonly
10-
12+
from control.exception import slycot_check
1113

1214
class TestLTI:
1315

@@ -153,3 +155,47 @@ def test_isdtime(self, objfun, arg, dt, ref, strictref):
153155
strictref = not strictref
154156
assert isctime(obj) == ref
155157
assert isctime(obj, strict=True) == strictref
158+
159+
@pytest.mark.usefixtures("editsdefaults")
160+
@pytest.mark.parametrize("fcn", [ct.ss, ct.tf, ct.frd])
161+
@pytest.mark.parametrize("nstate, nout, ninp, squeeze, shape", [
162+
[1, 1, 1, None, (8,)], # SISO
163+
[2, 1, 1, True, (8,)],
164+
[3, 1, 1, False, (1, 1, 8)],
165+
[1, 2, 1, None, (2, 1, 8)], # SIMO
166+
[2, 2, 1, True, (2, 1, 8)],
167+
[3, 2, 1, False, (2, 1, 8)],
168+
[1, 1, 2, None, (1, 2, 8)], # MISO
169+
[2, 1, 2, True, (1, 2, 8)],
170+
[3, 1, 2, False, (1, 2, 8)],
171+
[1, 2, 2, None, (2, 2, 8)], # MIMO
172+
[2, 2, 2, True, (2, 2, 8)],
173+
[3, 2, 2, False, (2, 2, 8)]
174+
])
175+
def test_squeeze(self, fcn, nstate, nout, ninp, squeeze, shape):
176+
# Compute the length of the frequency array
177+
omega = np.logspace(-2, 2, 8)
178+
179+
# Create the system to be tested
180+
if fcn == ct.frd:
181+
sys = fcn(ct.rss(nstate, nout, ninp), omega)
182+
elif fcn == ct.tf and (nout > 1 or ninp > 1) and not slycot_check():
183+
pytest.skip("Conversion of MIMO systems to transfer functions "
184+
"requires slycot.")
185+
else:
186+
sys = fcn(ct.rss(nstate, nout, ninp))
187+
188+
# Pass squeeze argument and make sure the shape is correct
189+
mag, phase, _ = sys.frequency_response(omega, squeeze=squeeze)
190+
assert mag.shape == shape
191+
assert phase.shape == shape
192+
assert sys(omega * 1j, squeeze=squeeze).shape == shape
193+
assert ct.evalfr(sys, omega * 1j, squeeze=squeeze).shape == shape
194+
195+
# Changing config.default to False should return 3D frequency response
196+
ct.config.set_defaults('control', squeeze=False)
197+
mag, phase, _ = sys.frequency_response(omega)
198+
assert mag.shape == (sys.outputs, sys.inputs, 8)
199+
assert phase.shape == (sys.outputs, sys.inputs, 8)
200+
assert sys(omega * 1j).shape == (sys.outputs, sys.inputs, 8)
201+
assert ct.evalfr(sys, omega * 1j).shape == (sys.outputs, sys.inputs, 8)

control/xferfcn.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
from warnings import warn
6464
from itertools import chain
6565
from re import sub
66-
from .lti import LTI, common_timebase, isdtime
66+
from .lti import LTI, common_timebase, isdtime, _process_frequency_response
6767
from . import config
6868

6969
__all__ = ['TransferFunction', 'tf', 'ss2tf', 'tfdata']
@@ -265,19 +265,8 @@ def __call__(self, x, squeeze=None):
265265
only if system is SISO and ``squeeze=True``.
266266
267267
"""
268-
# Set value of squeeze argument if not set
269-
if squeeze is None:
270-
squeeze = config.defaults['control.squeeze']
271-
272268
out = self.horner(x)
273-
if not hasattr(x, '__len__'):
274-
# received a scalar x, squeeze down the array along last dim
275-
out = np.squeeze(out, axis=2)
276-
if squeeze and self.issiso():
277-
# return a scalar/1d array of outputs
278-
return out[0][0]
279-
else:
280-
return out
269+
return _process_frequency_response(self, x, out, squeeze=squeeze)
281270

282271
def horner(self, x):
283272
"""Evaluate system's transfer function at complex frequency

0 commit comments

Comments
 (0)