Skip to content

check for and fix mutable keyword defaults #794

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions control/flatsys/flatsys.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def __init__(self,
forward, reverse, # flat system
updfcn=None, outfcn=None, # I/O system
inputs=None, outputs=None,
states=None, params={}, dt=None, name=None):
states=None, params=None, dt=None, name=None):
"""Create a differentially flat I/O system.

The FlatIOSystem constructor is used to create an input/output system
Expand Down Expand Up @@ -171,7 +171,7 @@ def __str__(self):
+ f"Forward: {self.forward}\n" \
+ f"Reverse: {self.reverse}"

def forward(self, x, u, params={}):
def forward(self, x, u, params=None):

"""Compute the flat flag given the states and input.

Expand Down Expand Up @@ -200,7 +200,7 @@ def forward(self, x, u, params={}):
"""
raise NotImplementedError("internal error; forward method not defined")

def reverse(self, zflag, params={}):
def reverse(self, zflag, params=None):
"""Compute the states and input given the flat flag.

Parameters
Expand All @@ -224,18 +224,18 @@ def reverse(self, zflag, params={}):
"""
raise NotImplementedError("internal error; reverse method not defined")

def _flat_updfcn(self, t, x, u, params={}):
def _flat_updfcn(self, t, x, u, params=None):
# TODO: implement state space update using flat coordinates
raise NotImplementedError("update function for flat system not given")

def _flat_outfcn(self, t, x, u, params={}):
def _flat_outfcn(self, t, x, u, params=None):
# Return the flat output
zflag = self.forward(x, u, params)
return np.array([zflag[i][0] for i in range(len(zflag))])


# Utility function to compute flag matrix given a basis
def _basis_flag_matrix(sys, basis, flag, t, params={}):
def _basis_flag_matrix(sys, basis, flag, t):
"""Compute the matrix of basis functions and their derivatives

This function computes the matrix ``M`` that is used to solve for the
Expand Down
4 changes: 2 additions & 2 deletions control/flatsys/linflat.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,11 @@ def reverse(self, zflag, params):
return np.reshape(x, self.nstates), np.reshape(u, self.ninputs)

# Update function
def _rhs(self, t, x, u, params={}):
def _rhs(self, t, x, u):
# Use LinearIOSystem._rhs instead of default (MRO) NonlinearIOSystem
return LinearIOSystem._rhs(self, t, x, u)

# output function
def _out(self, t, x, u, params={}):
def _out(self, t, x, u):
# Use LinearIOSystem._out instead of default (MRO) NonlinearIOSystem
return LinearIOSystem._out(self, t, x, u)
5 changes: 3 additions & 2 deletions control/iosys.py
Original file line number Diff line number Diff line change
Expand Up @@ -1584,7 +1584,7 @@ def __init__(self, io_sys, ss_sys=None):
def input_output_response(
sys, T, U=0., X0=0, params=None,
transpose=False, return_x=False, squeeze=None,
solve_ivp_kwargs={}, t_eval='T', **kwargs):
solve_ivp_kwargs=None, t_eval='T', **kwargs):
"""Compute the output response of a system to a given input.

Simulate a dynamical system with a given input and return its output
Expand Down Expand Up @@ -1650,7 +1650,7 @@ def input_output_response(
solve_ivp_method : str, optional
Set the method used by :func:`scipy.integrate.solve_ivp`. Defaults
to 'RK45'.
solve_ivp_kwargs : str, optional
solve_ivp_kwargs : dict, optional
Pass additional keywords to :func:`scipy.integrate.solve_ivp`.

Raises
Expand All @@ -1676,6 +1676,7 @@ def input_output_response(
#

# Figure out the method to be used
solve_ivp_kwargs = solve_ivp_kwargs.copy() if solve_ivp_kwargs else {}
if kwargs.get('solve_ivp_method', None):
if kwargs.get('method', None):
raise ValueError("ivp_method specified more than once")
Expand Down
71 changes: 67 additions & 4 deletions control/tests/kwargs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ def test_kwarg_search(module, prefix):
# Skip anything that isn't part of the control package
continue

# Look for classes and then check member functions
if inspect.isclass(obj):
test_kwarg_search(obj, prefix + obj.__name__ + '.')

# Only look for functions with keyword arguments
if not inspect.isfunction(obj):
continue
Expand Down Expand Up @@ -70,10 +74,6 @@ def test_kwarg_search(module, prefix):
f"'unrecognized keyword' not found in unit test "
f"for {name}")

# Look for classes and then check member functions
if inspect.isclass(obj):
test_kwarg_search(obj, prefix + obj.__name__ + '.')


@pytest.mark.parametrize(
"function, nsssys, ntfsys, moreargs, kwargs",
Expand Down Expand Up @@ -201,3 +201,66 @@ def test_matplotlib_kwargs(function, nsysargs, moreargs, kwargs, mplcleanup):
'TimeResponseData.__call__': trdata_test.test_response_copy,
'TransferFunction.__init__': test_unrecognized_kwargs,
}

#
# Look for keywords with mutable defaults
#
# This test goes through every function and looks for signatures that have a
# default value for a keyword that is mutable. An error is generated unless
# the function is listed in the `mutable_ok` set (which should only be used
# for cases were the code has been explicitly checked to make sure that the
# value of the mutable is not modified in the code).
#
mutable_ok = { # initial and date
control.flatsys.SystemTrajectory.__init__, # RMM, 18 Nov 2022
control.freqplot._add_arrows_to_line2D, # RMM, 18 Nov 2022
control.namedio._process_dt_keyword, # RMM, 13 Nov 2022
control.namedio._process_namedio_keywords, # RMM, 18 Nov 2022
control.optimal.OptimalControlProblem.__init__, # RMM, 18 Nov 2022
control.optimal.solve_ocp, # RMM, 18 Nov 2022
control.optimal.create_mpc_iosystem, # RMM, 18 Nov 2022
}

@pytest.mark.parametrize("module", [control, control.flatsys])
def test_mutable_defaults(module, recurse=True):
# Look through every object in the package
for name, obj in inspect.getmembers(module):
# Skip anything that is outside of this module
if inspect.getmodule(obj) is not None and \
not inspect.getmodule(obj).__name__.startswith('control'):
# Skip anything that isn't part of the control package
continue

# Look for classes and then check member functions
if inspect.isclass(obj):
test_mutable_defaults(obj, True)

# Look for modules and check for internal functions (w/ no recursion)
if inspect.ismodule(obj) and recurse:
test_mutable_defaults(obj, False)

# Only look at functions and skip any that are marked as OK
if not inspect.isfunction(obj) or obj in mutable_ok:
continue

# Get the signature for the function
sig = inspect.signature(obj)

# Skip anything that is inherited
if inspect.isclass(module) and obj.__name__ not in module.__dict__:
continue

# See if there is a variable keyword argument
for argname, par in sig.parameters.items():
if par.default is inspect._empty or \
not par.kind == inspect.Parameter.KEYWORD_ONLY and \
not par.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
continue

# Check to see if the default value is mutable
if par.default is not None and not \
isinstance(par.default, (bool, int, float, tuple, str)):
pytest.fail(
f"function '{obj.__name__}' in module '{module.__name__}'"
f" has mutable default for keyword '{par.name}'")