Skip to content

Commit 0aa80aa

Browse files
committed
debugging updates to conversions to handle function task super methods
1 parent 63a4077 commit 0aa80aa

File tree

9 files changed

+259
-107
lines changed

9 files changed

+259
-107
lines changed

nipype2pydra/helpers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ def _converted_code(self) -> ty.Tuple[str, ty.List[str]]:
350350
# Write to file for debugging
351351
debug_file = "~/unparsable-nipype2pydra-output.py"
352352
with open(Path(debug_file).expanduser(), "w") as f:
353+
f.write(f"# Attemping to convert {self.full_name}\n")
353354
f.write(code_str)
354355
raise RuntimeError(
355356
f"Black could not parse generated code (written to {debug_file}): "
@@ -413,6 +414,7 @@ def _converted_code(self) -> ty.Tuple[str, ty.List[str]]:
413414
# Write to file for debugging
414415
debug_file = "~/unparsable-nipype2pydra-output.py"
415416
with open(Path(debug_file).expanduser(), "w") as f:
417+
f.write(f"# Attemping to convert {self.full_name}\n")
416418
f.write(code_str)
417419
raise RuntimeError(
418420
f"Black could not parse generated code (written to {debug_file}): "

nipype2pydra/interface/base.py

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@
1414
import attrs
1515
from attrs.converters import default_if_none
1616
import nipype.interfaces.base
17-
from nipype.interfaces.base import traits_extension, CommandLine, BaseInterface
17+
from nipype.interfaces.base import (
18+
traits_extension,
19+
CommandLine,
20+
BaseInterface,
21+
)
22+
from nipype.interfaces.base.core import SimpleInterface
1823
from pydra.engine import specs
1924
from pydra.engine.helpers import ensure_list
2025
from ..utils import (
@@ -33,6 +38,7 @@
3338
extract_args,
3439
strip_comments,
3540
find_super_method,
41+
min_indentation,
3642
)
3743
from ..statements import (
3844
ImportStatement,
@@ -364,6 +370,8 @@ class BaseInterfaceConverter(metaclass=ABCMeta):
364370
},
365371
)
366372

373+
_output_name_mappings: ty.Dict[str, str] = attrs.field(factory=dict)
374+
367375
def __attrs_post_init__(self):
368376
if self.output_module is None:
369377
if self.nipype_module.__name__.startswith("nipype.interfaces."):
@@ -682,6 +690,7 @@ def pydra_fld_input(self, field, nm):
682690
f"the filed {nm} has genfile=True, but no template or "
683691
"`callables_default` function in the callables_module provided"
684692
)
693+
self._output_name_mappings[getattr(field, "output_name")] = nm
685694

686695
pydra_metadata.update(metadata_extra_spec)
687696

@@ -1021,18 +1030,26 @@ def _misc_cleanups(self, body: str) -> str:
10211030
if hasattr(self.nipype_interface, "_cmd"):
10221031
body = body.replace("self.cmd", f'"{self.nipype_interface._cmd}"')
10231032

1024-
body = re.sub(
1025-
r"outputs = self\.(output_spec|_outputs)\(\).*$",
1026-
r"outputs = {}",
1027-
body,
1028-
flags=re.MULTILINE,
1029-
)
1033+
body = body.replace("self.output_spec().get()", "{}")
1034+
body = body.replace("self._outputs()", "{}")
1035+
# body = re.sub(
1036+
# r"outputs = self\.(output_spec|_outputs)\(\).*$",
1037+
# r"outputs = {}",
1038+
# body,
1039+
# flags=re.MULTILINE,
1040+
# )
10301041
body = re.sub(r"\bruntime\.(stdout|stderr)", r"\1", body)
10311042
body = re.sub(r"\boutputs\.(\w+)", r"outputs['\1']", body)
10321043
body = re.sub(r"getattr\(inputs, ([^)]+)\)", r"inputs[\1]", body)
10331044
body = re.sub(
10341045
r"setattr\(outputs, ([^,]+), ([^)]+)\)", r"outputs[\1] = \2", body
10351046
)
1047+
body = re.sub(r"self\._results\[(?:'|\")(\w+)(?:'|\")\]", r"\1", body)
1048+
body = re.sub(r"\s+runtime.returncode = (.*)", "", body)
1049+
new_body = re.sub(r"self\.(\w+)\b(?!\()", r"self_dict['\1']", body)
1050+
if new_body != body:
1051+
body = " " * min_indentation(body) + "self_dict = {}\n" + new_body
1052+
body = body.replace("return runtime", "")
10361053
body = body.replace("TraitError", "KeyError")
10371054
body = body.replace("os.getcwd()", "output_dir")
10381055
return body
@@ -1237,7 +1254,10 @@ def process_method(
12371254
if self.method_returns.get(method.__name__):
12381255
return_args = self.method_returns[method.__name__]
12391256
method_body = (
1240-
" " + " = ".join(return_args) + " = attrs.NOTHING\n" + method_body
1257+
" " * min_indentation(method_body)
1258+
+ " = ".join(return_args)
1259+
+ " = attrs.NOTHING\n"
1260+
+ method_body
12411261
)
12421262
method_lines = method_body.rstrip().splitlines()
12431263
method_body = "\n".join(method_lines[:-1])
@@ -1295,12 +1315,9 @@ def process_method_body(
12951315
self.task_name,
12961316
)
12971317
method_body = output_re.sub(r"\1", method_body)
1298-
# Strip initialisation of outputs
1299-
method_body = re.sub(
1300-
r"outputs = self.output_spec().*", r"outputs = {}", method_body
1301-
)
1302-
method_body = self._misc_cleanups(method_body)
1303-
return self.unwrap_nested_methods(method_body)
1318+
method_body = self.unwrap_nested_methods(method_body)
1319+
# method_body = self._misc_cleanups(method_body)
1320+
return method_body
13041321

13051322
def replace_supers(self, method_body, super_base=None):
13061323
if super_base is None:
@@ -1335,6 +1352,7 @@ def unwrap_nested_methods(
13351352
"""
13361353
Converts nested method calls into function calls
13371354
"""
1355+
method_body = self._misc_cleanups(method_body)
13381356
# Add args to the function signature of method calls
13391357
method_re = re.compile(r"self\.(\w+)(?=\()", flags=re.MULTILINE | re.DOTALL)
13401358
method_names = [m.__name__ for m in self.referenced_methods] + list(
@@ -1426,6 +1444,10 @@ def unwrap_nested_methods(
14261444
BaseInterface.aggregate_outputs: "{{}}",
14271445
BaseInterface.run: "None",
14281446
BaseInterface._list_outputs: "{{}}",
1447+
BaseInterface.__init__: "",
1448+
SimpleInterface.__init__: "",
1449+
BaseInterface._outputs: "{{}}",
1450+
None: "",
14291451
}
14301452

14311453
INPUT_KEYS = [

nipype2pydra/interface/function.py

Lines changed: 130 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -63,49 +63,9 @@ def types_to_names(spec_fields):
6363
output_names = [o[0] for o in output_fields]
6464
output_type_names = [o[1] for o in output_fields_str]
6565

66-
# Combined src of run_interface and list_outputs
67-
method_body = inspect.getsource(self.nipype_interface._run_interface).strip()
68-
# Strip out method def and return statement
69-
method_lines = method_body.strip().split("\n")[1:]
70-
if re.match(r"\s*return", method_lines[-1]):
71-
method_lines = method_lines[:-1]
72-
method_body = "\n".join(method_lines)
73-
method_body = self.process_method_body(
74-
method_body,
75-
input_names,
76-
output_names,
77-
super_base=find_super_method(
78-
self.nipype_interface, "_run_interface", include_class=True
79-
)[1],
80-
)
81-
lo_src = inspect.getsource(self.nipype_interface._list_outputs).strip()
82-
# Strip out method def and return statement
83-
lo_lines = lo_src.strip().split("\n")[1:]
84-
if re.match(r"\s*(return|raise NotImplementedError)", lo_lines[-1]):
85-
lo_lines = lo_lines[:-1]
86-
lo_src = "\n".join(lo_lines)
87-
lo_src = self.process_method_body(
88-
lo_src,
89-
input_names,
90-
output_names,
91-
super_base=find_super_method(
92-
self.nipype_interface, "_list_outputs", include_class=True
93-
)[1],
94-
)
95-
method_body += "\n" + lo_src
96-
method_body = re.sub(
97-
r"self\._results\[(?:'|\")(\w+)(?:'|\")\]", r"\1", method_body
98-
)
99-
10066
used = UsedSymbols.find(
10167
self.nipype_module,
102-
[method_body]
103-
+ [
104-
inspect.getsource(f)
105-
for f in itertools.chain(
106-
self.referenced_local_functions, self.referenced_methods
107-
)
108-
],
68+
self.referenced_local_functions,
10969
omit_classes=self.package.omit_classes + [BaseInterface, TraitedSpec],
11070
omit_modules=self.package.omit_modules,
11171
omit_functions=self.package.omit_functions,
@@ -115,6 +75,128 @@ def types_to_names(spec_fields):
11575
absolute_imports=True,
11676
)
11777

78+
for ref_method in self.referenced_methods:
79+
method_module = find_super_method(
80+
self.nipype_interface, ref_method.__name__, include_class=True
81+
)[1].__module__
82+
method_used = UsedSymbols.find(
83+
method_module,
84+
[ref_method],
85+
omit_classes=self.package.omit_classes + [BaseInterface, TraitedSpec],
86+
omit_modules=self.package.omit_modules,
87+
omit_functions=self.package.omit_functions,
88+
omit_constants=self.package.omit_constants,
89+
always_include=self.package.all_explicit,
90+
translations=self.package.all_import_translations,
91+
absolute_imports=True,
92+
)
93+
used.update(method_used, from_other_module=False)
94+
95+
method_body = ""
96+
for field in output_fields:
97+
method_body += f" {field[0]} = attrs.NOTHING\n"
98+
99+
# Combined src of init and list_outputs
100+
init_code = inspect.getsource(self.nipype_interface.__init__).strip()
101+
init_class = find_super_method(
102+
self.nipype_interface, "__init__", include_class=True
103+
)[1]
104+
if not self.package.is_omitted(init_class):
105+
# Strip out method def and return statement
106+
method_lines = init_code.strip().split("\n")[1:]
107+
if re.match(r"\s*return", method_lines[-1]):
108+
method_lines = method_lines[:-1]
109+
init_code = "\n".join(method_lines)
110+
init_code = self.process_method_body(
111+
init_code,
112+
input_names,
113+
output_names,
114+
super_base=init_class,
115+
)
116+
117+
init_used = UsedSymbols.find(
118+
init_class.__module__,
119+
[init_code],
120+
omit_classes=self.package.omit_classes + [BaseInterface, TraitedSpec],
121+
omit_modules=self.package.omit_modules,
122+
omit_functions=self.package.omit_functions,
123+
omit_constants=self.package.omit_constants,
124+
always_include=self.package.all_explicit,
125+
translations=self.package.all_import_translations,
126+
absolute_imports=True,
127+
)
128+
used.update(init_used, from_other_module=False)
129+
method_body += init_code + "\n"
130+
131+
# Combined src of run_interface and list_outputs
132+
run_interface_code = inspect.getsource(
133+
self.nipype_interface._run_interface
134+
).strip()
135+
run_interface_class = find_super_method(
136+
self.nipype_interface, "_run_interface", include_class=True
137+
)[1]
138+
if not self.package.is_omitted(run_interface_class):
139+
# Strip out method def and return statement
140+
method_lines = run_interface_code.strip().split("\n")[1:]
141+
if re.match(r"\s*return", method_lines[-1]):
142+
method_lines = method_lines[:-1]
143+
run_interface_code = "\n".join(method_lines)
144+
run_interface_code = self.process_method_body(
145+
run_interface_code,
146+
input_names,
147+
output_names,
148+
super_base=run_interface_class,
149+
)
150+
151+
run_interface_used = UsedSymbols.find(
152+
run_interface_class.__module__,
153+
[run_interface_code],
154+
omit_classes=self.package.omit_classes + [BaseInterface, TraitedSpec],
155+
omit_modules=self.package.omit_modules,
156+
omit_functions=self.package.omit_functions,
157+
omit_constants=self.package.omit_constants,
158+
always_include=self.package.all_explicit,
159+
translations=self.package.all_import_translations,
160+
absolute_imports=True,
161+
)
162+
used.update(run_interface_used, from_other_module=False)
163+
method_body += run_interface_code + "\n"
164+
165+
list_outputs_code = inspect.getsource(
166+
self.nipype_interface._list_outputs
167+
).strip()
168+
list_outputs_class = find_super_method(
169+
self.nipype_interface, "_list_outputs", include_class=True
170+
)[1]
171+
if not self.package.is_omitted(list_outputs_class):
172+
# Strip out method def and return statement
173+
lo_lines = list_outputs_code.strip().split("\n")[1:]
174+
if re.match(r"\s*(return|raise NotImplementedError)", lo_lines[-1]):
175+
lo_lines = lo_lines[:-1]
176+
list_outputs_code = "\n".join(lo_lines)
177+
list_outputs_code = self.process_method_body(
178+
list_outputs_code,
179+
input_names,
180+
output_names,
181+
super_base=list_outputs_class,
182+
)
183+
184+
list_outputs_used = UsedSymbols.find(
185+
list_outputs_class.__module__,
186+
[list_outputs_code],
187+
omit_classes=self.package.omit_classes + [BaseInterface, TraitedSpec],
188+
omit_modules=self.package.omit_modules,
189+
omit_functions=self.package.omit_functions,
190+
omit_constants=self.package.omit_constants,
191+
always_include=self.package.all_explicit,
192+
translations=self.package.all_import_translations,
193+
absolute_imports=True,
194+
)
195+
used.update(list_outputs_used, from_other_module=False)
196+
method_body += list_outputs_code + "\n"
197+
198+
assert method_body, "Neither `run_interface` and `list_outputs` are defined"
199+
118200
spec_str = "@pydra.mark.task\n"
119201
spec_str += "@pydra.mark.annotate({'return': {"
120202
spec_str += ", ".join(f"'{n}': {t}" for n, t, _ in output_fields_str)
@@ -156,11 +238,13 @@ def types_to_names(spec_fields):
156238
additional_imports.add(imprt)
157239
spec_str = repl_spec_str
158240

159-
used.imports = self.construct_imports(
160-
nonstd_types,
161-
spec_str,
162-
include_task=False,
163-
base=base_imports + list(used.imports) + list(additional_imports),
241+
used.imports.update(
242+
self.construct_imports(
243+
nonstd_types,
244+
spec_str,
245+
include_task=False,
246+
base=base_imports + list(used.imports) + list(additional_imports),
247+
)
164248
)
165249

166250
return spec_str, used

0 commit comments

Comments
 (0)