"""
Nipype workflow executor for thesis framework.
This module provides the execution engine for running Nipype workflows
within the thesis framework, managing execution context, caching, and result collection.
"""
import getpass
import sys
import types
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from thesis.core.config import NipypeConfig
from thesis.core.context import ProcessingContext
from thesis.core.logging import get_logger
from thesis.core.output import EventBus, EventLevel, NodeProgressProtocol
if TYPE_CHECKING:
from nipype.pipeline.engine import Workflow
# Module-level logger; used by the free functions in this module (for example
# ``count_workflow_nodes``). ``NipypeExecutor`` creates its own logger in
# ``__init__``.
logger = get_logger(__name__)
# Names exported by ``from <this module> import *``. ``NipypeConfig`` is
# imported from ``thesis.core.config`` and re-exported here.
__all__ = [
"NipypeExecutor",
"NipypeConfig",
"NipypeStatusCallback",
"apply_nipype_execution_config",
"build_nipype_status_callback",
"count_workflow_nodes",
"run_workflow",
]
[docs]
def apply_nipype_execution_config(
    workflow: "Workflow",
    config: NipypeConfig,
    crash_dir: Optional[Path] = None,
) -> None:
    """Push a ``NipypeConfig`` into both the workflow and Nipype's global config.

    Every entry path must call this before ``workflow.run()``; that includes
    the CLI's batch parallel path, which does not go through
    :class:`NipypeExecutor`.

    Two separate targets are updated, and neither may be dropped:

    1. ``workflow.config['execution']``, which Nipype's scheduler reads
       (``stop_on_first_crash``, ``remove_unnecessary_outputs``,
       ``keep_inputs``, etc.).
    2. The ``nipype.config`` global singleton, which
       ``BaseTraitedSpec._get_sorteddict`` (``interfaces/base/specs.py:307``)
       reads when hashing file inputs. If only the workflow config is set,
       ``hash_method`` silently reverts to Nipype's package default
       (``timestamp``) whatever the user configured; that was the cause of
       content-mode trial YAMLs producing timestamp hashfiles and missing
       the cache on every pipeline run.

    Args:
        workflow: Workflow whose ``config`` mapping is updated.
        config: Source of the execution options.
        crash_dir: Optional crash-dump directory. When given, it is created
            (parents included) and recorded as ``crashdump_dir``.

    Raises:
        Exception: Nothing is caught here; any error from reading or
            mutating ``workflow.config`` propagates so the caller can log it
            against its own logger (executor vs CLI).
    """
    from nipype import config as nipype_global_config

    from thesis.core.logging import suppress_nipype_native_logging

    suppress_nipype_native_logging()

    execution_section = workflow.config.get("execution", {})
    # Options are copied one at a time, in this order, straight from ``config``.
    for option in (
        "stop_on_first_crash",
        "remove_unnecessary_outputs",
        "keep_inputs",
        "hash_method",
        "use_profiler",
    ):
        execution_section[option] = getattr(config, option)

    if crash_dir is not None:
        crash_dir.mkdir(parents=True, exist_ok=True)
        execution_section["crashdump_dir"] = str(crash_dir)

    # Assign back: when "execution" was absent, ``.get`` handed out a fresh
    # dict that the workflow config does not yet hold.
    workflow.config["execution"] = execution_section

    # Second, independent target: the global singleton used for input hashing.
    nipype_global_config.set("execution", "hash_method", config.hash_method)
_NODE_FINISH_STATUSES = {"end", "done", "finish", "finished", "cached"}
_NODE_FAIL_STATUSES = {"exception", "failed", "crash", "error"}
[docs]
def count_workflow_nodes(workflow: "Workflow") -> int:
    """Return the number of executable leaf nodes in ``workflow``.

    Meta-workflows (e.g. ``tract_synthseg``) hold child ``Workflow`` objects
    in their top-level graph instead of ``Node`` objects. Graph members are
    therefore classified by attribute: anything with an ``interface`` counts
    as one node, anything else with a ``_graph`` is recursed into, and
    members with neither are ignored.

    Args:
        workflow: Nipype workflow whose executable leaf nodes should be counted.

    Returns:
        Number of executable nodes, or ``0`` if graph introspection fails.
    """
    total = 0
    try:
        for member in workflow._graph.nodes():
            if hasattr(member, "interface"):
                total += 1
                continue
            if hasattr(member, "_graph"):
                total += count_workflow_nodes(member)
    except (AttributeError, TypeError) as exc:
        # AttributeError: ``_graph`` is missing or not built yet.
        # TypeError: the node collection cannot be iterated.
        # Zero keeps progress bars conservatively sized; the warning lets a
        # failed introspection be told apart from a truly empty workflow.
        logger.warning("Could not introspect workflow graph for node count: {}", exc)
        return 0
    return total
def _node_label(node: Any) -> str:
"""Build a stable, human-readable node label."""
fullname = getattr(node, "fullname", None)
if fullname:
return str(fullname)
full_name = getattr(node, "full_name", None)
if full_name:
return str(full_name)
name = getattr(node, "name", None)
if name:
return str(name)
return str(node)
[docs]
class NipypeStatusCallback:
    """Picklable callable for Nipype's plugin ``status_callback`` hook.

    This has to be a top-level class rather than a closure: when
    ``MultiProcPlugin`` submits work to workers, ``multiprocessing`` pickles
    everything that transitively references the callback. A nested-function
    closure fails ``_CallItem`` pickling with
    ``AttributeError: Can't get local object ...``, which wedges the scheduler.
    """

    def __init__(
        self,
        progress: Optional[NodeProgressProtocol] = None,
        event_bus: Optional[EventBus] = None,
        patient_id: str = "",
    ) -> None:
        """Store the optional sinks and start with empty per-node bookkeeping.

        Args:
            progress: Optional tracker notified on node start and finish.
            event_bus: Optional bus that receives one event per status change.
            patient_id: Identifier attached to every emitted event.
        """
        self._progress = progress
        self._event_bus = event_bus
        self._patient_id = patient_id
        # Labels of nodes that have already reached a terminal status.
        self._completed_nodes: set[str] = set()
        self._finished_ok_nodes: set[str] = set()
        self._failed_nodes: set[str] = set()

    @property
    def failed_nodes(self) -> set[str]:
        """Fullnames of nodes whose terminal status was exception/failed/crash/error."""
        return self._failed_nodes.copy()

    @property
    def finished_ok_nodes(self) -> set[str]:
        """Fullnames of nodes whose terminal status was end/done/finish/finished/cached."""
        return self._finished_ok_nodes.copy()

    def __call__(self, node: Any, status: Any, *_args: Any, **_kwargs: Any) -> None:
        """Handle one status notification for ``node``.

        ``status`` is stringified and lower-cased. ``"start"`` and the
        terminal statuses are acted on; every other value is ignored. Extra
        positional and keyword arguments are accepted and discarded.
        """
        status_text = str(status).lower()
        node_name = _node_label(node)

        if status_text == "start":
            if self._progress is not None:
                self._progress.node_started(node_name)
            self._emit_node_event(
                f"Node started: {node_name}", EventLevel.INFO, node_name, status_text
            )
            return

        failed = status_text in _NODE_FAIL_STATUSES
        if not (failed or status_text in _NODE_FINISH_STATUSES):
            return

        # Bookkeeping and the progress tracker see each node only once, even
        # if several terminal statuses arrive for it.
        if node_name not in self._completed_nodes:
            self._completed_nodes.add(node_name)
            outcome_bucket = self._failed_nodes if failed else self._finished_ok_nodes
            outcome_bucket.add(node_name)
            if self._progress is not None:
                self._progress.node_finished(node_name, status="[FAIL]" if failed else "[OK]")

        # The event bus, by contrast, gets every terminal notification.
        self._emit_node_event(
            f"Node {status_text}: {node_name}",
            EventLevel.ERROR if failed else EventLevel.INFO,
            node_name,
            status_text,
        )

    def _emit_node_event(
        self, message: str, level: Any, node_name: str, status_text: str
    ) -> None:
        """Emit one ``nipype.node`` event if an event bus is attached."""
        if self._event_bus is None:
            return
        self._event_bus.emit(
            message,
            level=level,
            category="nipype.node",
            patient_id=self._patient_id,
            metadata={"node": node_name, "status": status_text},
        )

    def __getstate__(self) -> Dict[str, Any]:
        # Workers never invoke this callback. Only inert, picklable state is
        # shipped so that any object transitively referencing the callback
        # can still be sent through multiprocessing: the sinks are dropped
        # and the bookkeeping sets are reset.
        return {
            "_progress": None,
            "_event_bus": None,
            "_patient_id": self._patient_id,
            "_completed_nodes": set(),
            "_finished_ok_nodes": set(),
            "_failed_nodes": set(),
        }

    def __setstate__(self, state: Dict[str, Any]) -> None:
        # Restore exactly the attributes produced by ``__getstate__``.
        self.__dict__.update(state)
[docs]
def build_nipype_status_callback(
    *,
    progress: Optional[NodeProgressProtocol] = None,
    event_bus: Optional[EventBus] = None,
    patient_id: str = "",
) -> NipypeStatusCallback:
    """Build the status callback used for node-level progress reporting.

    Args:
        progress: Optional progress bar updated on node start and finish.
        event_bus: Optional event bus for structured node lifecycle events.
        patient_id: Patient identifier attached to emitted events.

    Returns:
        A callback compatible with Nipype plugin ``status_callback`` hooks.
        After the workflow has run, its ``failed_nodes`` and
        ``finished_ok_nodes`` properties give the per-node final status.
    """
    callback = NipypeStatusCallback(
        progress=progress,
        event_bus=event_bus,
        patient_id=patient_id,
    )
    return callback
def _uses_fireants_cuda_registration(context: ProcessingContext, workflow: "Workflow") -> bool:
"""Return whether this workflow should avoid fork-based MultiProc execution.
Detection is config-driven: any workflow run with
``registration.method == "fireants"`` and ``registration.fireants.device``
starting with ``"cuda"`` is unsafe under fork-based MultiProc workers,
regardless of whether the registration runs as a top-level workflow or
as an embedded sub-workflow inside a meta-workflow such as
``full_pipeline``.
"""
config = getattr(context, "config", None)
registration_cfg = getattr(config, "registration", None)
if registration_cfg is None:
return False
if getattr(registration_cfg, "method", "") != "fireants":
return False
fireants_cfg = getattr(registration_cfg, "fireants", None)
device = str(getattr(fireants_cfg, "device", "cpu"))
return device.startswith("cuda")
[docs]
class NipypeExecutor:
    """
    Executes Nipype workflows within the thesis framework.

    This executor bridges the framework's context system with Nipype's
    workflow engine, managing:

    - Workflow base directory and working directories
    - Plugin selection and configuration
    - Logging and error handling

    Workflow outputs are written by individual nodes to their configured
    ``output_dir`` paths under ``workflow.base_dir``; downstream code reads
    them back from the filesystem rather than from a returned dict.

    Example:
        >>> from thesis.core import ConfigManager, create_context
        >>> from nipype import Workflow
        >>>
        >>> config = ConfigManager().load_config("default", patient_id="patient_001")
        >>> context = create_context("patient_001", config, Path("./data"))
        >>> workflow = Workflow(name="example")
        >>>
        >>> executor = NipypeExecutor(context, config.nipype)
        >>> executor.execute(workflow)
    """

    def __init__(
        self,
        context: ProcessingContext,
        config: Optional[Union[Dict[str, Any], NipypeConfig]] = None,
    ):
        """
        Initialize the executor.

        Args:
            context: ProcessingContext from the framework.
            config: Nipype configuration (dict or NipypeConfig instance).
                When it is neither, ``context.config.nipype`` is used if
                present, otherwise a default ``NipypeConfig()``.
        """
        self.context = context

        # Convert dict config to NipypeConfig if needed
        if isinstance(config, dict):
            self.config = NipypeConfig.model_validate(config)
        elif isinstance(config, NipypeConfig):
            self.config = config
        else:
            # Try to extract nipype config from context
            if hasattr(context, "config") and hasattr(context.config, "nipype"):
                nipype_cfg = context.config.nipype
                if isinstance(nipype_cfg, dict):
                    self.config = NipypeConfig.model_validate(nipype_cfg)
                else:
                    self.config = nipype_cfg
            else:
                self.config = NipypeConfig()

        self.logger = get_logger(self.__class__.__name__)
        # Set by ``execute``; ``None`` until a workflow has been run.
        self.workflow = None

    def execute(
        self,
        workflow: "Workflow",
        *,
        event_bus: Optional[EventBus] = None,
        progress: Optional[NodeProgressProtocol] = None,
    ) -> None:
        """
        Execute a Nipype workflow.

        Workflow outputs are written to disk under ``workflow.base_dir`` (and
        whatever ``output_dir`` paths individual nodes were configured with);
        downstream code reads them from the filesystem rather than from a
        returned dict.

        Args:
            workflow: nipype.pipeline.engine.Workflow instance
            event_bus: Optional event bus for emitting execution events.
            progress: Optional workflow progress tracker.

        Raises:
            Exception: Any error raised while preparing plugin arguments or
                running the workflow is logged and re-raised unchanged.
        """
        self.workflow = workflow

        # Set workflow base directory
        workflow_base_dir = self._resolve_working_dir()
        workflow_base_dir.mkdir(parents=True, exist_ok=True)
        workflow.base_dir = str(workflow_base_dir)

        self.logger.info(f"Executing workflow: {workflow.name}")
        self.logger.info(f"Working directory: {workflow.base_dir}")

        # Apply Nipype execution config options
        self._apply_workflow_config(workflow)

        try:
            # WORKAROUND: Nipype unconditionally imports the Unix-only ``pwd``
            # module when loading SGE/PBS/Torque plugin backends. On Windows
            # ``pwd`` does not exist, so we inject a minimal stub into
            # ``sys.modules`` to prevent ImportError. This is safe because the
            # SGE/PBS plugins are never actually *used* on Windows.
            if sys.platform.startswith("win"):
                try:
                    import pwd as _pwd  # noqa: F401
                except ImportError:
                    pwd_stub = types.ModuleType("pwd")

                    # Create a stub class for pwd struct_passwd
                    class _PwdStub:
                        pw_name: str

                        def __init__(self, name: str):
                            self.pw_name = name

                    setattr(pwd_stub, "getpwuid", lambda _uid: _PwdStub(getpass.getuser()))
                    setattr(pwd_stub, "getpwnam", lambda name: _PwdStub(name))
                    sys.modules["pwd"] = pwd_stub

            # Get plugin arguments (copied so the config's dict is never mutated)
            plugin = self.config.plugin
            plugin_args = dict(self.config.plugin_args or {})
            plugin_args = self._inject_gpu_plugin_args(plugin_args)
            plugin, plugin_args = self._resolve_plugin_for_workflow(workflow, plugin, plugin_args)

            self.logger.info(f"Plugin: {plugin}")
            self.logger.info(f"Plugin args: {plugin_args}")

            # The callback is attached after the log line above, so it does
            # not appear in the logged plugin args.
            if progress is not None or event_bus is not None:
                plugin_args["status_callback"] = build_nipype_status_callback(
                    progress=progress,
                    event_bus=event_bus,
                    patient_id=str(getattr(self.context, "patient_id", "")),
                )

            # Run workflow (avoid verbose kwarg for Nipype compatibility)
            workflow.run(
                plugin=plugin,
                plugin_args=plugin_args,
            )

            self.logger.info("Workflow execution completed successfully")
        except Exception as e:
            # NOTE(review): ``exc_info=True`` assumes the logger returned by
            # ``get_logger`` honours that keyword - confirm.
            self.logger.error(f"Workflow execution failed: {e}", exc_info=True)
            raise

    def get_working_directory(self) -> Path:
        """
        Get the workflow working directory.

        Returns:
            Path to the workflow working directory: the executed workflow's
            ``base_dir`` when one has been run, otherwise the directory
            resolved from the configuration.
        """
        if self.workflow is None:
            return self._resolve_working_dir()
        return Path(str(self.workflow.base_dir))

    def cleanup(self, remove_intermediate: Optional[bool] = None) -> None:
        """
        Clean up workflow intermediate files.

        Only sub-directories of the workflow base directory whose name starts
        with ``_`` are removed. Errors are logged as warnings, not raised.

        Args:
            remove_intermediate: Whether to remove intermediate outputs.
                If None, uses config setting.
        """
        if remove_intermediate is None:
            remove_intermediate = self.config.remove_unnecessary_outputs

        if not remove_intermediate or self.workflow is None:
            self.logger.info("Skipping cleanup of intermediate files")
            return

        import shutil

        working_dir = Path(str(self.workflow.base_dir))

        # Remove intermediate directories (keep final outputs)
        try:
            for item in working_dir.iterdir():
                if item.is_dir() and item.name.startswith("_"):
                    self.logger.debug(f"Removing intermediate directory: {item}")
                    shutil.rmtree(item)
        except Exception as e:
            # Best-effort: a failed cleanup must not fail the pipeline.
            self.logger.warning(f"Error during cleanup: {e}")

    def _resolve_working_dir(self) -> Path:
        """Resolve working directory, adding patient_id if not already templated."""
        base_str = str(self.config.working_dir)
        if "{patient_id}" in base_str:
            return Path(base_str.format(patient_id=self.context.patient_id))
        return Path(base_str) / str(self.context.patient_id)

    def _resolve_crash_dir(self) -> Optional[Path]:
        """Resolve crash directory, if set.

        A ``{patient_id}`` placeholder is filled in; unlike the working
        directory, no patient sub-directory is appended when it is absent.
        """
        if self.config.crash_dir is None:
            return None
        crash_str = str(self.config.crash_dir)
        if "{patient_id}" in crash_str:
            return Path(crash_str.format(patient_id=self.context.patient_id))
        return Path(crash_str)

    def _inject_gpu_plugin_args(self, plugin_args: Dict[str, Any]) -> Dict[str, Any]:
        """
        Inject ``n_gpu_procs`` and ``n_gpus`` into plugin_args when GPU is enabled.

        Reads ``context.config.hardware.gpu_enabled``. If ``True`` and
        ``n_gpu_procs`` is not already present, ``n_gpu_procs`` and ``n_gpus``
        are taken from the hardware config (each defaulting to ``1``) so the
        Nipype scheduler serializes GPU nodes at launch time.
        Existing user-provided values are never overridden: a supplied
        ``n_gpu_procs`` disables injection entirely, and a supplied
        ``n_gpus`` is kept even when ``n_gpu_procs`` is injected.

        Args:
            plugin_args: Current plugin arguments dict (updated in place).

        Returns:
            Updated plugin arguments dict.
        """
        try:
            hw_cfg = getattr(self.context.config, "hardware", None)
            if hw_cfg is None:
                return plugin_args
            gpu_enabled = bool(getattr(hw_cfg, "gpu_enabled", False))
        except AttributeError:
            # The context has no ``config`` at all: nothing to inject.
            return plugin_args

        if gpu_enabled and "n_gpu_procs" not in plugin_args:
            plugin_args["n_gpu_procs"] = int(getattr(hw_cfg, "n_gpu_procs", 1))
            # setdefault rather than assignment: a user-supplied ``n_gpus``
            # must survive even though ``n_gpu_procs`` had to be injected.
            plugin_args.setdefault("n_gpus", int(getattr(hw_cfg, "n_gpus", 1)))

        return plugin_args

    def _resolve_plugin_for_workflow(
        self,
        workflow: "Workflow",
        plugin: str,
        plugin_args: Dict[str, Any],
    ) -> tuple[str, Dict[str, Any]]:
        """Adjust plugin selection for workflows with backend-specific constraints.

        ``MultiProc`` combined with FireANTs CUDA registration is downgraded
        to ``Linear``; the GPU scheduling arguments are stripped from a copy
        of ``plugin_args`` in that case. Every other combination is returned
        unchanged.
        """
        if plugin == "MultiProc" and _uses_fireants_cuda_registration(self.context, workflow):
            self.logger.warning(
                "FireANTs CUDA registration is incompatible with fork-based MultiProc workers; "
                "falling back to Linear execution for workflow {}.",
                workflow.name,
            )
            safe_plugin_args = dict(plugin_args)
            safe_plugin_args.pop("n_gpu_procs", None)
            safe_plugin_args.pop("n_gpus", None)
            return "Linear", safe_plugin_args
        return plugin, plugin_args

    def _apply_workflow_config(self, workflow: "Workflow") -> None:
        """Apply Nipype execution config settings to workflow.

        Failures are logged as a warning and execution continues.
        """
        try:
            apply_nipype_execution_config(
                workflow, self.config, crash_dir=self._resolve_crash_dir()
            )
        except Exception as exc:
            self.logger.warning(f"Could not apply Nipype config: {exc}")
[docs]
def run_workflow(
    workflow: "Workflow",
    context: ProcessingContext,
    config: Optional[Union[Dict[str, Any], NipypeConfig]] = None,
    *,
    event_bus: Optional[EventBus] = None,
    progress: Optional[NodeProgressProtocol] = None,
) -> None:
    """
    One-shot helper: build a :class:`NipypeExecutor` and run ``workflow`` with it.

    Args:
        workflow: nipype.pipeline.engine.Workflow instance
        context: ProcessingContext
        config: Optional NipypeConfig or dict overrides
        event_bus: Optional event bus, forwarded to ``NipypeExecutor.execute``.
        progress: Optional progress tracker, forwarded to
            ``NipypeExecutor.execute``.
    """
    NipypeExecutor(context=context, config=config).execute(
        workflow,
        event_bus=event_bus,
        progress=progress,
    )