Skip to content

Commit 73f93a2

Browse files
committed
ENH: Enable Merge(0) to merge a list of lists
1 parent 2534772 commit 73f93a2

File tree

2 files changed

+37
-14
lines changed

2 files changed

+37
-14
lines changed

nipype/interfaces/utility/base.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -128,24 +128,38 @@ class Merge(IOBase):
128128
def __init__(self, numinputs=0, **inputs):
129129
super(Merge, self).__init__(**inputs)
130130
self._numinputs = numinputs
131-
add_traits(self.inputs, ['in%d' % (i + 1) for i in range(numinputs)])
131+
if numinputs > 0:
132+
input_names = ['in%d' % (i + 1) for i in range(numinputs)]
133+
elif numinputs == 0:
134+
input_names = ['in_lists']
135+
else:
136+
input_names = []
137+
add_traits(self.inputs, input_names)
132138

133139
def _list_outputs(self):
134140
outputs = self._outputs().get()
135141
out = []
142+
143+
if self._numinputs == 0:
144+
values = getattr(self.inputs, 'in_lists')
145+
if not isdefined(values):
146+
return outputs
147+
else:
148+
getval = lambda idx: getattr(self.inputs, 'in%d' % (idx + 1))
149+
values = [getval(idx) for idx in range(self._numinputs)
150+
if isdefined(getval(idx))]
151+
136152
if self.inputs.axis == 'vstack':
137-
for idx in range(self._numinputs):
138-
value = getattr(self.inputs, 'in%d' % (idx + 1))
139-
if isdefined(value):
140-
if isinstance(value, list) and not self.inputs.no_flatten:
141-
out.extend(value)
142-
else:
143-
out.append(value)
153+
for value in values:
154+
if isinstance(value, list) and not self.inputs.no_flatten:
155+
out.extend(value)
156+
else:
157+
out.append(value)
144158
else:
145-
for i in range(len(filename_to_list(self.inputs.in1))):
159+
for i in range(len(filename_to_list(values[0]))):
146160
out.insert(i, [])
147-
for j in range(self._numinputs):
148-
out[i].append(filename_to_list(getattr(self.inputs, 'in%d' % (j + 1)))[i])
161+
for value in values:
162+
out[i].append(filename_to_list(value)[i])
149163
if out:
150164
outputs['out'] = out
151165
return outputs

nipype/interfaces/utility/tests/test_base.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,23 +55,32 @@ def test_split(tmpdir, args, expected):
5555
@pytest.mark.parametrize("args, kwargs, in_lists, expected", [
5656
([3], {}, [0, [1, 2], [3, 4, 5]], [0, 1, 2, 3, 4, 5]),
5757
([], {}, None, None),
58+
([], {}, [0, [1, 2], [3, 4, 5]], [0, 1, 2, 3, 4, 5]),
5859
([3], {'axis': 'hstack'}, [[0], [1, 2], [3, 4, 5]], [[0, 1, 3]]),
5960
([3], {'axis': 'hstack'}, [[0, 1], [2, 3], [4, 5]],
6061
[[0, 2, 4], [1, 3, 5]]),
6162
([3], {'axis': 'hstack'}, [[0, 1], [2, 3], [4, 5]],
6263
[[0, 2, 4], [1, 3, 5]]),
64+
# Note: Merge(0, axis='hstack') would error on run, prior to
65+
# in_lists implementation
66+
([0], {'axis': 'hstack'}, [[0], [1, 2], [3, 4, 5]], [[0, 1, 3]]),
67+
([0], {'axis': 'hstack'}, [[0, 1], [2, 3], [4, 5]],
68+
[[0, 2, 4], [1, 3, 5]]),
6369
])
6470
def test_merge(tmpdir, args, kwargs, in_lists, expected):
6571
os.chdir(str(tmpdir))
6672

6773
node = pe.Node(utility.Merge(*args, **kwargs), name='merge')
6874

6975
numinputs = args[0] if args else 0
70-
for i in range(1, numinputs + 1):
71-
setattr(node.inputs, 'in{:d}'.format(i), in_lists[i - 1])
76+
if numinputs == 0 and in_lists:
77+
node.inputs.in_lists = in_lists
78+
else:
79+
for i in range(1, numinputs + 1):
80+
setattr(node.inputs, 'in{:d}'.format(i), in_lists[i - 1])
7281

7382
res = node.run()
74-
if numinputs == 0:
83+
if numinputs == 0 and in_lists is None:
7584
assert not isdefined(res.outputs.out)
7685
else:
7786
assert res.outputs.out == expected

0 commit comments

Comments
 (0)