Skip to content

Commit fcf1d48

Browse files
committed
Move list-flattening to Merge(1), make default
If accumulated list is empty, return as valid, rather than undefined
1 parent b515ac2 commit fcf1d48

File tree

2 files changed

+24
-18
lines changed

2 files changed

+24
-18
lines changed

nipype/interfaces/utility/base.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ class MergeOutputSpec(TraitedSpec):
109109
class Merge(IOBase):
110110
"""Basic interface class to merge inputs into a single list
111111
112+
``Merge(1)`` will merge a list of lists
113+
112114
Examples
113115
--------
114116
@@ -121,16 +123,22 @@ class Merge(IOBase):
121123
>>> out.outputs.out
122124
[1, 2, 5, 3]
123125
126+
>>> merge = Merge() # Or Merge(1)
127+
>>> merge.inputs.in_lists = [1, [2, 5], 3]
128+
>>> out = merge.run()
129+
>>> out.outputs.out
130+
[1, 2, 5, 3]
131+
124132
"""
125133
input_spec = MergeInputSpec
126134
output_spec = MergeOutputSpec
127135

128-
def __init__(self, numinputs=0, **inputs):
136+
def __init__(self, numinputs=1, **inputs):
129137
super(Merge, self).__init__(**inputs)
130138
self._numinputs = numinputs
131-
if numinputs > 0:
139+
if numinputs > 1:
132140
input_names = ['in%d' % (i + 1) for i in range(numinputs)]
133-
elif numinputs == 0:
141+
elif numinputs == 1:
134142
input_names = ['in_lists']
135143
else:
136144
input_names = []
@@ -140,10 +148,10 @@ def _list_outputs(self):
140148
outputs = self._outputs().get()
141149
out = []
142150

143-
if self._numinputs == 0:
144-
values = getattr(self.inputs, 'in_lists')
145-
if not isdefined(values):
146-
return outputs
151+
if self._numinputs < 1:
152+
return outputs
153+
elif self._numinputs == 1:
154+
values = self.inputs.in_lists
147155
else:
148156
getval = lambda idx: getattr(self.inputs, 'in%d' % (idx + 1))
149157
values = [getval(idx) for idx in range(self._numinputs)
@@ -158,8 +166,7 @@ def _list_outputs(self):
158166
else:
159167
lists = [filename_to_list(val) for val in values]
160168
out = [[val[i] for val in lists] for i in range(len(lists[0]))]
161-
if out:
162-
outputs['out'] = out
169+
outputs['out'] = out
163170
return outputs
164171

165172

nipype/interfaces/utility/tests/test_base.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,33 +54,32 @@ def test_split(tmpdir, args, expected):
5454

5555
@pytest.mark.parametrize("args, kwargs, in_lists, expected", [
5656
([3], {}, [0, [1, 2], [3, 4, 5]], [0, 1, 2, 3, 4, 5]),
57-
([], {}, None, None),
57+
([0], {}, None, None),
58+
([], {}, [], []),
5859
([], {}, [0, [1, 2], [3, 4, 5]], [0, 1, 2, 3, 4, 5]),
5960
([3], {'axis': 'hstack'}, [[0], [1, 2], [3, 4, 5]], [[0, 1, 3]]),
6061
([3], {'axis': 'hstack'}, [[0, 1], [2, 3], [4, 5]],
6162
[[0, 2, 4], [1, 3, 5]]),
6263
([3], {'axis': 'hstack'}, [[0, 1], [2, 3], [4, 5]],
6364
[[0, 2, 4], [1, 3, 5]]),
64-
# Note: Merge(0, axis='hstack') would error on run, prior to
65-
# in_lists implementation
66-
([0], {'axis': 'hstack'}, [[0], [1, 2], [3, 4, 5]], [[0, 1, 3]]),
67-
([0], {'axis': 'hstack'}, [[0, 1], [2, 3], [4, 5]],
65+
([1], {'axis': 'hstack'}, [[0], [1, 2], [3, 4, 5]], [[0, 1, 3]]),
66+
([1], {'axis': 'hstack'}, [[0, 1], [2, 3], [4, 5]],
6867
[[0, 2, 4], [1, 3, 5]]),
6968
])
7069
def test_merge(tmpdir, args, kwargs, in_lists, expected):
7170
os.chdir(str(tmpdir))
7271

7372
node = pe.Node(utility.Merge(*args, **kwargs), name='merge')
7473

75-
numinputs = args[0] if args else 0
76-
if numinputs == 0 and in_lists:
74+
numinputs = args[0] if args else 1
75+
if numinputs == 1:
7776
node.inputs.in_lists = in_lists
78-
else:
77+
elif numinputs > 1:
7978
for i in range(1, numinputs + 1):
8079
setattr(node.inputs, 'in{:d}'.format(i), in_lists[i - 1])
8180

8281
res = node.run()
83-
if numinputs == 0 and in_lists is None:
82+
if numinputs < 1:
8483
assert not isdefined(res.outputs.out)
8584
else:
8685
assert res.outputs.out == expected

0 commit comments

Comments
 (0)