Skip to content

Commit 7983829

Browse files
committed
fix up copying of signal names
1 parent 91466d1 commit 7983829

File tree

3 files changed

+39
-8
lines changed

3 files changed

+39
-8
lines changed

control/statesp.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1521,7 +1521,6 @@ def output(self, t, x, u=None, params=None):
15211521

15221522

15231523
# TODO: add discrete time check
1524-
# TODO: copy signal names
15251524
def _convert_to_statespace(sys):
15261525
"""Convert a system to state space form (if needed).
15271526

control/tests/xferfcn_test.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
import operator
99

1010
import control as ct
11-
from control import StateSpace, TransferFunction, rss, ss2tf, evalfr
11+
from control import StateSpace, TransferFunction, rss, evalfr
12+
from control import ss, ss2tf, tf, tf2ss
1213
from control import isctime, isdtime, sample_system, defaults
1314
from control.statesp import _convert_to_statespace
1415
from control.xferfcn import _convert_to_transfer_function
@@ -1111,3 +1112,33 @@ def test_zpk(zeros, poles, gain, args, kwargs):
11111112

11121113
if kwargs.get('name'):
11131114
assert sys.name == kwargs.get('name')
1115+
1116+
@pytest.mark.parametrize("sys, convert", [
1117+
(StateSpace([-1], [1], [1], [0]), ss2tf),
1118+
(StateSpace([-1], [1], [1], [0]), ss),
1119+
(StateSpace([-1], [1], [1], [0]), tf),
1120+
(StateSpace([-1], [1], [1], [0], inputs='in', outputs='out'), ss2tf),
1121+
(StateSpace([-1], [1], [1], [0], inputs=1, outputs=1), ss2tf),
1122+
(StateSpace([-1], [1], [1], [0], inputs='in', outputs='out'), ss),
1123+
(StateSpace([-1], [1], [1], [0], inputs='in', outputs='out'), tf),
1124+
(TransferFunction([1], [1, 1]), tf2ss),
1125+
(TransferFunction([1], [1, 1]), tf),
1126+
(TransferFunction([1], [1, 1]), ss),
1127+
(TransferFunction([1], [1, 1], inputs='in', outputs='out'), tf2ss),
1128+
(TransferFunction([1], [1, 1], inputs=1, outputs=1), tf2ss),
1129+
(TransferFunction([1], [1, 1], inputs='in', outputs='out'), tf),
1130+
(TransferFunction([1], [1, 1], inputs='in', outputs='out'), ss),
1131+
])
1132+
def test_copy_names(sys, convert):
1133+
# Convert a system with no renaming
1134+
cpy = convert(sys)
1135+
1136+
assert cpy.input_labels == sys.input_labels
1137+
assert cpy.input_labels == sys.input_labels
1138+
if cpy.nstates is not None and sys.nstates is not None:
1139+
assert cpy.state_labels == sys.state_labels
1140+
1141+
# Relabel inputs and outputs
1142+
cpy = convert(sys, inputs='myin', outputs='myout')
1143+
assert cpy.input_labels == ['myin']
1144+
assert cpy.output_labels == ['myout']

control/xferfcn.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1424,16 +1424,13 @@ def _convert_to_transfer_function(sys, inputs=1, outputs=1):
14241424
num = squeeze(num) # Convert to 1D array
14251425
den = squeeze(den) # Probably not needed
14261426

1427-
return TransferFunction(
1428-
num, den, sys.dt, inputs=sys.input_labels,
1429-
outputs=sys.output_labels)
1427+
return TransferFunction(num, den, sys.dt)
14301428

14311429
elif isinstance(sys, (int, float, complex, np.number)):
14321430
num = [[[sys] for j in range(inputs)] for i in range(outputs)]
14331431
den = [[[1] for j in range(inputs)] for i in range(outputs)]
14341432

1435-
return TransferFunction(
1436-
num, den, inputs=inputs, outputs=outputs)
1433+
return TransferFunction(num, den)
14371434

14381435
elif isinstance(sys, FrequencyResponseData):
14391436
raise TypeError("Can't convert given FRD to TransferFunction system.")
@@ -1623,7 +1620,6 @@ def zpk(zeros, poles, gain, *args, **kwargs):
16231620
return TransferFunction(num, den, *args, **kwargs)
16241621

16251622

1626-
# TODO: copy signal names
16271623
def ss2tf(*args, **kwargs):
16281624

16291625
"""ss2tf(sys)
@@ -1705,6 +1701,11 @@ def ss2tf(*args, **kwargs):
17051701
if len(args) == 1:
17061702
sys = args[0]
17071703
if isinstance(sys, StateSpace):
1704+
kwargs = kwargs.copy()
1705+
if not kwargs.get('inputs'):
1706+
kwargs['inputs'] = sys.input_labels
1707+
if not kwargs.get('outputs'):
1708+
kwargs['outputs'] = sys.output_labels
17081709
return TransferFunction(
17091710
_convert_to_transfer_function(sys), **kwargs)
17101711
else:

0 commit comments

Comments
 (0)