5
5
from abc import ABCMeta , abstractmethod
6
6
from importlib import import_module
7
7
from types import ModuleType
8
+ from collections import defaultdict
8
9
import itertools
9
10
import inspect
10
11
import traits .trait_types
13
14
import attrs
14
15
from attrs .converters import default_if_none
15
16
import nipype .interfaces .base
16
- from nipype .interfaces .base import traits_extension
17
+ from nipype .interfaces .base import traits_extension , CommandLine
17
18
from pydra .engine import specs
18
19
from pydra .engine .helpers import ensure_list
19
20
from ..utils import (
@@ -459,34 +460,66 @@ def referenced_methods(self):
459
460
return self ._referenced_funcs_and_methods [1 ]
460
461
461
462
@property
462
- def method_args (self ):
463
+ def referenced_supers (self ):
463
464
return self ._referenced_funcs_and_methods [2 ]
464
465
465
466
@property
466
- def method_returns (self ):
467
+ def method_args (self ):
467
468
return self ._referenced_funcs_and_methods [3 ]
468
469
470
+ @property
471
+ def method_returns (self ):
472
+ return self ._referenced_funcs_and_methods [4 ]
473
+
474
+ @property
475
+ def method_stacks (self ):
476
+ return self ._referenced_funcs_and_methods [5 ]
477
+
478
+ @property
479
+ def method_supers (self ):
480
+ return self ._referenced_funcs_and_methods [6 ]
481
+
469
482
@cached_property
470
483
def _referenced_funcs_and_methods (self ):
471
484
referenced_funcs = set ()
472
485
referenced_methods = set ()
486
+ referenced_supers = {}
473
487
method_args = {}
474
488
method_returns = {}
489
+ method_stacks = {}
490
+ method_supers = defaultdict (dict )
475
491
already_processed = set (
476
492
getattr (self .nipype_interface , m ) for m in self .INCLUDED_METHODS
477
493
)
494
+ for method_name in self .INCLUDED_METHODS :
495
+ method_args [method_name ] = []
496
+ method_returns [method_name ] = []
497
+ method_stacks [method_name ] = ()
478
498
for method_name in self .INCLUDED_METHODS :
479
499
if method_name not in self .nipype_interface .__dict__ :
480
500
continue # Don't include base methods
501
+ method = getattr (self .nipype_interface , method_name )
502
+ referenced_methods .add (method )
481
503
self ._get_referenced (
482
- getattr (self .nipype_interface , method_name ),
483
- referenced_funcs ,
484
- referenced_methods ,
485
- method_args ,
486
- method_returns ,
504
+ method ,
505
+ referenced_funcs = referenced_funcs ,
506
+ referenced_methods = referenced_methods ,
507
+ referenced_supers = referenced_supers ,
508
+ method_args = method_args ,
509
+ method_returns = method_returns ,
510
+ method_stacks = method_stacks ,
511
+ method_supers = method_supers ,
487
512
already_processed = already_processed ,
488
513
)
489
- return referenced_funcs , referenced_methods , method_args , method_returns
514
+ return (
515
+ referenced_funcs ,
516
+ referenced_methods ,
517
+ referenced_supers ,
518
+ method_args ,
519
+ method_returns ,
520
+ method_stacks ,
521
+ method_supers ,
522
+ )
490
523
491
524
@cached_property
492
525
def source_code (self ):
@@ -717,13 +750,14 @@ def function_callables(self):
717
750
"callables module must be provided if output_callables are set in the spec file"
718
751
)
719
752
fun_str = ""
720
- fun_names = list (set (self .outputs .callables .values ()))
721
- fun_names .sort ()
722
- for fun_nm in fun_names :
723
- fun = getattr (self .callables_module , fun_nm )
724
- fun_str += inspect .getsource (fun ) + "\n "
725
- list_outputs = getattr (self .callables_module , "_list_outputs" )
726
- fun_str += inspect .getsource (list_outputs ) + "\n "
753
+ if list (set (self .outputs .callables .values ())):
754
+ fun_str = inspect .getsource (self .callables_module )
755
+ # fun_names.sort()
756
+ # for fun_nm in fun_names:
757
+ # fun = getattr(self.callables_module, fun_nm)
758
+ # fun_str += inspect.getsource(fun) + "\n"
759
+ # list_outputs = getattr(self.callables_module, "_list_outputs")
760
+ # fun_str += inspect.getsource(list_outputs) + "\n"
727
761
return fun_str
728
762
729
763
def pydra_type_converter (self , field , spec_type , name ):
@@ -975,14 +1009,33 @@ def create_doctests(self, input_fields, nonstd_types):
975
1009
976
1010
return " Examples\n -------\n \n " + doctest_str
977
1011
1012
+ def _misc_cleanups (self , body : str ) -> str :
1013
+ if hasattr (self .nipype_interface , "_cmd" ):
1014
+ body = body .replace ("self.cmd" , f'"{ self .nipype_interface ._cmd } "' )
1015
+
1016
+ body = re .sub (
1017
+ r"outputs = self\.(output_spec|_outputs)\(\).*$" ,
1018
+ r"outputs = {}" ,
1019
+ body ,
1020
+ flags = re .MULTILINE ,
1021
+ )
1022
+ body = re .sub (r"\w+runtime\.(stdout|stderr)" , r"\1" , body )
1023
+ body = body .replace ("os.getcwd()" , "output_dir" )
1024
+ return body
1025
+
978
1026
def _get_referenced (
979
1027
self ,
980
1028
method : ty .Callable ,
981
1029
referenced_funcs : ty .Set [ty .Callable ],
982
- referenced_methods : ty .Set [ty .Callable ] = None ,
1030
+ referenced_methods : ty .Set [ty .Callable ],
1031
+ referenced_supers : ty .Dict [str , ty .Tuple [ty .Callable , type ]],
983
1032
method_args : ty .Dict [str , ty .List [str ]] = None ,
984
1033
method_returns : ty .Dict [str , ty .List [str ]] = None ,
1034
+ method_stacks : ty .Dict [str , ty .Tuple [ty .Callable ]] = None ,
1035
+ method_supers : ty .Dict [type , ty .Dict [str , str ]] = None ,
985
1036
already_processed : ty .Set [ty .Callable ] = None ,
1037
+ method_stack : ty .Optional [ty .Tuple [ty .Callable ]] = None ,
1038
+ super_base : ty .Optional [type ] = None ,
986
1039
) -> ty .Tuple [ty .Set , ty .Set ]:
987
1040
"""Get the local functions referenced in the source code
988
1041
@@ -1012,6 +1065,12 @@ def _get_referenced(
1012
1065
already_processed .add (method )
1013
1066
else :
1014
1067
already_processed = {method }
1068
+ if method_stack is None :
1069
+ method_stack = (method ,)
1070
+ else :
1071
+ method_stack += (method ,)
1072
+ if super_base is None :
1073
+ super_base = self .nipype_interface
1015
1074
method_body = inspect .getsource (method )
1016
1075
method_body = re .sub (r"\s*#.*" , "" , method_body ) # Strip out comments
1017
1076
return_value = get_return_line (method_body )
@@ -1034,14 +1093,52 @@ def _get_referenced(
1034
1093
referenced_outputs .update (
1035
1094
re .findall (return_value + r"\[(?:'|\")(\w+)(?:'|\")\] *=" , method_body )
1036
1095
)
1096
+ for match in re .findall (r"super\([^\)]*\)\.(\w+)\(" , method_body ):
1097
+ super_method = None
1098
+ for base in self .nipype_interface .__mro__ [1 :]:
1099
+ if match in base .__dict__ : # Found the match
1100
+ super_method = getattr (base , match )
1101
+ break
1102
+ assert super_method is not None , (
1103
+ f"Could not find super of '{ match } ' method in base classes of "
1104
+ f"{ self .nipype_interface } "
1105
+ )
1106
+ func_name = self ._common_parent_pkg_prefix (base ) + match
1107
+ if func_name not in referenced_supers :
1108
+ referenced_supers [func_name ] = (super_method , base )
1109
+ method_supers [super_base ][match ] = func_name
1110
+ method_stacks [func_name ] = method_stack
1111
+ rf_inputs , rf_outputs = self ._get_referenced (
1112
+ super_method ,
1113
+ referenced_funcs ,
1114
+ referenced_methods ,
1115
+ referenced_supers = referenced_supers ,
1116
+ method_args = method_args ,
1117
+ method_returns = method_returns ,
1118
+ method_stacks = method_stacks ,
1119
+ method_supers = method_supers ,
1120
+ already_processed = already_processed ,
1121
+ method_stack = method_stack ,
1122
+ super_base = base ,
1123
+ )
1124
+ referenced_inputs .update (rf_inputs )
1125
+ referenced_outputs .update (rf_outputs )
1126
+ method_args [func_name ] = rf_inputs
1127
+ method_returns [func_name ] = rf_outputs
1128
+ method_stacks [func_name ] = method_stack
1037
1129
for func in ref_local_funcs :
1038
1130
if func in already_processed :
1039
1131
continue
1040
1132
rf_inputs , rf_outputs = self ._get_referenced (
1041
1133
func ,
1042
1134
referenced_funcs ,
1043
1135
referenced_methods ,
1136
+ referenced_supers = referenced_supers ,
1137
+ method_stacks = method_stacks ,
1138
+ method_supers = method_supers ,
1044
1139
already_processed = already_processed ,
1140
+ method_stack = method_stack ,
1141
+ super_base = super_base ,
1045
1142
)
1046
1143
referenced_inputs .update (rf_inputs )
1047
1144
referenced_outputs .update (rf_outputs )
@@ -1052,16 +1149,36 @@ def _get_referenced(
1052
1149
meth ,
1053
1150
referenced_funcs ,
1054
1151
referenced_methods ,
1152
+ referenced_supers = referenced_supers ,
1055
1153
method_args = method_args ,
1056
1154
method_returns = method_returns ,
1155
+ method_stacks = method_stacks ,
1156
+ method_supers = method_supers ,
1057
1157
already_processed = already_processed ,
1158
+ method_stack = method_stack ,
1159
+ super_base = super_base ,
1058
1160
)
1059
1161
method_args [meth .__name__ ] = ref_inputs
1060
1162
method_returns [meth .__name__ ] = ref_outputs
1163
+ method_stacks [meth .__name__ ] = method_stack
1061
1164
referenced_inputs .update (ref_inputs )
1062
1165
referenced_outputs .update (ref_outputs )
1063
1166
return referenced_inputs , sorted (referenced_outputs )
1064
1167
1168
+ def _common_parent_pkg_prefix (self , base : type ) -> str :
1169
+ """Return the common part of two package names"""
1170
+ ref_parts = self .nipype_interface .__module__ .split ("." )
1171
+ mod_parts = base .__module__ .split ("." )
1172
+ common = []
1173
+ for r_part , m_part in zip (ref_parts , mod_parts ):
1174
+ if r_part == m_part :
1175
+ common .append (r_part )
1176
+ else :
1177
+ break
1178
+ if not common :
1179
+ return ""
1180
+ return "_" .join (common + [base .__name__ ]) + "__"
1181
+
1065
1182
@cached_property
1066
1183
def local_functions (self ):
1067
1184
"""Get the functions defined in the same file as the interface"""
@@ -1078,7 +1195,12 @@ def process_method(
1078
1195
output_names : ty .List [str ],
1079
1196
method_args : ty .Dict [str , ty .List [str ]] = None ,
1080
1197
method_returns : ty .Dict [str , ty .List [str ]] = None ,
1198
+ additional_args : ty .Sequence [str ] = (),
1199
+ new_name : ty .Optional [str ] = None ,
1200
+ super_base : ty .Optional [type ] = None ,
1081
1201
):
1202
+ if super_base is None :
1203
+ super_base = self .nipype_interface
1082
1204
src = inspect .getsource (method )
1083
1205
pre , args , post = extract_args (src )
1084
1206
try :
@@ -1088,11 +1210,16 @@ def process_method(
1088
1210
if "runtime" in args :
1089
1211
args .remove ("runtime" )
1090
1212
if method .__name__ in self .method_args :
1091
- args += [f"{ a } =None" for a in self .method_args [method .__name__ ]]
1213
+ args += [
1214
+ f"{ a } =None"
1215
+ for a in (list (self .method_args [method .__name__ ]) + additional_args )
1216
+ ]
1092
1217
# Insert method args in signature if present
1093
1218
return_types , method_body = post .split (":" , maxsplit = 1 )
1094
1219
method_body = method_body .split ("\n " , maxsplit = 1 )[1 ]
1095
- method_body = self .process_method_body (method_body , input_names , output_names )
1220
+ method_body = self .process_method_body (
1221
+ method_body , input_names , output_names , super_base
1222
+ )
1096
1223
if self .method_returns .get (method .__name__ ):
1097
1224
return_args = self .method_returns [method .__name__ ]
1098
1225
method_body = (
@@ -1109,11 +1236,19 @@ def process_method(
1109
1236
)
1110
1237
pre = re .sub (r"^\s*" , "" , pre , flags = re .MULTILINE )
1111
1238
pre = pre .replace ("@staticmethod\n " , "" )
1239
+ if new_name :
1240
+ pre = re .sub (r"^def (\w+)\(" , f"def { new_name } (" , pre , flags = re .MULTILINE )
1112
1241
return f"{ pre } { ', ' .join (args )} { return_types } :\n { method_body } "
1113
1242
1114
1243
def process_method_body (
1115
- self , method_body : str , input_names : ty .List [str ], output_names : ty .List [str ]
1244
+ self ,
1245
+ method_body : str ,
1246
+ input_names : ty .List [str ],
1247
+ output_names : ty .List [str ],
1248
+ super_base : ty .Optional [type ] = None ,
1116
1249
) -> str :
1250
+ if super_base is None :
1251
+ super_base = self .nipype_interface
1117
1252
return_value = get_return_line (method_body )
1118
1253
method_body = method_body .replace ("if self.output_spec:" , "if True:" )
1119
1254
# Replace self.inputs.<name> with <name> in the function body
@@ -1129,6 +1264,7 @@ def process_method_body(
1129
1264
self .task_name ,
1130
1265
)
1131
1266
method_body = input_re .sub (r"\1" , method_body )
1267
+ method_body = self .replace_supers (method_body , super_base )
1132
1268
1133
1269
if return_value :
1134
1270
output_re = re .compile (return_value + r"\[(?:'|\")(\w+)(?:'|\")\]" )
@@ -1147,9 +1283,20 @@ def process_method_body(
1147
1283
method_body = re .sub (
1148
1284
r"outputs = self.output_spec().*" , r"outputs = {}" , method_body
1149
1285
)
1286
+ method_body = self ._misc_cleanups (method_body )
1150
1287
return self .unwrap_nested_methods (method_body )
1151
1288
1152
- def unwrap_nested_methods (self , method_body ):
1289
+ def replace_supers (self , method_body , super_base = None ):
1290
+ if super_base is None :
1291
+ super_base = self .nipype_interface
1292
+ super_name_map = self .method_supers [super_base ]
1293
+ return re .sub (
1294
+ r"super\([^\)]*\)\.(\w+)\(" ,
1295
+ lambda m : super_name_map [m .group (1 )] + "(" ,
1296
+ method_body ,
1297
+ )
1298
+
1299
+ def unwrap_nested_methods (self , method_body , additional_args = ()):
1153
1300
"""
1154
1301
Converts nested method calls into function calls
1155
1302
"""
@@ -1193,7 +1340,11 @@ def unwrap_nested_methods(self, method_body):
1193
1340
# Insert additional arguments to the method call (which were previously
1194
1341
# accessed via member attributes)
1195
1342
new_body += name + insert_args_in_signature (
1196
- args , [f"{ a } ={ a } " for a in self .method_args [name ]]
1343
+ args ,
1344
+ [
1345
+ f"{ a } ={ a } "
1346
+ for a in (list (self .method_args [name ]) + list (additional_args ))
1347
+ ],
1197
1348
)
1198
1349
method_body = new_body
1199
1350
# Convert assignment to self attributes into method-scoped variables (hopefully
@@ -1203,6 +1354,8 @@ def unwrap_nested_methods(self, method_body):
1203
1354
)
1204
1355
return cleanup_function_body (method_body )
1205
1356
1357
+ SUPER_MAPPINGS = {CommandLine : {"_list_outputs" : "{}" }}
1358
+
1206
1359
INPUT_KEYS = [
1207
1360
"allowed_values" ,
1208
1361
"argstr" ,
0 commit comments