Skip to content

Commit 9aed7cc

Browse files
committed
changes so that a default dt=0 passes unit tests. fixed code everywhere that combines systems with different timebases
1 parent 0113f99 commit 9aed7cc

File tree

9 files changed

+173
-249
lines changed

9 files changed

+173
-249
lines changed

control/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ def reset_defaults():
5959
from .statesp import _statesp_defaults
6060
defaults.update(_statesp_defaults)
6161

62+
from .iosys import _iosys_defaults
63+
defaults.update(_iosys_defaults)
64+
6265

6366
def _get_param(module, param, argval=None, defval=None, pop=False):
6467
"""Return the default value for a configuration option.
@@ -170,5 +173,8 @@ def use_legacy_defaults(version):
170173
"""
171174
if version == '0.8.3':
172175
use_numpy_matrix(True) # alternatively: set_defaults('statesp', use_numpy_matrix=True)
176+
set_defaults('statesp', default_dt=None)
177+
set_defaults('xferfcn', default_dt=None)
178+
set_defaults('iosys', default_dt=None)
173179
else:
174180
raise ValueError('''version number not recognized. Possible values are: ['0.8.3']''')

control/iosys.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,16 @@
3838

3939
from .statesp import StateSpace, tf2ss
4040
from .timeresp import _check_convert_array
41-
from .lti import isctime, isdtime, _find_timebase
41+
from .lti import isctime, isdtime, common_timebase
42+
from . import config
4243

4344
__all__ = ['InputOutputSystem', 'LinearIOSystem', 'NonlinearIOSystem',
4445
'InterconnectedSystem', 'input_output_response', 'find_eqpt',
4546
'linearize', 'ss2io', 'tf2io']
4647

48+
# Define module default parameter values
49+
_iosys_defaults = {
50+
'iosys.default_dt': 0}
4751

4852
class InputOutputSystem(object):
4953
"""A class for representing input/output systems.
@@ -109,7 +113,7 @@ class for a set of subclasses that are used to implement specific
109113
110114
"""
111115
def __init__(self, inputs=None, outputs=None, states=None, params={},
112-
dt=None, name=None):
116+
name=None, **kwargs):
113117
"""Create an input/output system.
114118
115119
The InputOutputSystem contructor is used to create an input/output
@@ -153,7 +157,7 @@ def __init__(self, inputs=None, outputs=None, states=None, params={},
153157
"""
154158
# Store the input arguments
155159
self.params = params.copy() # default parameters
156-
self.dt = dt # timebase
160+
self.dt = kwargs.get('dt', config.defaults['iosys.default_dt']) # timebase
157161
self.name = name # system name
158162

159163
# Parse and store the number of inputs, outputs, and states
@@ -200,10 +204,8 @@ def __mul__(sys2, sys1):
200204
"inputs and outputs")
201205

202206
# Make sure timebase are compatible
203-
dt = _find_timebase(sys1, sys2)
204-
if dt is False:
205-
raise ValueError("System timebases are not compabile")
206-
207+
dt = common_timebase(sys1.dt, sys2.dt)
208+
207209
# Return the series interconnection between the systems
208210
newsys = InterconnectedSystem((sys1, sys2))
209211

@@ -478,10 +480,8 @@ def feedback(self, other=1, sign=-1, params={}):
478480
"inputs and outputs")
479481

480482
# Make sure timebases are compatible
481-
dt = _find_timebase(self, other)
482-
if dt is False:
483-
raise ValueError("System timebases are not compabile")
484-
483+
dt = common_timebase(self.dt, other.dt)
484+
485485
# Return the series interconnection between the systems
486486
newsys = InterconnectedSystem((self, other), params=params, dt=dt)
487487

@@ -670,7 +670,8 @@ class NonlinearIOSystem(InputOutputSystem):
670670
671671
"""
672672
def __init__(self, updfcn, outfcn=None, inputs=None, outputs=None,
673-
states=None, params={}, dt=None, name=None):
673+
states=None, params={},
674+
name=None, **kwargs):
674675
"""Create a nonlinear I/O system given update and output functions.
675676
676677
Creates an `InputOutputSystem` for a nonlinear system by specifying a
@@ -741,6 +742,7 @@ def __init__(self, updfcn, outfcn=None, inputs=None, outputs=None,
741742
self.outfcn = outfcn
742743

743744
# Initialize the rest of the structure
745+
dt = kwargs.get('dt', config.defaults['iosys.default_dt'])
744746
super(NonlinearIOSystem, self).__init__(
745747
inputs=inputs, outputs=outputs, states=states,
746748
params=params, dt=dt, name=name
@@ -881,19 +883,14 @@ def __init__(self, syslist, connections=[], inplist=[], outlist=[],
881883
# Check to make sure all systems are consistent
882884
self.syslist = syslist
883885
self.syslist_index = {}
884-
dt = None
885886
nstates = 0; self.state_offset = []
886887
ninputs = 0; self.input_offset = []
887888
noutputs = 0; self.output_offset = []
888889
system_count = 0
890+
889891
for sys in syslist:
890892
# Make sure time bases are consistent
891-
# TODO: Use lti._find_timebase() instead?
892-
if dt is None and sys.dt is not None:
893-
# Timebase was not specified; set to match this system
894-
dt = sys.dt
895-
elif dt != sys.dt:
896-
raise TypeError("System timebases are not compatible")
893+
dt = common_timebase(dt, sys.dt)
897894

898895
# Make sure number of inputs, outputs, states is given
899896
if sys.ninputs is None or sys.noutputs is None or \

control/lti.py

Lines changed: 27 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99
isdtime()
1010
isctime()
1111
timebase()
12-
timebaseEqual()
12+
common_timebase()
1313
"""
1414

1515
import numpy as np
1616
from numpy import absolute, real
1717

18-
__all__ = ['issiso', 'timebase', 'timebaseEqual', 'isdtime', 'isctime',
18+
__all__ = ['issiso', 'timebase', 'common_timebase', 'isdtime', 'isctime',
1919
'pole', 'zero', 'damp', 'evalfr', 'freqresp', 'dcgain']
2020

2121
class LTI:
@@ -157,48 +157,31 @@ def timebase(sys, strict=True):
157157

158158
return sys.dt
159159

160-
# Check to see if two timebases are equal
161-
def timebaseEqual(sys1, sys2):
162-
"""Check to see if two systems have the same timebase
163-
164-
timebaseEqual(sys1, sys2)
165-
166-
returns True if the timebases for the two systems are compatible. By
167-
default, systems with timebase 'None' are compatible with either
168-
discrete or continuous timebase systems. If two systems have a discrete
169-
timebase (dt > 0) then their timebases must be equal.
170-
"""
171-
172-
if (type(sys1.dt) == bool or type(sys2.dt) == bool):
173-
# Make sure both are unspecified discrete timebases
174-
return type(sys1.dt) == type(sys2.dt) and sys1.dt == sys2.dt
175-
elif (sys1.dt is None or sys2.dt is None):
176-
# One or the other is unspecified => the other can be anything
177-
return True
178-
else:
179-
return sys1.dt == sys2.dt
180-
181-
# Find a common timebase between two or more systems
182-
def _find_timebase(sys1, *sysn):
183-
"""Find the common timebase between systems, otherwise return False"""
184-
185-
# Create a list of systems to check
186-
syslist = [sys1]
187-
syslist.append(*sysn)
188-
189-
# Look for a common timebase
190-
dt = None
191-
192-
for sys in syslist:
193-
# Make sure time bases are consistent
194-
if (dt is None and sys.dt is not None) or \
195-
(dt is True and isdiscrete(sys)):
196-
# Timebase was not specified; set to match this system
197-
dt = sys.dt
198-
elif dt != sys.dt:
199-
return False
200-
return dt
201-
160+
def common_timebase(dt1, dt2):
161+
"""Find the common timebase when interconnecting systems."""
162+
# cases:
163+
# if either dt is None, they are compatible with anything
164+
# if either dt is True (discrete with unspecified time base),
165+
# use the timebase of the other, if it is also discrete
166+
# otherwise they must be equal (holds for both cont and discrete systems)
167+
if dt1 is None:
168+
return dt2
169+
elif dt2 is None:
170+
return dt1
171+
elif dt1 is True:
172+
if dt2 > 0:
173+
return dt2
174+
else:
175+
raise ValueError("Systems have incompatible timebases")
176+
elif dt2 is True:
177+
if dt1 > 0:
178+
return dt1
179+
else:
180+
raise ValueError("Systems have incompatible timebases")
181+
elif np.isclose(dt1, dt2):
182+
return dt1
183+
else:
184+
raise ValueError("Systems have incompatible timebases")
202185

203186
# Check to see if a system is a discrete time system
204187
def isdtime(sys, strict=False):

control/statesp.py

Lines changed: 25 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
import scipy as sp
6262
from scipy.signal import lti, cont2discrete
6363
from warnings import warn
64-
from .lti import LTI, timebase, timebaseEqual, isdtime
64+
from .lti import LTI, common_timebase, isdtime
6565
from . import config
6666
from copy import deepcopy
6767

@@ -174,7 +174,10 @@ def __init__(self, *args, **kw):
174174
if len(args) == 4:
175175
# The user provided A, B, C, and D matrices.
176176
(A, B, C, D) = args
177-
dt = config.defaults['statesp.default_dt']
177+
if _isstaticgain(A, B, C, D):
178+
dt = None
179+
else:
180+
dt = config.defaults['statesp.default_dt']
178181
elif len(args) == 5:
179182
# Discrete time system
180183
(A, B, C, D, dt) = args
@@ -190,9 +193,12 @@ def __init__(self, *args, **kw):
190193
try:
191194
dt = args[0].dt
192195
except NameError:
193-
dt = config.defaults['statesp.default_dt']
196+
if _isstaticgain(A, B, C, D):
197+
dt = None
198+
else:
199+
dt = config.defaults['statesp.default_dt']
194200
else:
195-
raise ValueError("Needs 1 or 4 arguments; received %i." % len(args))
201+
raise ValueError("Expected 1, 4, or 5 arguments; received %i." % len(args))
196202

197203
# Process keyword arguments
198204
remove_useless = kw.get('remove_useless', config.defaults['statesp.remove_useless_states'])
@@ -316,14 +322,7 @@ def __add__(self, other):
316322
(self.outputs != other.outputs)):
317323
raise ValueError("Systems have different shapes.")
318324

319-
# Figure out the sampling time to use
320-
if self.dt is None and other.dt is not None:
321-
dt = other.dt # use dt from second argument
322-
elif (other.dt is None and self.dt is not None) or \
323-
(timebaseEqual(self, other)):
324-
dt = self.dt # use dt from first argument
325-
else:
326-
raise ValueError("Systems have different sampling times")
325+
dt = common_timebase(self.dt, other.dt)
327326

328327
# Concatenate the various arrays
329328
A = concatenate((
@@ -372,16 +371,8 @@ def __mul__(self, other):
372371
# Check to make sure the dimensions are OK
373372
if self.inputs != other.outputs:
374373
raise ValueError("C = A * B: A has %i column(s) (input(s)), \
375-
but B has %i row(s)\n(output(s))." % (self.inputs, other.outputs))
376-
377-
# Figure out the sampling time to use
378-
if (self.dt == None and other.dt != None):
379-
dt = other.dt # use dt from second argument
380-
elif (other.dt == None and self.dt != None) or \
381-
(timebaseEqual(self, other)):
382-
dt = self.dt # use dt from first argument
383-
else:
384-
raise ValueError("Systems have different sampling times")
374+
but B has %i row(s)\n(output(s))." % (self.inputs, other.outputs))
375+
dt = common_timebase(self.dt, other.dt)
385376

386377
# Concatenate the various arrays
387378
A = concatenate(
@@ -453,9 +444,8 @@ def _evalfr(self, omega):
453444
"""Evaluate a SS system's transfer function at a single frequency"""
454445
# Figure out the point to evaluate the transfer function
455446
if isdtime(self, strict=True):
456-
dt = timebase(self)
457-
s = exp(1.j * omega * dt)
458-
if omega * dt > math.pi:
447+
s = exp(1.j * omega * self.dt)
448+
if omega * self.dt > math.pi:
459449
warn("_evalfr: frequency evaluation above Nyquist frequency")
460450
else:
461451
s = omega * 1.j
@@ -512,9 +502,8 @@ def freqresp(self, omega):
512502
# axis (continuous time) or unit circle (discrete time).
513503
omega.sort()
514504
if isdtime(self, strict=True):
515-
dt = timebase(self)
516-
cmplx_freqs = exp(1.j * omega * dt)
517-
if max(np.abs(omega)) * dt > math.pi:
505+
cmplx_freqs = exp(1.j * omega * self.dt)
506+
if max(np.abs(omega)) * self.dt > math.pi:
518507
warn("freqresp: frequency evaluation above Nyquist frequency")
519508
else:
520509
cmplx_freqs = omega * 1.j
@@ -617,14 +606,7 @@ def feedback(self, other=1, sign=-1):
617606
if (self.inputs != other.outputs) or (self.outputs != other.inputs):
618607
raise ValueError("State space systems don't have compatible inputs/outputs for "
619608
"feedback.")
620-
621-
# Figure out the sampling time to use
622-
if self.dt is None and other.dt is not None:
623-
dt = other.dt # use dt from second argument
624-
elif other.dt is None and self.dt is not None or timebaseEqual(self, other):
625-
dt = self.dt # use dt from first argument
626-
else:
627-
raise ValueError("Systems have different sampling times")
609+
dt = common_timebase(self.dt, other.dt)
628610

629611
A1 = self.A
630612
B1 = self.B
@@ -694,14 +676,7 @@ def lft(self, other, nu=-1, ny=-1):
694676
# dimension check
695677
# TODO
696678

697-
# Figure out the sampling time to use
698-
if (self.dt == None and other.dt != None):
699-
dt = other.dt # use dt from second argument
700-
elif (other.dt == None and self.dt != None) or \
701-
timebaseEqual(self, other):
702-
dt = self.dt # use dt from first argument
703-
else:
704-
raise ValueError("Systems have different time bases")
679+
dt = common_timebase(self.dt, other.dt)
705680

706681
# submatrices
707682
A = self.A
@@ -815,8 +790,7 @@ def append(self, other):
815790
if not isinstance(other, StateSpace):
816791
other = _convertToStateSpace(other)
817792

818-
if self.dt != other.dt:
819-
raise ValueError("Systems must have the same time step")
793+
self.dt = common_timebase(self.dt, other.dt)
820794

821795
n = self.states + other.states
822796
m = self.inputs + other.inputs
@@ -1246,6 +1220,11 @@ def _mimo2simo(sys, input, warn_conversion=False):
12461220

12471221
return sys
12481222

1223+
def _isstaticgain(A, B, C, D):
1224+
"""returns True if and only if the system has no dynamics, that is,
1225+
if A and B are zero. """
1226+
return not np.any(np.matrix(A, dtype=float)) \
1227+
and not np.any(np.matrix(B, dtype=float))
12491228

12501229
def ss(*args):
12511230
"""ss(A, B, C, D[, dt])

control/tests/config_test.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -218,14 +218,27 @@ def test_legacy_defaults(self):
218218
assert(isinstance(ct.ss(0,0,0,1).D, np.ndarray))
219219

220220
def test_change_default_dt(self):
221-
ct.set_defaults('statesp', default_dt=0)
222-
self.assertEqual(ct.ss(0,0,0,1).dt, 0)
223-
ct.set_defaults('statesp', default_dt=None)
224-
self.assertEqual(ct.ss(0,0,0,1).dt, None)
221+
# TransferFunction
222+
# test that system with dynamics uses correct default dt
225223
ct.set_defaults('xferfcn', default_dt=0)
226-
self.assertEqual(ct.tf(1, 1).dt, 0)
224+
self.assertEqual(ct.tf(1, [1,1]).dt, 0)
227225
ct.set_defaults('xferfcn', default_dt=None)
226+
self.assertEqual(ct.tf(1, [1,1]).dt, None)
227+
# test that a static gain transfer function always has dt=None
228+
ct.set_defaults('xferfcn', default_dt=0)
228229
self.assertEqual(ct.tf(1, 1).dt, None)
230+
231+
# StateSpace
232+
# test that system with dynamics uses correct default dt
233+
ct.set_defaults('statesp', default_dt=0)
234+
self.assertEqual(ct.ss(1,0,0,1).dt, 0)
235+
ct.set_defaults('statesp', default_dt=None)
236+
self.assertEqual(ct.ss(1,0,0,1).dt, None)
237+
# test that a static gain state space system always has dt=None
238+
ct.set_defaults('statesp', default_dt=0)
239+
self.assertEqual(ct.ss(0,0,0,1).dt, None)
240+
241+
ct.reset_defaults()
229242

230243

231244
def tearDown(self):

0 commit comments

Comments
 (0)