Skip to content

Commit 07700f2

Browse files
committed
debugging reworked _list_outputs handling
1 parent 843a83a commit 07700f2

File tree

5 files changed

+379
-72
lines changed

5 files changed

+379
-72
lines changed

nipype2pydra/interface/base.py

Lines changed: 175 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from abc import ABCMeta, abstractmethod
66
from importlib import import_module
77
from types import ModuleType
8+
from collections import defaultdict
89
import itertools
910
import inspect
1011
import traits.trait_types
@@ -13,7 +14,7 @@
1314
import attrs
1415
from attrs.converters import default_if_none
1516
import nipype.interfaces.base
16-
from nipype.interfaces.base import traits_extension
17+
from nipype.interfaces.base import traits_extension, CommandLine
1718
from pydra.engine import specs
1819
from pydra.engine.helpers import ensure_list
1920
from ..utils import (
@@ -459,34 +460,66 @@ def referenced_methods(self):
459460
return self._referenced_funcs_and_methods[1]
460461

461462
@property
462-
def method_args(self):
463+
def referenced_supers(self):
463464
return self._referenced_funcs_and_methods[2]
464465

465466
@property
466-
def method_returns(self):
467+
def method_args(self):
467468
return self._referenced_funcs_and_methods[3]
468469

470+
@property
471+
def method_returns(self):
472+
return self._referenced_funcs_and_methods[4]
473+
474+
@property
475+
def method_stacks(self):
476+
return self._referenced_funcs_and_methods[5]
477+
478+
@property
479+
def method_supers(self):
480+
return self._referenced_funcs_and_methods[6]
481+
469482
@cached_property
470483
def _referenced_funcs_and_methods(self):
471484
referenced_funcs = set()
472485
referenced_methods = set()
486+
referenced_supers = {}
473487
method_args = {}
474488
method_returns = {}
489+
method_stacks = {}
490+
method_supers = defaultdict(dict)
475491
already_processed = set(
476492
getattr(self.nipype_interface, m) for m in self.INCLUDED_METHODS
477493
)
494+
for method_name in self.INCLUDED_METHODS:
495+
method_args[method_name] = []
496+
method_returns[method_name] = []
497+
method_stacks[method_name] = ()
478498
for method_name in self.INCLUDED_METHODS:
479499
if method_name not in self.nipype_interface.__dict__:
480500
continue # Don't include base methods
501+
method = getattr(self.nipype_interface, method_name)
502+
referenced_methods.add(method)
481503
self._get_referenced(
482-
getattr(self.nipype_interface, method_name),
483-
referenced_funcs,
484-
referenced_methods,
485-
method_args,
486-
method_returns,
504+
method,
505+
referenced_funcs=referenced_funcs,
506+
referenced_methods=referenced_methods,
507+
referenced_supers=referenced_supers,
508+
method_args=method_args,
509+
method_returns=method_returns,
510+
method_stacks=method_stacks,
511+
method_supers=method_supers,
487512
already_processed=already_processed,
488513
)
489-
return referenced_funcs, referenced_methods, method_args, method_returns
514+
return (
515+
referenced_funcs,
516+
referenced_methods,
517+
referenced_supers,
518+
method_args,
519+
method_returns,
520+
method_stacks,
521+
method_supers,
522+
)
490523

491524
@cached_property
492525
def source_code(self):
@@ -717,13 +750,14 @@ def function_callables(self):
717750
"callables module must be provided if output_callables are set in the spec file"
718751
)
719752
fun_str = ""
720-
fun_names = list(set(self.outputs.callables.values()))
721-
fun_names.sort()
722-
for fun_nm in fun_names:
723-
fun = getattr(self.callables_module, fun_nm)
724-
fun_str += inspect.getsource(fun) + "\n"
725-
list_outputs = getattr(self.callables_module, "_list_outputs")
726-
fun_str += inspect.getsource(list_outputs) + "\n"
753+
if list(set(self.outputs.callables.values())):
754+
fun_str = inspect.getsource(self.callables_module)
755+
# fun_names.sort()
756+
# for fun_nm in fun_names:
757+
# fun = getattr(self.callables_module, fun_nm)
758+
# fun_str += inspect.getsource(fun) + "\n"
759+
# list_outputs = getattr(self.callables_module, "_list_outputs")
760+
# fun_str += inspect.getsource(list_outputs) + "\n"
727761
return fun_str
728762

729763
def pydra_type_converter(self, field, spec_type, name):
@@ -975,14 +1009,33 @@ def create_doctests(self, input_fields, nonstd_types):
9751009

9761010
return " Examples\n -------\n\n" + doctest_str
9771011

1012+
def _misc_cleanups(self, body: str) -> str:
1013+
if hasattr(self.nipype_interface, "_cmd"):
1014+
body = body.replace("self.cmd", f'"{self.nipype_interface._cmd}"')
1015+
1016+
body = re.sub(
1017+
r"outputs = self\.(output_spec|_outputs)\(\).*$",
1018+
r"outputs = {}",
1019+
body,
1020+
flags=re.MULTILINE,
1021+
)
1022+
body = re.sub(r"\w+runtime\.(stdout|stderr)", r"\1", body)
1023+
body = body.replace("os.getcwd()", "output_dir")
1024+
return body
1025+
9781026
def _get_referenced(
9791027
self,
9801028
method: ty.Callable,
9811029
referenced_funcs: ty.Set[ty.Callable],
982-
referenced_methods: ty.Set[ty.Callable] = None,
1030+
referenced_methods: ty.Set[ty.Callable],
1031+
referenced_supers: ty.Dict[str, ty.Tuple[ty.Callable, type]],
9831032
method_args: ty.Dict[str, ty.List[str]] = None,
9841033
method_returns: ty.Dict[str, ty.List[str]] = None,
1034+
method_stacks: ty.Dict[str, ty.Tuple[ty.Callable]] = None,
1035+
method_supers: ty.Dict[type, ty.Dict[str, str]] = None,
9851036
already_processed: ty.Set[ty.Callable] = None,
1037+
method_stack: ty.Optional[ty.Tuple[ty.Callable]] = None,
1038+
super_base: ty.Optional[type] = None,
9861039
) -> ty.Tuple[ty.Set, ty.Set]:
9871040
"""Get the local functions referenced in the source code
9881041
@@ -1012,6 +1065,12 @@ def _get_referenced(
10121065
already_processed.add(method)
10131066
else:
10141067
already_processed = {method}
1068+
if method_stack is None:
1069+
method_stack = (method,)
1070+
else:
1071+
method_stack += (method,)
1072+
if super_base is None:
1073+
super_base = self.nipype_interface
10151074
method_body = inspect.getsource(method)
10161075
method_body = re.sub(r"\s*#.*", "", method_body) # Strip out comments
10171076
return_value = get_return_line(method_body)
@@ -1034,14 +1093,52 @@ def _get_referenced(
10341093
referenced_outputs.update(
10351094
re.findall(return_value + r"\[(?:'|\")(\w+)(?:'|\")\] *=", method_body)
10361095
)
1096+
for match in re.findall(r"super\([^\)]*\)\.(\w+)\(", method_body):
1097+
super_method = None
1098+
for base in self.nipype_interface.__mro__[1:]:
1099+
if match in base.__dict__: # Found the match
1100+
super_method = getattr(base, match)
1101+
break
1102+
assert super_method is not None, (
1103+
f"Could not find super of '{match}' method in base classes of "
1104+
f"{self.nipype_interface}"
1105+
)
1106+
func_name = self._common_parent_pkg_prefix(base) + match
1107+
if func_name not in referenced_supers:
1108+
referenced_supers[func_name] = (super_method, base)
1109+
method_supers[super_base][match] = func_name
1110+
method_stacks[func_name] = method_stack
1111+
rf_inputs, rf_outputs = self._get_referenced(
1112+
super_method,
1113+
referenced_funcs,
1114+
referenced_methods,
1115+
referenced_supers=referenced_supers,
1116+
method_args=method_args,
1117+
method_returns=method_returns,
1118+
method_stacks=method_stacks,
1119+
method_supers=method_supers,
1120+
already_processed=already_processed,
1121+
method_stack=method_stack,
1122+
super_base=base,
1123+
)
1124+
referenced_inputs.update(rf_inputs)
1125+
referenced_outputs.update(rf_outputs)
1126+
method_args[func_name] = rf_inputs
1127+
method_returns[func_name] = rf_outputs
1128+
method_stacks[func_name] = method_stack
10371129
for func in ref_local_funcs:
10381130
if func in already_processed:
10391131
continue
10401132
rf_inputs, rf_outputs = self._get_referenced(
10411133
func,
10421134
referenced_funcs,
10431135
referenced_methods,
1136+
referenced_supers=referenced_supers,
1137+
method_stacks=method_stacks,
1138+
method_supers=method_supers,
10441139
already_processed=already_processed,
1140+
method_stack=method_stack,
1141+
super_base=super_base,
10451142
)
10461143
referenced_inputs.update(rf_inputs)
10471144
referenced_outputs.update(rf_outputs)
@@ -1052,16 +1149,36 @@ def _get_referenced(
10521149
meth,
10531150
referenced_funcs,
10541151
referenced_methods,
1152+
referenced_supers=referenced_supers,
10551153
method_args=method_args,
10561154
method_returns=method_returns,
1155+
method_stacks=method_stacks,
1156+
method_supers=method_supers,
10571157
already_processed=already_processed,
1158+
method_stack=method_stack,
1159+
super_base=super_base,
10581160
)
10591161
method_args[meth.__name__] = ref_inputs
10601162
method_returns[meth.__name__] = ref_outputs
1163+
method_stacks[meth.__name__] = method_stack
10611164
referenced_inputs.update(ref_inputs)
10621165
referenced_outputs.update(ref_outputs)
10631166
return referenced_inputs, sorted(referenced_outputs)
10641167

1168+
def _common_parent_pkg_prefix(self, base: type) -> str:
1169+
"""Return the common part of two package names"""
1170+
ref_parts = self.nipype_interface.__module__.split(".")
1171+
mod_parts = base.__module__.split(".")
1172+
common = []
1173+
for r_part, m_part in zip(ref_parts, mod_parts):
1174+
if r_part == m_part:
1175+
common.append(r_part)
1176+
else:
1177+
break
1178+
if not common:
1179+
return ""
1180+
return "_".join(common + [base.__name__]) + "__"
1181+
10651182
@cached_property
10661183
def local_functions(self):
10671184
"""Get the functions defined in the same file as the interface"""
@@ -1078,7 +1195,12 @@ def process_method(
10781195
output_names: ty.List[str],
10791196
method_args: ty.Dict[str, ty.List[str]] = None,
10801197
method_returns: ty.Dict[str, ty.List[str]] = None,
1198+
additional_args: ty.Sequence[str] = (),
1199+
new_name: ty.Optional[str] = None,
1200+
super_base: ty.Optional[type] = None,
10811201
):
1202+
if super_base is None:
1203+
super_base = self.nipype_interface
10821204
src = inspect.getsource(method)
10831205
pre, args, post = extract_args(src)
10841206
try:
@@ -1088,11 +1210,16 @@ def process_method(
10881210
if "runtime" in args:
10891211
args.remove("runtime")
10901212
if method.__name__ in self.method_args:
1091-
args += [f"{a}=None" for a in self.method_args[method.__name__]]
1213+
args += [
1214+
f"{a}=None"
1215+
for a in (list(self.method_args[method.__name__]) + additional_args)
1216+
]
10921217
# Insert method args in signature if present
10931218
return_types, method_body = post.split(":", maxsplit=1)
10941219
method_body = method_body.split("\n", maxsplit=1)[1]
1095-
method_body = self.process_method_body(method_body, input_names, output_names)
1220+
method_body = self.process_method_body(
1221+
method_body, input_names, output_names, super_base
1222+
)
10961223
if self.method_returns.get(method.__name__):
10971224
return_args = self.method_returns[method.__name__]
10981225
method_body = (
@@ -1109,11 +1236,19 @@ def process_method(
11091236
)
11101237
pre = re.sub(r"^\s*", "", pre, flags=re.MULTILINE)
11111238
pre = pre.replace("@staticmethod\n", "")
1239+
if new_name:
1240+
pre = re.sub(r"^def (\w+)\(", f"def {new_name}(", pre, flags=re.MULTILINE)
11121241
return f"{pre}{', '.join(args)}{return_types}:\n{method_body}"
11131242

11141243
def process_method_body(
1115-
self, method_body: str, input_names: ty.List[str], output_names: ty.List[str]
1244+
self,
1245+
method_body: str,
1246+
input_names: ty.List[str],
1247+
output_names: ty.List[str],
1248+
super_base: ty.Optional[type] = None,
11161249
) -> str:
1250+
if super_base is None:
1251+
super_base = self.nipype_interface
11171252
return_value = get_return_line(method_body)
11181253
method_body = method_body.replace("if self.output_spec:", "if True:")
11191254
# Replace self.inputs.<name> with <name> in the function body
@@ -1129,6 +1264,7 @@ def process_method_body(
11291264
self.task_name,
11301265
)
11311266
method_body = input_re.sub(r"\1", method_body)
1267+
method_body = self.replace_supers(method_body, super_base)
11321268

11331269
if return_value:
11341270
output_re = re.compile(return_value + r"\[(?:'|\")(\w+)(?:'|\")\]")
@@ -1147,9 +1283,20 @@ def process_method_body(
11471283
method_body = re.sub(
11481284
r"outputs = self.output_spec().*", r"outputs = {}", method_body
11491285
)
1286+
method_body = self._misc_cleanups(method_body)
11501287
return self.unwrap_nested_methods(method_body)
11511288

1152-
def unwrap_nested_methods(self, method_body):
1289+
def replace_supers(self, method_body, super_base=None):
1290+
if super_base is None:
1291+
super_base = self.nipype_interface
1292+
super_name_map = self.method_supers[super_base]
1293+
return re.sub(
1294+
r"super\([^\)]*\)\.(\w+)\(",
1295+
lambda m: super_name_map[m.group(1)] + "(",
1296+
method_body,
1297+
)
1298+
1299+
def unwrap_nested_methods(self, method_body, additional_args=()):
11531300
"""
11541301
Converts nested method calls into function calls
11551302
"""
@@ -1193,7 +1340,11 @@ def unwrap_nested_methods(self, method_body):
11931340
# Insert additional arguments to the method call (which were previously
11941341
# accessed via member attributes)
11951342
new_body += name + insert_args_in_signature(
1196-
args, [f"{a}={a}" for a in self.method_args[name]]
1343+
args,
1344+
[
1345+
f"{a}={a}"
1346+
for a in (list(self.method_args[name]) + list(additional_args))
1347+
],
11971348
)
11981349
method_body = new_body
11991350
# Convert assignment to self attributes into method-scoped variables (hopefully
@@ -1203,6 +1354,8 @@ def unwrap_nested_methods(self, method_body):
12031354
)
12041355
return cleanup_function_body(method_body)
12051356

1357+
SUPER_MAPPINGS = {CommandLine: {"_list_outputs": "{}"}}
1358+
12061359
INPUT_KEYS = [
12071360
"allowed_values",
12081361
"argstr",

nipype2pydra/interface/function.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ def types_to_names(spec_fields):
7676
lo_src = "\n".join(lo_lines)
7777
method_body += "\n" + lo_src
7878
method_body = self.process_method_body(method_body, input_names, output_names)
79+
method_body = re.sub(
80+
r"self\._results\[(?:'|\")(\w+)(?:'|\")\]", r"\1", method_body
81+
)
7982

8083
used = UsedSymbols.find(
8184
self.nipype_module,

0 commit comments

Comments
 (0)