Skip to content

Commit 63a4077

Browse files
committed
finally got all packages to convert again!
1 parent 1e92a40 commit 63a4077

File tree

4 files changed

+124
-48
lines changed

4 files changed

+124
-48
lines changed

nipype2pydra/interface/base.py

Lines changed: 42 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -492,20 +492,21 @@ def _referenced_funcs_and_methods(self):
492492
method_stacks = {}
493493
method_supers = defaultdict(dict)
494494
already_processed = set(
495-
getattr(self.nipype_interface, m) for m in self.INCLUDED_METHODS
495+
getattr(self.nipype_interface, m) for m in self.included_methods
496496
)
497-
for method_name in self.INCLUDED_METHODS:
497+
for method_name in self.included_methods:
498498
method_args[method_name] = []
499499
method_returns[method_name] = []
500500
method_stacks[method_name] = ()
501-
for method_name in self.INCLUDED_METHODS:
502-
base = find_super_method(
501+
for method_name in self.included_methods:
502+
method = getattr(self.nipype_interface, method_name)
503+
super_base = find_super_method(
503504
self.nipype_interface, method_name, include_class=True
504505
)[1]
505-
if self.package.is_omitted(base):
506-
continue # Don't include base methods
507-
method = getattr(self.nipype_interface, method_name)
508-
referenced_methods.add(method)
506+
# if super_base is not self.nipype_interface:
507+
# method_supers[self.nipype_interface][method_name] = (
508+
# self._common_parent_pkg_prefix(super_base) + method_name
509+
# )
509510
self._get_referenced(
510511
method,
511512
referenced_funcs=referenced_funcs,
@@ -516,6 +517,7 @@ def _referenced_funcs_and_methods(self):
516517
method_stacks=method_stacks,
517518
method_supers=method_supers,
518519
already_processed=already_processed,
520+
super_base=super_base,
519521
)
520522
return (
521523
referenced_funcs,
@@ -1095,7 +1097,14 @@ def _get_referenced(
10951097

10961098
ref_method_names = re.findall(r"(?<=self\.)(\w+)\(", method_body)
10971099
ref_methods = set(m for m in self.methods if m.__name__ in ref_method_names)
1098-
1100+
# Filter methods in omitted common base-classes like BaseInterface & CommandLine
1101+
ref_methods = [
1102+
m
1103+
for m in ref_methods
1104+
if not self.package.is_omitted(
1105+
find_super_method(super_base, m.__name__, include_class=True)[1]
1106+
)
1107+
]
10991108
referenced_funcs.update(ref_local_funcs)
11001109
referenced_methods.update(ref_methods)
11011110

@@ -1107,7 +1116,7 @@ def _get_referenced(
11071116
)
11081117
for match in re.findall(r"super\([^\)]*\)\.(\w+)\(", method_body):
11091118
super_method, base = find_super_method(super_base, match)
1110-
if self.package.is_omitted(super_method):
1119+
if self.package.is_omitted(base):
11111120
continue
11121121
func_name = self._common_parent_pkg_prefix(base) + match
11131122
if func_name not in referenced_supers:
@@ -1144,7 +1153,6 @@ def _get_referenced(
11441153
method_supers=method_supers,
11451154
already_processed=already_processed,
11461155
method_stack=method_stack,
1147-
super_base=super_base,
11481156
)
11491157
referenced_inputs.update(rf_inputs)
11501158
referenced_outputs.update(rf_outputs)
@@ -1162,7 +1170,6 @@ def _get_referenced(
11621170
method_supers=method_supers,
11631171
already_processed=already_processed,
11641172
method_stack=method_stack,
1165-
super_base=super_base,
11661173
)
11671174
method_args[meth.__name__] = ref_inputs
11681175
method_returns[meth.__name__] = ref_outputs
@@ -1215,13 +1222,12 @@ def process_method(
12151222
pass
12161223
if "runtime" in args:
12171224
args.remove("runtime")
1218-
if method.__name__ in self.method_args:
1219-
args += [
1220-
f"{a}=None"
1221-
for a in (
1222-
list(self.method_args[method.__name__]) + list(additional_args)
1223-
)
1224-
]
1225+
args_to_add = list(self.method_args.get(method.__name__, [])) + list(
1226+
additional_args
1227+
)
1228+
if args_to_add:
1229+
kwargs = [args.pop()] if args and args[-1].startswith("**") else []
1230+
args += [f"{a}=None" for a in args_to_add] + kwargs
12251231
# Insert method args in signature if present
12261232
return_types, method_body = post.split(":", maxsplit=1)
12271233
method_body = method_body.split("\n", maxsplit=1)[1]
@@ -1255,6 +1261,8 @@ def process_method_body(
12551261
output_names: ty.List[str],
12561262
super_base: ty.Optional[type] = None,
12571263
) -> str:
1264+
if not method_body:
1265+
return ""
12581266
if super_base is None:
12591267
super_base = self.nipype_interface
12601268
return_value = get_return_line(method_body)
@@ -1330,7 +1338,7 @@ def unwrap_nested_methods(
13301338
# Add args to the function signature of method calls
13311339
method_re = re.compile(r"self\.(\w+)(?=\()", flags=re.MULTILINE | re.DOTALL)
13321340
method_names = [m.__name__ for m in self.referenced_methods] + list(
1333-
self.INCLUDED_METHODS
1341+
self.included_methods
13341342
)
13351343
method_body = strip_comments(method_body)
13361344
omitted_methods = {}
@@ -1345,9 +1353,16 @@ def unwrap_nested_methods(
13451353
for name, args in zip(splits[1::2], splits[2::2]):
13461354
if name in omitted_methods:
13471355
args, post = extract_args(args)[1:]
1348-
new_body += self.SPECIAL_SUPER_MAPPINGS[omitted_methods[name]].format(
1349-
args=", ".join(args)
1350-
)
1356+
omitted_method = omitted_methods[name]
1357+
try:
1358+
new_body += self.SPECIAL_SUPER_MAPPINGS[omitted_method].format(
1359+
args=", ".join(args)
1360+
)
1361+
except KeyError:
1362+
raise KeyError(
1363+
f"Require special mapping for {omitted_methods[name]} method "
1364+
"as methods in that module are being omitted from the conversion"
1365+
) from None
13511366
new_body += post[1:] # drop the leading parenthesis
13521367
continue
13531368
# Assign additional return values (which were previously saved to member
@@ -1407,6 +1422,10 @@ def unwrap_nested_methods(
14071422
CommandLine._filename_from_source: "{args} + '_generated'",
14081423
BaseInterface._check_version_requirements: "[]",
14091424
CommandLine._parse_inputs: "{{}}",
1425+
CommandLine._gen_filename: "",
1426+
BaseInterface.aggregate_outputs: "{{}}",
1427+
BaseInterface.run: "None",
1428+
BaseInterface._list_outputs: "{{}}",
14101429
}
14111430

14121431
INPUT_KEYS = [

nipype2pydra/interface/function.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import attrs
99
from nipype.interfaces.base import BaseInterface, TraitedSpec
1010
from .base import BaseInterfaceConverter
11-
from ..utils import UsedSymbols, get_return_line
11+
from ..utils import UsedSymbols, get_return_line, find_super_method
1212

1313

1414
logger = logging.getLogger("nipype2pydra")
@@ -17,7 +17,9 @@
1717
@attrs.define(slots=False)
1818
class FunctionInterfaceConverter(BaseInterfaceConverter):
1919

20-
INCLUDED_METHODS = ("_run_interface", "_list_outputs")
20+
@property
21+
def included_methods(self) -> ty.Tuple[str, ...]:
22+
return ("_run_interface", "_list_outputs")
2123

2224
def generate_code(self, input_fields, nonstd_types, output_fields) -> ty.Tuple[
2325
str,
@@ -68,14 +70,29 @@ def types_to_names(spec_fields):
6870
if re.match(r"\s*return", method_lines[-1]):
6971
method_lines = method_lines[:-1]
7072
method_body = "\n".join(method_lines)
73+
method_body = self.process_method_body(
74+
method_body,
75+
input_names,
76+
output_names,
77+
super_base=find_super_method(
78+
self.nipype_interface, "_run_interface", include_class=True
79+
)[1],
80+
)
7181
lo_src = inspect.getsource(self.nipype_interface._list_outputs).strip()
7282
# Strip out method def and return statement
7383
lo_lines = lo_src.strip().split("\n")[1:]
7484
if re.match(r"\s*(return|raise NotImplementedError)", lo_lines[-1]):
7585
lo_lines = lo_lines[:-1]
7686
lo_src = "\n".join(lo_lines)
87+
lo_src = self.process_method_body(
88+
lo_src,
89+
input_names,
90+
output_names,
91+
super_base=find_super_method(
92+
self.nipype_interface, "_list_outputs", include_class=True
93+
)[1],
94+
)
7795
method_body += "\n" + lo_src
78-
method_body = self.process_method_body(method_body, input_names, output_names)
7996
method_body = re.sub(
8097
r"self\._results\[(?:'|\")(\w+)(?:'|\")\]", r"\1", method_body
8198
)

nipype2pydra/interface/shell_command.py

Lines changed: 57 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,21 @@ class ShellCommandInterfaceConverter(BaseInterfaceConverter):
2929

3030
_format_argstrs: ty.Dict[str, str] = attrs.field(factory=dict)
3131

32-
INCLUDED_METHODS = (
33-
"_parse_inputs",
34-
"_format_arg",
35-
"_list_outputs",
36-
"_gen_filename",
37-
)
32+
@cached_property
33+
def included_methods(self) -> ty.Tuple[str, ...]:
34+
included = []
35+
if not self.method_omitted("_parse_inputs"):
36+
included.append("_parse_inputs"),
37+
if not self.method_omitted("_format_arg"):
38+
included.append("_format_arg")
39+
if not self.method_omitted("_gen_filename"):
40+
included.append("_gen_filename")
41+
if self.callable_output_fields:
42+
if not self.method_omitted("aggregate_outputs"):
43+
included.append("aggregate_outputs")
44+
if not self.method_omitted("_list_outputs"):
45+
included.append("_list_outputs")
46+
return tuple(included)
3847

3948
def generate_code(self, input_fields, nonstd_types, output_fields) -> ty.Tuple[
4049
str,
@@ -142,7 +151,7 @@ def types_to_names(spec_fields):
142151
spec_str = re.sub(r"'#([^'#]+)#'", r"\1", spec_str)
143152

144153
for m in sorted(self.referenced_methods, key=attrgetter("__name__")):
145-
if m.__name__ in self.INCLUDED_METHODS:
154+
if m.__name__ in self.included_methods:
146155
continue
147156
if self.method_stacks[m.__name__][0] == self.nipype_interface._list_outputs:
148157
additional_args = CALLABLES_ARGS
@@ -251,23 +260,23 @@ def callable_output_field_names(self):
251260

252261
@cached_property
253262
def _format_arg_body(self):
254-
if "_format_arg" not in self.nipype_interface.__dict__:
263+
if self.method_omitted("_format_arg"):
255264
return ""
256265
return _strip_doc_string(
257266
inspect.getsource(self.nipype_interface._format_arg).split("\n", 1)[-1]
258267
)
259268

260269
@cached_property
261270
def _gen_filename_body(self):
262-
if "_gen_filename" not in self.nipype_interface.__dict__:
271+
if self.method_omitted("_gen_filename"):
263272
return ""
264273
return _strip_doc_string(
265274
inspect.getsource(self.nipype_interface._gen_filename).split("\n", 1)[-1]
266275
)
267276

268277
@property
269278
def format_arg_code(self):
270-
if not self._format_arg_body:
279+
if "_format_arg" not in self.included_methods:
271280
return ""
272281
body = self._format_arg_body
273282
body = self._process_inputs(body)
@@ -296,7 +305,12 @@ def format_arg_code(self):
296305
if not body:
297306
return ""
298307
body = self.unwrap_nested_methods(body, inputs_as_dict=True)
299-
body = self.replace_supers(body)
308+
body = self.replace_supers(
309+
body,
310+
super_base=find_super_method(
311+
self.nipype_interface, "_format_arg", include_class=True
312+
)[1],
313+
)
300314

301315
code_str = f"""def _format_arg({name_arg}, {val_arg}, inputs, argstr):{self.parse_inputs_call}
302316
if {val_arg} is None:
@@ -318,7 +332,7 @@ def format_arg_code(self):
318332

319333
@property
320334
def parse_inputs_code(self) -> str:
321-
if "_parse_inputs" not in self.nipype_interface.__dict__:
335+
if "_parse_inputs" not in self.included_methods:
322336
return ""
323337
body = _strip_doc_string(
324338
inspect.getsource(self.nipype_interface._parse_inputs).split("\n", 1)[-1]
@@ -336,7 +350,12 @@ def parse_inputs_code(self) -> str:
336350
if not body:
337351
return ""
338352
body = self.unwrap_nested_methods(body, inputs_as_dict=True)
339-
body = self.replace_supers(body)
353+
body = self.replace_supers(
354+
body,
355+
super_base=find_super_method(
356+
self.nipype_interface, "_parse_inputs", include_class=True
357+
)[1],
358+
)
340359

341360
code_str = "def _parse_inputs(inputs):\n parsed_inputs = {}"
342361
if re.findall(r"\bargstrs\b", body):
@@ -352,7 +371,7 @@ def parse_inputs_code(self) -> str:
352371

353372
@cached_property
354373
def defaults_code(self):
355-
if not self.callable_default_input_field_names:
374+
if "_gen_filename" not in self.included_methods:
356375
return ""
357376

358377
body = _strip_doc_string(
@@ -364,7 +383,12 @@ def defaults_code(self):
364383
if not body:
365384
return ""
366385
body = self.unwrap_nested_methods(body, inputs_as_dict=True)
367-
body = self.replace_supers(body)
386+
body = self.replace_supers(
387+
body,
388+
super_base=find_super_method(
389+
self.nipype_interface, "_gen_filename", include_class=True
390+
)[1],
391+
)
368392

369393
code_str = f"""def _gen_filename(name, inputs):{self.parse_inputs_call}
370394
{body}
@@ -387,10 +411,7 @@ def callables_code(self):
387411
if not self.callable_output_fields:
388412
return ""
389413
code_str = ""
390-
if (
391-
find_super_method(self.nipype_interface, "aggregate_outputs")[1]
392-
is not BaseInterface
393-
):
414+
if "aggregate_outputs" in self.included_methods:
394415
func_name = "aggregate_outputs"
395416
body = _strip_doc_string(
396417
inspect.getsource(self.nipype_interface.aggregate_outputs).split(
@@ -406,7 +427,12 @@ def callables_code(self):
406427
body = self.unwrap_nested_methods(
407428
body, additional_args=CALLABLES_ARGS, inputs_as_dict=True
408429
)
409-
body = self.replace_supers(body)
430+
body = self.replace_supers(
431+
body,
432+
super_base=find_super_method(
433+
self.nipype_interface, "aggregate_outputs", include_class=True
434+
)[1],
435+
)
410436

411437
code_str += f"""def aggregate_outputs(inputs=None, stdout=None, stderr=None, output_dir=None):
412438
inputs = attrs.asdict(inputs){self.parse_inputs_call}
@@ -436,7 +462,12 @@ def callables_code(self):
436462
body = self.unwrap_nested_methods(
437463
body, additional_args=CALLABLES_ARGS, inputs_as_dict=True
438464
)
439-
body = self.replace_supers(body)
465+
body = self.replace_supers(
466+
body,
467+
super_base=find_super_method(
468+
self.nipype_interface, "_list_outputs", include_class=True
469+
)[1],
470+
)
440471

441472
code_str += f"""def _list_outputs(inputs=None, stdout=None, stderr=None, output_dir=None):{inputs_as_dict_call}{self.parse_inputs_call}
442473
{body}
@@ -476,6 +507,11 @@ def parse_inputs_call(self):
476507
return ""
477508
return "\n parsed_inputs = _parse_inputs(inputs) if inputs else {}"
478509

510+
def method_omitted(self, method_name: str) -> bool:
511+
return self.package.is_omitted(
512+
find_super_method(self.nipype_interface, method_name, include_class=True)[1]
513+
)
514+
479515

480516
def _strip_doc_string(body: str) -> str:
481517
if re.match(r"\s*(\"|')", body):

nipype2pydra/utils/misc.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,11 @@ def insert_args_in_signature(snippet: str, new_args: ty.Iterable[str]) -> str:
361361
pre, args, post = extract_args(snippet)
362362
if "runtime" in args:
363363
args.remove("runtime")
364-
return pre + ", ".join(args + new_args) + post
364+
if args and args[-1].startswith("**"):
365+
kwargs = [args.pop()]
366+
else:
367+
kwargs = []
368+
return pre + ", ".join(args + new_args + kwargs) + post
365369

366370

367371
def get_source_code(func_or_klass: ty.Union[ty.Callable, ty.Type]) -> str:

0 commit comments

Comments
 (0)