# STATEMENT OF CHANGES:
# This file is derived from sources licensed under the Apache-2.0 terms,
# and this file has been changed.
# CHANGES:
# * Supports just-in-time dynamic memory allocation
# * Skips doctests that require files that we haven't copied over
# * Applies a random seed
# * Supports overriding memory estimates via a log file and a buffer
# * Adds quotation marks around strings in dotfiles
# ORIGINAL WORK'S ATTRIBUTION NOTICE:
# Copyright (c) 2009-2016, Nipype developers
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Prior to release 0.12, Nipype was licensed under a BSD license.
# Modifications Copyright (C) 2022-2024 C-PAC Developers
# This file is part of C-PAC.
# C-PAC is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the
# Free Software Foundation, either version 3 of the License, or (at your
# option) any later version.
# C-PAC is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
# You should have received a copy of the GNU Lesser General Public
# License along with C-PAC. If not, see <https://www.gnu.org/licenses/>.
"""Module to import Nipype Pipeline engine and override some Classes.
See https://fcp-indi.github.io/docs/developer/nodes
for C-PAC-specific documentation.
See https://nipype.readthedocs.io/en/latest/api/generated/nipype.pipeline.engine.html
for Nipype's documentation.
""" # pylint: disable=line-too-long
from copy import deepcopy
from inspect import Parameter, Signature, signature
import os
import re
from typing import Any, ClassVar, Optional
from numpy import prod
from traits.trait_base import Undefined
from traits.trait_handlers import TraitListObject
from nibabel import load
from nipype.interfaces.utility import Function
from nipype.pipeline import engine as pe
from nipype.pipeline.engine.utils import (
_create_dot_graph,
_replacefunk,
_run_dot,
format_dot,
generate_expanded_graph,
get_print_name,
load_resultfile as _load_resultfile,
)
from nipype.utils.filemanip import fname_presuffix
from nipype.utils.functions import getsource
from CPAC.utils.monitoring import getLogger, WFLOGGER
# set global default mem_gb
DEFAULT_MEM_GB = 2.0
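# fallback stand-in image shape (x, y, z, t), used when a data-dependent
# memory estimate is needed before the real input file exists
# (see ``Workflow._handle_just_in_time_exception`` below)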
UNDEFINED_SIZE = (42, 42, 42, 1200)
def _check_mem_x_path(mem_x_path):
"""Check if a supplied multiplier path exists.
Parameters
----------
mem_x_path : str, iterable, Undefined or None
Returns
-------
bool
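
    Examples
    --------
    Illustrative paths (no real files are required):

    >>> _check_mem_x_path("/this/path/does/not/exist")
    False
    >>> _check_mem_x_path(Undefined)
    False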
"""
mem_x_path = _grab_first_path(mem_x_path)
try:
return mem_x_path is not Undefined and os.path.exists(mem_x_path)
except (TypeError, ValueError):
return False
def _doctest_skiplines(docstring, lines_to_skip):
"""
Add ' # doctest: +SKIP' to the end of docstring lines.
Used to skip doctests in imported docstrings.
Parameters
----------
docstring : str
lines_to_skip : set or list
Returns
-------
docstring : str
Examples
--------
>>> _doctest_skiplines('skip this line', {'skip this line'})
'skip this line # doctest: +SKIP'
"""
    if not isinstance(lines_to_skip, (set, list)):
msg = "_doctest_skiplines: `lines_to_skip` must be a set or list."
raise TypeError(msg)
return "\n".join(
[
f"{line} # doctest: +SKIP" if line in lines_to_skip else line
for line in docstring.split("\n")
]
)
def _grab_first_path(mem_x_path):
"""Grab the first path if multiple paths for given multiplier input.
Parameters
----------
mem_x_path : str, iterable, Undefined or None
Returns
-------
str, Undefined or None
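
    Examples
    --------
    Illustrative paths (the files need not exist):

    >>> _grab_first_path(["/fake/a.nii.gz", "/fake/b.nii.gz"])
    '/fake/a.nii.gz'
    >>> _grab_first_path([]) is Undefined
    True
    >>> _grab_first_path("/fake/a.nii.gz")
    '/fake/a.nii.gz'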
"""
if isinstance(mem_x_path, (list, TraitListObject, tuple)):
mem_x_path = mem_x_path[0] if len(mem_x_path) else Undefined
return mem_x_path
class Node(pe.Node): # noqa: D101
# pylint: disable=empty-docstring,too-many-instance-attributes
__doc__ = _doctest_skiplines(
pe.Node.__doc__, {" >>> realign.inputs.in_files = 'functional.nii'"}
)
def __init__(
self,
*args,
mem_gb: Optional[float] = DEFAULT_MEM_GB,
throttle: Optional[bool] = False,
**kwargs,
) -> None:
# pylint: disable=import-outside-toplevel
from CPAC.pipeline.random_state import random_seed
super().__init__(*args, mem_gb=mem_gb, **kwargs)
self.logger = WFLOGGER
self.seed = random_seed()
self.seed_applied = False
self.input_data_shape = Undefined
self._debug = False
if throttle:
self.throttle = True
self.verbose_logger = None
self._mem_x = {}
if "mem_x" in kwargs and isinstance(kwargs["mem_x"], (tuple, list)):
if len(kwargs["mem_x"]) == 3: # noqa: PLR2004
(
self._mem_x["multiplier"],
self._mem_x["file"],
self._mem_x["mode"],
) = kwargs["mem_x"]
else:
self._mem_x["mode"] = "xyzt"
if len(kwargs["mem_x"]) == 2: # noqa: PLR2004
(self._mem_x["multiplier"], self._mem_x["file"]) = kwargs["mem_x"]
else:
self._mem_x["multiplier"] = kwargs["mem_x"]
self._mem_x["file"] = None
else:
delattr(self, "_mem_x")
setattr(self, "skip_timeout", False)
_orig_sig_params: ClassVar[list[tuple[str, Parameter]]] = list(
signature(pe.Node).parameters.items()
)
    __init__.__signature__ = Signature(
        parameters=[
            p[1]
            if p[0] != "mem_gb"
            else Parameter(
                "mem_gb", Parameter.POSITIONAL_OR_KEYWORD, default=DEFAULT_MEM_GB
            )
            for p in _orig_sig_params[:-1]
        ]
        + [
            Parameter("mem_x", Parameter.KEYWORD_ONLY, default=None),
            Parameter("throttle", Parameter.KEYWORD_ONLY, default=False),
            _orig_sig_params[-1][1],
        ]
    )
del _orig_sig_params
__init__.__doc__ = re.sub(
r"(?<!\s):",
" :",
"\n".join(
[
pe.Node.__init__.__doc__.rstrip(),
"""
mem_gb : int or float
Estimate (in GB) of constant memory to allocate for this
node.
mem_x : 2-tuple or 3-tuple
(``multiplier``, ``input_file``)
(int or float, str)
(``multiplier``, ``input_file``, ``mode``)
(int or float, str, str)
**Note**
This parameter (``mem_x``) is likely to change in a future
release as we incorporate more factors into memory
estimates.
See also: `⚡️ Setting data- and operation-dependent memory-estimates <https://github.com/FCP-INDI/C-PAC/issues/1509>`_
GitHub epic of issues related to improving Node
memory estimates based on the data and operations
involved.
Multiplier for memory allocation such that ``multiplier``
times ``mode`` of 4-D file at ``input_file`` plus
``self._mem_gb`` equals the total memory allocation for
the node. ``input_file`` can be a Node input string or
an actual path.
``mode`` can be any one of
* 'xyzt' (spatial * temporal) (default if not specified)
* 'xyz' (spatial)
* 't' (temporal)
            throttle : bool, optional
                Assume this Node will use all available memory if no observed
                run is provided.""",
]
),
) # pylint: disable=line-too-long
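    # Illustrative use of ``mem_x`` (hypothetical interface and multiplier):
    #
    #     node = Node(SomeInterface(), name="example", mem_gb=1.5,
    #                 mem_x=(0.0025, "in_file", "t"))
    #
    # estimates ``1.5 + 0.0025 * t`` GB for this node, where ``t`` is the
    # number of TRs in the image eventually supplied to ``in_file``.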
    def _add_flags(self, flags: list[str] | tuple[list[str], list[str]]) -> None:
r"""
Update an interface's flags by adding (list) or replacing (tuple).
Parameters
----------
flags : list or tuple
If a list, add ``flags`` to ``self.inputs.flags`` or
``self.inputs.args``
If a tuple, remove ``flags[1]`` from and add ``flags[0]``
to ``self.inputs.flags`` or ``self.inputs.args``
"""
def prep_flags(attr):
to_remove = []
if isinstance(flags, tuple):
to_remove += flags[1]
new_flags = flags[0]
else:
new_flags = flags
old_flags = getattr(self.inputs, attr)
if isinstance(old_flags, str):
to_remove.sort(key=lambda x: -x.count(" "))
for flag in to_remove:
if f" {flag} " in old_flags:
old_flags = old_flags.replace(f" {flag}", "")
old_flags = [old_flags]
if isinstance(old_flags, list):
new_flags = [
flag for flag in old_flags if flag not in to_remove
] + new_flags
if attr == "args":
new_flags = " ".join(new_flags)
while " " in new_flags:
new_flags = new_flags.replace(" ", " ")
return new_flags
if hasattr(self.inputs, "flags"):
self.inputs.flags = prep_flags("flags")
else:
self.inputs.args = prep_flags("args")
def _apply_mem_x(self, multiplicand=None):
"""Calculate and memoize a Node's estimated memory footprint.
Parameters
----------
multiplicand : str or int or float or list thereof or
3-or-4-tuple or None
Any of
* path to file(s) with shape to multiply by multiplier
* multiplicand
* shape of image to consider with mode
Returns
-------
number
estimated memory usage (GB)
"""
def parse_multiplicand(multiplicand: Any) -> Optional[int | float]:
"""Return a numeric value or None for a multiplicand."""
if self._debug:
self.verbose_logger.debug(
"%s multiplicand: %s", self.name, multiplicand
)
if isinstance(multiplicand, list):
return max([parse_multiplicand(part) for part in multiplicand])
if isinstance(multiplicand, (int, float)):
return multiplicand
if (
isinstance(multiplicand, tuple)
and 3 <= len(multiplicand) <= 4 # noqa: PLR2004
and all(isinstance(i, (int, float)) for i in multiplicand)
):
return get_data_size(
multiplicand, getattr(self, "_mem_x", {}).get("mode")
)
if _check_mem_x_path(multiplicand):
return get_data_size(
_grab_first_path(multiplicand),
getattr(self, "_mem_x", {}).get("mode"),
)
return 1
if hasattr(self, "_mem_x"):
if self._debug:
self.verbose_logger.debug("%s._mem_x: %s", self.name, self._mem_x)
if multiplicand is None:
multiplicand = self._mem_x_file()
setattr(
self,
"_mem_gb",
(
self._mem_gb
+ self._mem_x.get("multiplier", 0)
* parse_multiplicand(multiplicand)
),
)
try:
if self._mem_gb > 1000:
self.logger.warning(
"%s is estimated to use %.3f GB (%s).",
self.name,
self._mem_gb,
getattr(self, "_mem_x"),
)
except FileNotFoundError:
pass
del self._mem_x
if self._debug:
self.verbose_logger.debug("%s._mem_gb: %s", self.name, self._mem_gb)
return self._mem_gb
def _apply_random_seed(self):
"""Apply flags for the first matched interface."""
# pylint: disable=import-outside-toplevel
from CPAC.pipeline.random_state import random_seed_flags
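        # ``random_seed_flags()`` maps interface classes (and Function source
        # strings) to the flags or source transformations that pin their RNGs;
        # e.g., an entry like ``{SomeAntsInterface: ["--random-seed <seed>"]}``
        # (illustrative) would be applied via ``self._add_flags`` below.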
if isinstance(self.interface, Function):
for rsf, flags in random_seed_flags()["functions"].items():
if self.interface.inputs.function_str == getsource(rsf):
self.interface.inputs.function_str = flags(
self.interface.inputs.function_str
)
self.seed_applied = True
return
for rsf, flags in random_seed_flags()["interfaces"].items():
if isinstance(self.interface, rsf):
self._add_flags(flags)
self.seed_applied = True
return
@property
def mem_gb(self):
"""Get estimated memory (GB)."""
if hasattr(self._interface, "estimated_memory_gb"):
self._mem_gb = self._interface.estimated_memory_gb
self.logger.warning(
'Setting "estimated_memory_gb" on Interfaces has been '
"deprecated as of nipype 1.0, please use Node.mem_gb."
)
if hasattr(self, "_mem_x"):
if self._mem_x["file"] is None:
return self._apply_mem_x()
try:
mem_x_path = getattr(self.inputs, self._mem_x["file"])
except AttributeError as attribute_error:
msg = f"{attribute_error.args[0]} in Node '{self.name}'"
raise AttributeError(msg) from attribute_error
if _check_mem_x_path(mem_x_path):
# constant + mem_x[0] * t
return self._apply_mem_x()
raise FileNotFoundError(
2,
"The memory estimate for Node "
f"'{self.name}' depends on the input "
f"'{self._mem_x['file']}' but "
"no such file or directory",
mem_x_path,
)
return self._mem_gb
@property
def mem_x(self) -> Optional[dict[str, int | float | str]]:
"""Get dict of 'multiplier', 'file', and 'multiplier mode'.
'multiplier' is a memory multiplier.
'file' is an input file.
'multiplier mode' is one of
- spatial * temporal
- spatial only
or
- temporal only.
Returns ``None`` if already consumed or not set.
"""
return getattr(self, "_mem_x", None)
def _mem_x_file(self):
return getattr(self.inputs, getattr(self, "_mem_x", {}).get("file"))
def override_mem_gb(self, new_mem_gb):
"""Override the Node's memory estimate with a new value.
Parameters
----------
new_mem_gb : int or float
new memory estimate in GB
"""
if hasattr(self, "_mem_x"):
delattr(self, "_mem_x")
setattr(self, "_mem_gb", new_mem_gb)
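        # e.g., ``node.override_mem_gb(4)`` discards any pending ``mem_x``
        # calculation and pins the estimate at 4 GB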
def run(self, updatehash=False):
self.__doc__ = getattr(super(), "__doc__", "")
if self.seed is not None:
self._apply_random_seed()
if self.seed_applied:
random_state_logger = getLogger("random")
random_state_logger.info(
"%s\t%s",
"# (Atropos constant)"
if "atropos" in self.name
else str(self.seed),
self.name,
)
return super().run(updatehash)
class MapNode(Node, pe.MapNode):
# pylint: disable=empty-docstring
__doc__ = _doctest_skiplines(
pe.MapNode.__doc__, {" ... 'functional3.nii']"}
)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not self.name.endswith("_"):
self.name = f"{self.name}_"
_parameters: ClassVar[dict[str, Parameter]] = {}
_custom_params: ClassVar[dict[str, bool | float]] = {
"mem_gb": DEFAULT_MEM_GB,
"throttle": False,
}
    for _name, _param in signature(pe.Node).parameters.items():
        if _name in _custom_params:
            # apply the custom default where one is defined
            _parameters[_name] = Parameter(
                _name, Parameter.POSITIONAL_OR_KEYWORD, default=_custom_params[_name]
            )
        elif _param.kind is Parameter.VAR_KEYWORD:
            # custom parameters that ``pe.Node`` lacks (e.g., ``throttle``)
            # become keyword-only parameters just ahead of ``**kwargs``
            for _custom in _custom_params:
                if _custom not in _parameters:
                    _parameters[_custom] = Parameter(
                        _custom, Parameter.KEYWORD_ONLY, default=_custom_params[_custom]
                    )
            _parameters[_name] = _param
        else:
            _parameters[_name] = _param
    __init__.__signature__ = Signature(parameters=list(_parameters.values()))
    del _custom_params, _parameters
class Workflow(pe.Workflow):
"""Controls the setup and execution of a pipeline of processes."""
def __init__(self, name, base_dir=None, debug=False):
"""Create a workflow object.
Parameters
----------
name : alphanumeric string
unique identifier for the workflow
base_dir : string, optional
path to workflow storage
debug : boolean, optional
enable verbose debug-level logging
"""
import networkx as nx
super().__init__(name, base_dir)
self._debug = debug
self.verbose_logger = getLogger("CPAC.engine") if debug else None
self._graph = nx.DiGraph()
self._nodes_cache = set()
self._nested_workflows_cache = set()
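        # e.g., ``Workflow("demo", base_dir="/tmp", debug=True)`` routes
        # node-level memory-estimate details to the "CPAC.engine" logger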
def _configure_exec_nodes(self, graph):
"""Ensure that each node knows where to get inputs from."""
for node in graph.nodes():
node._debug = self._debug # pylint: disable=protected-access
node.verbose_logger = self.verbose_logger
node.input_source = {}
for edge in graph.in_edges(node):
data = graph.get_edge_data(*edge)
for sourceinfo, field in data["connect"]:
node.input_source[field] = (
os.path.join(
edge[0].output_dir(), "result_%s.pklz" % edge[0].name
),
sourceinfo,
)
if node and hasattr(node, "_mem_x"):
if (
isinstance(
node._mem_x, # pylint: disable=protected-access
dict,
)
and node._mem_x["file"] # pylint: disable=protected-access
== field
):
input_resultfile = node.input_source.get(field)
if input_resultfile:
# pylint: disable=protected-access
if isinstance(input_resultfile, tuple):
input_resultfile = input_resultfile[0]
try:
# memoize node._mem_gb if path
# already exists
node._apply_mem_x(
_load_resultfile(input_resultfile).inputs[field]
)
except (FileNotFoundError, KeyError, TypeError):
self._handle_just_in_time_exception(node)
def _get_dot(
self, prefix=None, hierarchy=None, colored=False, simple_form=True, level=0
):
"""Create a dot file with connection info."""
# pylint: disable=invalid-name,protected-access
import networkx as nx
if prefix is None:
prefix = " "
if hierarchy is None:
hierarchy = []
colorset = [
"#FFFFC8", # Y
"#0000FF",
"#B4B4FF",
"#E6E6FF", # B
"#FF0000",
"#FFB4B4",
"#FFE6E6", # R
"#00A300",
"#B4FFB4",
"#E6FFE6", # G
"#0000FF",
"#B4B4FF",
] # loop B
if level > len(colorset) - 2:
level = 3 # Loop back to blue
quoted_prefix = f'"{prefix}"' if len(prefix.strip()) else prefix
dotlist = [f'{quoted_prefix}label="{self.name}";']
for node in nx.topological_sort(self._graph):
fullname = ".".join([*hierarchy, node.fullname])
nodename = fullname.replace(".", "_")
if not isinstance(node, Workflow):
node_class_name = get_print_name(node, simple_form=simple_form)
if not simple_form:
node_class_name = ".".join(node_class_name.split(".")[1:])
if hasattr(node, "iterables") and node.iterables:
dotlist.append(
f'"{nodename}"[label="{node_class_name}", '
"shape=box3d, style=filled, color=black, "
"colorscheme=greys7 fillcolor=2];"
)
elif colored:
dotlist.append(
f'"{nodename}"[label="'
f'{node_class_name}", style=filled,'
f' fillcolor="{colorset[level]}"];'
)
else:
dotlist.append(f'"{nodename}"[label="' f'{node_class_name}"];')
for node in nx.topological_sort(self._graph):
if isinstance(node, Workflow):
fullname = ".".join([*hierarchy, node.fullname])
nodename = fullname.replace(".", "_")
dotlist.append(f'subgraph "cluster_{nodename}" {{')
if colored:
dotlist.append(
f'{prefix}{prefix}edge [color="' f'{colorset[level + 1]}"];'
)
dotlist.append(f"{prefix}{prefix}style=filled;")
dotlist.append(
f"{prefix}{prefix}fillcolor=" f'"{colorset[level + 2]}";'
)
dotlist.append(
node._get_dot(
prefix=prefix + prefix,
hierarchy=[*hierarchy, self.name],
colored=colored,
simple_form=simple_form,
level=level + 3,
)
)
dotlist.append("}")
else:
for subnode in self._graph.successors(node):
if node._hierarchy != subnode._hierarchy:
continue
if not isinstance(subnode, Workflow):
nodefullname = ".".join([*hierarchy, node.fullname])
subnodefullname = ".".join([*hierarchy, subnode.fullname])
nodename = nodefullname.replace(".", "_")
subnodename = subnodefullname.replace(".", "_")
for _ in self._graph.get_edge_data(node, subnode)["connect"]:
dotlist.append(f'"{nodename}" -> "{subnodename}";')
WFLOGGER.debug("connection: %s", dotlist[-1])
# add between workflow connections
for u, v, d in self._graph.edges(data=True):
uname = ".".join([*hierarchy, u.fullname])
vname = ".".join([*hierarchy, v.fullname])
for src, dest in d["connect"]:
uname1 = uname
vname1 = vname
if isinstance(src, tuple):
srcname = src[0]
else:
srcname = src
if "." in srcname:
uname1 += "." + ".".join(srcname.split(".")[:-1])
if "." in dest and "@" not in dest:
if not isinstance(v, Workflow):
if "datasink" not in str(v._interface.__class__).lower():
vname1 += "." + ".".join(dest.split(".")[:-1])
else:
vname1 += "." + ".".join(dest.split(".")[:-1])
if uname1.split(".")[:-1] != vname1.split(".")[:-1]:
dotlist.append(
f'"{uname1.replace(".", "_")}" -> '
f'"{vname1.replace(".", "_")}";'
)
WFLOGGER.debug("cross connection: %s", dotlist[-1])
return ("\n" + prefix).join(dotlist)
def _handle_just_in_time_exception(self, node):
# pylint: disable=protected-access
if hasattr(self, "_local_func_scans"):
node._apply_mem_x(self._local_func_scans) # pylint: disable=no-member
else:
# TODO: handle S3 files
node._apply_mem_x(UNDEFINED_SIZE)
def write_graph(
self,
dotfilename="graph.dot",
graph2use="hierarchical",
format="png", # noqa: A002
simple_form=True,
):
graphtypes = ["orig", "flat", "hierarchical", "exec", "colored"]
if graph2use not in graphtypes:
raise ValueError(
"Unknown graph2use keyword. Must be one of: " + str(graphtypes)
)
base_dir, dotfilename = os.path.split(dotfilename)
if base_dir == "":
if self.base_dir:
base_dir = self.base_dir
if self.name:
base_dir = os.path.join(base_dir, self.name)
else:
base_dir = os.getcwd()
os.makedirs(base_dir, exist_ok=True)
if graph2use in ["hierarchical", "colored"]:
if self.name[:1].isdigit(): # these graphs break if int
msg = (
f"{graph2use} graph failed, workflow name "
"cannot begin with a number"
)
raise ValueError(msg)
dotfilename = os.path.join(base_dir, dotfilename)
self.write_hierarchical_dotfile(
dotfilename=dotfilename,
colored=graph2use == "colored",
simple_form=simple_form,
)
outfname = format_dot(dotfilename, format=format)
else:
graph = self._graph
if graph2use in ["flat", "exec"]:
graph = self._create_flat_graph()
if graph2use == "exec":
graph = generate_expanded_graph(deepcopy(graph))
outfname = export_graph(
graph,
base_dir,
dotfilename=dotfilename,
format=format,
simple_form=simple_form,
)
WFLOGGER.info(
"Generated workflow graph: %s (graph2use=%s, simple_form=%s).",
outfname,
graph2use,
simple_form,
)
return outfname
write_graph.__doc__ = pe.Workflow.write_graph.__doc__
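    # Illustrative usage (hypothetical workflow):
    #
    #     wf = Workflow(name="demo", base_dir="/tmp")
    #     wf.write_graph(graph2use="hierarchical", format="png")
    #
    # writes ``graph.dot`` with quoted node-name strings and renders it to
    # PNG via ``dot``.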
def write_hierarchical_dotfile(
self, dotfilename=None, colored=False, simple_form=True
):
# pylint: disable=invalid-name
dotlist = [f'digraph "{self.name}"{{']
dotlist.append(
self._get_dot(prefix=" ", colored=colored, simple_form=simple_form)
)
dotlist.append("}")
dotstr = "\n".join(dotlist)
if dotfilename:
            with open(dotfilename, "wt", encoding="utf-8") as fp:
                fp.write(dotstr)
else:
WFLOGGER.info(dotstr)
def get_data_size(filepath, mode="xyzt"):
"""Return the size of a functional image (x * y * z * t).
Parameters
----------
filepath : str or path
path to image file
OR
4-tuple
stand-in dimensions (x, y, z, t)
mode : str
One of:
* 'xyzt' (all dimensions multiplied) (DEFAULT)
* 'xyz' (spatial dimensions multiplied)
* 't' (number of TRs)
Returns
-------
int or float
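
    Examples
    --------
    With stand-in dimensions in place of a real image file:

    >>> get_data_size((10, 10, 10, 50))
    500000
    >>> get_data_size((10, 10, 10, 50), mode="xyz")
    1000
    >>> get_data_size((10, 10, 10, 50), mode="t")
    50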
"""
if isinstance(filepath, str):
data_shape = load(filepath).shape
elif isinstance(filepath, tuple) and len(filepath) == 4: # noqa: PLR2004
data_shape = filepath
if mode == "t":
        # if the data has multiple TRs, return that number
if len(data_shape) > 3: # noqa: PLR2004
return data_shape[3]
# otherwise return 1
return 1
if mode == "xyz":
return prod(data_shape[0:3]).item()
return prod(data_shape).item()
def export_graph(
graph_in,
base_dir=None,
show=False,
use_execgraph=False,
show_connectinfo=False,
dotfilename="graph.dot",
format="png", # noqa: A002
simple_form=True,
):
"""Display the graph layout of the pipeline.
This function requires that pygraphviz and matplotlib are available on
the system.
Parameters
----------
    show : boolean
        Indicates whether to generate pygraphviz output from
        networkx. default [False]
    use_execgraph : boolean
        Indicates whether to use the specification graph or the
        execution graph. default [False]
    show_connectinfo : boolean
        Indicates whether to show the edge data on the graph. This
        makes the graph rather cluttered. default [False]
"""
import networkx as nx
graph = deepcopy(graph_in)
if use_execgraph:
graph = generate_expanded_graph(graph)
WFLOGGER.debug("using execgraph")
else:
WFLOGGER.debug("using input graph")
if base_dir is None:
base_dir = os.getcwd()
os.makedirs(base_dir, exist_ok=True)
out_dot = fname_presuffix(
dotfilename, suffix="_detailed.dot", use_ext=False, newpath=base_dir
)
_write_detailed_dot(graph, out_dot)
# Convert .dot if format != 'dot'
outfname, res = _run_dot(out_dot, format_ext=format)
if res is not None and res.runtime.returncode:
WFLOGGER.warning("dot2png: %s", res.runtime.stderr)
pklgraph = _create_dot_graph(graph, show_connectinfo, simple_form)
simple_dot = fname_presuffix(
dotfilename, suffix=".dot", use_ext=False, newpath=base_dir
)
nx.drawing.nx_pydot.write_dot(pklgraph, simple_dot)
# Convert .dot if format != 'dot'
simplefname, res = _run_dot(simple_dot, format_ext=format)
if res is not None and res.runtime.returncode:
WFLOGGER.warning("dot2png: %s", res.runtime.stderr)
if show:
pos = nx.graphviz_layout(pklgraph, prog="dot")
nx.draw(pklgraph, pos)
if show_connectinfo:
nx.draw_networkx_edge_labels(pklgraph, pos)
return simplefname if simple_form else outfname
def _write_detailed_dot(graph, dotfilename):
r"""
Create a dot file with connection info ::
digraph structs {
node [shape=record];
struct1 [label="<f0> left|<f1> middle|<f2> right"];
struct2 [label="<f0> one|<f1> two"];
struct3 [label="hello\nworld |{ b |{c|<here> d|e}| f}| g | h"];
struct1:f1 -> struct2:f0;
struct1:f0 -> struct2:f1;
struct1:f2 -> struct3:here;
}
""" # noqa: D205,D400
# pylint: disable=invalid-name
import networkx as nx
text = ["digraph structs {", "node [shape=record];"]
# write nodes
edges = []
for n in nx.topological_sort(graph):
nodename = n.itername
inports = []
for u, v, d in graph.in_edges(nbunch=n, data=True):
for cd in d["connect"]:
if isinstance(cd[0], (str, bytes)):
outport = cd[0]
else:
outport = cd[0][0]
inport = cd[1]
ipstrip = f"in{_replacefunk(inport)}"
opstrip = f"out{_replacefunk(outport)}"
edges.append(
f'"{u.itername.replace(".", "")}":'
f'"{opstrip}":e -> '
f'"{v.itername.replace(".", "")}":'
f'"{ipstrip}":w;'
)
if inport not in inports:
inports.append(inport)
inputstr = (
["{IN"]
+ [f"|<in{_replacefunk(ip)}> {ip}" for ip in sorted(inports)]
+ ["}"]
)
outports = []
for u, v, d in graph.out_edges(nbunch=n, data=True):
for cd in d["connect"]:
if isinstance(cd[0], (str, bytes)):
outport = cd[0]
else:
outport = cd[0][0]
if outport not in outports:
outports.append(outport)
outputstr = (
["{OUT"]
+ [f"|<out{_replacefunk(oport)}> {oport}" for oport in sorted(outports)]
+ ["}"]
)
srcpackage = ""
if hasattr(n, "_interface"):
pkglist = n.interface.__class__.__module__.split(".")
if len(pkglist) > 2: # noqa: PLR2004
srcpackage = pkglist[2]
srchierarchy = ".".join(nodename.split(".")[1:-1])
nodenamestr = f"{{ {nodename.split('.')[-1]} | {srcpackage} | {srchierarchy} }}"
text += [
f'"{nodename.replace(".", "")}" [label='
f'"{"".join(inputstr)}|{nodenamestr}|{"".join(outputstr)}"];'
]
# write edges
for edge in sorted(edges):
text.append(edge)
text.append("}")
with open(dotfilename, "wt", encoding="utf-8") as filep:
filep.write("\n".join(text))
return text