Source code for thesis.core.nipype.executor

"""
Nipype workflow executor for thesis framework.

This module provides the execution engine for running Nipype workflows
within the thesis framework, managing execution context, caching, and result collection.
"""

import getpass
import sys
import types
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

from thesis.core.config import NipypeConfig
from thesis.core.context import ProcessingContext
from thesis.core.logging import get_logger
from thesis.core.output import EventBus, EventLevel, NodeProgressProtocol

if TYPE_CHECKING:
    from nipype.pipeline.engine import Workflow

# Module-level logger, obtained by passing this module's name to get_logger.
logger = get_logger(__name__)

# Names exported by ``from ... import *``. ``NipypeConfig`` is imported from
# ``thesis.core.config`` above and listed here so it is re-exported alongside
# the executor API defined in this module.
__all__ = [
    "NipypeExecutor",
    "NipypeConfig",
    "NipypeStatusCallback",
    "apply_nipype_execution_config",
    "build_nipype_status_callback",
    "count_workflow_nodes",
    "run_workflow",
]


def apply_nipype_execution_config(
    workflow: "Workflow",
    config: NipypeConfig,
    crash_dir: Optional[Path] = None,
) -> None:
    """Push a NipypeConfig into both the workflow and Nipype's global config.

    Every entry path must call this before ``workflow.run()`` -- including the
    CLI's batch parallel path, which bypasses :class:`NipypeExecutor`.

    Two separate updates are performed and neither is optional:

    1. ``workflow.config['execution']`` -- read by Nipype's scheduler
       (stop_on_first_crash, remove_unnecessary_outputs, keep_inputs, etc.).
    2. The ``nipype.config`` global singleton -- read by
       ``BaseTraitedSpec._get_sorteddict`` (``interfaces/base/specs.py:307``)
       when hashing file inputs. If this is skipped, ``hash_method`` silently
       reverts to Nipype's package default (``timestamp``) regardless of the
       user's setting, which is what made content-mode trial YAMLs produce
       timestamp hashfiles and miss the cache on every pipeline run.

    Args:
        workflow: Workflow whose ``config`` mapping is updated in place.
        config: Source of the execution settings.
        crash_dir: Optional crash-dump directory; created if missing and
            recorded as ``crashdump_dir`` when given.

    Raises:
        Exception: any error reading or mutating ``workflow.config`` is
            re-raised so the caller can log it against its own logger
            (executor vs CLI).
    """
    from nipype import config as nipype_global_config

    from thesis.core.logging import suppress_nipype_native_logging

    suppress_nipype_native_logging()

    # Each of these NipypeConfig attributes maps onto an execution key of the
    # same name, so the copy can be driven from a single tuple.
    mirrored_options = (
        "stop_on_first_crash",
        "remove_unnecessary_outputs",
        "keep_inputs",
        "hash_method",
        "use_profiler",
    )
    execution_section = workflow.config.get("execution", {})
    for option in mirrored_options:
        execution_section[option] = getattr(config, option)

    if crash_dir is not None:
        crash_dir.mkdir(parents=True, exist_ok=True)
        execution_section["crashdump_dir"] = str(crash_dir)

    # Re-attach the section: ``get`` may have handed back a fresh dict.
    workflow.config["execution"] = execution_section

    # Step 2: the global singleton is what the input-hashing code consults.
    nipype_global_config.set("execution", "hash_method", config.hash_method)
# Lower-cased status strings that NipypeStatusCallback records as a node
# finishing successfully.
_NODE_FINISH_STATUSES = {"end", "done", "finish", "finished", "cached"}
# Lower-cased status strings that NipypeStatusCallback records as a node failing.
_NODE_FAIL_STATUSES = {"exception", "failed", "crash", "error"}
def count_workflow_nodes(workflow: "Workflow") -> int:
    """Count executable nodes in a workflow graph, including nested sub-workflows.

    For meta-workflows (e.g. ``tract_synthseg``) whose top-level graph contains
    child ``Workflow`` objects rather than ``Node`` objects, this recurses into
    each nested workflow so that every leaf node with an ``interface``
    attribute is counted. Graph members that have neither an ``interface`` nor
    a ``_graph`` attribute are ignored.

    Args:
        workflow: Nipype workflow whose executable leaf nodes should be counted.

    Returns:
        Number of executable nodes, or ``0`` if graph introspection fails.
        A nested workflow whose own introspection fails contributes ``0`` to
        the total rather than aborting the whole count, because the failure is
        handled inside the recursive call.
    """
    try:
        total = 0
        for member in workflow._graph.nodes():
            if hasattr(member, "interface"):
                total += 1
            elif hasattr(member, "_graph"):
                total += count_workflow_nodes(member)
        return total
    except (AttributeError, TypeError) as exc:
        # ``_graph`` missing/not-yet-built (AttributeError) or not iterable
        # (TypeError). Returning 0 sizes progress bars conservatively; log so a
        # 0-on-error is distinguishable from a genuinely empty workflow.
        # The message is rendered up front (like most other log calls in this
        # module) so it does not depend on the logger substituting ``{}``.
        logger.warning(f"Could not introspect workflow graph for node count: {exc}")
        return 0
def _node_label(node: Any) -> str: """Build a stable, human-readable node label.""" fullname = getattr(node, "fullname", None) if fullname: return str(fullname) full_name = getattr(node, "full_name", None) if full_name: return str(full_name) name = getattr(node, "name", None) if name: return str(name) return str(node)
class NipypeStatusCallback:
    """Picklable callable for Nipype's plugin ``status_callback`` hook.

    This has to be a top-level class rather than a closure so that
    ``multiprocessing`` can pickle anything that transitively references it
    when ``MultiProcPlugin`` submits work to workers. A nested-function closure
    here causes ``AttributeError: Can't get local object ...`` during
    ``_CallItem`` pickling, which wedges the scheduler.
    """

    def __init__(
        self,
        progress: Optional[NodeProgressProtocol] = None,
        event_bus: Optional[EventBus] = None,
        patient_id: str = "",
    ) -> None:
        """Store the optional sinks and start with empty bookkeeping sets.

        Args:
            progress: Optional progress tracker notified on node start/finish.
            event_bus: Optional event bus that receives node lifecycle events.
            patient_id: Identifier attached to every emitted event.
        """
        self._progress = progress
        self._event_bus = event_bus
        self._patient_id = patient_id
        # Labels that already reached a terminal status (success or failure).
        self._completed_nodes: set[str] = set()
        self._finished_ok_nodes: set[str] = set()
        self._failed_nodes: set[str] = set()

    @property
    def failed_nodes(self) -> set[str]:
        """Fullnames of nodes whose terminal status was exception/failed/crash/error."""
        return set(self._failed_nodes)

    @property
    def finished_ok_nodes(self) -> set[str]:
        """Fullnames of nodes whose terminal status was end/done/finish/finished/cached."""
        return set(self._finished_ok_nodes)

    def _publish(self, message: str, failed: bool, label: str, state: str) -> None:
        """Emit one node event on the bus, if a bus was supplied."""
        if self._event_bus is None:
            return
        self._event_bus.emit(
            message,
            level=EventLevel.ERROR if failed else EventLevel.INFO,
            category="nipype.node",
            patient_id=self._patient_id,
            metadata={"node": label, "status": state},
        )

    def __call__(self, node: Any, status: Any, *_args: Any, **_kwargs: Any) -> None:
        """Handle one status report from the Nipype plugin.

        ``"start"`` notifies the progress tracker and emits an INFO event.
        A terminal status updates the bookkeeping sets and the progress
        tracker only the first time a given node label reaches a terminal
        state; the event itself is emitted for every terminal report. Any
        other status is ignored. Extra positional/keyword arguments are
        accepted and discarded.
        """
        state = str(status).lower()
        label = _node_label(node)

        if state == "start":
            if self._progress is not None:
                self._progress.node_started(label)
            self._publish(f"Node started: {label}", False, label, state)
            return

        failed = state in _NODE_FAIL_STATUSES
        if not failed and state not in _NODE_FINISH_STATUSES:
            return

        if label not in self._completed_nodes:
            self._completed_nodes.add(label)
            if failed:
                self._failed_nodes.add(label)
            else:
                self._finished_ok_nodes.add(label)
            if self._progress is not None:
                icon = "[FAIL]" if failed else "[OK]"
                self._progress.node_finished(label, status=icon)

        self._publish(f"Node {state}: {label}", failed, label, state)

    def __getstate__(self) -> Dict[str, Any]:
        # Workers never invoke this callback; ship only inert, picklable state
        # so that any object transitively referencing it can still be sent
        # through multiprocessing.
        inert_state: Dict[str, Any] = {
            "_progress": None,
            "_event_bus": None,
            "_patient_id": self._patient_id,
            "_completed_nodes": set(),
            "_finished_ok_nodes": set(),
            "_failed_nodes": set(),
        }
        return inert_state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        self.__dict__.update(state)
def build_nipype_status_callback(
    *,
    progress: Optional[NodeProgressProtocol] = None,
    event_bus: Optional[EventBus] = None,
    patient_id: str = "",
) -> NipypeStatusCallback:
    """Build the status callback used for node-level progress reporting.

    Args:
        progress: Optional progress tracker updated on node start and finish.
        event_bus: Optional event bus for structured node lifecycle events.
        patient_id: Patient identifier attached to emitted events.

    Returns:
        A :class:`NipypeStatusCallback` suitable for a Nipype plugin's
        ``status_callback`` hook. After the workflow has run, its
        ``failed_nodes`` / ``finished_ok_nodes`` properties expose the final
        per-node status.
    """
    callback = NipypeStatusCallback(
        progress=progress,
        event_bus=event_bus,
        patient_id=patient_id,
    )
    return callback
def _uses_fireants_cuda_registration(context: ProcessingContext, workflow: "Workflow") -> bool: """Return whether this workflow should avoid fork-based MultiProc execution. Detection is config-driven: any workflow run with ``registration.method == "fireants"`` and ``registration.fireants.device`` starting with ``"cuda"`` is unsafe under fork-based MultiProc workers, regardless of whether the registration runs as a top-level workflow or as an embedded sub-workflow inside a meta-workflow such as ``full_pipeline``. """ config = getattr(context, "config", None) registration_cfg = getattr(config, "registration", None) if registration_cfg is None: return False if getattr(registration_cfg, "method", "") != "fireants": return False fireants_cfg = getattr(registration_cfg, "fireants", None) device = str(getattr(fireants_cfg, "device", "cpu")) return device.startswith("cuda")
class NipypeExecutor:
    """
    Executes Nipype workflows within the thesis framework.

    The executor connects a ``ProcessingContext`` to Nipype's workflow engine
    and takes care of:
    - Workflow base directory and working directories
    - Plugin selection and configuration
    - Logging and error handling

    Workflow outputs are written by individual nodes to their configured
    ``output_dir`` paths under ``workflow.base_dir``; downstream code reads
    them back from the filesystem rather than from a returned dict.

    Example:
        >>> from thesis.core import ConfigManager, create_context
        >>> from nipype import Workflow
        >>>
        >>> config = ConfigManager().load_config("default", patient_id="patient_001")
        >>> context = create_context("patient_001", config, Path("./data"))
        >>> workflow = Workflow(name="example")
        >>>
        >>> executor = NipypeExecutor(context, config.nipype)
        >>> executor.execute(workflow)
    """

    def __init__(
        self,
        context: ProcessingContext,
        config: Optional[Union[Dict[str, Any], NipypeConfig]] = None,
    ):
        """
        Initialize the executor.

        Args:
            context: ProcessingContext for the current run.
            config: Nipype configuration (dict or NipypeConfig instance). When
                omitted (or of any other type), ``context.config.nipype`` is
                used if present, otherwise a default ``NipypeConfig()``.
        """
        self.context = context

        # Convert dict config to NipypeConfig if needed
        if isinstance(config, dict):
            self.config = NipypeConfig.model_validate(config)
        elif isinstance(config, NipypeConfig):
            self.config = config
        else:
            # Try to extract nipype config from context
            if hasattr(context, "config") and hasattr(context.config, "nipype"):
                nipype_cfg = context.config.nipype
                if isinstance(nipype_cfg, dict):
                    self.config = NipypeConfig.model_validate(nipype_cfg)
                else:
                    self.config = nipype_cfg
            else:
                self.config = NipypeConfig()

        self.logger = get_logger(self.__class__.__name__)
        # Set by ``execute``; ``None`` until a workflow has been run.
        self.workflow = None

    def execute(
        self,
        workflow: "Workflow",
        *,
        event_bus: Optional[EventBus] = None,
        progress: Optional[NodeProgressProtocol] = None,
    ) -> None:
        """
        Execute a Nipype workflow.

        Workflow outputs are written to disk under ``workflow.base_dir`` (and
        whatever ``output_dir`` paths individual nodes were configured with);
        downstream code reads them from the filesystem rather than from a
        returned dict.

        Args:
            workflow: nipype.pipeline.engine.Workflow instance
            event_bus: Optional event bus for emitting execution events.
            progress: Optional workflow progress tracker.

        Raises:
            Exception: any error raised while preparing plugin arguments or
                running the workflow is logged and re-raised unchanged.
        """
        self.workflow = workflow

        # Set workflow base directory
        workflow_base_dir = self._resolve_working_dir()
        workflow_base_dir.mkdir(parents=True, exist_ok=True)
        workflow.base_dir = str(workflow_base_dir)

        self.logger.info(f"Executing workflow: {workflow.name}")
        self.logger.info(f"Working directory: {workflow.base_dir}")

        # Apply Nipype execution config options
        self._apply_workflow_config(workflow)

        try:
            self._ensure_pwd_module()

            # Get plugin arguments (copied so the config object is not mutated)
            plugin = self.config.plugin
            plugin_args = dict(self.config.plugin_args or {})
            plugin_args = self._inject_gpu_plugin_args(plugin_args)
            plugin, plugin_args = self._resolve_plugin_for_workflow(workflow, plugin, plugin_args)

            self.logger.info(f"Plugin: {plugin}")
            self.logger.info(f"Plugin args: {plugin_args}")

            if progress is not None or event_bus is not None:
                plugin_args["status_callback"] = build_nipype_status_callback(
                    progress=progress,
                    event_bus=event_bus,
                    patient_id=str(getattr(self.context, "patient_id", "")),
                )

            # Run workflow (avoid verbose kwarg for Nipype compatibility)
            workflow.run(
                plugin=plugin,
                plugin_args=plugin_args,
            )

            self.logger.info("Workflow execution completed successfully")

        except Exception as e:
            # NOTE(review): ``exc_info=True`` is the stdlib-logging keyword for
            # attaching the traceback -- confirm the logger returned by
            # ``get_logger`` honours it.
            self.logger.error(f"Workflow execution failed: {e}", exc_info=True)
            raise

    def get_working_directory(self) -> Path:
        """
        Get the workflow working directory.

        Returns:
            The executed workflow's ``base_dir`` if ``execute`` has been
            called, otherwise the directory resolved from the configuration.
        """
        if self.workflow is None:
            return self._resolve_working_dir()
        return Path(str(self.workflow.base_dir))

    def cleanup(self, remove_intermediate: Optional[bool] = None) -> None:
        """
        Clean up workflow intermediate files.

        Only sub-directories of the working directory whose name starts with
        ``_`` are removed. Errors during removal are logged as warnings and
        not raised.

        Args:
            remove_intermediate: Whether to remove intermediate outputs.
                If None, uses the ``remove_unnecessary_outputs`` config setting.
        """
        if remove_intermediate is None:
            remove_intermediate = self.config.remove_unnecessary_outputs

        if not remove_intermediate or self.workflow is None:
            self.logger.info("Skipping cleanup of intermediate files")
            return

        import shutil

        working_dir = Path(str(self.workflow.base_dir))

        # Remove intermediate directories (keep final outputs)
        try:
            for item in working_dir.iterdir():
                if item.is_dir() and item.name.startswith("_"):
                    self.logger.debug(f"Removing intermediate directory: {item}")
                    shutil.rmtree(item)
        except Exception as e:
            self.logger.warning(f"Error during cleanup: {e}")

    @staticmethod
    def _ensure_pwd_module() -> None:
        """Install a minimal ``pwd`` stub on Windows if the real module is absent.

        WORKAROUND: Nipype unconditionally imports the Unix-only ``pwd`` module
        when loading SGE/PBS/Torque plugin backends. On Windows ``pwd`` does
        not exist, so a stub is injected into ``sys.modules`` to prevent
        ImportError. This is safe because the SGE/PBS plugins are never
        actually *used* on Windows. On other platforms this is a no-op.
        """
        if not sys.platform.startswith("win"):
            return
        try:
            import pwd as _pwd  # noqa: F401
        except ImportError:
            pwd_stub = types.ModuleType("pwd")

            # Stand-in for the ``struct_passwd`` record; only ``pw_name`` is provided.
            class _PwdStub:
                pw_name: str

                def __init__(self, name: str):
                    self.pw_name = name

            setattr(pwd_stub, "getpwuid", lambda _uid: _PwdStub(getpass.getuser()))
            setattr(pwd_stub, "getpwnam", lambda name: _PwdStub(name))
            sys.modules["pwd"] = pwd_stub

    def _resolve_working_dir(self) -> Path:
        """Resolve working directory, adding patient_id if not already templated."""
        base_str = str(self.config.working_dir)
        if "{patient_id}" in base_str:
            return Path(base_str.format(patient_id=self.context.patient_id))
        return Path(base_str) / str(self.context.patient_id)

    def _resolve_crash_dir(self) -> Optional[Path]:
        """Resolve crash directory, if set.

        A ``{patient_id}`` placeholder is substituted when present; unlike the
        working directory, no patient sub-directory is appended otherwise.
        """
        if self.config.crash_dir is None:
            return None
        crash_str = str(self.config.crash_dir)
        if "{patient_id}" in crash_str:
            return Path(crash_str.format(patient_id=self.context.patient_id))
        return Path(crash_str)

    def _inject_gpu_plugin_args(self, plugin_args: Dict[str, Any]) -> Dict[str, Any]:
        """
        Inject ``n_gpu_procs`` and ``n_gpus`` into plugin_args when GPU is enabled.

        Reads ``context.config.hardware.gpu_enabled``. If it is truthy and
        ``n_gpu_procs`` is not already present, ``n_gpu_procs`` and ``n_gpus``
        are filled in from ``hardware.n_gpu_procs`` / ``hardware.n_gpus``
        (each defaulting to ``1``) so the Nipype scheduler serializes GPU
        nodes at launch time. Existing user-provided values are never
        overridden -- in particular a caller-supplied ``n_gpus`` is kept even
        when ``n_gpu_procs`` has to be injected.

        Args:
            plugin_args: Current plugin arguments dict (updated in place).

        Returns:
            The same plugin arguments dict, possibly with the GPU keys added.
        """
        try:
            hw_cfg = getattr(self.context.config, "hardware", None)
        except AttributeError:
            # No context/config to consult: leave the arguments untouched.
            return plugin_args

        if hw_cfg is None or not bool(getattr(hw_cfg, "gpu_enabled", False)):
            return plugin_args

        if "n_gpu_procs" not in plugin_args:
            plugin_args["n_gpu_procs"] = int(getattr(hw_cfg, "n_gpu_procs", 1))
            # setdefault keeps an explicit user value for n_gpus intact.
            plugin_args.setdefault("n_gpus", int(getattr(hw_cfg, "n_gpus", 1)))

        return plugin_args

    def _resolve_plugin_for_workflow(
        self,
        workflow: "Workflow",
        plugin: str,
        plugin_args: Dict[str, Any],
    ) -> tuple[str, Dict[str, Any]]:
        """Adjust plugin selection for workflows with backend-specific constraints.

        When ``MultiProc`` is requested for a FireANTs CUDA registration run,
        the plugin is switched to ``Linear`` and the GPU scheduling keys are
        dropped from a copy of ``plugin_args``. Otherwise both arguments are
        returned unchanged.
        """
        if plugin == "MultiProc" and _uses_fireants_cuda_registration(self.context, workflow):
            # Rendered up front so the workflow name appears regardless of
            # whether the logger substitutes ``{}`` placeholders.
            self.logger.warning(
                "FireANTs CUDA registration is incompatible with fork-based MultiProc workers; "
                f"falling back to Linear execution for workflow {workflow.name}."
            )
            safe_plugin_args = dict(plugin_args)
            safe_plugin_args.pop("n_gpu_procs", None)
            safe_plugin_args.pop("n_gpus", None)
            return "Linear", safe_plugin_args
        return plugin, plugin_args

    def _apply_workflow_config(self, workflow: "Workflow") -> None:
        """Apply Nipype execution config settings to workflow.

        Best-effort: a failure is logged as a warning and execution continues.
        """
        try:
            apply_nipype_execution_config(
                workflow, self.config, crash_dir=self._resolve_crash_dir()
            )
        except Exception as exc:
            self.logger.warning(f"Could not apply Nipype config: {exc}")
def run_workflow(
    workflow: "Workflow",
    context: ProcessingContext,
    config: Optional[Union[Dict[str, Any], NipypeConfig]] = None,
    *,
    event_bus: Optional[EventBus] = None,
    progress: Optional[NodeProgressProtocol] = None,
) -> None:
    """
    Run a Nipype workflow with the thesis context in a single call.

    Builds a :class:`NipypeExecutor` from ``context`` and ``config`` and
    immediately executes ``workflow`` with it.

    Args:
        workflow: nipype.pipeline.engine.Workflow instance
        context: ProcessingContext
        config: Optional NipypeConfig or dict overrides
        event_bus: Optional event bus, forwarded to ``NipypeExecutor.execute``.
        progress: Optional progress tracker, forwarded to
            ``NipypeExecutor.execute``.
    """
    NipypeExecutor(context=context, config=config).execute(
        workflow,
        event_bus=event_bus,
        progress=progress,
    )