@@ -63,49 +63,9 @@ def types_to_names(spec_fields):
63
63
output_names = [o [0 ] for o in output_fields ]
64
64
output_type_names = [o [1 ] for o in output_fields_str ]
65
65
66
- # Combined src of run_interface and list_outputs
67
- method_body = inspect .getsource (self .nipype_interface ._run_interface ).strip ()
68
- # Strip out method def and return statement
69
- method_lines = method_body .strip ().split ("\n " )[1 :]
70
- if re .match (r"\s*return" , method_lines [- 1 ]):
71
- method_lines = method_lines [:- 1 ]
72
- method_body = "\n " .join (method_lines )
73
- method_body = self .process_method_body (
74
- method_body ,
75
- input_names ,
76
- output_names ,
77
- super_base = find_super_method (
78
- self .nipype_interface , "_run_interface" , include_class = True
79
- )[1 ],
80
- )
81
- lo_src = inspect .getsource (self .nipype_interface ._list_outputs ).strip ()
82
- # Strip out method def and return statement
83
- lo_lines = lo_src .strip ().split ("\n " )[1 :]
84
- if re .match (r"\s*(return|raise NotImplementedError)" , lo_lines [- 1 ]):
85
- lo_lines = lo_lines [:- 1 ]
86
- lo_src = "\n " .join (lo_lines )
87
- lo_src = self .process_method_body (
88
- lo_src ,
89
- input_names ,
90
- output_names ,
91
- super_base = find_super_method (
92
- self .nipype_interface , "_list_outputs" , include_class = True
93
- )[1 ],
94
- )
95
- method_body += "\n " + lo_src
96
- method_body = re .sub (
97
- r"self\._results\[(?:'|\")(\w+)(?:'|\")\]" , r"\1" , method_body
98
- )
99
-
100
66
used = UsedSymbols .find (
101
67
self .nipype_module ,
102
- [method_body ]
103
- + [
104
- inspect .getsource (f )
105
- for f in itertools .chain (
106
- self .referenced_local_functions , self .referenced_methods
107
- )
108
- ],
68
+ self .referenced_local_functions ,
109
69
omit_classes = self .package .omit_classes + [BaseInterface , TraitedSpec ],
110
70
omit_modules = self .package .omit_modules ,
111
71
omit_functions = self .package .omit_functions ,
@@ -115,6 +75,128 @@ def types_to_names(spec_fields):
115
75
absolute_imports = True ,
116
76
)
117
77
78
+ for ref_method in self .referenced_methods :
79
+ method_module = find_super_method (
80
+ self .nipype_interface , ref_method .__name__ , include_class = True
81
+ )[1 ].__module__
82
+ method_used = UsedSymbols .find (
83
+ method_module ,
84
+ [ref_method ],
85
+ omit_classes = self .package .omit_classes + [BaseInterface , TraitedSpec ],
86
+ omit_modules = self .package .omit_modules ,
87
+ omit_functions = self .package .omit_functions ,
88
+ omit_constants = self .package .omit_constants ,
89
+ always_include = self .package .all_explicit ,
90
+ translations = self .package .all_import_translations ,
91
+ absolute_imports = True ,
92
+ )
93
+ used .update (method_used , from_other_module = False )
94
+
95
+ method_body = ""
96
+ for field in output_fields :
97
+ method_body += f" { field [0 ]} = attrs.NOTHING\n "
98
+
99
+ # Combined src of init and list_outputs
100
+ init_code = inspect .getsource (self .nipype_interface .__init__ ).strip ()
101
+ init_class = find_super_method (
102
+ self .nipype_interface , "__init__" , include_class = True
103
+ )[1 ]
104
+ if not self .package .is_omitted (init_class ):
105
+ # Strip out method def and return statement
106
+ method_lines = init_code .strip ().split ("\n " )[1 :]
107
+ if re .match (r"\s*return" , method_lines [- 1 ]):
108
+ method_lines = method_lines [:- 1 ]
109
+ init_code = "\n " .join (method_lines )
110
+ init_code = self .process_method_body (
111
+ init_code ,
112
+ input_names ,
113
+ output_names ,
114
+ super_base = init_class ,
115
+ )
116
+
117
+ init_used = UsedSymbols .find (
118
+ init_class .__module__ ,
119
+ [init_code ],
120
+ omit_classes = self .package .omit_classes + [BaseInterface , TraitedSpec ],
121
+ omit_modules = self .package .omit_modules ,
122
+ omit_functions = self .package .omit_functions ,
123
+ omit_constants = self .package .omit_constants ,
124
+ always_include = self .package .all_explicit ,
125
+ translations = self .package .all_import_translations ,
126
+ absolute_imports = True ,
127
+ )
128
+ used .update (init_used , from_other_module = False )
129
+ method_body += init_code + "\n "
130
+
131
+ # Combined src of run_interface and list_outputs
132
+ run_interface_code = inspect .getsource (
133
+ self .nipype_interface ._run_interface
134
+ ).strip ()
135
+ run_interface_class = find_super_method (
136
+ self .nipype_interface , "_run_interface" , include_class = True
137
+ )[1 ]
138
+ if not self .package .is_omitted (run_interface_class ):
139
+ # Strip out method def and return statement
140
+ method_lines = run_interface_code .strip ().split ("\n " )[1 :]
141
+ if re .match (r"\s*return" , method_lines [- 1 ]):
142
+ method_lines = method_lines [:- 1 ]
143
+ run_interface_code = "\n " .join (method_lines )
144
+ run_interface_code = self .process_method_body (
145
+ run_interface_code ,
146
+ input_names ,
147
+ output_names ,
148
+ super_base = run_interface_class ,
149
+ )
150
+
151
+ run_interface_used = UsedSymbols .find (
152
+ run_interface_class .__module__ ,
153
+ [run_interface_code ],
154
+ omit_classes = self .package .omit_classes + [BaseInterface , TraitedSpec ],
155
+ omit_modules = self .package .omit_modules ,
156
+ omit_functions = self .package .omit_functions ,
157
+ omit_constants = self .package .omit_constants ,
158
+ always_include = self .package .all_explicit ,
159
+ translations = self .package .all_import_translations ,
160
+ absolute_imports = True ,
161
+ )
162
+ used .update (run_interface_used , from_other_module = False )
163
+ method_body += run_interface_code + "\n "
164
+
165
+ list_outputs_code = inspect .getsource (
166
+ self .nipype_interface ._list_outputs
167
+ ).strip ()
168
+ list_outputs_class = find_super_method (
169
+ self .nipype_interface , "_list_outputs" , include_class = True
170
+ )[1 ]
171
+ if not self .package .is_omitted (list_outputs_class ):
172
+ # Strip out method def and return statement
173
+ lo_lines = list_outputs_code .strip ().split ("\n " )[1 :]
174
+ if re .match (r"\s*(return|raise NotImplementedError)" , lo_lines [- 1 ]):
175
+ lo_lines = lo_lines [:- 1 ]
176
+ list_outputs_code = "\n " .join (lo_lines )
177
+ list_outputs_code = self .process_method_body (
178
+ list_outputs_code ,
179
+ input_names ,
180
+ output_names ,
181
+ super_base = list_outputs_class ,
182
+ )
183
+
184
+ list_outputs_used = UsedSymbols .find (
185
+ list_outputs_class .__module__ ,
186
+ [list_outputs_code ],
187
+ omit_classes = self .package .omit_classes + [BaseInterface , TraitedSpec ],
188
+ omit_modules = self .package .omit_modules ,
189
+ omit_functions = self .package .omit_functions ,
190
+ omit_constants = self .package .omit_constants ,
191
+ always_include = self .package .all_explicit ,
192
+ translations = self .package .all_import_translations ,
193
+ absolute_imports = True ,
194
+ )
195
+ used .update (list_outputs_used , from_other_module = False )
196
+ method_body += list_outputs_code + "\n "
197
+
198
+ assert method_body , "Neither `run_interface` and `list_outputs` are defined"
199
+
118
200
spec_str = "@pydra.mark.task\n "
119
201
spec_str += "@pydra.mark.annotate({'return': {"
120
202
spec_str += ", " .join (f"'{ n } ': { t } " for n , t , _ in output_fields_str )
@@ -156,11 +238,13 @@ def types_to_names(spec_fields):
156
238
additional_imports .add (imprt )
157
239
spec_str = repl_spec_str
158
240
159
- used .imports = self .construct_imports (
160
- nonstd_types ,
161
- spec_str ,
162
- include_task = False ,
163
- base = base_imports + list (used .imports ) + list (additional_imports ),
241
+ used .imports .update (
242
+ self .construct_imports (
243
+ nonstd_types ,
244
+ spec_str ,
245
+ include_task = False ,
246
+ base = base_imports + list (used .imports ) + list (additional_imports ),
247
+ )
164
248
)
165
249
166
250
return spec_str , used
0 commit comments