Source code for CPAC.utils.interfaces.function

from builtins import str, bytes
import inspect

from nipype import logging
from nipype.interfaces.base import (traits, DynamicTraitedSpec, Undefined, isdefined,
                    BaseInterfaceInputSpec)
from nipype.interfaces.io import IOBase, add_traits
from nipype.utils.filemanip import ensure_list
from nipype.utils.functions import getsource, create_function_from_source

iflogger = logging.getLogger('nipype.interface')


class FunctionInputSpec(DynamicTraitedSpec, BaseInterfaceInputSpec):
    function_str = traits.Str(mandatory=True, desc='code for function')


[docs]class Function(IOBase): """Runs arbitrary function as an interface Examples -------- >>> func = 'def func(arg1, arg2=5): return arg1 + arg2' >>> fi = Function(input_names=['arg1', 'arg2'], output_names=['out']) >>> fi.inputs.function_str = func >>> res = fi.run(arg1=1) >>> res.outputs.out 6 """ input_spec = FunctionInputSpec output_spec = DynamicTraitedSpec
[docs] def __init__(self, input_names=None, output_names='out', function=None, imports=None, as_module=False, **inputs): """ Parameters ---------- input_names : single str or list or None names corresponding to function inputs if ``None``, derive input names from function argument names output_names : single str or list names corresponding to function outputs (default: 'out'). if list of length > 1, has to match the number of outputs function : callable callable python object. must be able to execute in an isolated namespace (possibly in concert with the ``imports`` parameter) imports : list of strings list of import statements that allow the function to execute in an otherwise empty namespace """ super(Function, self).__init__(**inputs) if function: if as_module: module = inspect.getmodule(function).__name__ full_name = "%s.%s" % (module, function.__name__) self.inputs.function_str = full_name elif hasattr(function, '__call__'): try: self.inputs.function_str = getsource(function) except IOError: raise Exception('Interface Function does not accept ' 'function objects defined interactively ' 'in a python session') else: if input_names is None: fninfo = function.__code__ elif isinstance(function, (str, bytes)): self.inputs.function_str = function if input_names is None: fninfo = create_function_from_source(function, imports).__code__ else: raise Exception('Unknown type of function') if input_names is None: input_names = fninfo.co_varnames[:fninfo.co_argcount] self.as_module = as_module self.inputs.on_trait_change(self._set_function_string, 'function_str') self._input_names = ensure_list(input_names) self._output_names = ensure_list(output_names) add_traits(self.inputs, [name for name in self._input_names]) self.imports = imports self._out = {} for name in self._output_names: self._out[name] = None
def _set_function_string(self, obj, name, old, new): if name == 'function_str': if self.as_module: module = inspect.getmodule(new).__name__ full_name = "%s.%s" % (module, new.__name__) self.inputs.function_str = full_name elif hasattr(new, '__call__'): function_source = getsource(new) fninfo = new.__code__ elif isinstance(new, (str, bytes)): function_source = new fninfo = create_function_from_source(new, self.imports).__code__ self.inputs.trait_set( trait_change_notify=False, **{ '%s' % name: function_source }) # Update input traits input_names = fninfo.co_varnames[:fninfo.co_argcount] new_names = set(input_names) - set(self._input_names) add_traits(self.inputs, list(new_names)) self._input_names.extend(new_names) def _add_output_traits(self, base): undefined_traits = {} for key in self._output_names: base.add_trait(key, traits.Any) undefined_traits[key] = Undefined base.trait_set(trait_change_notify=False, **undefined_traits) return base def _run_interface(self, runtime): # Create function handle if self.as_module: import importlib pieces = self.inputs.function_str.split('.') module = '.'.join(pieces[:-1]) function = pieces[-1] try: function_handle = getattr(importlib.import_module(module), function) except ImportError: raise RuntimeError('Could not import module: %s' % self.inputs.function_str) else: function_handle = create_function_from_source(self.inputs.function_str, self.imports) # Get function args args = {} for name in self._input_names: value = getattr(self.inputs, name) if isdefined(value): args[name] = value out = function_handle(**args) if len(self._output_names) == 1: self._out[self._output_names[0]] = out else: if isinstance(out, tuple) and \ (len(out) != len(self._output_names)): raise RuntimeError('Mismatch in number of expected outputs') else: for idx, name in enumerate(self._output_names): self._out[name] = out[idx] return runtime def _list_outputs(self): outputs = self._outputs().get() for key in self._output_names: outputs[key] = self._out[key] return outputs