Skip to content

Commit 2e211f7

Browse files
authored
Merge pull request nipy#1390 from effigies/merge_all
ENH: Make interfaces.utility.Merge(1) merge a list of lists
2 parents cf2b1f8 + fcf1d48 commit 2e211f7

File tree

2 files changed

+68
-15
lines changed

2 files changed

+68
-15
lines changed

nipype/interfaces/utility/base.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ class MergeOutputSpec(TraitedSpec):
109109
class Merge(IOBase):
110110
"""Basic interface class to merge inputs into a single list
111111
112+
``Merge(1)`` will merge a list of lists
113+
112114
Examples
113115
--------
114116
@@ -121,33 +123,50 @@ class Merge(IOBase):
121123
>>> out.outputs.out
122124
[1, 2, 5, 3]
123125
126+
>>> merge = Merge() # Or Merge(1)
127+
>>> merge.inputs.in_lists = [1, [2, 5], 3]
128+
>>> out = merge.run()
129+
>>> out.outputs.out
130+
[1, 2, 5, 3]
131+
124132
"""
125133
input_spec = MergeInputSpec
126134
output_spec = MergeOutputSpec
127135

128-
def __init__(self, numinputs=0, **inputs):
136+
def __init__(self, numinputs=1, **inputs):
129137
super(Merge, self).__init__(**inputs)
130138
self._numinputs = numinputs
131-
add_traits(self.inputs, ['in%d' % (i + 1) for i in range(numinputs)])
139+
if numinputs > 1:
140+
input_names = ['in%d' % (i + 1) for i in range(numinputs)]
141+
elif numinputs == 1:
142+
input_names = ['in_lists']
143+
else:
144+
input_names = []
145+
add_traits(self.inputs, input_names)
132146

133147
def _list_outputs(self):
134148
outputs = self._outputs().get()
135149
out = []
150+
151+
if self._numinputs < 1:
152+
return outputs
153+
elif self._numinputs == 1:
154+
values = self.inputs.in_lists
155+
else:
156+
getval = lambda idx: getattr(self.inputs, 'in%d' % (idx + 1))
157+
values = [getval(idx) for idx in range(self._numinputs)
158+
if isdefined(getval(idx))]
159+
136160
if self.inputs.axis == 'vstack':
137-
for idx in range(self._numinputs):
138-
value = getattr(self.inputs, 'in%d' % (idx + 1))
139-
if isdefined(value):
140-
if isinstance(value, list) and not self.inputs.no_flatten:
141-
out.extend(value)
142-
else:
143-
out.append(value)
161+
for value in values:
162+
if isinstance(value, list) and not self.inputs.no_flatten:
163+
out.extend(value)
164+
else:
165+
out.append(value)
144166
else:
145-
for i in range(len(filename_to_list(self.inputs.in1))):
146-
out.insert(i, [])
147-
for j in range(self._numinputs):
148-
out[i].append(filename_to_list(getattr(self.inputs, 'in%d' % (j + 1)))[i])
149-
if out:
150-
outputs['out'] = out
167+
lists = [filename_to_list(val) for val in values]
168+
out = [[val[i] for val in lists] for i in range(len(lists[0]))]
169+
outputs['out'] = out
151170
return outputs
152171

153172

nipype/interfaces/utility/tests/test_base.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pytest
77

88
from nipype.interfaces import utility
9+
from nipype.interfaces.base import isdefined
910
import nipype.pipeline.engine as pe
1011

1112

@@ -49,3 +50,36 @@ def test_split(tmpdir, args, expected):
4950
res = node.run()
5051
assert res.outputs.out1 == expected[0]
5152
assert res.outputs.out2 == expected[1]
53+
54+
55+
@pytest.mark.parametrize("args, kwargs, in_lists, expected", [
56+
([3], {}, [0, [1, 2], [3, 4, 5]], [0, 1, 2, 3, 4, 5]),
57+
([0], {}, None, None),
58+
([], {}, [], []),
59+
([], {}, [0, [1, 2], [3, 4, 5]], [0, 1, 2, 3, 4, 5]),
60+
([3], {'axis': 'hstack'}, [[0], [1, 2], [3, 4, 5]], [[0, 1, 3]]),
61+
([3], {'axis': 'hstack'}, [[0, 1], [2, 3], [4, 5]],
62+
[[0, 2, 4], [1, 3, 5]]),
63+
([3], {'axis': 'hstack'}, [[0, 1], [2, 3], [4, 5]],
64+
[[0, 2, 4], [1, 3, 5]]),
65+
([1], {'axis': 'hstack'}, [[0], [1, 2], [3, 4, 5]], [[0, 1, 3]]),
66+
([1], {'axis': 'hstack'}, [[0, 1], [2, 3], [4, 5]],
67+
[[0, 2, 4], [1, 3, 5]]),
68+
])
69+
def test_merge(tmpdir, args, kwargs, in_lists, expected):
70+
os.chdir(str(tmpdir))
71+
72+
node = pe.Node(utility.Merge(*args, **kwargs), name='merge')
73+
74+
numinputs = args[0] if args else 1
75+
if numinputs == 1:
76+
node.inputs.in_lists = in_lists
77+
elif numinputs > 1:
78+
for i in range(1, numinputs + 1):
79+
setattr(node.inputs, 'in{:d}'.format(i), in_lists[i - 1])
80+
81+
res = node.run()
82+
if numinputs < 1:
83+
assert not isdefined(res.outputs.out)
84+
else:
85+
assert res.outputs.out == expected

0 commit comments

Comments
 (0)