From 4968cb3ed4b76a147439b20c6f58e36c6e2b4839 Mon Sep 17 00:00:00 2001 From: Richard Murray Date: Fri, 18 Nov 2022 22:00:22 -0800 Subject: [PATCH] check for and fix mutable keyword args --- control/flatsys/flatsys.py | 12 +++--- control/flatsys/linflat.py | 4 +- control/iosys.py | 5 ++- control/tests/kwargs_test.py | 71 ++++++++++++++++++++++++++++++++++-- 4 files changed, 78 insertions(+), 14 deletions(-) diff --git a/control/flatsys/flatsys.py b/control/flatsys/flatsys.py index 849c41c72..e0023c4de 100644 --- a/control/flatsys/flatsys.py +++ b/control/flatsys/flatsys.py @@ -142,7 +142,7 @@ def __init__(self, forward, reverse, # flat system updfcn=None, outfcn=None, # I/O system inputs=None, outputs=None, - states=None, params={}, dt=None, name=None): + states=None, params=None, dt=None, name=None): """Create a differentially flat I/O system. The FlatIOSystem constructor is used to create an input/output system @@ -171,7 +171,7 @@ def __str__(self): + f"Forward: {self.forward}\n" \ + f"Reverse: {self.reverse}" - def forward(self, x, u, params={}): + def forward(self, x, u, params=None): """Compute the flat flag given the states and input. @@ -200,7 +200,7 @@ def forward(self, x, u, params={}): """ raise NotImplementedError("internal error; forward method not defined") - def reverse(self, zflag, params={}): + def reverse(self, zflag, params=None): """Compute the states and input given the flat flag. Parameters @@ -224,18 +224,18 @@ def reverse(self, zflag, params={}): """ raise NotImplementedError("internal error; reverse method not defined") - def _flat_updfcn(self, t, x, u, params={}): + def _flat_updfcn(self, t, x, u, params=None): # TODO: implement state space update using flat coordinates raise NotImplementedError("update function for flat system not given") - def _flat_outfcn(self, t, x, u, params={}): + def _flat_outfcn(self, t, x, u, params=None): # Return the flat output zflag = self.forward(x, u, params) return np.array([zflag[i][0] for i in range(len(zflag))]) # Utility function to compute flag matrix given a basis -def _basis_flag_matrix(sys, basis, flag, t, params={}): +def _basis_flag_matrix(sys, basis, flag, t): """Compute the matrix of basis functions and their derivatives This function computes the matrix ``M`` that is used to solve for the diff --git a/control/flatsys/linflat.py b/control/flatsys/linflat.py index e4a31c6de..8e6c23604 100644 --- a/control/flatsys/linflat.py +++ b/control/flatsys/linflat.py @@ -142,11 +142,11 @@ def reverse(self, zflag, params): return np.reshape(x, self.nstates), np.reshape(u, self.ninputs) # Update function - def _rhs(self, t, x, u, params={}): + def _rhs(self, t, x, u): # Use LinearIOSystem._rhs instead of default (MRO) NonlinearIOSystem return LinearIOSystem._rhs(self, t, x, u) # output function - def _out(self, t, x, u, params={}): + def _out(self, t, x, u): # Use LinearIOSystem._out instead of default (MRO) NonlinearIOSystem return LinearIOSystem._out(self, t, x, u) diff --git a/control/iosys.py b/control/iosys.py index df75f3b54..6fa4a3e76 100644 --- a/control/iosys.py +++ b/control/iosys.py @@ -1584,7 +1584,7 @@ def __init__(self, io_sys, ss_sys=None): def input_output_response( sys, T, U=0., X0=0, params=None, transpose=False, return_x=False, squeeze=None, - solve_ivp_kwargs={}, t_eval='T', **kwargs): + solve_ivp_kwargs=None, t_eval='T', **kwargs): """Compute the output response of a system to a given input. Simulate a dynamical system with a given input and return its output @@ -1650,7 +1650,7 @@ def input_output_response( solve_ivp_method : str, optional Set the method used by :func:`scipy.integrate.solve_ivp`. Defaults to 'RK45'. - solve_ivp_kwargs : str, optional + solve_ivp_kwargs : dict, optional Pass additional keywords to :func:`scipy.integrate.solve_ivp`. Raises @@ -1676,6 +1676,7 @@ def input_output_response( # # Figure out the method to be used + solve_ivp_kwargs = solve_ivp_kwargs.copy() if solve_ivp_kwargs else {} if kwargs.get('solve_ivp_method', None): if kwargs.get('method', None): raise ValueError("ivp_method specified more than once") diff --git a/control/tests/kwargs_test.py b/control/tests/kwargs_test.py index 855bb9dda..2dc7f0563 100644 --- a/control/tests/kwargs_test.py +++ b/control/tests/kwargs_test.py @@ -38,6 +38,10 @@ def test_kwarg_search(module, prefix): # Skip anything that isn't part of the control package continue + # Look for classes and then check member functions + if inspect.isclass(obj): + test_kwarg_search(obj, prefix + obj.__name__ + '.') + # Only look for functions with keyword arguments if not inspect.isfunction(obj): continue @@ -70,10 +74,6 @@ def test_kwarg_search(module, prefix): f"'unrecognized keyword' not found in unit test " f"for {name}") - # Look for classes and then check member functions - if inspect.isclass(obj): - test_kwarg_search(obj, prefix + obj.__name__ + '.') - @pytest.mark.parametrize( "function, nsssys, ntfsys, moreargs, kwargs", @@ -201,3 +201,66 @@ def test_matplotlib_kwargs(function, nsysargs, moreargs, kwargs, mplcleanup): 'TimeResponseData.__call__': trdata_test.test_response_copy, 'TransferFunction.__init__': test_unrecognized_kwargs, } + +# +# Look for keywords with mutable defaults +# +# This test goes through every function and looks for signatures that have a +# default value for a keyword that is mutable. An error is generated unless +# the function is listed in the `mutable_ok` set (which should only be used +# for cases were the code has been explicitly checked to make sure that the +# value of the mutable is not modified in the code). +# +mutable_ok = { # initial and date + control.flatsys.SystemTrajectory.__init__, # RMM, 18 Nov 2022 + control.freqplot._add_arrows_to_line2D, # RMM, 18 Nov 2022 + control.namedio._process_dt_keyword, # RMM, 13 Nov 2022 + control.namedio._process_namedio_keywords, # RMM, 18 Nov 2022 + control.optimal.OptimalControlProblem.__init__, # RMM, 18 Nov 2022 + control.optimal.solve_ocp, # RMM, 18 Nov 2022 + control.optimal.create_mpc_iosystem, # RMM, 18 Nov 2022 +} + +@pytest.mark.parametrize("module", [control, control.flatsys]) +def test_mutable_defaults(module, recurse=True): + # Look through every object in the package + for name, obj in inspect.getmembers(module): + # Skip anything that is outside of this module + if inspect.getmodule(obj) is not None and \ + not inspect.getmodule(obj).__name__.startswith('control'): + # Skip anything that isn't part of the control package + continue + + # Look for classes and then check member functions + if inspect.isclass(obj): + test_mutable_defaults(obj, True) + + # Look for modules and check for internal functions (w/ no recursion) + if inspect.ismodule(obj) and recurse: + test_mutable_defaults(obj, False) + + # Only look at functions and skip any that are marked as OK + if not inspect.isfunction(obj) or obj in mutable_ok: + continue + + # Get the signature for the function + sig = inspect.signature(obj) + + # Skip anything that is inherited + if inspect.isclass(module) and obj.__name__ not in module.__dict__: + continue + + # See if there is a variable keyword argument + for argname, par in sig.parameters.items(): + if par.default is inspect._empty or \ + not par.kind == inspect.Parameter.KEYWORD_ONLY and \ + not par.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: + continue + + # Check to see if the default value is mutable + if par.default is not None and not \ + isinstance(par.default, (bool, int, float, tuple, str)): + pytest.fail( + f"function '{obj.__name__}' in module '{module.__name__}'" + f" has mutable default for keyword '{par.name}'") +