9
9
from types import ModuleType
10
10
import black .report
11
11
import yaml
12
+ from .symbols import UsedSymbols
12
13
from .utils import (
13
- UsedSymbols ,
14
14
extract_args ,
15
15
full_address ,
16
16
multiline_comment ,
@@ -121,17 +121,13 @@ def nipype_object(self):
121
121
return getattr (self .nipype_module , self .nipype_name )
122
122
123
123
@cached_property
124
- def used_symbols (self ) -> UsedSymbols :
124
+ def used (self ) -> UsedSymbols :
125
125
used = UsedSymbols .find (
126
126
self .nipype_module ,
127
127
[self .src ],
128
+ package = self .package ,
128
129
collapse_intra_pkg = False ,
129
- omit_classes = self .package .omit_classes ,
130
- omit_modules = self .package .omit_modules ,
131
- omit_functions = self .package .omit_functions ,
132
- omit_constants = self .package .omit_constants ,
133
130
always_include = self .package .all_explicit ,
134
- translations = self .package .all_import_translations ,
135
131
)
136
132
used .import_stmts .update (i .to_statement () for i in self .imports )
137
133
return used
@@ -147,10 +143,10 @@ def converted_code(self) -> ty.List[str]:
147
143
@cached_property
148
144
def nested_interfaces (self ):
149
145
potential_classes = {
150
- full_address (c [1 ]): c [0 ] for c in self .used_symbols .imported_classes if c [0 ]
146
+ full_address (c [1 ]): c [0 ] for c in self .used .imported_classes if c [0 ]
151
147
}
152
148
potential_classes .update (
153
- (full_address (c ), c .__name__ ) for c in self .used_symbols .classes
149
+ (full_address (c ), c .__name__ ) for c in self .used .classes
154
150
)
155
151
return {
156
152
potential_classes [address ]: workflow
@@ -377,8 +373,7 @@ class ClassConverter(BaseHelperConverter):
377
373
378
374
@cached_property
379
375
def _converted_code (self ) -> ty .Tuple [str , ty .List [str ]]:
380
- """Convert the Nipype workflow function to a Pydra workflow function and determine
381
- the configuration parameters that are used
376
+ """Convert a class into Pydra-
382
377
383
378
Returns
384
379
-------
@@ -389,18 +384,30 @@ def _converted_code(self) -> ty.Tuple[str, ty.List[str]]:
389
384
"""
390
385
391
386
used_configs = set ()
392
- parts = re .split (
393
- r"\n (?!\s|\))" , replace_undefined (self .src ), flags = re .MULTILINE
394
- )
387
+
388
+ src = replace_undefined (self .src )[len ("class " ) :]
389
+ name , bases , class_body = extract_args (src , drop_parens = True )
390
+ bases = [
391
+ b
392
+ for b in bases
393
+ if not self .package .is_omitted (getattr (self .nipype_module , b ))
394
+ ]
395
+
396
+ parts = re .split (r"\n (?!\s|\))" , class_body , flags = re .MULTILINE )
395
397
converted_parts = []
396
- for part in parts :
398
+ for part in parts [ 1 :] :
397
399
if part .startswith ("def" ):
398
400
converted_func , func_used_configs = self ._convert_function (part )
399
401
converted_parts .append (converted_func )
400
402
used_configs .update (func_used_configs )
401
403
else :
402
404
converted_parts .append (part )
403
- code_str = "\n " .join (converted_parts )
405
+ code_str = (
406
+ f"class { name } ("
407
+ + ", " .join (bases )
408
+ + "):\n "
409
+ + "\n " .join (converted_parts )
410
+ )
404
411
# Format the the code before the find and replace so it is more predictable
405
412
try :
406
413
code_str = black .format_file_contents (
0 commit comments