# Source code for thesis.cli

"""
Command-line interface for the thesis framework.

Provides CLI commands for running custom Nipype workflows.
"""

import collections
import contextlib
import fnmatch
import importlib
import inspect
import json
import pkgutil
import re
import sys
import time
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

import click
import yaml

from thesis import __version__
from thesis.core.config import ConfigManager, PipelineConfig, merge_configs
from thesis.core.context import create_context
from thesis.core.exceptions import ConfigurationError, ProcessingError
from thesis.core.logging import (
    get_logger,
    set_console_level,
    setup_logging,
    suppress_nipype_native_logging,
    suspend_console_logging,
)
from thesis.core.nipype import run_workflow
from thesis.core.nipype.executor import (
    apply_nipype_execution_config,
    build_nipype_status_callback,
    count_workflow_nodes,
)
from thesis.core.output import (
    BatchSummary,
    EventLevel,
    OutputConfig,
    OutputMode,
    OutputRenderer,
    RunResult,
    RunStatus,
    RunSummary,
    SummaryDetail,
    get_event_bus,
    reset_event_bus,
)
from thesis.core.registry import WORKFLOW_REGISTRY, WorkflowEntry

logger = get_logger(__name__)


# ---------------------------------------------------------------------------
# Click-based progress wrapper for Nipype node callbacks
# ---------------------------------------------------------------------------


class ClickNodeProgress:
    """Simple wrapper for node progress using Click's progressbar.

    This class provides the same interface as WorkflowProgress (node_started,
    node_finished, start, stop) so it can be used with Nipype's status
    callback mechanism without changing executor.py.

    While the bar is active, console logging is suspended (via
    ``suspend_console_logging()``) so log records do not tear the bar.
    ``stop()`` is idempotent: calling it more than once, or without a prior
    ``start()``, only releases what is still held.
    """

    def __init__(self, total: int, label: str, enabled: bool = True):
        """Store progress parameters; nothing is drawn until :meth:`start`.

        Args:
            total: Number of node completions the bar represents.
            label: Text shown to the left of the bar.
            enabled: When ``False``, :meth:`start` is a no-op.
        """
        self.total = total
        self.label = label
        self.enabled = enabled
        self._current_node: Optional[str] = None
        self._bar: Optional[Any] = None
        self._completed = 0
        self._ctx: Optional[Any] = None
        # Context manager from suspend_console_logging(); None whenever this
        # instance is not currently holding console logging suspended.
        self._log_cm: Optional[Any] = None

    def start(self) -> None:
        """Create and enter the Click progressbar.

        Does nothing when disabled, when ``total <= 0``, or when the bar is
        already running (a second bar would leak the first one).
        """
        if not self.enabled or self.total <= 0:
            return
        if self._ctx is not None:
            return
        # Create the progressbar instance
        ctx = click.progressbar(
            length=self.total,
            label=self.label,
            item_show_func=lambda _item: self._current_node or "",
            show_pos=True,
            show_eta=True,
            file=sys.stderr,
            color=True,
            width=40,
        )
        log_cm = suspend_console_logging()
        log_cm.__enter__()
        try:
            bar = ctx.__enter__()
        except BaseException:
            # Never leave console logging muted if the bar failed to open.
            log_cm.__exit__(None, None, None)
            raise
        self._ctx = ctx
        self._log_cm = log_cm
        self._bar = bar

    def stop(self) -> None:
        """Close the progressbar and restore console logging (idempotent)."""
        try:
            if self._ctx is not None:
                ctx, self._ctx = self._ctx, None
                self._bar = None
                ctx.__exit__(None, None, None)
        finally:
            # Clear the reference before exiting so a repeated stop() cannot
            # exit the logging suspension twice.
            if self._log_cm is not None:
                log_cm, self._log_cm = self._log_cm, None
                log_cm.__exit__(None, None, None)

    def node_started(self, node_name: str) -> None:
        """Remember *node_name* as the current node and redraw the label."""
        self._current_node = node_name
        if self._bar is not None:
            self._bar.render_progress()  # Force redraw with new label

    def node_finished(self, node_name: str, status: str = "") -> None:
        """Advance the bar by one completed node.

        ``node_name`` and ``status`` are accepted for interface compatibility;
        the item label shown is the most recently *started* node.
        """
        self._completed += 1
        if self._bar is not None:
            self._bar.update(1, self._current_node)
# --------------------------------------------------------------------------- # GPU startup check # --------------------------------------------------------------------------- # These module-level variables ensure the GPU check (and its interactive # prompt) runs at most once per process, regardless of how many patients # are processed in sequence. _gpu_checked: bool = False _gpu_available: bool = False def _resolve_gpu(cfg: PipelineConfig) -> None: """ Validate GPU availability and update ``cfg.hardware.gpu_enabled`` in place. Called once per process immediately after the config is loaded. On the first call it runs :func:`thesis.core.gpu.check_gpu`, logs the outcome, and — if the GPU was requested but is not available — prompts the user: * On interactive terminals: asks whether to fall back to CPU or abort. * On non-interactive terminals (e.g. SLURM batch scripts): automatically falls back to CPU and logs a warning without blocking. Subsequent calls (e.g. when processing multiple patients in one run) use the cached result from the first call. Args: cfg: Loaded ``PipelineConfig``; ``cfg.hardware.gpu_enabled`` is set to ``False`` in place when falling back to CPU. 
""" global _gpu_checked, _gpu_available if not cfg.hardware.gpu_enabled: return # GPU not requested — nothing to check if not _gpu_checked: import sys from thesis.core.gpu import check_gpu status = check_gpu() _gpu_available = status.available _gpu_checked = True if status.available: logger.info("GPU acceleration available: {}", status.reason) else: logger.warning("GPU requested but not available: {}", status.reason) if sys.stdin.isatty(): click.echo( f"\n [WARN] GPU acceleration requested but not available:\n" f" {status.reason}\n", err=True, ) proceed = click.confirm( " Continue with CPU processing instead?", default=True, ) click.echo("", err=True) # blank line for readability if not proceed: raise click.Abort() else: logger.warning( "Non-interactive terminal detected; " "falling back to CPU processing automatically." ) if not _gpu_available: cfg.hardware.gpu_enabled = False def _discover_patient_ids( data_dir: Path, config_dir: Path, config_name: str = "default", protocol: Optional[str] = None, raw_data_dir_override: Optional[Path] = None, name_pattern: Optional[str] = None, ) -> List[str]: """Scan *inputs_dir* for sub-folders whose names consist only of digits. ``inputs_dir`` is an independent project-level root (it is *not* resolved under *data_dir*), matching how :class:`ProcessingContext` resolves the per-patient input directory. *data_dir* (the shared assets base) is accepted for signature compatibility but does not affect input discovery. Returns a sorted list of patient-ID strings. 
""" config_manager = ConfigManager(config_dir=config_dir) overrides = None if raw_data_dir_override: overrides = {"paths": {"inputs_dir": str(raw_data_dir_override)}} cfg = config_manager.load_config( config_name=config_name, protocol=protocol, overrides=overrides, protocol_required=protocol is not None, ) raw = Path(cfg.paths.inputs_dir) search_dir = raw if raw.is_absolute() else Path.cwd() / raw if not search_dir.is_dir(): resolved = search_dir.absolute() raise click.ClickException( f"inputs_dir does not exist or is not a directory: {search_dir} " f"(resolved to {resolved})" ) def _accept(name: str) -> bool: if name_pattern: return fnmatch.fnmatch(name, name_pattern) return name.isdigit() ids = sorted( entry.name for entry in search_dir.iterdir() if entry.is_dir() and _accept(entry.name) ) if not ids: if name_pattern: raise click.ClickException( f"No patient folders matching pattern '{name_pattern}' found in {search_dir}" ) else: raise click.ClickException(f"No numeric patient folders found in {search_dir}") return ids _VALID_PATIENT_ID_RE = re.compile(r"^[A-Za-z0-9._-]+$") def _validate_explicit_patient_ids( ids: List[str], config_dir: Path, config_name: str, protocol: Optional[str], raw_data_dir_override: Optional[Path], data_dir: Path, ) -> None: """Validate explicitly-supplied ``-p`` patient IDs before any setup work. Two checks are performed: * Format: each ID must contain only ``[A-Za-z0-9._-]``; an offending ID raises :class:`click.UsageError` naming it. * Existence: each ID's input directory must exist under the resolved ``inputs_dir``. A missing ID raises :class:`click.ClickException` listing the available IDs. Existence is best-effort: if ``inputs_dir`` cannot be resolved, the existence check is skipped rather than crashing. Args: ids: Explicit patient identifiers from ``-p/--patient-id``. config_dir: Configuration directory. config_name: Base config name (``-c``). protocol: Effective protocol name (or ``None``). 
raw_data_dir_override: Optional ``--raw-data-dir`` override (sets inputs_dir). data_dir: Resolved shared assets base (unused for input discovery). Raises: click.UsageError: If any ID has an invalid format. click.ClickException: If any (well-formed) ID is missing under ``inputs_dir``. """ bad = [pid for pid in ids if not _VALID_PATIENT_ID_RE.match(pid)] if bad: raise click.UsageError( "Invalid patient ID(s) " + ", ".join(repr(b) for b in bad) + "; only characters [A-Za-z0-9._-] are allowed." ) try: config_manager = ConfigManager(config_dir=config_dir) overrides = None if raw_data_dir_override: overrides = {"paths": {"inputs_dir": str(raw_data_dir_override)}} cfg = config_manager.load_config( config_name=config_name, protocol=protocol, overrides=overrides, protocol_required=protocol is not None, ) raw = Path(cfg.paths.inputs_dir) search_dir: Optional[Path] = raw if raw.is_absolute() else Path.cwd() / raw except Exception as exc: # noqa: BLE001 — best-effort existence check logger.debug("Skipping patient-ID existence check (inputs_dir unresolved): {}", exc) return if search_dir is None or not search_dir.is_dir(): logger.debug("Skipping patient-ID existence check: inputs_dir not a directory") return missing = [pid for pid in ids if not (search_dir / pid).is_dir()] if missing: available = sorted(entry.name for entry in search_dir.iterdir() if entry.is_dir()) raise click.ClickException( f"Patient '{missing[0]}' not found under inputs_dir; available: {available}" ) def _ensure_workflow_imported(name: str, is_package: bool = True) -> None: """Import workflow modules so their self-registration side effects run.""" try: importlib.import_module(f"thesis.workflows.{name}") except ImportError as exc: logger.debug("Workflow package import failed for {}: {}", name, exc) if is_package: try: importlib.import_module(f"thesis.workflows.{name}.workflow") except (ImportError, ModuleNotFoundError) as exc: logger.debug("Workflow module import failed for {}: {}", name, exc) def 
_resolve_workflow(name: str) -> WorkflowEntry: """Resolve a workflow name to its :class:`WorkflowEntry`. Attempts to import ``thesis.workflows.<name>`` (and for packages also ``thesis.workflows.<name>.workflow``) so that self-registration is triggered even if the module has not been imported yet. Args: name: Short workflow name (e.g. ``"hcp"``). Returns: The matching :class:`WorkflowEntry`. Raises: click.ClickException: If the name is not found after auto-import. """ _ensure_workflow_imported(name) try: return WORKFLOW_REGISTRY.get(name) except KeyError: known = WORKFLOW_REGISTRY.list() known_str = ", ".join(known) if known else "<none>" raise click.ClickException( f"Unknown workflow: '{name}'. Known workflows: {known_str}.\n" f"Run 'thesis list-workflows' to see all available workflows." ) def _build_workflow( factory: Callable[..., Any], config: PipelineConfig, context: Any, ) -> Any: """Call workflow factory with a flexible signature.""" signature = inspect.signature(factory) positional = [ param for param in signature.parameters.values() if param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD) ] if len(positional) >= 2: return factory(config, context) if len(positional) == 1: return factory(config) return factory() def _resolve_protocol_override( config_manager: ConfigManager, patient_id: str, protocol: Optional[str], entry: WorkflowEntry, ) -> str: """Resolve the effective protocol for a patient run.""" temp_cfg_dict = config_manager.load_config_dict(patient_id, subdir="patients") protocol_override = ( protocol or (temp_cfg_dict.get("protocol") if temp_cfg_dict else None) or entry.default_protocol ) if protocol_override: return str(protocol_override) raise click.ClickException( f"No protocol for workflow '{entry.name}' and patient '{patient_id}'. " f"Tried: --protocol ({'set' if protocol else 'not set'}), " f"config/patients/{patient_id}.yaml (no 'protocol' key), " f"workflow default ({entry.default_protocol or 'not configured'}). 
" f"Provide --protocol or add a 'protocol' key to a config file." ) def _resolve_data_dir( data_dir: Optional[Path], config_dir: Path, config_name: str, protocol: Optional[str], ) -> Path: """Resolve the effective shared assets directory for a run. When ``--data-dir`` is passed explicitly it wins. Otherwise we fall back to ``config.paths.assets_dir`` (the shared templates/atlases/ROIs base), and finally to the legacy ``"data"`` default if neither is set or the config cannot be loaded here. Args: data_dir: The ``--data-dir`` CLI value, or ``None`` when not passed. config_dir: Configuration directory. config_name: Base config name (``-c``). protocol: Effective protocol name (``--protocol`` or the workflow's ``default_protocol``). Returns: The resolved base data directory. """ if data_dir is not None: return data_dir try: cfg = ConfigManager(config_dir=config_dir).load_config( config_name=config_name, protocol=protocol ) configured = getattr(getattr(cfg, "paths", None), "assets_dir", None) except Exception as exc: # noqa: BLE001 — best-effort fallback to legacy default logger.debug( "Could not read paths.assets_dir for --data-dir default ({}); using 'data'", exc ) configured = None return Path(str(configured)) if configured else Path("data") def _build_patient_workflow( entry: WorkflowEntry, patient_id: str, config_name: str, config_dir: Path, data_dir: Path, output_dir: Optional[Path], protocol: Optional[str], raw_data_dir_override: Optional[Path], config_overrides: Optional[Dict[str, Any]] = None, ) -> Any: """Build a single patient's Nipype workflow for inclusion in a batch. This creates the patient's config, context, runs preflight verification, and calls the workflow factory. The returned Nipype Workflow exposes its full node graph to the parent meta-workflow scheduler, allowing proper GPU serialisation and memory accounting. Args: entry: Resolved WorkflowEntry. patient_id: Patient identifier string. config_name: Configuration name to load. 
config_dir: Path to config directory. data_dir: Base data directory. output_dir: Optional output directory override. protocol: Protocol name override. raw_data_dir_override: Optional raw data dir override. config_overrides: Optional CLI-level config overrides (e.g. hemisphere). Returns: A Nipype Workflow object (or raises on failure). Raises: click.ClickException: On verification failure or missing protocol. RuntimeError: If the factory does not return a valid workflow. """ config_manager = ConfigManager(config_dir=config_dir) protocol_override = _resolve_protocol_override(config_manager, patient_id, protocol, entry) overrides: Dict[str, Any] = {} if raw_data_dir_override: overrides["paths"] = {"inputs_dir": str(raw_data_dir_override)} if config_overrides: overrides = merge_configs(overrides, config_overrides) cfg = config_manager.load_config( config_name=config_name, patient_id=patient_id, protocol=protocol_override, overrides=overrides or None, protocol_required=protocol is not None, ) _resolve_gpu(cfg) ctx_obj = create_context( patient_id=patient_id, config=cfg, data_dir=data_dir, output_dir=output_dir, ) # Preflight verification if callable(entry.verifier): errors = entry.verifier(cfg, ctx_obj) if errors: msg = f"Preflight verification failed for {patient_id}:\n" msg += "\n".join(f" - {e}" for e in errors) raise click.ClickException(msg) wf = _build_workflow(entry.factory, cfg, ctx_obj) if not hasattr(wf, "run"): raise RuntimeError(f"Workflow factory did not return a Nipype Workflow for {patient_id}.") return wf def _split_patient_ids(raw: Iterable[str]) -> Tuple[str, ...]: """Expand comma-separated ``--patient-id`` values into a flat tuple. Accepts repeated flags (``-p P001 -p P002``), comma-separated values (``-p P001,P002``), or any mix (``-p P001,P002 -p P003``). Whitespace around each id is stripped and empty entries are dropped. Order is preserved and duplicates are removed (first occurrence wins). 
Args: raw: The raw ``patient_ids`` tuple as collected by Click. Returns: A de-duplicated tuple of individual patient identifiers. """ seen: set[str] = set() out: List[str] = [] for value in raw: for pid in value.split(","): pid = pid.strip() if pid and pid not in seen: seen.add(pid) out.append(pid) return tuple(out) def _attribute_nodes_to_patients( node_names: Iterable[str], patient_ids: Iterable[str], ) -> Dict[str, List[str]]: """Map Nipype node fullnames back to the patient IDs they belong to. In the batched meta-workflow each patient is wrapped in a sub-workflow whose name embeds the patient ID (e.g. ``tract_synthseg_818455``), so a node fullname like ``batch_tractography.tract_synthseg_818455.hcp_818455. probtrackx2.a0`` is attributable to patient ``818455``. The match is anchored at alphanumeric word boundaries so a shorter ID does not match inside a longer one (``8184`` vs ``818455``); when both could match, the longer ID wins. Args: node_names: Fullnames of nodes whose attribution is wanted. patient_ids: Patient identifiers that were added to the batch. Returns: Mapping from patient ID to the list of node fullnames attributed to that patient. Patients with no matching nodes are omitted. 
""" sorted_ids = sorted({pid for pid in patient_ids if pid}, key=len, reverse=True) patterns = [ (pid, re.compile(rf"(?<![0-9A-Za-z]){re.escape(pid)}(?![0-9A-Za-z])")) for pid in sorted_ids ] attributed: Dict[str, List[str]] = {} for node_name in node_names: matched = False for pid, pattern in patterns: if pattern.search(node_name): attributed.setdefault(pid, []).append(node_name) matched = True break if not matched: logger.warning("Could not attribute node {!r} to any patient", node_name) return attributed def _classify_batch_failure( *, built: Iterable[str], workflow: str, failed_by_pid: Dict[str, List[str]], ok_by_pid: Dict[str, List[str]], last_error: Optional[BaseException], ) -> List[RunResult]: """Build per-patient RunResults from a failed batch's per-node observations. Pure (no side effects) so the CLI's failure path can be exercised without standing up a real Nipype workflow. The returned ``RunResult`` for a SUCCESS patient is *not* tract-similarity-enriched -- the caller does that separately because enrichment needs config/context the helper shouldn't own. Args: built: Patient IDs whose workflows were successfully added to the batch. workflow: Workflow name to stamp onto each result. failed_by_pid: Mapping from patient ID to fullnames of nodes that failed for that patient (from ``_attribute_nodes_to_patients``). ok_by_pid: Same shape, for nodes that finished successfully. last_error: Exception raised by the underlying ``meta_wf.run()``. Returns: One ``RunResult`` per patient in ``built`` (preserving order): - **FAILED** if the patient has any entry in ``failed_by_pid``. ``error_message`` lists the failed node names plus the batch error. - **SUCCESS** if the patient has at least one entry in ``ok_by_pid`` and none in ``failed_by_pid``. - **BLOCKED** if the callback observed neither -- the scheduler likely aborted before this patient ran. 
If ``failed_by_pid`` is empty, every patient is marked **FAILED** with ``str(last_error)``: the callback recorded no per-node failures, so the safe fallback is to attribute the batch error to everyone rather than silently call them SUCCESS. """ results: List[RunResult] = [] if not failed_by_pid: for pid in built: results.append( RunResult( patient_id=pid, workflow=workflow, status=RunStatus.FAILED, error_message=str(last_error), ) ) return results for pid in built: if pid in failed_by_pid: failed_nodes_str = ", ".join(sorted(failed_by_pid[pid])) results.append( RunResult( patient_id=pid, workflow=workflow, status=RunStatus.FAILED, error_message=(f"Failed nodes: {failed_nodes_str} (batch error: {last_error})"), ) ) elif pid in ok_by_pid: results.append(RunResult(patient_id=pid, workflow=workflow, status=RunStatus.SUCCESS)) else: results.append( RunResult( patient_id=pid, workflow=workflow, status=RunStatus.BLOCKED, error_message="Workflow aborted before this patient ran.", ) ) return results def _build_batch_workflow( entry: WorkflowEntry, ids: list[str], cfg: PipelineConfig, config_name: str, data_dir: Path, output_dir: Optional[Path], config_dir: Path, protocol: Optional[str], raw_data_dir: Optional[Path], event_bus: Optional[Any] = None, config_overrides: Optional[Dict[str, Any]] = None, ) -> Any: """Build a flat Nipype meta-workflow containing all patient sub-workflows. Instead of wrapping each patient in an opaque ``Function`` node, this function constructs each patient's real Nipype workflow and adds it as a direct sub-workflow of the meta-workflow. This exposes every node (including GPU-bound ones like ``synthseg`` and ``probtrackx2``) to the global ``MultiProc`` scheduler, which can then: * Parallelise CPU-bound nodes across patients. * Serialise GPU-bound nodes via ``n_gpu_procs`` (preventing OOM). * Account for ``mem_gb`` per node to avoid memory oversubscription. Args: entry: Resolved WorkflowEntry. ids: List of patient ID strings. 
cfg: Base PipelineConfig (used for working_dir resolution). config_name: Configuration name to load per patient. data_dir: Base data directory. output_dir: Optional output directory override. config_dir: Path to config directory. protocol: Protocol name override. raw_data_dir: Optional raw data dir override. event_bus: Optional event bus for emitting build status events. config_overrides: Optional CLI-level config overrides (e.g. hemisphere). Returns: Nipype Workflow containing all patient sub-workflows. """ from nipype import Workflow bus = event_bus or get_event_bus() meta_wf = Workflow(name="batch_tractography") work_dir = Path(cfg.nipype.working_dir) / "batch" work_dir.mkdir(parents=True, exist_ok=True) meta_wf.base_dir = str(work_dir) built: List[str] = [] failed: List[str] = [] failed_reasons: Dict[str, str] = {} for pid in ids: try: patient_wf = _build_patient_workflow( entry=entry, patient_id=pid, config_name=config_name, config_dir=config_dir, data_dir=data_dir, output_dir=output_dir, protocol=protocol, raw_data_dir_override=raw_data_dir, config_overrides=config_overrides, ) meta_wf.add_nodes([patient_wf]) built.append(pid) bus.emit( f"Workflow built for {pid}", level=EventLevel.INFO, category="build", patient_id=pid, ) except Exception as exc: logger.error("Failed to build workflow for {}: {}", pid, exc, exc_info=True) bus.emit( f"Skipping {pid}: {type(exc).__name__}: {exc}", level=EventLevel.ERROR, category="build", patient_id=pid, ) failed.append(pid) failed_reasons[pid] = f"{type(exc).__name__}: {exc}" if not built: base_msg = ( "No patient workflows could be built. " f"All {len(ids)} patient(s) failed to build." ) reasons = list(failed_reasons.values()) if reasons and len(set(reasons)) == 1: common_reason = reasons[0] raise click.ClickException( f"{base_msg} Common error: {common_reason}. " "Check logs or rerun with -v for per-patient detail." 
) if reasons: counter = collections.Counter(reasons) common_reason, count = counter.most_common(1)[0] if count > 1: raise click.ClickException( f"{base_msg} Common error: {common_reason}. " "Check logs or rerun with -v for per-patient detail." ) raise click.ClickException( f"{base_msg} Reasons vary across patients; rerun with -v to see " "per-patient errors (e.g. thesis run -v ...)." ) if failed: bus.emit( f"Built {len(built)}/{len(ids)} patient workflows " f"({len(failed)} skipped: {', '.join(failed)})", level=EventLevel.WARNING, category="build", ) return meta_wf, built, failed def _run_slurm_batch( meta_wf: Any, built: List[str], cfg: PipelineConfig, *, workflow_name: str, config_name: str, dry_run: bool, event_bus: Optional[Any] = None, ) -> Dict[str, Any]: """Partition each built patient workflow and submit grouped SLURM stage jobs. Each top-level patient sub-workflow inside ``meta_wf`` is partitioned into a linear chain of resource-homogeneous stage groups and submitted via ``sbatch`` with an ``afterok`` dependency chain per patient. All SLURM modules are imported locally here so the disabled path never touches them. Args: meta_wf: The batch meta-workflow whose direct children are the per-patient built workflows. built: Patient ids successfully built (mapped onto ``meta_wf`` children by name suffix). cfg: Loaded pipeline config (``cfg.slurm`` must be enabled). workflow_name: Registered workflow name (``-w``). config_name: Configuration name (``-c``). dry_run: When ``True``, scripts + planned manifest are written but no ``sbatch`` is invoked. event_bus: Optional event bus for status events. Returns: The submission manifest dict from :func:`thesis.core.slurm.submitter.submit_jobs`. """ from thesis.core.slurm.partitioner import partition from thesis.core.slurm.submitter import submit_jobs bus = event_bus or get_event_bus() # Map each patient id to its built sub-workflow inside meta_wf by name suffix. 
children = [n for n in meta_wf._graph.nodes() if hasattr(n, "_graph")] per_patient_groups: Dict[str, Any] = {} for pid in built: match = next( (c for c in children if str(getattr(c, "name", "")).endswith(pid)), None, ) if match is None: logger.warning("Could not locate built sub-workflow for patient {}", pid) continue groups = partition(match) per_patient_groups[pid] = groups bus.emit( f"Partitioned {pid} into {len(groups)} stage group(s)", level=EventLevel.INFO, category="slurm", patient_id=pid, ) return submit_jobs( per_patient_groups, cfg.slurm, workflow_name=workflow_name, config_name=config_name, dry_run=dry_run, ) def _render_batch_tractography_stats( run_results: List["RunResult"], config_name: str, config_dir: Path, data_dir: Path, output_dir: Optional[Path], protocol: Optional[str], ) -> None: """Print aggregated tractography statistics for a batch run. Only includes patients that completed successfully. Silently returns if no successful patients exist or if statistics collection fails. Args: run_results: Per-patient run results from the batch. config_name: Configuration name (for resolving output paths). config_dir: Configuration directory. data_dir: Base data directory. output_dir: Override output directory (may be ``None``). protocol: Protocol name override. """ succeeded_ids = [r.patient_id for r in run_results if r.status == RunStatus.SUCCESS] if len(succeeded_ids) < 1: logger.debug("Batch stats rendering skipped: no successful patients") return try: from thesis.workflows.qc.statistics import collect_batch_stats, format_stats_table # Resolve the output base directory the same way contexts do. 
config_manager = ConfigManager(config_dir=config_dir) cfg = config_manager.load_config( config_name=config_name, patient_id=succeeded_ids[0], protocol=protocol, ) output_base: Optional[Path] = output_dir if output_base is None: raw_out = getattr(cfg.paths, "output_dir", None) if raw_out: output_base = Path(raw_out) if output_base is None or not output_base.is_dir(): logger.debug( "Batch stats rendering skipped: output_dir not resolved or does " "not exist (output_base={})", output_base, ) return tractography_relpath = getattr( getattr(cfg, "atlas", None), "tractography_relpath", "tractography/probtrackx2" ) stats = collect_batch_stats( output_base, patient_ids=succeeded_ids, tractography_relpath=tractography_relpath, ) if not stats: return include_rois = len(succeeded_ids) <= 10 # avoid huge tables table = format_stats_table(stats, include_roi_counts=include_rois) click.echo(f"\n{'─' * 70}", err=True) click.echo("Tractography Statistics", err=True) click.echo(f"{'─' * 70}", err=True) click.echo(table, err=True) outliers: List[Dict[str, Any]] = [] sd_thresh = getattr(getattr(cfg, "qc", None), "outlier_sd_threshold", 2.0) if len(stats) >= 3: from thesis.workflows.qc.checks import detect_batch_outliers outliers = detect_batch_outliers(stats, sd_threshold=sd_thresh) if outliers: click.echo(f"\n{'─' * 70}", err=True) click.echo(f"Outliers (>{sd_thresh} SD from mean)", err=True) click.echo(f"{'─' * 70}", err=True) for o in outliers: click.echo( f" {o['patient_id']:<12} {o['metric']:<20} " f"value={o['value']:<10} z={o['z_score']:.1f}", err=True, ) _persist_batch_stats(output_base, stats, outliers) except Exception as _exc: # Best-effort — never fail the run for stats. logger.debug("Batch statistics rendering failed: {}", _exc) def _persist_batch_stats( output_base: Path, stats: List[Dict[str, Any]], outliers: List[Dict[str, Any]], ) -> None: """Write batch stats and outliers to JSON under ``output_base/batch_stats``. 
Emits timestamped files plus a convenience ``latest.json`` that is overwritten each run. Best-effort — any IO failure is logged and swallowed so stats persistence never breaks a pipeline run. """ try: stats_dir = output_base / "batch_stats" stats_dir.mkdir(exist_ok=True) ts = time.strftime("%Y%m%d_%H%M%S") payload = {"subjects": stats, "outliers": outliers} stats_path = stats_dir / f"stats_{ts}.json" outliers_path = stats_dir / f"outliers_{ts}.json" latest_path = stats_dir / "latest.json" stats_path.write_text(json.dumps(payload, indent=2, default=str), encoding="utf-8") outliers_path.write_text(json.dumps(outliers, indent=2, default=str), encoding="utf-8") latest_path.write_text(json.dumps(payload, indent=2, default=str), encoding="utf-8") logger.info("Wrote batch stats to {}", stats_path) except Exception as exc: logger.debug("Persisting batch stats failed: {}", exc) def _enrich_with_tract_similarity_metrics( result: RunResult, cfg: PipelineConfig, ctx_obj: Any, ) -> None: """Populate ``result.metadata['tract_similarity']`` from metrics.json. Called after a successful ``full_pipeline`` run so that the terminal summary can show headline similarity numbers without re-reading the full metric file. A best-effort operation: any failure leaves the metadata unchanged. 
""" if result.workflow != "full_pipeline": return if ctx_obj.output_dir is None: return subdir = getattr(getattr(cfg, "tract_similarity", None), "output_subdir", None) if not subdir: return metrics_file = Path(ctx_obj.output_dir) / subdir / "metrics.json" if not metrics_file.is_file(): return try: data = json.loads(metrics_file.read_text(encoding="utf-8")) except (json.JSONDecodeError, OSError) as exc: logger.debug("Could not parse tract_similarity metrics.json: {}", exc) return overlap = data.get("overlap") or {} correlation = data.get("correlation") or {} distance = data.get("distance_mm") or {} distribution = data.get("distribution") or {} result.metadata["tract_similarity"] = { "dice": overlap.get("dice"), "pearson": correlation.get("pearson"), "hausdorff95": distance.get("hausdorff95"), "nmi": distribution.get("nmi"), } def _run_single_patient_with_retries( entry: WorkflowEntry, config: str, patient_id: str, data_dir: Path, output_dir: Optional[Path], config_dir: Path, graph: bool, dry_run: bool, protocol: Optional[str], raw_data_dir_override: Optional[Path], retries: int, show_workflow_progress: bool = False, config_overrides: Optional[Dict[str, Any]] = None, stage: Optional[str] = None, ) -> RunResult: """Run a patient workflow with retry support for sequential execution. Returns a :class:`RunResult` populated with status, timing, and error info. 
""" bus = get_event_bus() result = RunResult(patient_id=patient_id, workflow=entry.name) start_time = time.monotonic() error_msg = "" error_type = "" for attempt in range(retries + 1): if attempt > 0: result.retry_count = attempt bus.emit( f"Retry {attempt}/{retries} for {patient_id} (waiting 60s)", level=EventLevel.WARNING, category="retry", patient_id=patient_id, ) time.sleep(60) ok, error_msg, error_type = _run_single_patient( entry=entry, config=config, patient_id=patient_id, data_dir=data_dir, output_dir=output_dir, config_dir=config_dir, graph=graph, dry_run=dry_run, protocol=protocol, raw_data_dir_override=raw_data_dir_override, show_workflow_progress=show_workflow_progress, config_overrides=config_overrides, result=result, stage=stage, ) if ok: result.status = RunStatus.SKIPPED if dry_run else RunStatus.SUCCESS result.elapsed_seconds = time.monotonic() - start_time return result result.error_history.append(f"Attempt {attempt + 1}: {error_type or 'Error'}: {error_msg}") # All attempts failed result.status = RunStatus.FAILED result.error_message = error_msg result.error_type = error_type result.elapsed_seconds = time.monotonic() - start_time return result def _make_output_config(ctx: click.Context) -> OutputConfig: """Build an :class:`OutputConfig` from CLI context flags. Args: ctx: Click context whose ``obj`` dict may contain ``output_mode``, ``summary_detail``, and ``show_progress`` keys. Returns: Configured :class:`OutputConfig`. """ obj = ctx.obj or {} mode = obj.get("output_mode", OutputMode.NORMAL) summary = obj.get("summary_detail", SummaryDetail.COMPACT) progress = obj.get("show_progress", None) return OutputConfig(mode=mode, summary=summary, progress=progress) def _should_show_workflow_progress(entry: WorkflowEntry, show_progress: bool) -> bool: """Determine whether Click node progress should be created for a workflow run. Args: entry: Resolved workflow entry. show_progress: Whether progress UI is otherwise enabled. 
Returns: ``True`` when node-level Click progress should be enabled. """ if not show_progress: return False return not bool(getattr(entry, "is_cohort_level", False)) @click.group() @click.version_option(version=__version__) @click.option( "--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR"], case_sensitive=False), default="INFO", show_default=True, help="Logging level", ) @click.option( "--log-dir", type=click.Path(path_type=Path), # type: ignore[type-var] default="logs", show_default=True, help="Directory for log files", ) @click.option( "-v", "--verbose", is_flag=True, default=False, help="Verbose output: show DEBUG-level events, full logs, timing " "(independent of --summary detail).", ) @click.option( "-q", "--quiet", is_flag=True, default=False, help="Quiet output (errors and final result only)", ) @click.option( "--summary", "summary_opt", type=click.Choice(["off", "compact", "full"], case_sensitive=False), default=None, help="Summary detail level: off|compact|full (controls the final summary " "structure; independent of -v). Default: compact.", ) @click.option( "--no-progress", is_flag=True, default=False, help="Disable animated progress bars and spinners", ) @click.pass_context def main( ctx: click.Context, log_level: str, log_dir: Path, verbose: bool, quiet: bool, summary_opt: Optional[str], no_progress: bool, ) -> None: """ Thesis Medical Imaging Pipeline Framework. A modular Python framework for preprocessing and tractography analysis of diffusion MRI data. 
Examples: \b thesis run -w hcp -p P001 -c default thesis list-workflows thesis stats collect -o outputs/ """ # Ensure context object ctx.ensure_object(dict) # Resolve output mode from flags if verbose and quiet: raise click.UsageError("Cannot use both --verbose/-v and --quiet/-q.") if verbose: output_mode = OutputMode.VERBOSE # Also bump log level to DEBUG in verbose mode log_level = "DEBUG" elif quiet: output_mode = OutputMode.QUIET else: output_mode = OutputMode.NORMAL # Resolve summary detail if summary_opt is not None: summary_detail = SummaryDetail(summary_opt.lower()) else: summary_detail = SummaryDetail.COMPACT # Setup logging setup_logging(log_dir=log_dir, log_level=log_level.upper()) if output_mode == OutputMode.VERBOSE: set_console_level("DEBUG") elif output_mode == OutputMode.QUIET: set_console_level("ERROR") else: set_console_level("WARNING") # Store in context ctx.obj["log_level"] = log_level ctx.obj["log_dir"] = log_dir ctx.obj["output_mode"] = output_mode ctx.obj["summary_detail"] = summary_detail ctx.obj["show_progress"] = False if no_progress else None # None = auto-detect logger.info("Thesis framework v{}", __version__) @main.command() @click.option( "--config-dir", type=click.Path(exists=True, file_okay=False, path_type=Path), # type: ignore[type-var] default="config", show_default=True, help="Configuration directory", ) @click.option( "--subdir", type=str, default=None, help="Subdirectory to list (e.g., 'patients', 'protocols')", ) def list_configs(config_dir: Path, subdir: Optional[str]) -> None: """ List available configuration files. 
Example: thesis list-configs thesis list-configs --subdir patients """ config_manager = ConfigManager(config_dir=config_dir) configs = config_manager.list_configs(subdir=subdir) if configs: location = config_dir / subdir if subdir else config_dir click.echo(f"Available configs in {location}:") for cfg in configs: click.echo(f" - {cfg}") else: click.echo("No configurations found") @main.command() @click.argument("config_name") @click.option( "--config-dir", type=click.Path(exists=True, file_okay=False, path_type=Path), # type: ignore[type-var] default="config", show_default=True, help="Configuration directory", ) def show_config(config_name: str, config_dir: Path) -> None: """ Display the contents of a configuration file. Example: thesis show-config default """ try: config_manager = ConfigManager(config_dir=config_dir) cfg = config_manager.load_config(config_name=config_name) # Convert to YAML string for display config_dict = cfg.to_dict() yaml_str = yaml.dump(config_dict, default_flow_style=False, sort_keys=False) click.echo(f"Configuration: {config_name}") click.echo("=" * 60) click.echo(yaml_str) except FileNotFoundError: click.echo(f"[ERROR] Configuration not found: {config_name}", err=True) raise click.Abort() except ConfigurationError as e: click.echo(f"[ERROR] {e}", err=True) raise click.Abort() except Exception as e: click.echo(f"[ERROR] Error loading configuration: {e}", err=True) raise click.Abort() def _run_single_patient( entry: WorkflowEntry, config: str, patient_id: str, data_dir: Path, output_dir: Optional[Path], config_dir: Path, graph: bool, dry_run: bool, protocol: Optional[str] = None, raw_data_dir_override: Optional[Path] = None, config_overrides: Optional[Dict[str, Any]] = None, show_workflow_progress: bool = False, result: Optional[RunResult] = None, stage: Optional[str] = None, ) -> Tuple[bool, str, str]: """Run a Nipype workflow for a single patient. 
When ``stage`` is provided the built workflow is pruned to that SLURM stage group's node set before execution (the grouped-SLURM ``--stage`` path), leaning on Nipype caching for upstream inputs. Returns a tuple of ``(success, error_message, error_type)``. """ bus = get_event_bus() logger.info("Starting Nipype workflow for patient: {}", patient_id) bus.emit( f"Starting workflow for {patient_id}", level=EventLevel.IMPORTANT, category="workflow", patient_id=patient_id, ) try: # Load configuration config_manager = ConfigManager(config_dir=config_dir) # Determine protocol: CLI flag > patient YAML > WorkflowEntry.default_protocol protocol_override = _resolve_protocol_override(config_manager, patient_id, protocol, entry) overrides: Dict[str, Any] = {} if raw_data_dir_override: overrides["paths"] = {"inputs_dir": str(raw_data_dir_override)} if config_overrides: overrides = merge_configs(overrides, config_overrides) try: cfg = config_manager.load_config( config_name=config, patient_id=patient_id, protocol=protocol_override, overrides=overrides or None, protocol_required=protocol is not None, ) except ConfigurationError as cfg_exc: logger.error( "Configuration validation failed for {}: {}", patient_id, cfg_exc, exc_info=True ) augmented = ( f"{cfg_exc} (check config/{protocol_override}.yaml, " f"config/patients/{patient_id}.yaml, or {config}.yaml)" ) bus.emit( f"Configuration error for {patient_id}: {augmented}", level=EventLevel.ERROR, category="workflow", patient_id=patient_id, ) return (False, augmented, "ConfigurationError") # Validate GPU availability once per process; mutates cfg if falling back. 
# Skip in the grouped-SLURM --stage path: the partition was computed # up-front on the submit node (where the GPU is visible), so re-probing # here would flip hardware.gpu_enabled on a CPU-only stage node, # reclassify the GPU node as CPU, and collapse the partition to a single # stage whose name no longer matches the submitted --stage token (the # stage then either runs the whole pipeline or fails with "Unknown # stage"). Trust the submit-time config; the GPU node is isolated in its # own stage and never executes on a CPU stage node. if stage is None: _resolve_gpu(cfg) bus.emit( f"GPU resolution: gpu_enabled={cfg.hardware.gpu_enabled}", level=EventLevel.INFO, category="hardware", ) # Create processing context ctx_obj = create_context( patient_id=patient_id, config=cfg, data_dir=data_dir, output_dir=output_dir, ) logger.info("Configuration loaded: {}", config) logger.info("Input directory: {}", ctx_obj.input_dir) logger.info("Output directory: {}", ctx_obj.output_dir) # Input data directory preflight: a missing patient directory gives a # clear "directory not found" message instead of a per-file @requires # failure deeper in the workflow. if ctx_obj.input_dir is not None and not Path(ctx_obj.input_dir).exists(): msg = ( f"Input data directory not found: {ctx_obj.input_dir}. " f"Check that patient ID {patient_id} exists under inputs_dir." 
) bus.emit( msg, level=EventLevel.ERROR, category="preflight", patient_id=patient_id, ) return (False, msg, "FileIOError") # Preflight verification if callable(entry.verifier): try: errors = entry.verifier(cfg, ctx_obj) if errors: error_msg = "; ".join(errors) bus.emit( f"Preflight verification failed for {patient_id}: {error_msg}", level=EventLevel.ERROR, category="preflight", patient_id=patient_id, ) return (False, f"Preflight failed: {error_msg}", "PreflightError") except Exception as ve: logger.error("Verification step failed: {}", ve, exc_info=True) bus.emit( f"Verification error for {patient_id}: {type(ve).__name__}: {ve} " "(see logs in logs or rerun with -v)", level=EventLevel.ERROR, category="preflight", patient_id=patient_id, ) return (False, str(ve), type(ve).__name__) # Build workflow bus.emit( f"Building workflow for {patient_id}", level=EventLevel.INFO, category="workflow", patient_id=patient_id, ) wf = _build_workflow(entry.factory, cfg, ctx_obj) if not hasattr(wf, "run"): msg = ( f"Workflow factory returned {type(wf).__name__} instead of a " f"nipype.pipeline.engine.Workflow for {patient_id}. " "Check the factory's return value." ) bus.emit(msg, level=EventLevel.ERROR, category="workflow", patient_id=patient_id) return (False, msg, "RuntimeError") # Grouped-SLURM stage pruning: restrict the built graph to one stage's # nodes before normal execution. Lazy-imported so the non-stage path # never touches the slurm package. if stage is not None: from thesis.core.slurm.pruning import find_stage_group, prune_workflow_to_stage # Grouped-SLURM stages hand results between separate sbatch processes # through Nipype's on-disk work/ cache. 
Nipype mixes ``needed_outputs`` # into a node's hash when ``remove_unnecessary_outputs`` is on, and a # node's needed-outputs set differs per stage (its consumers are often # pruned into a *later* stage) — so the producing stage stores one hash # and the consuming stage computes another, the cache always misses, # and every retained upstream ancestor re-runs (SynthSeg recomputed on # the CPU stage). Disabling it drops needed_outputs from the hash so the # cross-stage cache hits; keeping all outputs is also what staging needs # (later stages read them). cfg.nipype.remove_unnecessary_outputs = False stage_group = find_stage_group(wf, stage) prune_workflow_to_stage(wf, stage_group) bus.emit( f"Restricted workflow to stage '{stage}' for {patient_id}", level=EventLevel.IMPORTANT, category="workflow", patient_id=patient_id, ) if graph: if ctx_obj.output_dir is None: raise click.ClickException("output_dir must be set in context when --graph is used") graph_dir = ctx_obj.output_dir / "workflow_graphs" graph_dir.mkdir(parents=True, exist_ok=True) wf.write_graph( graph2use="flat", format="png", simple_form=True, dotfilename=str(graph_dir / "workflow"), ) bus.emit( f"Workflow graph written to: {graph_dir}", level=EventLevel.IMPORTANT, category="output", patient_id=patient_id, ) if dry_run: bus.emit( f"[DRY RUN] Workflow built for {patient_id} but not executed", level=EventLevel.IMPORTANT, category="workflow", patient_id=patient_id, ) return (True, "", "") # Execute workflow bus.emit( f"Executing workflow for {patient_id}", level=EventLevel.IMPORTANT, category="workflow", patient_id=patient_id, ) workflow_progress: Optional[ClickNodeProgress] = None if show_workflow_progress: total_nodes = count_workflow_nodes(wf) if total_nodes > 0: workflow_progress = ClickNodeProgress( total=total_nodes, label=f"Nodes for {patient_id}", enabled=True, ) workflow_progress.start() try: run_workflow( wf, ctx_obj, cfg.nipype, event_bus=bus, progress=workflow_progress, ) finally: if 
workflow_progress is not None: workflow_progress.stop() bus.emit( f"Workflow completed successfully for {patient_id}", level=EventLevel.IMPORTANT, category="workflow", patient_id=patient_id, ) logger.info("Nipype workflow completed for {}", patient_id) # Post-workflow QC overlays (best-effort) if ctx_obj.output_dir is not None: from thesis.workflows.qc.operations import run_post_workflow_qc run_post_workflow_qc(cfg, ctx_obj) if result is not None: _enrich_with_tract_similarity_metrics(result, cfg, ctx_obj) return (True, "", "") except click.ClickException as exc: logger.error("Nipype workflow failed for {}: {}", patient_id, exc, exc_info=True) log_dir = "logs" bus.emit( f"Error for {patient_id}: ClickException: {exc.format_message()} " f"(see logs in {log_dir} or rerun with -v)", level=EventLevel.ERROR, category="workflow", patient_id=patient_id, ) return (False, exc.format_message(), "ClickException") except Exception as e: logger.error("Nipype workflow failed for {}: {}", patient_id, e, exc_info=True) log_dir = "logs" category_hint = "processing error -- " if isinstance(e, ProcessingError) else "" bus.emit( f"Error for {patient_id}: {category_hint}{type(e).__name__}: {e} " f"(see logs in {log_dir} or rerun with -v)", level=EventLevel.ERROR, category="workflow", patient_id=patient_id, ) return (False, str(e), type(e).__name__) @main.command() @click.option( "--workflow", "-w", required=False, default=None, help="Registered workflow name (e.g. hcp, tract_synthseg, synthseg, minimal). " "Required unless --script is given." " Some workflows (tract_synthseg, full_pipeline) support multiple backends " "via config.tractography.method: probtrackx2 (default) or mrtrix3.", ) @click.option( "--script", "-s", "script", type=click.Path(path_type=Path, dir_okay=False), # type: ignore[type-var] default=None, help="Path to a user-supplied workflow script (.py) that registers a workflow " "via the @workflow decorator. 
Overrides --workflow.", ) @click.option( "--config", "-c", type=str, default="default", show_default=True, help="Configuration name to load", ) @click.option( "--protocol", "-t", type=str, default=None, help="Protocol name override. Resolved as: CLI --protocol > " "config/patients/<patient-id>.yaml protocol field > workflow default > " "error if none set.", ) @click.option( "--patient-id", "-p", "patient_ids", multiple=True, help="Patient identifier(s). Repeat (-p P001 -p P002) and/or pass a " "comma-separated list (-p P001,P002,P003). Mutually exclusive with --all.", ) @click.option( "--all", "all_patients", is_flag=True, default=False, help="Discover and process all numeric patient folders in inputs_dir.", ) @click.option( "--pattern", "-n", type=str, default=None, show_default=True, help="Glob pattern to match patient folder names in inputs_dir " "(requires --all). If omitted, only numeric folder names are discovered.", ) @click.option( "--parallel", is_flag=True, default=False, help="Enable parallel execution via a Nipype meta-workflow. Without -j, " "uses nipype.plugin_args.n_procs from config; set -j/--max-workers for an " "explicit worker count.", ) @click.option( "--max-workers", "-j", type=click.IntRange(min=1), default=None, help="Worker count (implies --parallel). Overrides nipype.plugin_args.n_procs " "(default: uses config nipype.plugin_args.n_procs or system default).", ) @click.option( "--retries", "-r", type=click.IntRange(min=0), default=0, show_default=True, help="Number of retry attempts per patient on failure (default: 0).", ) @click.option( "--data-dir", type=click.Path(exists=True, file_okay=False, path_type=Path), # type: ignore[type-var] default=None, help="Shared assets directory (templates, atlases, ROIs); overrides config " "paths.assets_dir (default: 'data'). 
Independent of the inputs/outputs roots.", ) @click.option( "--raw-data-dir", type=click.Path(path_type=Path), # type: ignore[type-var] default=None, help="Override paths.inputs_dir, the per-patient source-data root (default: " "from config, typically data/raw). An independent project-level root, " "resolved relative to the current directory when not absolute.", ) @click.option( "--output-dir", type=click.Path(path_type=Path), # type: ignore[type-var] default=None, help="Output directory (default: outputs/<patient-id>, from config " "paths.output_dir).", ) @click.option( "--config-dir", type=click.Path(exists=True, file_okay=False, path_type=Path), # type: ignore[type-var] default="config", show_default=True, help="Configuration directory", ) @click.option( "--graph", is_flag=True, default=False, help="Write workflow graph to output directory", ) @click.option( "--dry-run", is_flag=True, default=False, help="Build workflow but do not execute", ) @click.option( "--hemisphere", type=click.Choice(["left", "right", "both", "both-separately"], case_sensitive=False), default=None, help=( "Hemisphere selection for tractography (supported by hcp, mrtrix3, " "tract_synthseg, full_pipeline): left, right, both (merged ROIs, default), " "or both-separately (two independent runs with separate output " "directories). Ignored/rejected for cohort-level and non-tractography " "workflows." ), ) @click.option( "--stage", "stage", type=str, default=None, help=( "Restrict execution to a single SLURM stage group (e.g. 'stage_0_cpu'). " "Prunes the built workflow to that stage's nodes before running; intended " "for grouped SLURM submission and not normally used directly." 
), ) @click.pass_context def run( ctx: click.Context, workflow: Optional[str], script: Optional[Path], config: str, protocol: Optional[str], patient_ids: Tuple[str, ...], all_patients: bool, pattern: Optional[str], parallel: bool, max_workers: Optional[int], retries: int, data_dir: Optional[Path], raw_data_dir: Optional[Path], output_dir: Optional[Path], config_dir: Path, graph: bool, dry_run: bool, hemisphere: Optional[str], stage: Optional[str], ) -> None: """ Run a registered workflow for one or more patients. Workflow is identified either by its short registered name (use ``thesis list-workflows`` to see what's available) or by an out-of-tree script via ``--script PATH.py``. Execution is sequential by default; pass ``--parallel`` or ``-j N`` to fan out across patients via a Nipype meta-workflow. Examples: \b thesis run -w hcp -p DTI_001 --dry-run thesis run -w hcp --all -c default --dry-run thesis run -w hcp --all -j 4 thesis run -w hcp -p P001 -p P002 --parallel thesis run -w hcp -p P001,P002,P003 -c default thesis run -w hcp -p DTI_001 --hemisphere left thesis run -w hcp -p DTI_001 --hemisphere both-separately thesis run -w full_pipeline -p 114823 -c full_pipeline_mrtrix3 # MRtrix3 backend thesis run --script ./my_workflow.py -p DTI_001 -c default """ # Normalise -p values: support both repeated flags and comma-separated lists. 
patient_ids = _split_patient_ids(patient_ids) if all_patients and patient_ids: raise click.UsageError("Cannot use both --patient-id/-p and --all.") if pattern and not all_patients: raise click.UsageError("--pattern can only be used together with --all.") if workflow is None and script is None: raise click.UsageError("Either --workflow / -w or --script / -s is required.") script_workflow_mismatch: Optional[Tuple[str, str]] = None if script is not None: from thesis.core.loader import load_workflow_script loaded_entry = load_workflow_script(script) if workflow is not None and workflow != loaded_entry.name: # Defer the warning until the event bus exists so it surfaces through # the structured output system (a plain logger.warning here would be # suppressed in -q mode). script_workflow_mismatch = (workflow, loaded_entry.name) workflow = loaded_entry.name assert workflow is not None # guarded above # Build config overrides from CLI flags cli_overrides: Dict[str, Any] = {} if hemisphere is not None: cli_overrides.setdefault("tractography", {})["hemisphere"] = hemisphere # Setup output system output_config = _make_output_config(ctx) reset_event_bus() bus = get_event_bus() renderer = OutputRenderer(config=output_config, event_bus=bus) renderer.attach(bus) if script_workflow_mismatch is not None: requested_wf, script_wf = script_workflow_mismatch bus.emit( f"--workflow {requested_wf!r} ignored; script registered " f"workflow {script_wf!r}.", level=EventLevel.WARNING, category="workflow", ) # Resolve workflow entry (triggers auto-import/registration) entry = _resolve_workflow(workflow) # Resolve the base data dir: explicit --data-dir wins, else fall back to # config.paths.assets_dir (the shared assets base), else "data". 
data_dir = _resolve_data_dir(data_dir, config_dir, config, protocol or entry.default_protocol) # ----------------------------------------------------------------------- # Cohort-level workflow execution bypass # ----------------------------------------------------------------------- if getattr(entry, "is_cohort_level", False): if hemisphere is not None: raise click.UsageError( "--hemisphere cannot be used with cohort-level workflows (e.g. " "atlas, tract_similarity_cohort). Hemisphere is a patient-level " "tractography setting." ) if patient_ids or all_patients: raise click.UsageError( f"Cohort-level workflow '{workflow}' does not accept -p/--patient-id " "or --all. These workflows operate over the configured output " "directory; omit these flags." ) if retries > 0: logger.warning( "--retries applies to the single cohort run; there is one logical " "run per invocation." ) bus.emit( f"Starting cohort-level workflow: {workflow}", level=EventLevel.IMPORTANT, category="workflow", ) batch_start = time.monotonic() result = _run_single_patient_with_retries( entry=entry, config=config, patient_id="cohort", data_dir=data_dir, output_dir=output_dir, config_dir=config_dir, graph=graph, dry_run=dry_run, protocol=protocol or entry.default_protocol, raw_data_dir_override=raw_data_dir, retries=retries, show_workflow_progress=_should_show_workflow_progress( entry, output_config.show_progress, ), config_overrides=cli_overrides, stage=stage, ) renderer.detach() summary = RunSummary.from_result(result) renderer.render_run_summary(summary) if result.status in (RunStatus.FAILED, RunStatus.BLOCKED): ctx.exit(1) return # ----------------------------------------------------------------------- # Standard Patient-level execution # ----------------------------------------------------------------------- if not all_patients and not patient_ids: raise click.UsageError( "Either --patient-id / -p or --all is required for patient-level workflows." 
) # Fast-fail config/protocol preflight: a typo in -c/--protocol should fail # here, before any patient discovery or per-patient setup work (M4). _effective_protocol = protocol or entry.default_protocol try: ConfigManager(config_dir=config_dir).load_config( config_name=config, protocol=_effective_protocol, protocol_required=protocol is not None, ) except ConfigurationError as cfg_exc: raise click.ClickException( f"{cfg_exc} (failed to load config '{config}'" + (f" with protocol '{_effective_protocol}'" if _effective_protocol else "") + ")" ) # Collect patient IDs if all_patients: ids = _discover_patient_ids( data_dir=data_dir, config_dir=config_dir, config_name=config, protocol=protocol or entry.default_protocol, raw_data_dir_override=raw_data_dir, name_pattern=pattern, ) bus.emit( f"Discovered {len(ids)} patient(s): {', '.join(ids)}", level=EventLevel.IMPORTANT, category="discovery", ) else: ids = list(patient_ids) _validate_explicit_patient_ids( ids, config_dir=config_dir, config_name=config, protocol=_effective_protocol, raw_data_dir_override=raw_data_dir, data_dir=data_dir, ) total = len(ids) is_batch = total > 1 is_parallel = parallel or max_workers is not None batch_start = time.monotonic() # ----------------------------------------------------------------------- # Sequential execution # ----------------------------------------------------------------------- if not is_parallel: run_results: List[RunResult] = [] # Helper for click.progressbar item display def _patient_show(pid: Optional[str]) -> str: return pid if pid else "" # Use patient-level progress for sequential batch runs. 
show_patient_progress = is_batch and output_config.show_progress cm = suspend_console_logging() if show_patient_progress else contextlib.nullcontext() with cm: with ( click.progressbar( ids, label=f"Running {workflow}", item_show_func=_patient_show, show_pos=True, show_eta=True, file=sys.stderr, color=True, width=40, ) if show_patient_progress else contextlib.nullcontext(ids) ) as patient_iter: for pid in patient_iter: result = _run_single_patient_with_retries( entry=entry, config=config, patient_id=pid, data_dir=data_dir, output_dir=output_dir, config_dir=config_dir, graph=graph, dry_run=dry_run, protocol=protocol, raw_data_dir_override=raw_data_dir, retries=retries, show_workflow_progress=output_config.show_progress and not is_batch, config_overrides=cli_overrides or None, stage=stage, ) result.workflow = workflow run_results.append(result) # Detach renderer before printing summary renderer.detach() # Render summaries batch_elapsed = time.monotonic() - batch_start if is_batch: batch_summary = BatchSummary.from_results( run_results, workflow=workflow, batch_elapsed=batch_elapsed, ) renderer.render_batch_summary(batch_summary) _render_batch_tractography_stats( run_results, config_name=config, config_dir=config_dir, data_dir=data_dir, output_dir=output_dir, protocol=protocol or entry.default_protocol, ) if batch_summary.failed > 0 or batch_summary.blocked > 0: ctx.exit(1) else: # Single patient result = run_results[0] summary = RunSummary.from_result(result) renderer.render_run_summary(summary) if result.status in (RunStatus.FAILED, RunStatus.BLOCKED): ctx.exit(1) return # ----------------------------------------------------------------------- # Parallel execution via Nipype meta-workflow # # Each patient's Nipype workflow is added as a direct sub-workflow of # a single meta-workflow. 
This exposes *every* node (including GPU- # bound ones like ``synthseg`` and ``probtrackx2``) to the global # ``MultiProc`` scheduler, which: # # - Parallelises CPU-bound nodes across patients. # - Serialises GPU-bound nodes via ``n_gpu_procs`` (prevents OOM). # - Accounts for ``mem_gb`` per node to avoid memory oversubscription. # # Retries rely on Nipype's native result caching: re-running the # meta-workflow skips all previously successful nodes and only # re-executes failed ones. # ----------------------------------------------------------------------- bus.emit( f"Batch run: {total} patient(s) -- {', '.join(ids)}", level=EventLevel.IMPORTANT, category="batch", ) logger.info("Batch run started for {} patient(s): {}", total, ids) # Load config to get plugin settings cfg = ConfigManager(config_dir=config_dir).load_config( config_name=config, protocol=protocol or entry.default_protocol, overrides=cli_overrides or None, ) # Validate GPU once before building any patient workflows _resolve_gpu(cfg) bus.emit( f"GPU resolution: gpu_enabled={cfg.hardware.gpu_enabled}", level=EventLevel.INFO, category="hardware", ) plugin = cfg.nipype.plugin plugin_args: Dict[str, Any] = dict(cfg.nipype.plugin_args or {}) if max_workers is not None: plugin_args["n_procs"] = max_workers # Build-phase status: emit at IMPORTANT level so it surfaces even under # --no-progress (typical in SLURM/CI), where the multi-minute build phase # would otherwise produce no output. 
bus.emit( f"Building {total} patient workflow(s)...", level=EventLevel.IMPORTANT, category="build", ) meta_wf, built, build_failed = _build_batch_workflow( entry=entry, ids=ids, cfg=cfg, config_name=config, data_dir=data_dir, output_dir=output_dir, config_dir=config_dir, protocol=protocol, raw_data_dir=raw_data_dir, event_bus=bus, config_overrides=cli_overrides or None, ) bus.emit( f"Build complete: {len(built)}/{total} workflows ready", level=EventLevel.IMPORTANT, category="build", ) # ------------------------------------------------------------------------- # SLURM grouped-submission branch (patient batch path). # # When slurm.enabled, each built patient workflow is partitioned into a # linear chain of resource-homogeneous stage jobs and submitted via sbatch # instead of being run in-process. All slurm imports are lazy/local so the # disabled path is byte-identical to today. # ------------------------------------------------------------------------- if getattr(cfg, "slurm", None) is not None and cfg.slurm.enabled: renderer.detach() manifest = _run_slurm_batch( meta_wf=meta_wf, built=built, cfg=cfg, workflow_name=workflow, config_name=config, dry_run=dry_run, event_bus=bus, ) click.echo( f"SLURM submission {'planned' if dry_run else 'complete'}: " f"{len(manifest)} patient chain(s).", err=True, ) return if cfg.hardware.gpu_enabled: plugin_args.setdefault("n_gpu_procs", cfg.hardware.n_gpu_procs) plugin_args.setdefault("n_gpus", cfg.hardware.n_gpus) bus.emit( f"Workflows built: {len(built)}/{total} | " f"Plugin: {plugin} | n_procs: {plugin_args.get('n_procs', '?')} | " f"memory_gb: {plugin_args.get('memory_gb', '?')} | " f"n_gpu_procs: {plugin_args.get('n_gpu_procs', 'N/A')}", level=EventLevel.IMPORTANT, category="batch", ) if graph: graph_dir = Path(cfg.nipype.working_dir) / "batch" / "workflow_graphs" graph_dir.mkdir(parents=True, exist_ok=True) meta_wf.write_graph( graph2use="flat", format="png", simple_form=True, dotfilename=str(graph_dir / "batch_workflow"), ) 
bus.emit( f"Batch workflow graph written to: {graph_dir}", level=EventLevel.IMPORTANT, category="output", ) if dry_run: bus.emit( f"[DRY RUN] Meta-workflow built for {len(built)} patient(s) but not executed", level=EventLevel.IMPORTANT, category="workflow", ) # Build summary for dry run run_results_parallel: List[RunResult] = [] for pid in built: run_results_parallel.append( RunResult(patient_id=pid, workflow=workflow, status=RunStatus.SKIPPED) ) for pid in build_failed: run_results_parallel.append( RunResult( patient_id=pid, workflow=workflow, status=RunStatus.BLOCKED, error_message="Failed to build workflow", ) ) renderer.detach() batch_summary = BatchSummary.from_results( run_results_parallel, workflow=workflow, batch_elapsed=time.monotonic() - batch_start, ) renderer.render_batch_summary(batch_summary) if build_failed: ctx.exit(1) return # Execute with retry support. Nipype caches node results, so # re-running the same meta-workflow automatically skips successful # nodes and only re-executes failures. last_error: Optional[Exception] = None workflow_progress: Optional[ClickNodeProgress] = None total_nodes = count_workflow_nodes(meta_wf) if output_config.show_progress and total_nodes > 0: workflow_progress = ClickNodeProgress( total=total_nodes, label=f"Nodes for {workflow}", enabled=output_config.show_progress, ) workflow_progress.start() try: suppress_nipype_native_logging() # Apply cfg.nipype onto meta_wf AND the global nipype.config — this # path runs meta_wf.run() directly, bypassing NipypeExecutor, so # without this call hash_method="content" (and every other exec-level # setting) silently falls back to Nipype's package defaults. See # thesis.core.nipype.executor.apply_nipype_execution_config for why # the global-config update is load-bearing for file-input hashing. 
try: apply_nipype_execution_config(meta_wf, cfg.nipype) except Exception as cfg_exc: logger.warning(f"Could not apply Nipype config to meta-workflow: {cfg_exc}") status_callback = build_nipype_status_callback( progress=workflow_progress, event_bus=bus, ) plugin_args["status_callback"] = status_callback for attempt in range(retries + 1): if attempt > 0: logger.info( "Retry {}/{} -- re-running meta-workflow (cached nodes will be skipped)", attempt, retries, ) bus.emit( f"Retry {attempt}/{retries} -- " f"re-running (completed nodes cached, only failures re-execute)", level=EventLevel.WARNING, category="retry", ) time.sleep(5) # brief pause between retries try: meta_wf.run(plugin=plugin, plugin_args=plugin_args) last_error = None break # success except Exception as exc: last_error = exc logger.error("Batch attempt {}/{} failed: {}", attempt + 1, retries + 1, exc) # Attribute failed nodes to patients immediately so users see # which patients were affected (and that this is an execution- # phase failure, distinct from the build phase). failed_by_pid = _attribute_nodes_to_patients(status_callback.failed_nodes, built) attributed = ", ".join(sorted(failed_by_pid)) or "unattributed" bus.emit( f"Batch execution failed (attempt {attempt + 1}/{retries + 1}). 
" f"Failed nodes attributed to: {attributed}", level=EventLevel.ERROR, category="batch", ) finally: if workflow_progress: workflow_progress.stop() # Build batch summary renderer.detach() batch_elapsed = time.monotonic() - batch_start run_results_parallel = [] if last_error is None: for pid in built: res = RunResult(patient_id=pid, workflow=workflow, status=RunStatus.SUCCESS) try: patient_ctx = create_context( patient_id=pid, config=cfg, data_dir=data_dir, output_dir=output_dir ) _enrich_with_tract_similarity_metrics(res, cfg, patient_ctx) except Exception as _exc: logger.debug("tract_similarity enrichment skipped for {}: {}", pid, _exc) run_results_parallel.append(res) else: # Attribute failed/finished Nipype nodes back to specific patients so the # summary reflects which patients actually failed -- not "all of them". failed_by_pid = _attribute_nodes_to_patients(status_callback.failed_nodes, built) ok_by_pid = _attribute_nodes_to_patients(status_callback.finished_ok_nodes, built) classified = _classify_batch_failure( built=built, workflow=workflow, failed_by_pid=failed_by_pid, ok_by_pid=ok_by_pid, last_error=last_error, ) for res in classified: if res.status == RunStatus.SUCCESS: try: patient_ctx = create_context( patient_id=res.patient_id, config=cfg, data_dir=data_dir, output_dir=output_dir, ) _enrich_with_tract_similarity_metrics(res, cfg, patient_ctx) except Exception as _exc: logger.debug( "tract_similarity enrichment skipped for {}: {}", res.patient_id, _exc ) run_results_parallel.append(res) for pid in build_failed: run_results_parallel.append( RunResult( patient_id=pid, workflow=workflow, status=RunStatus.BLOCKED, error_message="Failed to build workflow", ) ) batch_summary = BatchSummary.from_results( run_results_parallel, workflow=workflow, batch_elapsed=batch_elapsed, ) batch_summary.retries = retries renderer.render_batch_summary(batch_summary) _render_batch_tractography_stats( run_results_parallel, config_name=config, config_dir=config_dir, 
data_dir=data_dir, output_dir=output_dir, protocol=protocol or entry.default_protocol, ) if last_error is not None or build_failed: ctx.exit(1) @main.command("list-workflows") @click.option( "--config", "-c", type=str, default="default", show_default=True, help="Configuration name to read paths.scripts_dir from.", ) @click.option( "--config-dir", type=click.Path(exists=True, file_okay=False, path_type=Path), # type: ignore[type-var] default="config", show_default=True, help="Configuration directory", ) def list_workflows(config: str, config_dir: Path) -> None: """ List all available registered workflows. Auto-discovers and imports all ``thesis.workflows.*`` submodules to trigger self-registration. If ``paths.scripts_dir`` is set in the selected config, also scans that directory for user-supplied scripts (``*.py``, top-level, names not starting with ``_``). Example: thesis list-workflows thesis list-workflows -c my_config """ import thesis.workflows as _wf_pkg # Auto-import every submodule/subpackage to trigger registrations for _finder, name, ispkg in pkgutil.iter_modules(_wf_pkg.__path__): _ensure_workflow_imported(name, is_package=ispkg) builtin_names = set(WORKFLOW_REGISTRY.list()) user_entries: List[Tuple[str, Path]] = [] user_failures: List[Tuple[Path, str]] = [] scripts_dir: Optional[Path] = None try: manager = ConfigManager(config_dir=config_dir) cfg = manager.load_config(config_name=config) scripts_dir = cfg.paths.scripts_dir except Exception as exc: # pragma: no cover - config read is best-effort here logger.debug("list-workflows could not read scripts_dir from config: {}", exc) if scripts_dir is not None: from thesis.core.loader import discover_workflows_dir, load_workflow_script if not scripts_dir.is_absolute(): scripts_dir = Path.cwd() / scripts_dir for script_path in discover_workflows_dir(scripts_dir): try: loaded = load_workflow_script(script_path) if loaded.name not in builtin_names: user_entries.append((loaded.name, script_path)) except 
click.ClickException as exc: user_failures.append((script_path, exc.message)) except Exception as exc: # pragma: no cover - defensive user_failures.append((script_path, f"{type(exc).__name__}: {exc}")) entries = WORKFLOW_REGISTRY.all_entries() if not entries: click.echo("No workflows registered.") else: click.echo("Available workflows:") user_paths = dict(user_entries) for e in entries: proto_str = f" (protocol: {e.default_protocol})" if e.default_protocol else "" suffix = "" if e.name in user_paths: suffix = f" (user script: {user_paths[e.name]})" click.echo(f" {e.name:<20} {e.description}{proto_str}{suffix}") if user_failures: click.echo("") click.echo("User scripts that failed to load:") for path, msg in user_failures: click.echo(f" {path}: {msg}") @main.command() def info() -> None: """ Display system and framework information. """ import platform import sys from importlib.metadata import PackageNotFoundError from importlib.metadata import version as package_version click.echo("=" * 60) click.echo("Thesis Framework Information") click.echo("=" * 60) click.echo(f"Version: {__version__}") click.echo(f"Python: {sys.version.split()[0]}") click.echo(f"Platform: {platform.system()} {platform.release()}") click.echo(f"Architecture: {platform.machine()}") click.echo() # Check for dependencies deps = { "NumPy": ("numpy", "numpy"), "SciPy": ("scipy", "scipy"), "NiBabel": ("nibabel", "nibabel"), "PyYAML": ("yaml", "PyYAML"), "Pydantic": ("pydantic", "pydantic"), "Click": ("click", "click"), "Loguru": ("loguru", "loguru"), } click.echo("Dependencies:") for name, (module, distribution) in deps.items(): try: __import__(module) try: version = package_version(distribution) except PackageNotFoundError: version = "unknown" click.echo(f" [OK] {name}: {version}") except ImportError: click.echo(f" [--] {name}: not installed") click.echo() # Check for external tools import shutil external_tools = { "FSL": ["fsl", "probtrackx2"], "ANTs": ["antsRegistration"], } click.echo("External 
Tools:") for tool_name, commands in external_tools.items(): found = any(shutil.which(cmd) for cmd in commands) status = "[OK]" if found else "[--]" click.echo(f" {status} {tool_name}: {'available' if found else 'not found'}") @main.group() def stats() -> None: """Inspect and export batch tractography statistics. Run ``thesis stats collect -o <dir>`` to extract stats from an output directory. """ @stats.command("collect") @click.option( "-o", "--output-dir", required=True, type=click.Path(exists=True, file_okay=False, path_type=Path), # type: ignore[type-var] help="Root directory containing per-subject output subdirectories.", ) @click.option( "-p", "--patient-id", "patient_ids", multiple=True, help="Limit to these patient IDs (repeat for multiple). " "Default: auto-discover all subjects.", ) @click.option( "--out", "output_file", type=click.Path(dir_okay=False, path_type=Path), # type: ignore[type-var] default=None, help="Destination JSON file. " "Default: <output-dir>/batch_stats/stats_<timestamp>.json.", ) @click.option( "--sd-threshold", type=float, default=2.0, show_default=True, help="Outlier detection threshold in standard deviations.", ) @click.option( "--tractography-relpath", "tractography_relpath", type=str, default="tractography/probtrackx2", show_default=True, help=( "Relative path under each patient directory where the tractography " "run lives. Use 'tractography/mrtrix3' for the MRtrix3 backend." ), ) @click.option( "--tract-similarity-subdir", "tract_similarity_subdir", type=str, default="tract_similarity", show_default=True, help=( "Relative path under each patient directory where " "tract_similarity / tract_similarity_hcp_loo wrote metrics.json. " "Attached under subject['tract_similarity'] when the file exists." 
), ) def stats_collect( output_dir: Path, patient_ids: Tuple[str, ...], output_file: Optional[Path], sd_threshold: float, tractography_relpath: str, tract_similarity_subdir: str, ) -> None: """Collect tractography stats from an existing output dir and write JSON. Reuses the same collectors as the post-run summary (``collect_batch_stats`` + ``detect_batch_outliers``) so the emitted schema matches what future runs write automatically. When a subject has a ``<tract_similarity_subdir>/metrics.json`` file (produced by the ``tract_similarity`` or ``tract_similarity_hcp_loo`` workflows), its contents are merged into that subject's record under a ``tract_similarity`` key — no separate file or post-processing needed. Example: thesis stats collect -o outputs/ --out /tmp/atlas_stats.json thesis stats collect -o outputs/ --tractography-relpath tractography/mrtrix3 """ from thesis.workflows.qc.checks import detect_batch_outliers from thesis.workflows.qc.statistics import collect_batch_stats pids = list(patient_ids) if patient_ids else None subjects = collect_batch_stats( output_dir, patient_ids=pids, tractography_relpath=tractography_relpath, tract_similarity_subdir=tract_similarity_subdir, ) if not subjects: raise click.ClickException(f"No subjects with tractography output found under {output_dir}") outliers: List[Dict[str, Any]] = [] if len(subjects) >= 3: outliers = detect_batch_outliers(subjects, sd_threshold=sd_threshold) if output_file is None: stats_dir = output_dir / "batch_stats" stats_dir.mkdir(exist_ok=True) output_file = stats_dir / f"stats_{time.strftime('%Y%m%d_%H%M%S')}.json" else: output_file.parent.mkdir(parents=True, exist_ok=True) payload = {"subjects": subjects, "outliers": outliers} output_file.write_text(json.dumps(payload, indent=2, default=str), encoding="utf-8") click.echo(f"Wrote {output_file}") click.echo(f" subjects: {len(subjects)}") click.echo(f" outliers: {len(outliers)}") @main.group() def slurm() -> None: """Inspect and retry grouped SLURM stage 
submissions.""" def _collect_job_ids(manifest: Dict[str, List[Dict[str, Any]]]) -> List[str]: """Return every non-null ``job_id`` recorded in a manifest.""" return [ str(entry["job_id"]) for entries in manifest.values() for entry in entries if entry.get("job_id") is not None ] @slurm.command("status") @click.option( "--manifest", "manifest_path", required=True, type=click.Path(path_type=Path), # type: ignore[type-var] help="Path to manifest.json (or a submit_dir containing one).", ) def slurm_status(manifest_path: Path) -> None: """Show per-patient / per-stage SLURM states from a submission manifest. Reads the durable manifest written at submit time, queries ``sacct`` for the current state of every recorded job, and prints a per-``(patient, stage)`` status table. Stages whose ``afterok`` predecessor failed are surfaced as ``BLOCKED``. Example: thesis slurm status --manifest $SCRATCH/thesis_slurm/manifest.json """ from thesis.core.slurm.monitor import query_states, read_manifest, summarize try: manifest = read_manifest(manifest_path) except (FileNotFoundError, ValueError) as exc: raise click.ClickException(str(exc)) states = query_states(_collect_job_ids(manifest)) summary = summarize(manifest, states) for pid, rows in summary.items(): click.echo(f"{pid}:") for row in rows: job = row["job_id"] if row["job_id"] is not None else "-" click.echo(f" {row['stage']:<24} {row['compute']:<3} " f"job={job:<10} {row['state']}") @slurm.command("resubmit") @click.option( "--manifest", "manifest_path", required=True, type=click.Path(path_type=Path), # type: ignore[type-var] help="Path to manifest.json (or a submit_dir containing one).", ) @click.option( "-c", "--config", "config_name", required=True, help="Configuration name providing slurm.* settings.", ) @click.option( "--config-dir", type=click.Path(path_type=Path), # type: ignore[type-var] default="config", help="Configuration directory.", ) @click.option( "--retries", type=int, default=1, show_default=True, help="Maximum 
resubmission passes for still-failing stages.", ) @click.option( "--dry-run", is_flag=True, default=False, help="Select failed stages and write the plan, but do not call sbatch.", ) def slurm_resubmit( manifest_path: Path, config_name: str, config_dir: Path, retries: int, dry_run: bool, ) -> None: """Re-submit only failed / timed-out / OOM stages, re-chaining dependents. Reconciles the manifest against ``sacct``, selects stages in ``FAILED`` / ``TIMEOUT`` / ``OUT_OF_MEMORY``, and re-submits them (and their downstream dependents) via the Phase D submitter with refreshed ``afterok`` chains. Nipype on-disk caching skips already-completed nodes. Example: thesis slurm resubmit --manifest $SCRATCH/thesis_slurm/manifest.json -c full_pipeline """ from thesis.core.slurm.monitor import query_states, read_manifest, select_resubmit from thesis.core.slurm.submitter import resubmit_jobs try: manifest = read_manifest(manifest_path) except (FileNotFoundError, ValueError) as exc: raise click.ClickException(str(exc)) config_manager = ConfigManager(config_dir=config_dir) cfg = config_manager.load_config(config_name=config_name) slurm_cfg = getattr(cfg, "slurm", None) if slurm_cfg is None or not slurm_cfg.enabled: raise click.ClickException( f"Config '{config_name}' does not enable slurm (slurm.enabled is false)." ) passes = max(1, retries) any_selected = False for attempt in range(1, passes + 1): states = query_states(_collect_job_ids(manifest)) to_resubmit = select_resubmit(manifest, states) if not to_resubmit: if not any_selected: click.echo("No failed/timeout/oom stages to resubmit.") break any_selected = True total = sum(len(v) for v in to_resubmit.values()) click.echo( f"Pass {attempt}/{passes}: resubmitting {total} stage(s) " f"across {len(to_resubmit)} patient(s)." ) manifest = resubmit_jobs(manifest, to_resubmit, slurm_cfg, dry_run=dry_run) if dry_run: # No new states to observe in a dry run; one planning pass is enough. break if __name__ == "__main__": main()