"""
Command-line interface for the thesis framework.
Provides CLI commands for running custom Nipype workflows.
"""
import collections
import contextlib
import fnmatch
import importlib
import inspect
import json
import pkgutil
import re
import sys
import time
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
import click
import yaml
from thesis import __version__
from thesis.core.config import ConfigManager, PipelineConfig, merge_configs
from thesis.core.context import create_context
from thesis.core.exceptions import ConfigurationError, ProcessingError
from thesis.core.logging import (
get_logger,
set_console_level,
setup_logging,
suppress_nipype_native_logging,
suspend_console_logging,
)
from thesis.core.nipype import run_workflow
from thesis.core.nipype.executor import (
apply_nipype_execution_config,
build_nipype_status_callback,
count_workflow_nodes,
)
from thesis.core.output import (
BatchSummary,
EventLevel,
OutputConfig,
OutputMode,
OutputRenderer,
RunResult,
RunStatus,
RunSummary,
SummaryDetail,
get_event_bus,
reset_event_bus,
)
from thesis.core.registry import WORKFLOW_REGISTRY, WorkflowEntry
# Module-level logger for this CLI module, obtained via thesis.core.logging.get_logger.
logger = get_logger(__name__)
# ---------------------------------------------------------------------------
# Click-based progress wrapper for Nipype node callbacks
# ---------------------------------------------------------------------------
class ClickNodeProgress:
    """Node-level progress display for Nipype runs built on ``click.progressbar``.

    Provides the same interface as ``WorkflowProgress`` (``start``, ``stop``,
    ``node_started``, ``node_finished``) so it can be used with Nipype's status
    callback mechanism without changing ``executor.py``.

    While the bar is live, console logging is suspended through
    :func:`suspend_console_logging`. ``stop`` is idempotent and always restores
    console logging, even if closing the bar raises.
    """

    def __init__(self, total: int, label: str, enabled: bool = True):
        """Configure the display.

        Args:
            total: Number of nodes expected; ``total <= 0`` makes ``start`` a no-op.
            label: Text shown next to the bar.
            enabled: When ``False``, ``start`` does nothing.
        """
        self.total = total
        self.label = label
        self.enabled = enabled
        self._current_node: Optional[str] = None
        self._bar: Optional[Any] = None
        self._completed = 0
        self._ctx: Optional[Any] = None
        # Context manager from suspend_console_logging(); non-None only while active.
        self._log_cm: Optional[Any] = None

    def start(self) -> None:
        """Create the progress bar and suspend console logging.

        Does nothing when disabled, when ``total <= 0``, or when already started.
        """
        if not self.enabled or self.total <= 0:
            return
        if self._ctx is not None:
            return  # already running; don't leak a second bar
        self._ctx = click.progressbar(
            length=self.total,
            label=self.label,
            item_show_func=lambda x: self._current_node or "",
            show_pos=True,
            show_eta=True,
            file=sys.stderr,
            color=True,
            width=40,
        )
        self._log_cm = suspend_console_logging()
        self._log_cm.__enter__()
        try:
            self._bar = self._ctx.__enter__()
        except BaseException:
            # Never leave console logging suspended if the bar fails to start.
            log_cm, self._log_cm = self._log_cm, None
            self._ctx = None
            log_cm.__exit__(None, None, None)
            raise

    def stop(self) -> None:
        """Close the bar and restore console logging; safe to call repeatedly."""
        ctx, self._ctx = self._ctx, None
        log_cm, self._log_cm = self._log_cm, None
        self._bar = None
        try:
            if ctx is not None:
                ctx.__exit__(None, None, None)
        finally:
            # Exit in reverse order of entry (bar first, then logging suspension).
            if log_cm is not None:
                log_cm.__exit__(None, None, None)

    def node_started(self, node_name: str) -> None:
        """Record *node_name* as the current node and redraw the bar label."""
        self._current_node = node_name
        if self._bar is not None:
            self._bar.render_progress()  # Force redraw with new label

    def node_finished(self, node_name: str, status: str = "") -> None:
        """Count one finished node and advance the bar by one step."""
        self._completed += 1
        if self._bar is not None:
            self._bar.update(1, self._current_node)
# ---------------------------------------------------------------------------
# GPU startup check
# ---------------------------------------------------------------------------
# These module-level variables ensure the GPU check (and its interactive
# prompt) runs at most once per process, regardless of how many patients
# are processed in sequence.
_gpu_checked: bool = False
_gpu_available: bool = False


def _resolve_gpu(cfg: PipelineConfig) -> None:
    """
    Validate GPU availability and update ``cfg.hardware.gpu_enabled`` in place.

    Does nothing when the config does not request the GPU. Otherwise, on the
    first call in the process it runs :func:`thesis.core.gpu.check_gpu`, logs
    the outcome, and — if the GPU is requested but not available — decides how
    to proceed:

    * Interactive terminal: asks whether to fall back to CPU; declining raises
      :class:`click.Abort`.
    * Non-interactive (e.g. SLURM batch scripts, or stdin missing/closed):
      falls back to CPU automatically and logs a warning without blocking.

    Subsequent calls (e.g. when processing multiple patients in one run) use
    the cached result from the first call and never prompt again.

    Args:
        cfg: Loaded ``PipelineConfig``; ``cfg.hardware.gpu_enabled`` is set to
            ``False`` in place when falling back to CPU.

    Raises:
        click.Abort: If the user declines the CPU fallback at the prompt.
    """
    global _gpu_checked, _gpu_available
    if not cfg.hardware.gpu_enabled:
        return  # GPU not requested — nothing to check
    if not _gpu_checked:
        from thesis.core.gpu import check_gpu

        status = check_gpu()
        _gpu_available = status.available
        _gpu_checked = True
        if status.available:
            logger.info("GPU acceleration available: {}", status.reason)
        else:
            logger.warning("GPU requested but not available: {}", status.reason)
            # stdin may be None (detached process) or closed; treat both as
            # non-interactive instead of crashing on isatty().
            try:
                interactive = sys.stdin is not None and sys.stdin.isatty()
            except ValueError:
                interactive = False
            if interactive:
                click.echo(
                    f"\n [WARN] GPU acceleration requested but not available:\n"
                    f" {status.reason}\n",
                    err=True,
                )
                proceed = click.confirm(
                    " Continue with CPU processing instead?",
                    default=True,
                )
                click.echo("", err=True)  # blank line for readability
                if not proceed:
                    raise click.Abort()
            else:
                logger.warning(
                    "Non-interactive terminal detected; "
                    "falling back to CPU processing automatically."
                )
    if not _gpu_available:
        cfg.hardware.gpu_enabled = False
def _discover_patient_ids(
    data_dir: Path,
    config_dir: Path,
    config_name: str = "default",
    protocol: Optional[str] = None,
    raw_data_dir_override: Optional[Path] = None,
    name_pattern: Optional[str] = None,
) -> List[str]:
    """Discover patient IDs as sub-folder names of the configured ``inputs_dir``.

    ``paths.inputs_dir`` comes from the loaded configuration, or from
    *raw_data_dir_override* when given; a relative value is taken relative to
    the current working directory. ``inputs_dir`` is an independent
    project-level root (it is *not* resolved under *data_dir*), matching how
    :class:`ProcessingContext` resolves the per-patient input directory.
    *data_dir* (the shared assets base) is accepted for signature
    compatibility but does not affect input discovery.

    A sub-folder is accepted when its name matches *name_pattern* (a
    shell-style :func:`fnmatch.fnmatch` glob) if one is given, and otherwise
    when its name consists only of digits.

    Args:
        data_dir: Shared assets base (unused here).
        config_dir: Configuration directory.
        config_name: Base config name.
        protocol: Protocol name; when set, the protocol config is required.
        raw_data_dir_override: Optional replacement for ``paths.inputs_dir``.
        name_pattern: Optional glob that patient folder names must match.

    Returns:
        Sorted list of patient-ID strings.

    Raises:
        click.ClickException: If ``inputs_dir`` is not an existing directory or
            no sub-folder is accepted.
    """
    config_manager = ConfigManager(config_dir=config_dir)
    overrides = None
    if raw_data_dir_override:
        overrides = {"paths": {"inputs_dir": str(raw_data_dir_override)}}
    cfg = config_manager.load_config(
        config_name=config_name,
        protocol=protocol,
        overrides=overrides,
        protocol_required=protocol is not None,
    )
    raw = Path(cfg.paths.inputs_dir)
    search_dir = raw if raw.is_absolute() else Path.cwd() / raw
    if not search_dir.is_dir():
        # search_dir is already absolute; resolve() additionally collapses ".."
        # and symlinks so the message shows the location actually checked.
        try:
            resolved = search_dir.resolve()
        except (OSError, RuntimeError):
            resolved = search_dir
        raise click.ClickException(
            f"inputs_dir does not exist or is not a directory: {search_dir} "
            f"(resolved to {resolved})"
        )

    def _accept(name: str) -> bool:
        if name_pattern:
            return fnmatch.fnmatch(name, name_pattern)
        return name.isdigit()

    ids = sorted(
        entry.name for entry in search_dir.iterdir() if entry.is_dir() and _accept(entry.name)
    )
    if not ids:
        if name_pattern:
            raise click.ClickException(
                f"No patient folders matching pattern '{name_pattern}' found in {search_dir}"
            )
        raise click.ClickException(f"No numeric patient folders found in {search_dir}")
    return ids
# Characters allowed in explicit ``-p`` patient IDs. Use ``fullmatch``: with
# ``match`` the ``$`` anchor would also accept one trailing newline.
_VALID_PATIENT_ID_RE = re.compile(r"^[A-Za-z0-9._-]+$")


def _validate_explicit_patient_ids(
    ids: List[str],
    config_dir: Path,
    config_name: str,
    protocol: Optional[str],
    raw_data_dir_override: Optional[Path],
    data_dir: Path,
) -> None:
    """Validate explicitly-supplied ``-p`` patient IDs before any setup work.

    Checks performed:

    * Format: each ID must consist only of ``[A-Za-z0-9._-]`` (whole string,
      so a trailing newline is rejected). An offending ID raises
      :class:`click.UsageError` naming it.
    * Dot-only IDs (``.``, ``..``, ...) are rejected with
      :class:`click.UsageError`: joined onto ``inputs_dir`` they would point
      at the inputs directory itself or its parent.
    * Existence: each ID's input directory must exist under the resolved
      ``inputs_dir``. Missing IDs raise :class:`click.ClickException` naming
      every missing ID and listing the available ones. Existence is
      best-effort: if ``inputs_dir`` cannot be resolved or is not a
      directory, the check is skipped rather than crashing.

    Args:
        ids: Explicit patient identifiers from ``-p/--patient-id``.
        config_dir: Configuration directory.
        config_name: Base config name (``-c``).
        protocol: Effective protocol name (or ``None``).
        raw_data_dir_override: Optional ``--raw-data-dir`` override (sets inputs_dir).
        data_dir: Resolved shared assets base (unused for input discovery).

    Raises:
        click.UsageError: If any ID has an invalid format.
        click.ClickException: If any (well-formed) ID is missing under
            ``inputs_dir``.
    """
    bad = [pid for pid in ids if not _VALID_PATIENT_ID_RE.fullmatch(pid)]
    if bad:
        raise click.UsageError(
            "Invalid patient ID(s) "
            + ", ".join(repr(b) for b in bad)
            + "; only characters [A-Za-z0-9._-] are allowed."
        )
    dot_only = [pid for pid in ids if set(pid) == {"."}]
    if dot_only:
        raise click.UsageError(
            "Invalid patient ID(s) "
            + ", ".join(repr(d) for d in dot_only)
            + "; an ID may not consist solely of dots."
        )
    try:
        config_manager = ConfigManager(config_dir=config_dir)
        overrides = None
        if raw_data_dir_override:
            overrides = {"paths": {"inputs_dir": str(raw_data_dir_override)}}
        cfg = config_manager.load_config(
            config_name=config_name,
            protocol=protocol,
            overrides=overrides,
            protocol_required=protocol is not None,
        )
        raw = Path(cfg.paths.inputs_dir)
        search_dir = raw if raw.is_absolute() else Path.cwd() / raw
    except Exception as exc:  # noqa: BLE001 — best-effort existence check
        logger.debug("Skipping patient-ID existence check (inputs_dir unresolved): {}", exc)
        return
    if not search_dir.is_dir():
        logger.debug("Skipping patient-ID existence check: inputs_dir not a directory")
        return
    missing = [pid for pid in ids if not (search_dir / pid).is_dir()]
    if missing:
        available = sorted(entry.name for entry in search_dir.iterdir() if entry.is_dir())
        if len(missing) == 1:
            subject = f"Patient '{missing[0]}'"
        else:
            subject = "Patients " + ", ".join(f"'{pid}'" for pid in missing)
        raise click.ClickException(
            f"{subject} not found under inputs_dir; available: {available}"
        )
def _ensure_workflow_imported(name: str, is_package: bool = True) -> None:
    """Import a workflow's modules so their self-registration side effects run.

    Imports ``thesis.workflows.<name>`` and then, when *is_package* is true,
    ``thesis.workflows.<name>.workflow``. Import failures never propagate:

    * A module that does not exist is logged at DEBUG; the caller reports the
      unknown workflow afterwards.
    * A module that exists but raises ``ImportError`` while importing (for
      example a missing third-party dependency inside it) is logged at
      WARNING, so the real cause is not hidden behind "Unknown workflow".

    If the package import fails, the ``.workflow`` submodule is not attempted:
    importing it would re-run the same failing package import.
    """
    package = f"thesis.workflows.{name}"
    if not _import_for_registration(package, "package", name):
        return
    if is_package:
        _import_for_registration(f"{package}.workflow", "module", name)


def _import_for_registration(module_name: str, kind: str, name: str) -> bool:
    """Import *module_name*, logging any failure; return ``True`` on success."""
    try:
        importlib.import_module(module_name)
    except ModuleNotFoundError as exc:
        missing = exc.name or ""
        # The module itself (or one of its parent packages) is absent: expected
        # for unknown workflow names, so keep it quiet.
        if missing and (module_name == missing or module_name.startswith(missing + ".")):
            logger.debug("Workflow {} import failed for {}: {}", kind, name, exc)
        else:
            logger.warning(
                "Workflow {} for {} exists but failed to import: {}", kind, name, exc
            )
        return False
    except ImportError as exc:
        logger.warning("Workflow {} for {} exists but failed to import: {}", kind, name, exc)
        return False
    return True
def _resolve_workflow(name: str) -> WorkflowEntry:
    """Resolve a workflow name to its :class:`WorkflowEntry`.

    Attempts to import ``thesis.workflows.<name>`` (and for packages also
    ``thesis.workflows.<name>.workflow``) so that self-registration is
    triggered even if the module has not been imported yet.

    Args:
        name: Short workflow name (e.g. ``"hcp"``).

    Returns:
        The matching :class:`WorkflowEntry`.

    Raises:
        click.ClickException: If the name is not found after auto-import.
    """
    _ensure_workflow_imported(name)
    try:
        return WORKFLOW_REGISTRY.get(name)
    except KeyError:
        known = WORKFLOW_REGISTRY.list()
        known_str = ", ".join(known) if known else "<none>"
        # "from None": the registry's KeyError is an implementation detail and
        # adds nothing to the user-facing message.
        raise click.ClickException(
            f"Unknown workflow: '{name}'. Known workflows: {known_str}.\n"
            f"Run 'thesis list-workflows' to see all available workflows."
        ) from None
def _build_workflow(
factory: Callable[..., Any],
config: PipelineConfig,
context: Any,
) -> Any:
"""Call workflow factory with a flexible signature."""
signature = inspect.signature(factory)
positional = [
param
for param in signature.parameters.values()
if param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD)
]
if len(positional) >= 2:
return factory(config, context)
if len(positional) == 1:
return factory(config)
return factory()
def _resolve_protocol_override(
config_manager: ConfigManager,
patient_id: str,
protocol: Optional[str],
entry: WorkflowEntry,
) -> str:
"""Resolve the effective protocol for a patient run."""
temp_cfg_dict = config_manager.load_config_dict(patient_id, subdir="patients")
protocol_override = (
protocol
or (temp_cfg_dict.get("protocol") if temp_cfg_dict else None)
or entry.default_protocol
)
if protocol_override:
return str(protocol_override)
raise click.ClickException(
f"No protocol for workflow '{entry.name}' and patient '{patient_id}'. "
f"Tried: --protocol ({'set' if protocol else 'not set'}), "
f"config/patients/{patient_id}.yaml (no 'protocol' key), "
f"workflow default ({entry.default_protocol or 'not configured'}). "
f"Provide --protocol or add a 'protocol' key to a config file."
)
def _resolve_data_dir(
data_dir: Optional[Path],
config_dir: Path,
config_name: str,
protocol: Optional[str],
) -> Path:
"""Resolve the effective shared assets directory for a run.
When ``--data-dir`` is passed explicitly it wins. Otherwise we fall back to
``config.paths.assets_dir`` (the shared templates/atlases/ROIs base), and
finally to the legacy ``"data"`` default if neither is set or the config
cannot be loaded here.
Args:
data_dir: The ``--data-dir`` CLI value, or ``None`` when not passed.
config_dir: Configuration directory.
config_name: Base config name (``-c``).
protocol: Effective protocol name (``--protocol`` or the workflow's
``default_protocol``).
Returns:
The resolved base data directory.
"""
if data_dir is not None:
return data_dir
try:
cfg = ConfigManager(config_dir=config_dir).load_config(
config_name=config_name, protocol=protocol
)
configured = getattr(getattr(cfg, "paths", None), "assets_dir", None)
except Exception as exc: # noqa: BLE001 — best-effort fallback to legacy default
logger.debug(
"Could not read paths.assets_dir for --data-dir default ({}); using 'data'", exc
)
configured = None
return Path(str(configured)) if configured else Path("data")
def _build_patient_workflow(
    entry: WorkflowEntry,
    patient_id: str,
    config_name: str,
    config_dir: Path,
    data_dir: Path,
    output_dir: Optional[Path],
    protocol: Optional[str],
    raw_data_dir_override: Optional[Path],
    config_overrides: Optional[Dict[str, Any]] = None,
) -> Any:
    """Build one patient's Nipype workflow for inclusion in a batch.

    Steps, in order: resolve the patient's protocol, load the patient config
    (with CLI overrides merged in), apply the process-wide GPU decision,
    create the processing context, run the entry's preflight verifier (if
    callable), then call the workflow factory. The returned workflow exposes
    its full node graph to the parent meta-workflow scheduler, allowing
    proper GPU serialisation and memory accounting.

    Args:
        entry: Resolved WorkflowEntry.
        patient_id: Patient identifier string.
        config_name: Configuration name to load.
        config_dir: Path to config directory.
        data_dir: Base data directory.
        output_dir: Optional output directory override.
        protocol: Protocol name override.
        raw_data_dir_override: Optional raw data dir override (sets inputs_dir).
        config_overrides: Optional CLI-level config overrides (e.g. hemisphere).

    Returns:
        The object returned by the factory (must have a ``run`` attribute).

    Raises:
        click.ClickException: On verification failure or missing protocol.
        RuntimeError: If the factory result has no ``run`` attribute.
    """
    manager = ConfigManager(config_dir=config_dir)
    effective_protocol = _resolve_protocol_override(manager, patient_id, protocol, entry)

    cli_overrides: Dict[str, Any] = (
        {"paths": {"inputs_dir": str(raw_data_dir_override)}} if raw_data_dir_override else {}
    )
    if config_overrides:
        cli_overrides = merge_configs(cli_overrides, config_overrides)

    patient_cfg = manager.load_config(
        config_name=config_name,
        patient_id=patient_id,
        protocol=effective_protocol,
        overrides=cli_overrides or None,
        protocol_required=protocol is not None,
    )
    _resolve_gpu(patient_cfg)

    patient_ctx = create_context(
        patient_id=patient_id,
        config=patient_cfg,
        data_dir=data_dir,
        output_dir=output_dir,
    )

    # Preflight verification: a callable verifier returns a list of problems.
    verifier = entry.verifier
    if callable(verifier):
        problems = verifier(patient_cfg, patient_ctx)
        if problems:
            details = "\n".join(f" - {e}" for e in problems)
            raise click.ClickException(
                f"Preflight verification failed for {patient_id}:\n{details}"
            )

    workflow_obj = _build_workflow(entry.factory, patient_cfg, patient_ctx)
    if hasattr(workflow_obj, "run"):
        return workflow_obj
    raise RuntimeError(f"Workflow factory did not return a Nipype Workflow for {patient_id}.")
def _split_patient_ids(raw: Iterable[str]) -> Tuple[str, ...]:
"""Expand comma-separated ``--patient-id`` values into a flat tuple.
Accepts repeated flags (``-p P001 -p P002``), comma-separated values
(``-p P001,P002``), or any mix (``-p P001,P002 -p P003``). Whitespace around
each id is stripped and empty entries are dropped. Order is preserved and
duplicates are removed (first occurrence wins).
Args:
raw: The raw ``patient_ids`` tuple as collected by Click.
Returns:
A de-duplicated tuple of individual patient identifiers.
"""
seen: set[str] = set()
out: List[str] = []
for value in raw:
for pid in value.split(","):
pid = pid.strip()
if pid and pid not in seen:
seen.add(pid)
out.append(pid)
return tuple(out)
def _attribute_nodes_to_patients(
node_names: Iterable[str],
patient_ids: Iterable[str],
) -> Dict[str, List[str]]:
"""Map Nipype node fullnames back to the patient IDs they belong to.
In the batched meta-workflow each patient is wrapped in a sub-workflow
whose name embeds the patient ID (e.g. ``tract_synthseg_818455``), so a
node fullname like ``batch_tractography.tract_synthseg_818455.hcp_818455.
probtrackx2.a0`` is attributable to patient ``818455``. The match is
anchored at alphanumeric word boundaries so a shorter ID does not match
inside a longer one (``8184`` vs ``818455``); when both could match,
the longer ID wins.
Args:
node_names: Fullnames of nodes whose attribution is wanted.
patient_ids: Patient identifiers that were added to the batch.
Returns:
Mapping from patient ID to the list of node fullnames attributed
to that patient. Patients with no matching nodes are omitted.
"""
sorted_ids = sorted({pid for pid in patient_ids if pid}, key=len, reverse=True)
patterns = [
(pid, re.compile(rf"(?<![0-9A-Za-z]){re.escape(pid)}(?![0-9A-Za-z])")) for pid in sorted_ids
]
attributed: Dict[str, List[str]] = {}
for node_name in node_names:
matched = False
for pid, pattern in patterns:
if pattern.search(node_name):
attributed.setdefault(pid, []).append(node_name)
matched = True
break
if not matched:
logger.warning("Could not attribute node {!r} to any patient", node_name)
return attributed
def _classify_batch_failure(
    *,
    built: Iterable[str],
    workflow: str,
    failed_by_pid: Dict[str, List[str]],
    ok_by_pid: Dict[str, List[str]],
    last_error: Optional[BaseException],
) -> List[RunResult]:
    """Turn a failed batch's per-node observations into per-patient RunResults.

    Side-effect free so the CLI's failure path can be tested without a real
    Nipype workflow. SUCCESS results are returned without tract-similarity
    enrichment; the caller adds that, since it needs config/context this
    helper does not own.

    Args:
        built: Patient IDs whose workflows were successfully added to the batch.
        workflow: Workflow name stamped onto each result.
        failed_by_pid: Patient ID -> fullnames of nodes that failed for that
            patient (from ``_attribute_nodes_to_patients``).
        ok_by_pid: Same shape, for nodes that finished successfully.
        last_error: Exception raised by the underlying ``meta_wf.run()``.

    Returns:
        One ``RunResult`` per patient in ``built``, in the same order:

        - **FAILED** when the patient appears in ``failed_by_pid``; the error
          message lists its failed nodes (sorted) and the batch error.
        - **SUCCESS** when it appears only in ``ok_by_pid``.
        - **BLOCKED** when it appears in neither (the scheduler likely aborted
          before this patient ran).

        When ``failed_by_pid`` is empty every patient is **FAILED** with
        ``str(last_error)``: with no per-node failures recorded, blaming the
        batch error on everyone is safer than reporting anyone as SUCCESS.
    """
    if not failed_by_pid:
        return [
            RunResult(
                patient_id=pid,
                workflow=workflow,
                status=RunStatus.FAILED,
                error_message=str(last_error),
            )
            for pid in built
        ]

    def _result_for(pid: str) -> RunResult:
        if pid in failed_by_pid:
            node_list = ", ".join(sorted(failed_by_pid[pid]))
            return RunResult(
                patient_id=pid,
                workflow=workflow,
                status=RunStatus.FAILED,
                error_message=f"Failed nodes: {node_list} (batch error: {last_error})",
            )
        if pid in ok_by_pid:
            return RunResult(patient_id=pid, workflow=workflow, status=RunStatus.SUCCESS)
        return RunResult(
            patient_id=pid,
            workflow=workflow,
            status=RunStatus.BLOCKED,
            error_message="Workflow aborted before this patient ran.",
        )

    return [_result_for(pid) for pid in built]
def _format_no_build_error(total: int, failed_reasons: Dict[str, str]) -> str:
    """Compose the error text raised when no patient workflow could be built."""
    base_msg = f"No patient workflows could be built. All {total} patient(s) failed to build."
    reasons = list(failed_reasons.values())
    if reasons:
        common_reason, count = collections.Counter(reasons).most_common(1)[0]
        # Surface a reason shared by several patients, or the only reason seen.
        if count > 1 or len(set(reasons)) == 1:
            return (
                f"{base_msg} Common error: {common_reason}. "
                "Check logs or rerun with -v for per-patient detail."
            )
    return (
        f"{base_msg} Reasons vary across patients; rerun with -v to see "
        "per-patient errors (e.g. thesis run -v ...)."
    )


def _build_batch_workflow(
    entry: WorkflowEntry,
    ids: list[str],
    cfg: PipelineConfig,
    config_name: str,
    data_dir: Path,
    output_dir: Optional[Path],
    config_dir: Path,
    protocol: Optional[str],
    raw_data_dir: Optional[Path],
    event_bus: Optional[Any] = None,
    config_overrides: Optional[Dict[str, Any]] = None,
) -> Tuple[Any, List[str], List[str]]:
    """Build a flat Nipype meta-workflow containing all patient sub-workflows.

    Instead of wrapping each patient in an opaque ``Function`` node, this
    function constructs each patient's real Nipype workflow and adds it as a
    direct sub-workflow of the meta-workflow. This exposes every node
    (including GPU-bound ones like ``synthseg`` and ``probtrackx2``) to
    the global ``MultiProc`` scheduler, which can then:

    * Parallelise CPU-bound nodes across patients.
    * Serialise GPU-bound nodes via ``n_gpu_procs`` (preventing OOM).
    * Account for ``mem_gb`` per node to avoid memory oversubscription.

    A patient whose workflow fails to build is skipped (logged and reported
    on the event bus); the batch continues with the rest.

    Args:
        entry: Resolved WorkflowEntry.
        ids: List of patient ID strings.
        cfg: Base PipelineConfig (used for working_dir resolution).
        config_name: Configuration name to load per patient.
        data_dir: Base data directory.
        output_dir: Optional output directory override.
        config_dir: Path to config directory.
        protocol: Protocol name override.
        raw_data_dir: Optional raw data dir override.
        event_bus: Optional event bus for emitting build status events.
        config_overrides: Optional CLI-level config overrides (e.g. hemisphere).

    Returns:
        A tuple ``(meta_wf, built, failed)``: the Nipype meta-workflow
        (``base_dir`` is ``<cfg.nipype.working_dir>/batch``), the patient IDs
        whose workflows were added, and the patient IDs that failed to build.

    Raises:
        click.ClickException: If no patient workflow could be built.
    """
    from nipype import Workflow

    bus = event_bus or get_event_bus()
    meta_wf = Workflow(name="batch_tractography")
    work_dir = Path(cfg.nipype.working_dir) / "batch"
    work_dir.mkdir(parents=True, exist_ok=True)
    meta_wf.base_dir = str(work_dir)

    built: List[str] = []
    failed: List[str] = []
    failed_reasons: Dict[str, str] = {}
    for pid in ids:
        try:
            patient_wf = _build_patient_workflow(
                entry=entry,
                patient_id=pid,
                config_name=config_name,
                config_dir=config_dir,
                data_dir=data_dir,
                output_dir=output_dir,
                protocol=protocol,
                raw_data_dir_override=raw_data_dir,
                config_overrides=config_overrides,
            )
            meta_wf.add_nodes([patient_wf])
            built.append(pid)
            bus.emit(
                f"Workflow built for {pid}",
                level=EventLevel.INFO,
                category="build",
                patient_id=pid,
            )
        except Exception as exc:
            logger.error("Failed to build workflow for {}: {}", pid, exc, exc_info=True)
            bus.emit(
                f"Skipping {pid}: {type(exc).__name__}: {exc}",
                level=EventLevel.ERROR,
                category="build",
                patient_id=pid,
            )
            failed.append(pid)
            failed_reasons[pid] = f"{type(exc).__name__}: {exc}"

    if not built:
        raise click.ClickException(_format_no_build_error(len(ids), failed_reasons))
    if failed:
        bus.emit(
            f"Built {len(built)}/{len(ids)} patient workflows "
            f"({len(failed)} skipped: {', '.join(failed)})",
            level=EventLevel.WARNING,
            category="build",
        )
    return meta_wf, built, failed
def _find_patient_subworkflow(children: List[Any], pid: str) -> Optional[Any]:
    """Return the child whose ``name`` ends with *pid*, preferring an ID boundary.

    A boundary match (the character before *pid* is not alphanumeric, as in
    ``hcp_818455``) is tried first so patient ``455`` cannot claim
    ``hcp_818455``; a plain suffix match remains as a fallback for names
    without a separator. Returns ``None`` when nothing matches.
    """
    boundary = re.compile(rf"(?<![0-9A-Za-z]){re.escape(pid)}\Z")
    named = [(child, str(getattr(child, "name", ""))) for child in children]
    for child, child_name in named:
        if boundary.search(child_name):
            return child
    for child, child_name in named:
        if child_name.endswith(pid):
            return child
    return None


def _run_slurm_batch(
    meta_wf: Any,
    built: List[str],
    cfg: PipelineConfig,
    *,
    workflow_name: str,
    config_name: str,
    dry_run: bool,
    event_bus: Optional[Any] = None,
) -> Dict[str, Any]:
    """Partition each built patient workflow and submit grouped SLURM stage jobs.

    Each top-level patient sub-workflow inside ``meta_wf`` is partitioned into a
    linear chain of resource-homogeneous stage groups and submitted via
    ``sbatch`` with an ``afterok`` dependency chain per patient. All SLURM
    modules are imported locally here so the disabled path never touches them.

    Args:
        meta_wf: The batch meta-workflow whose direct children are the per-patient
            built workflows.
        built: Patient ids successfully built (mapped onto ``meta_wf`` children
            by name suffix, preferring a match at an ID boundary).
        cfg: Loaded pipeline config (``cfg.slurm`` must be enabled).
        workflow_name: Registered workflow name (``-w``).
        config_name: Configuration name (``-c``).
        dry_run: When ``True``, scripts + planned manifest are written but no
            ``sbatch`` is invoked.
        event_bus: Optional event bus for status events.

    Returns:
        The submission manifest dict from
        :func:`thesis.core.slurm.submitter.submit_jobs`.
    """
    from thesis.core.slurm.partitioner import partition
    from thesis.core.slurm.submitter import submit_jobs

    bus = event_bus or get_event_bus()
    # Direct children of meta_wf that own a _graph, i.e. presumably sub-workflows.
    children = [n for n in meta_wf._graph.nodes() if hasattr(n, "_graph")]
    per_patient_groups: Dict[str, Any] = {}
    for pid in built:
        match = _find_patient_subworkflow(children, pid)
        if match is None:
            logger.warning("Could not locate built sub-workflow for patient {}", pid)
            continue
        groups = partition(match)
        per_patient_groups[pid] = groups
        bus.emit(
            f"Partitioned {pid} into {len(groups)} stage group(s)",
            level=EventLevel.INFO,
            category="slurm",
            patient_id=pid,
        )
    return submit_jobs(
        per_patient_groups,
        cfg.slurm,
        workflow_name=workflow_name,
        config_name=config_name,
        dry_run=dry_run,
    )
def _render_batch_tractography_stats(
    run_results: List["RunResult"],
    config_name: str,
    config_dir: Path,
    data_dir: Path,
    output_dir: Optional[Path],
    protocol: Optional[str],
) -> None:
    """Print aggregated tractography statistics for a batch run to stderr.

    Only patients whose status is SUCCESS are included. The output base is
    *output_dir* if given, else ``paths.output_dir`` from the config loaded for
    the first successful patient. When at least three patients have stats,
    outliers are detected and printed too; stats and outliers are then
    persisted via ``_persist_batch_stats``.

    Best-effort: returns quietly (DEBUG log) when no patient succeeded, the
    output base cannot be resolved, or any step raises.

    Args:
        run_results: Per-patient run results from the batch.
        config_name: Configuration name (for resolving output paths).
        config_dir: Configuration directory.
        data_dir: Base data directory (not used here).
        output_dir: Override output directory (may be ``None``).
        protocol: Protocol name override.
    """
    ok_ids = [res.patient_id for res in run_results if res.status == RunStatus.SUCCESS]
    if not ok_ids:
        logger.debug("Batch stats rendering skipped: no successful patients")
        return

    rule = "─" * 70

    def _err(text: str) -> None:
        click.echo(text, err=True)

    try:
        from thesis.workflows.qc.statistics import collect_batch_stats, format_stats_table

        # Resolve the output base directory the same way contexts do.
        cfg = ConfigManager(config_dir=config_dir).load_config(
            config_name=config_name,
            patient_id=ok_ids[0],
            protocol=protocol,
        )
        output_base: Optional[Path] = output_dir
        if output_base is None:
            configured_out = getattr(cfg.paths, "output_dir", None)
            output_base = Path(configured_out) if configured_out else None
        if output_base is None or not output_base.is_dir():
            logger.debug(
                "Batch stats rendering skipped: output_dir not resolved or does "
                "not exist (output_base={})",
                output_base,
            )
            return

        relpath = getattr(
            getattr(cfg, "atlas", None), "tractography_relpath", "tractography/probtrackx2"
        )
        stats = collect_batch_stats(
            output_base,
            patient_ids=ok_ids,
            tractography_relpath=relpath,
        )
        if not stats:
            return

        # Per-ROI columns only for small batches to avoid huge tables.
        table = format_stats_table(stats, include_roi_counts=len(ok_ids) <= 10)
        for line in (f"\n{rule}", "Tractography Statistics", rule, table):
            _err(line)

        outliers: List[Dict[str, Any]] = []
        sd_thresh = getattr(getattr(cfg, "qc", None), "outlier_sd_threshold", 2.0)
        if len(stats) >= 3:
            from thesis.workflows.qc.checks import detect_batch_outliers

            outliers = detect_batch_outliers(stats, sd_threshold=sd_thresh)
        if outliers:
            for line in (f"\n{rule}", f"Outliers (>{sd_thresh} SD from mean)", rule):
                _err(line)
            for item in outliers:
                _err(
                    f" {item['patient_id']:<12} {item['metric']:<20} "
                    f"value={item['value']:<10} z={item['z_score']:.1f}"
                )

        _persist_batch_stats(output_base, stats, outliers)
    except Exception as _exc:
        # Best-effort — never fail the run for stats.
        logger.debug("Batch statistics rendering failed: {}", _exc)
def _persist_batch_stats(
    output_base: Path,
    stats: List[Dict[str, Any]],
    outliers: List[Dict[str, Any]],
) -> None:
    """Write batch stats and outliers as JSON under ``output_base/batch_stats``.

    Files written:

    * ``stats_<ts>.json``: ``{"subjects": stats, "outliers": outliers}``.
    * ``outliers_<ts>.json``: the outliers list alone.
    * ``latest.json``: same content as the stats file, overwritten each run.

    ``<ts>`` is ``%Y%m%d_%H%M%S`` local time. Both payloads are serialised
    before any file is touched, so a serialisation error leaves no partial
    output. ``latest.json`` is written to a temporary sibling and renamed into
    place, so readers never observe a half-written file. Non-JSON values are
    stringified (``default=str``).

    Best-effort: any failure is logged at DEBUG and swallowed so stats
    persistence never breaks a pipeline run.
    """
    try:
        stats_dir = output_base / "batch_stats"
        stats_dir.mkdir(exist_ok=True)
        ts = time.strftime("%Y%m%d_%H%M%S")
        payload_text = json.dumps(
            {"subjects": stats, "outliers": outliers}, indent=2, default=str
        )
        outliers_text = json.dumps(outliers, indent=2, default=str)

        stats_path = stats_dir / f"stats_{ts}.json"
        stats_path.write_text(payload_text, encoding="utf-8")
        (stats_dir / f"outliers_{ts}.json").write_text(outliers_text, encoding="utf-8")

        latest_path = stats_dir / "latest.json"
        tmp_path = latest_path.with_name(latest_path.name + ".tmp")
        tmp_path.write_text(payload_text, encoding="utf-8")
        tmp_path.replace(latest_path)  # atomic rename on the same filesystem
        logger.info("Wrote batch stats to {}", stats_path)
    except Exception as exc:
        logger.debug("Persisting batch stats failed: {}", exc)
def _enrich_with_tract_similarity_metrics(
result: RunResult,
cfg: PipelineConfig,
ctx_obj: Any,
) -> None:
"""Populate ``result.metadata['tract_similarity']`` from metrics.json.
Called after a successful ``full_pipeline`` run so that the terminal
summary can show headline similarity numbers without re-reading the
full metric file. A best-effort operation: any failure leaves the
metadata unchanged.
"""
if result.workflow != "full_pipeline":
return
if ctx_obj.output_dir is None:
return
subdir = getattr(getattr(cfg, "tract_similarity", None), "output_subdir", None)
if not subdir:
return
metrics_file = Path(ctx_obj.output_dir) / subdir / "metrics.json"
if not metrics_file.is_file():
return
try:
data = json.loads(metrics_file.read_text(encoding="utf-8"))
except (json.JSONDecodeError, OSError) as exc:
logger.debug("Could not parse tract_similarity metrics.json: {}", exc)
return
overlap = data.get("overlap") or {}
correlation = data.get("correlation") or {}
distance = data.get("distance_mm") or {}
distribution = data.get("distribution") or {}
result.metadata["tract_similarity"] = {
"dice": overlap.get("dice"),
"pearson": correlation.get("pearson"),
"hausdorff95": distance.get("hausdorff95"),
"nmi": distribution.get("nmi"),
}
def _run_single_patient_with_retries(
    entry: WorkflowEntry,
    config: str,
    patient_id: str,
    data_dir: Path,
    output_dir: Optional[Path],
    config_dir: Path,
    graph: bool,
    dry_run: bool,
    protocol: Optional[str],
    raw_data_dir_override: Optional[Path],
    retries: int,
    show_workflow_progress: bool = False,
    config_overrides: Optional[Dict[str, Any]] = None,
    stage: Optional[str] = None,
    retry_delay: float = 60.0,
) -> RunResult:
    """Run one patient workflow sequentially, retrying on failure.

    Makes up to ``retries + 1`` attempts (at least one; a negative *retries*
    means no retries), sleeping *retry_delay* seconds before each retry and
    emitting a WARNING retry event. Each failed attempt is appended to
    ``result.error_history``.

    Args:
        retries: Number of retries after the first attempt.
        retry_delay: Seconds to wait before each retry (default 60; negative
            values are treated as 0).
        (remaining arguments are forwarded to ``_run_single_patient``)

    Returns:
        A :class:`RunResult` with status SUCCESS (or SKIPPED for dry runs) on
        the first successful attempt, otherwise FAILED with the last attempt's
        error message/type; ``elapsed_seconds`` covers all attempts.
    """
    bus = get_event_bus()
    result = RunResult(patient_id=patient_id, workflow=entry.name)
    start_time = time.monotonic()
    error_msg = ""
    error_type = ""
    max_retries = max(retries, 0)
    delay = max(retry_delay, 0)
    for attempt in range(max_retries + 1):
        if attempt > 0:
            result.retry_count = attempt
            bus.emit(
                f"Retry {attempt}/{max_retries} for {patient_id} (waiting {delay:g}s)",
                level=EventLevel.WARNING,
                category="retry",
                patient_id=patient_id,
            )
            time.sleep(delay)
        ok, error_msg, error_type = _run_single_patient(
            entry=entry,
            config=config,
            patient_id=patient_id,
            data_dir=data_dir,
            output_dir=output_dir,
            config_dir=config_dir,
            graph=graph,
            dry_run=dry_run,
            protocol=protocol,
            raw_data_dir_override=raw_data_dir_override,
            show_workflow_progress=show_workflow_progress,
            config_overrides=config_overrides,
            result=result,
            stage=stage,
        )
        if ok:
            result.status = RunStatus.SKIPPED if dry_run else RunStatus.SUCCESS
            result.elapsed_seconds = time.monotonic() - start_time
            return result
        result.error_history.append(f"Attempt {attempt + 1}: {error_type or 'Error'}: {error_msg}")
    # All attempts failed
    result.status = RunStatus.FAILED
    result.error_message = error_msg
    result.error_type = error_type
    result.elapsed_seconds = time.monotonic() - start_time
    return result
def _make_output_config(ctx: click.Context) -> OutputConfig:
    """Translate the root-group CLI flags stored on *ctx* into an output config.

    Args:
        ctx: Click context. Its ``obj`` mapping (which may be ``None``) can hold
            ``output_mode``, ``summary_detail`` and ``show_progress``.

    Returns:
        :class:`OutputConfig` built from those values. Missing keys fall back to
        ``OutputMode.NORMAL``, ``SummaryDetail.COMPACT`` and ``None``
        respectively (``None`` leaves progress display to auto-detection).
    """
    settings = ctx.obj if ctx.obj else {}
    return OutputConfig(
        mode=settings.get("output_mode", OutputMode.NORMAL),
        summary=settings.get("summary_detail", SummaryDetail.COMPACT),
        progress=settings.get("show_progress", None),
    )
def _should_show_workflow_progress(entry: WorkflowEntry, show_progress: bool) -> bool:
"""Determine whether Click node progress should be created for a workflow run.
Args:
entry: Resolved workflow entry.
show_progress: Whether progress UI is otherwise enabled.
Returns:
``True`` when node-level Click progress should be enabled.
"""
if not show_progress:
return False
return not bool(getattr(entry, "is_cohort_level", False))
@click.group()
@click.version_option(version=__version__)
@click.option(
    "--log-level",
    type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR"], case_sensitive=False),
    default="INFO",
    show_default=True,
    help="Logging level",
)
@click.option(
    "--log-dir",
    type=click.Path(path_type=Path),  # type: ignore[type-var]
    default="logs",
    show_default=True,
    help="Directory for log files",
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    default=False,
    help="Verbose output: show DEBUG-level events, full logs, timing "
    "(independent of --summary detail).",
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    default=False,
    help="Quiet output (errors and final result only)",
)
@click.option(
    "--summary",
    "summary_opt",
    type=click.Choice(["off", "compact", "full"], case_sensitive=False),
    default=None,
    help="Summary detail level: off|compact|full (controls the final summary "
    "structure; independent of -v). Default: compact.",
)
@click.option(
    "--no-progress",
    is_flag=True,
    default=False,
    help="Disable animated progress bars and spinners",
)
@click.pass_context
def main(
    ctx: click.Context,
    log_level: str,
    log_dir: Path,
    verbose: bool,
    quiet: bool,
    summary_opt: Optional[str],
    no_progress: bool,
) -> None:
    """
    Thesis Medical Imaging Pipeline Framework.

    A modular Python framework for preprocessing and tractography analysis
    of diffusion MRI data.

    Examples:

    \b
        thesis run -w hcp -p P001 -c default
        thesis list-workflows
        thesis stats collect -o outputs/
    """
    # Root-group state is handed to subcommands through ctx.obj.
    ctx.ensure_object(dict)

    # -v and -q are mutually exclusive; -v additionally forces DEBUG logging.
    if verbose and quiet:
        raise click.UsageError("Cannot use both --verbose/-v and --quiet/-q.")
    if verbose:
        output_mode, log_level = OutputMode.VERBOSE, "DEBUG"
    elif quiet:
        output_mode = OutputMode.QUIET
    else:
        output_mode = OutputMode.NORMAL

    summary_detail = (
        SummaryDetail.COMPACT if summary_opt is None else SummaryDetail(summary_opt.lower())
    )

    setup_logging(log_dir=log_dir, log_level=log_level.upper())
    # Console verbosity follows the output mode; anything not listed here
    # (i.e. NORMAL) only shows warnings and above on the console.
    console_level_by_mode = {
        OutputMode.VERBOSE: "DEBUG",
        OutputMode.QUIET: "ERROR",
    }
    set_console_level(console_level_by_mode.get(output_mode, "WARNING"))

    ctx.obj.update(
        log_level=log_level,
        log_dir=log_dir,
        output_mode=output_mode,
        summary_detail=summary_detail,
        # None means "auto-detect"; --no-progress forces progress UI off.
        show_progress=None if not no_progress else False,
    )
    logger.info("Thesis framework v{}", __version__)
@main.command()
@click.option(
    "--config-dir",
    type=click.Path(exists=True, file_okay=False, path_type=Path),  # type: ignore[type-var]
    default="config",
    show_default=True,
    help="Configuration directory",
)
@click.option(
    "--subdir",
    type=str,
    default=None,
    help="Subdirectory to list (e.g., 'patients', 'protocols')",
)
def list_configs(config_dir: Path, subdir: Optional[str]) -> None:
    """
    List available configuration files.

    Example:
        thesis list-configs
        thesis list-configs --subdir patients
    """
    found = ConfigManager(config_dir=config_dir).list_configs(subdir=subdir)
    if not found:
        click.echo("No configurations found")
        return
    # Show the directory that was actually listed (subdir-qualified if given).
    where = config_dir / subdir if subdir else config_dir
    click.echo(f"Available configs in {where}:")
    for name in found:
        click.echo(f" - {name}")
@main.command()
@click.argument("config_name")
@click.option(
    "--config-dir",
    type=click.Path(exists=True, file_okay=False, path_type=Path),  # type: ignore[type-var]
    default="config",
    show_default=True,
    help="Configuration directory",
)
def show_config(config_name: str, config_dir: Path) -> None:
    """
    Display the contents of a configuration file.

    Example:
        thesis show-config default
    """
    try:
        loaded = ConfigManager(config_dir=config_dir).load_config(config_name=config_name)
        # Dump in block style, preserving the config's own key order.
        rendered = yaml.dump(loaded.to_dict(), default_flow_style=False, sort_keys=False)
        for line in (f"Configuration: {config_name}", "=" * 60, rendered):
            click.echo(line)
    except Exception as exc:
        # Most specific first: missing file, invalid config, then anything else.
        if isinstance(exc, FileNotFoundError):
            message = f"Configuration not found: {config_name}"
        elif isinstance(exc, ConfigurationError):
            message = f"{exc}"
        else:
            message = f"Error loading configuration: {exc}"
        click.echo(f"[ERROR] {message}", err=True)
        raise click.Abort()
def _run_single_patient(
    entry: WorkflowEntry,
    config: str,
    patient_id: str,
    data_dir: Path,
    output_dir: Optional[Path],
    config_dir: Path,
    graph: bool,
    dry_run: bool,
    protocol: Optional[str] = None,
    raw_data_dir_override: Optional[Path] = None,
    config_overrides: Optional[Dict[str, Any]] = None,
    show_workflow_progress: bool = False,
    result: Optional[RunResult] = None,
    stage: Optional[str] = None,
) -> Tuple[bool, str, str]:
    """Run a Nipype workflow for a single patient.

    When ``stage`` is provided the built workflow is pruned to that SLURM stage
    group's node set before execution (the grouped-SLURM ``--stage`` path),
    leaning on Nipype caching for upstream inputs.

    Sequence: load the config (with raw-data-dir and CLI overrides), resolve
    GPU availability (skipped for ``--stage``), check that the patient input
    directory exists, run the entry's optional verifier, build (and optionally
    prune/graph) the workflow, execute it unless ``dry_run``, then run the
    best-effort post-workflow QC and tract-similarity enrichment.

    Args:
        entry: Resolved workflow entry (factory, optional verifier, defaults).
        config: Configuration name to load.
        patient_id: Patient identifier.
        data_dir: Shared assets directory passed to the processing context.
        output_dir: Optional output directory override.
        config_dir: Configuration directory.
        graph: Write a flat workflow graph under ``<output_dir>/workflow_graphs``.
        dry_run: Build the workflow but do not execute it.
        protocol: Explicit ``--protocol`` value (``None`` when not given).
        raw_data_dir_override: Overrides ``paths.inputs_dir`` when set.
        config_overrides: Extra config overrides merged on top.
        show_workflow_progress: Show a Click node-level progress bar.
        result: Optional :class:`RunResult` to enrich with tract-similarity
            metrics after a successful run.
        stage: Optional grouped-SLURM stage name to restrict execution to.

    Returns:
        A tuple of ``(success, error_message, error_type)``. Exceptions raised
        while loading, building or running are caught, reported on the event
        bus, and turned into a failure tuple.
    """
    bus = get_event_bus()
    logger.info("Starting Nipype workflow for patient: {}", patient_id)
    bus.emit(
        f"Starting workflow for {patient_id}",
        level=EventLevel.IMPORTANT,
        category="workflow",
        patient_id=patient_id,
    )
    try:
        # Load configuration
        config_manager = ConfigManager(config_dir=config_dir)
        # Determine protocol: CLI flag > patient YAML > WorkflowEntry.default_protocol
        protocol_override = _resolve_protocol_override(config_manager, patient_id, protocol, entry)
        overrides: Dict[str, Any] = {}
        if raw_data_dir_override:
            overrides["paths"] = {"inputs_dir": str(raw_data_dir_override)}
        if config_overrides:
            overrides = merge_configs(overrides, config_overrides)
        try:
            cfg = config_manager.load_config(
                config_name=config,
                patient_id=patient_id,
                protocol=protocol_override,
                overrides=overrides or None,
                protocol_required=protocol is not None,
            )
        except ConfigurationError as cfg_exc:
            logger.error(
                "Configuration validation failed for {}: {}", patient_id, cfg_exc, exc_info=True
            )
            # Point the user at the files that feed this config, relative to
            # the configured --config-dir. Skip the protocol file when no
            # protocol was resolved (previously rendered as "None.yaml").
            hint_paths: List[str] = []
            if protocol_override:
                hint_paths.append(str(Path(config_dir) / f"{protocol_override}.yaml"))
            hint_paths.append(str(Path(config_dir) / "patients" / f"{patient_id}.yaml"))
            hint_paths.append(f"{config}.yaml")
            augmented = f"{cfg_exc} (check {', '.join(hint_paths[:-1])}, or {hint_paths[-1]})"
            bus.emit(
                f"Configuration error for {patient_id}: {augmented}",
                level=EventLevel.ERROR,
                category="workflow",
                patient_id=patient_id,
            )
            return (False, augmented, "ConfigurationError")
        # Validate GPU availability once per process; mutates cfg if falling back.
        # Skip in the grouped-SLURM --stage path: the partition was computed
        # up-front on the submit node (where the GPU is visible), so re-probing
        # here would flip hardware.gpu_enabled on a CPU-only stage node,
        # reclassify the GPU node as CPU, and collapse the partition to a single
        # stage whose name no longer matches the submitted --stage token (the
        # stage then either runs the whole pipeline or fails with "Unknown
        # stage"). Trust the submit-time config; the GPU node is isolated in its
        # own stage and never executes on a CPU stage node.
        if stage is None:
            _resolve_gpu(cfg)
        bus.emit(
            f"GPU resolution: gpu_enabled={cfg.hardware.gpu_enabled}",
            level=EventLevel.INFO,
            category="hardware",
        )
        # Create processing context
        ctx_obj = create_context(
            patient_id=patient_id,
            config=cfg,
            data_dir=data_dir,
            output_dir=output_dir,
        )
        logger.info("Configuration loaded: {}", config)
        logger.info("Input directory: {}", ctx_obj.input_dir)
        logger.info("Output directory: {}", ctx_obj.output_dir)
        # Input data directory preflight: a missing patient directory gives a
        # clear "directory not found" message instead of a per-file @requires
        # failure deeper in the workflow.
        if ctx_obj.input_dir is not None and not Path(ctx_obj.input_dir).exists():
            msg = (
                f"Input data directory not found: {ctx_obj.input_dir}. "
                f"Check that patient ID {patient_id} exists under inputs_dir."
            )
            bus.emit(
                msg,
                level=EventLevel.ERROR,
                category="preflight",
                patient_id=patient_id,
            )
            return (False, msg, "FileIOError")
        # Preflight verification
        if callable(entry.verifier):
            try:
                errors = entry.verifier(cfg, ctx_obj)
                if errors:
                    error_msg = "; ".join(errors)
                    bus.emit(
                        f"Preflight verification failed for {patient_id}: {error_msg}",
                        level=EventLevel.ERROR,
                        category="preflight",
                        patient_id=patient_id,
                    )
                    return (False, f"Preflight failed: {error_msg}", "PreflightError")
            except Exception as ve:
                logger.error("Verification step failed: {}", ve, exc_info=True)
                bus.emit(
                    f"Verification error for {patient_id}: {type(ve).__name__}: {ve} "
                    "(see logs in logs or rerun with -v)",
                    level=EventLevel.ERROR,
                    category="preflight",
                    patient_id=patient_id,
                )
                return (False, str(ve), type(ve).__name__)
        # Build workflow
        bus.emit(
            f"Building workflow for {patient_id}",
            level=EventLevel.INFO,
            category="workflow",
            patient_id=patient_id,
        )
        wf = _build_workflow(entry.factory, cfg, ctx_obj)
        if not hasattr(wf, "run"):
            msg = (
                f"Workflow factory returned {type(wf).__name__} instead of a "
                f"nipype.pipeline.engine.Workflow for {patient_id}. "
                "Check the factory's return value."
            )
            bus.emit(msg, level=EventLevel.ERROR, category="workflow", patient_id=patient_id)
            return (False, msg, "RuntimeError")
        # Grouped-SLURM stage pruning: restrict the built graph to one stage's
        # nodes before normal execution. Lazy-imported so the non-stage path
        # never touches the slurm package.
        if stage is not None:
            from thesis.core.slurm.pruning import find_stage_group, prune_workflow_to_stage

            # Grouped-SLURM stages hand results between separate sbatch processes
            # through Nipype's on-disk work/ cache. Nipype mixes ``needed_outputs``
            # into a node's hash when ``remove_unnecessary_outputs`` is on, and a
            # node's needed-outputs set differs per stage (its consumers are often
            # pruned into a *later* stage) — so the producing stage stores one hash
            # and the consuming stage computes another, the cache always misses,
            # and every retained upstream ancestor re-runs (SynthSeg recomputed on
            # the CPU stage). Disabling it drops needed_outputs from the hash so the
            # cross-stage cache hits; keeping all outputs is also what staging needs
            # (later stages read them).
            cfg.nipype.remove_unnecessary_outputs = False
            stage_group = find_stage_group(wf, stage)
            prune_workflow_to_stage(wf, stage_group)
            bus.emit(
                f"Restricted workflow to stage '{stage}' for {patient_id}",
                level=EventLevel.IMPORTANT,
                category="workflow",
                patient_id=patient_id,
            )
        if graph:
            if ctx_obj.output_dir is None:
                raise click.ClickException("output_dir must be set in context when --graph is used")
            graph_dir = ctx_obj.output_dir / "workflow_graphs"
            graph_dir.mkdir(parents=True, exist_ok=True)
            wf.write_graph(
                graph2use="flat",
                format="png",
                simple_form=True,
                dotfilename=str(graph_dir / "workflow"),
            )
            bus.emit(
                f"Workflow graph written to: {graph_dir}",
                level=EventLevel.IMPORTANT,
                category="output",
                patient_id=patient_id,
            )
        if dry_run:
            bus.emit(
                f"[DRY RUN] Workflow built for {patient_id} but not executed",
                level=EventLevel.IMPORTANT,
                category="workflow",
                patient_id=patient_id,
            )
            return (True, "", "")
        # Execute workflow
        bus.emit(
            f"Executing workflow for {patient_id}",
            level=EventLevel.IMPORTANT,
            category="workflow",
            patient_id=patient_id,
        )
        workflow_progress: Optional[ClickNodeProgress] = None
        if show_workflow_progress:
            total_nodes = count_workflow_nodes(wf)
            if total_nodes > 0:
                workflow_progress = ClickNodeProgress(
                    total=total_nodes,
                    label=f"Nodes for {patient_id}",
                    enabled=True,
                )
                workflow_progress.start()
        try:
            run_workflow(
                wf,
                ctx_obj,
                cfg.nipype,
                event_bus=bus,
                progress=workflow_progress,
            )
        finally:
            # Always tear down the progress bar, even if the run raised.
            if workflow_progress is not None:
                workflow_progress.stop()
        bus.emit(
            f"Workflow completed successfully for {patient_id}",
            level=EventLevel.IMPORTANT,
            category="workflow",
            patient_id=patient_id,
        )
        logger.info("Nipype workflow completed for {}", patient_id)
        # Post-workflow steps are best-effort: the workflow itself has already
        # succeeded, so a QC or metric-enrichment failure is reported as a
        # warning instead of turning the patient into a failure (the parallel
        # batch path likewise tolerates enrichment errors).
        if ctx_obj.output_dir is not None:
            from thesis.workflows.qc.operations import run_post_workflow_qc

            try:
                run_post_workflow_qc(cfg, ctx_obj)
            except Exception as qc_exc:
                logger.warning(
                    "Post-workflow QC failed for {}: {}", patient_id, qc_exc, exc_info=True
                )
                bus.emit(
                    f"Post-workflow QC failed for {patient_id}: "
                    f"{type(qc_exc).__name__}: {qc_exc}",
                    level=EventLevel.WARNING,
                    category="workflow",
                    patient_id=patient_id,
                )
        if result is not None:
            try:
                _enrich_with_tract_similarity_metrics(result, cfg, ctx_obj)
            except Exception as enrich_exc:
                logger.debug(
                    "tract_similarity enrichment skipped for {}: {}", patient_id, enrich_exc
                )
        return (True, "", "")
    except click.ClickException as exc:
        logger.error("Nipype workflow failed for {}: {}", patient_id, exc, exc_info=True)
        log_dir = "logs"
        bus.emit(
            f"Error for {patient_id}: ClickException: {exc.format_message()} "
            f"(see logs in {log_dir} or rerun with -v)",
            level=EventLevel.ERROR,
            category="workflow",
            patient_id=patient_id,
        )
        return (False, exc.format_message(), "ClickException")
    except Exception as e:
        logger.error("Nipype workflow failed for {}: {}", patient_id, e, exc_info=True)
        log_dir = "logs"
        category_hint = "processing error -- " if isinstance(e, ProcessingError) else ""
        bus.emit(
            f"Error for {patient_id}: {category_hint}{type(e).__name__}: {e} "
            f"(see logs in {log_dir} or rerun with -v)",
            level=EventLevel.ERROR,
            category="workflow",
            patient_id=patient_id,
        )
        return (False, str(e), type(e).__name__)
@main.command()
@click.option(
    "--workflow",
    "-w",
    required=False,
    default=None,
    help="Registered workflow name (e.g. hcp, tract_synthseg, synthseg, minimal). "
    "Required unless --script is given."
    " Some workflows (tract_synthseg, full_pipeline) support multiple backends "
    "via config.tractography.method: probtrackx2 (default) or mrtrix3.",
)
@click.option(
    "--script",
    "-s",
    "script",
    type=click.Path(path_type=Path, dir_okay=False),  # type: ignore[type-var]
    default=None,
    help="Path to a user-supplied workflow script (.py) that registers a workflow "
    "via the @workflow decorator. Overrides --workflow.",
)
@click.option(
    "--config",
    "-c",
    type=str,
    default="default",
    show_default=True,
    help="Configuration name to load",
)
@click.option(
    "--protocol",
    "-t",
    type=str,
    default=None,
    help="Protocol name override. Resolved as: CLI --protocol > "
    "config/patients/<patient-id>.yaml protocol field > workflow default > "
    "error if none set.",
)
@click.option(
    "--patient-id",
    "-p",
    "patient_ids",
    multiple=True,
    help="Patient identifier(s). Repeat (-p P001 -p P002) and/or pass a "
    "comma-separated list (-p P001,P002,P003). Mutually exclusive with --all.",
)
@click.option(
    "--all",
    "all_patients",
    is_flag=True,
    default=False,
    help="Discover and process all numeric patient folders in inputs_dir.",
)
@click.option(
    "--pattern",
    "-n",
    type=str,
    default=None,
    show_default=True,
    help="Glob pattern to match patient folder names in inputs_dir "
    "(requires --all). If omitted, only numeric folder names are discovered.",
)
@click.option(
    "--parallel",
    is_flag=True,
    default=False,
    help="Enable parallel execution via a Nipype meta-workflow. Without -j, "
    "uses nipype.plugin_args.n_procs from config; set -j/--max-workers for an "
    "explicit worker count.",
)
@click.option(
    "--max-workers",
    "-j",
    type=click.IntRange(min=1),
    default=None,
    help="Worker count (implies --parallel). Overrides nipype.plugin_args.n_procs "
    "(default: uses config nipype.plugin_args.n_procs or system default).",
)
@click.option(
    "--retries",
    "-r",
    type=click.IntRange(min=0),
    default=0,
    show_default=True,
    help="Number of retry attempts per patient on failure (default: 0).",
)
@click.option(
    "--data-dir",
    type=click.Path(exists=True, file_okay=False, path_type=Path),  # type: ignore[type-var]
    default=None,
    help="Shared assets directory (templates, atlases, ROIs); overrides config "
    "paths.assets_dir (default: 'data'). Independent of the inputs/outputs roots.",
)
@click.option(
    "--raw-data-dir",
    type=click.Path(path_type=Path),  # type: ignore[type-var]
    default=None,
    help="Override paths.inputs_dir, the per-patient source-data root (default: "
    "from config, typically data/raw). An independent project-level root, "
    "resolved relative to the current directory when not absolute.",
)
@click.option(
    "--output-dir",
    type=click.Path(path_type=Path),  # type: ignore[type-var]
    default=None,
    help="Output directory (default: outputs/<patient-id>, from config " "paths.output_dir).",
)
@click.option(
    "--config-dir",
    type=click.Path(exists=True, file_okay=False, path_type=Path),  # type: ignore[type-var]
    default="config",
    show_default=True,
    help="Configuration directory",
)
@click.option(
    "--graph",
    is_flag=True,
    default=False,
    help="Write workflow graph to output directory",
)
@click.option(
    "--dry-run",
    is_flag=True,
    default=False,
    help="Build workflow but do not execute",
)
@click.option(
    "--hemisphere",
    type=click.Choice(["left", "right", "both", "both-separately"], case_sensitive=False),
    default=None,
    help=(
        "Hemisphere selection for tractography (supported by hcp, mrtrix3, "
        "tract_synthseg, full_pipeline): left, right, both (merged ROIs, default), "
        "or both-separately (two independent runs with separate output "
        "directories). Ignored/rejected for cohort-level and non-tractography "
        "workflows."
    ),
)
@click.option(
    "--stage",
    "stage",
    type=str,
    default=None,
    help=(
        "Restrict execution to a single SLURM stage group (e.g. 'stage_0_cpu'). "
        "Prunes the built workflow to that stage's nodes before running; intended "
        "for grouped SLURM submission and not normally used directly."
    ),
)
@click.pass_context
def run(
    ctx: click.Context,
    workflow: Optional[str],
    script: Optional[Path],
    config: str,
    protocol: Optional[str],
    patient_ids: Tuple[str, ...],
    all_patients: bool,
    pattern: Optional[str],
    parallel: bool,
    max_workers: Optional[int],
    retries: int,
    data_dir: Optional[Path],
    raw_data_dir: Optional[Path],
    output_dir: Optional[Path],
    config_dir: Path,
    graph: bool,
    dry_run: bool,
    hemisphere: Optional[str],
    stage: Optional[str],
) -> None:
    """
    Run a registered workflow for one or more patients.

    Workflow is identified either by its short registered name (use
    ``thesis list-workflows`` to see what's available) or by an
    out-of-tree script via ``--script PATH.py``. Execution is sequential
    by default; pass ``--parallel`` or ``-j N`` to fan out across
    patients via a Nipype meta-workflow.

    Examples:

    \b
        thesis run -w hcp -p DTI_001 --dry-run
        thesis run -w hcp --all -c default --dry-run
        thesis run -w hcp --all -j 4
        thesis run -w hcp -p P001 -p P002 --parallel
        thesis run -w hcp -p P001,P002,P003 -c default
        thesis run -w hcp -p DTI_001 --hemisphere left
        thesis run -w hcp -p DTI_001 --hemisphere both-separately
        thesis run -w full_pipeline -p 114823 -c full_pipeline_mrtrix3  # MRtrix3 backend
        thesis run --script ./my_workflow.py -p DTI_001 -c default
    """
    # Normalise -p values: support both repeated flags and comma-separated lists.
    patient_ids = _split_patient_ids(patient_ids)
    if all_patients and patient_ids:
        raise click.UsageError("Cannot use both --patient-id/-p and --all.")
    if pattern and not all_patients:
        raise click.UsageError("--pattern can only be used together with --all.")
    if workflow is None and script is None:
        raise click.UsageError("Either --workflow / -w or --script / -s is required.")
    script_workflow_mismatch: Optional[Tuple[str, str]] = None
    if script is not None:
        from thesis.core.loader import load_workflow_script

        loaded_entry = load_workflow_script(script)
        if workflow is not None and workflow != loaded_entry.name:
            # Defer the warning until the event bus exists so it surfaces through
            # the structured output system (a plain logger.warning here would be
            # suppressed in -q mode).
            script_workflow_mismatch = (workflow, loaded_entry.name)
        workflow = loaded_entry.name
    assert workflow is not None  # guarded above
    # Build config overrides from CLI flags
    cli_overrides: Dict[str, Any] = {}
    if hemisphere is not None:
        cli_overrides.setdefault("tractography", {})["hemisphere"] = hemisphere
    # Setup output system
    output_config = _make_output_config(ctx)
    reset_event_bus()
    bus = get_event_bus()
    renderer = OutputRenderer(config=output_config, event_bus=bus)
    renderer.attach(bus)
    if script_workflow_mismatch is not None:
        requested_wf, script_wf = script_workflow_mismatch
        bus.emit(
            f"--workflow {requested_wf!r} ignored; script registered " f"workflow {script_wf!r}.",
            level=EventLevel.WARNING,
            category="workflow",
        )
    # Resolve workflow entry (triggers auto-import/registration)
    entry = _resolve_workflow(workflow)
    # Resolve the base data dir: explicit --data-dir wins, else fall back to
    # config.paths.assets_dir (the shared assets base), else "data".
    data_dir = _resolve_data_dir(data_dir, config_dir, config, protocol or entry.default_protocol)
    # -----------------------------------------------------------------------
    # Cohort-level workflow execution bypass
    # -----------------------------------------------------------------------
    if getattr(entry, "is_cohort_level", False):
        if hemisphere is not None:
            raise click.UsageError(
                "--hemisphere cannot be used with cohort-level workflows (e.g. "
                "atlas, tract_similarity_cohort). Hemisphere is a patient-level "
                "tractography setting."
            )
        if patient_ids or all_patients:
            raise click.UsageError(
                f"Cohort-level workflow '{workflow}' does not accept -p/--patient-id "
                "or --all. These workflows operate over the configured output "
                "directory; omit these flags."
            )
        if retries > 0:
            logger.warning(
                "--retries applies to the single cohort run; there is one logical "
                "run per invocation."
            )
        bus.emit(
            f"Starting cohort-level workflow: {workflow}",
            level=EventLevel.IMPORTANT,
            category="workflow",
        )
        batch_start = time.monotonic()
        result = _run_single_patient_with_retries(
            entry=entry,
            config=config,
            patient_id="cohort",
            data_dir=data_dir,
            output_dir=output_dir,
            config_dir=config_dir,
            graph=graph,
            dry_run=dry_run,
            protocol=protocol or entry.default_protocol,
            raw_data_dir_override=raw_data_dir,
            retries=retries,
            show_workflow_progress=_should_show_workflow_progress(
                entry,
                output_config.show_progress,
            ),
            config_overrides=cli_overrides,
            stage=stage,
        )
        renderer.detach()
        summary = RunSummary.from_result(result)
        renderer.render_run_summary(summary)
        if result.status in (RunStatus.FAILED, RunStatus.BLOCKED):
            ctx.exit(1)
        return
    # -----------------------------------------------------------------------
    # Standard Patient-level execution
    # -----------------------------------------------------------------------
    if not all_patients and not patient_ids:
        raise click.UsageError(
            "Either --patient-id / -p or --all is required for patient-level workflows."
        )
    # Fast-fail config/protocol preflight: a typo in -c/--protocol should fail
    # here, before any patient discovery or per-patient setup work (M4).
    _effective_protocol = protocol or entry.default_protocol
    try:
        ConfigManager(config_dir=config_dir).load_config(
            config_name=config,
            protocol=_effective_protocol,
            protocol_required=protocol is not None,
        )
    except ConfigurationError as cfg_exc:
        raise click.ClickException(
            f"{cfg_exc} (failed to load config '{config}'"
            + (f" with protocol '{_effective_protocol}'" if _effective_protocol else "")
            + ")"
        )
    # Collect patient IDs
    if all_patients:
        ids = _discover_patient_ids(
            data_dir=data_dir,
            config_dir=config_dir,
            config_name=config,
            protocol=protocol or entry.default_protocol,
            raw_data_dir_override=raw_data_dir,
            name_pattern=pattern,
        )
        if not ids:
            # Fail fast: with zero patients the sequential path would crash on
            # run_results[0] and the parallel path would run an empty batch.
            pattern_note = f" matching pattern {pattern!r}" if pattern else ""
            raise click.ClickException(
                f"No patient folders{pattern_note} were discovered in inputs_dir; "
                "nothing to run."
            )
        bus.emit(
            f"Discovered {len(ids)} patient(s): {', '.join(ids)}",
            level=EventLevel.IMPORTANT,
            category="discovery",
        )
    else:
        ids = list(patient_ids)
        _validate_explicit_patient_ids(
            ids,
            config_dir=config_dir,
            config_name=config,
            protocol=_effective_protocol,
            raw_data_dir_override=raw_data_dir,
            data_dir=data_dir,
        )
    total = len(ids)
    is_batch = total > 1
    is_parallel = parallel or max_workers is not None
    batch_start = time.monotonic()
    # -----------------------------------------------------------------------
    # Sequential execution
    # -----------------------------------------------------------------------
    if not is_parallel:
        run_results: List[RunResult] = []

        # Helper for click.progressbar item display
        def _patient_show(pid: Optional[str]) -> str:
            return pid if pid else ""

        # Use patient-level progress for sequential batch runs.
        show_patient_progress = is_batch and output_config.show_progress
        cm = suspend_console_logging() if show_patient_progress else contextlib.nullcontext()
        with cm:
            with (
                click.progressbar(
                    ids,
                    label=f"Running {workflow}",
                    item_show_func=_patient_show,
                    show_pos=True,
                    show_eta=True,
                    file=sys.stderr,
                    color=True,
                    width=40,
                )
                if show_patient_progress
                else contextlib.nullcontext(ids)
            ) as patient_iter:
                for pid in patient_iter:
                    result = _run_single_patient_with_retries(
                        entry=entry,
                        config=config,
                        patient_id=pid,
                        data_dir=data_dir,
                        output_dir=output_dir,
                        config_dir=config_dir,
                        graph=graph,
                        dry_run=dry_run,
                        protocol=protocol,
                        raw_data_dir_override=raw_data_dir,
                        retries=retries,
                        show_workflow_progress=output_config.show_progress and not is_batch,
                        config_overrides=cli_overrides or None,
                        stage=stage,
                    )
                    result.workflow = workflow
                    run_results.append(result)
        # Detach renderer before printing summary
        renderer.detach()
        # Render summaries
        batch_elapsed = time.monotonic() - batch_start
        if is_batch:
            batch_summary = BatchSummary.from_results(
                run_results,
                workflow=workflow,
                batch_elapsed=batch_elapsed,
            )
            renderer.render_batch_summary(batch_summary)
            _render_batch_tractography_stats(
                run_results,
                config_name=config,
                config_dir=config_dir,
                data_dir=data_dir,
                output_dir=output_dir,
                protocol=protocol or entry.default_protocol,
            )
            if batch_summary.failed > 0 or batch_summary.blocked > 0:
                ctx.exit(1)
        else:
            # Single patient
            result = run_results[0]
            summary = RunSummary.from_result(result)
            renderer.render_run_summary(summary)
            if result.status in (RunStatus.FAILED, RunStatus.BLOCKED):
                ctx.exit(1)
        return
    # -----------------------------------------------------------------------
    # Parallel execution via Nipype meta-workflow
    #
    # Each patient's Nipype workflow is added as a direct sub-workflow of
    # a single meta-workflow. This exposes *every* node (including GPU-
    # bound ones like ``synthseg`` and ``probtrackx2``) to the global
    # ``MultiProc`` scheduler, which:
    #
    # - Parallelises CPU-bound nodes across patients.
    # - Serialises GPU-bound nodes via ``n_gpu_procs`` (prevents OOM).
    # - Accounts for ``mem_gb`` per node to avoid memory oversubscription.
    #
    # Retries rely on Nipype's native result caching: re-running the
    # meta-workflow skips all previously successful nodes and only
    # re-executes failed ones.
    # -----------------------------------------------------------------------
    bus.emit(
        f"Batch run: {total} patient(s) -- {', '.join(ids)}",
        level=EventLevel.IMPORTANT,
        category="batch",
    )
    logger.info("Batch run started for {} patient(s): {}", total, ids)
    # Load config to get plugin settings
    cfg = ConfigManager(config_dir=config_dir).load_config(
        config_name=config,
        protocol=protocol or entry.default_protocol,
        overrides=cli_overrides or None,
    )
    # Validate GPU once before building any patient workflows
    _resolve_gpu(cfg)
    bus.emit(
        f"GPU resolution: gpu_enabled={cfg.hardware.gpu_enabled}",
        level=EventLevel.INFO,
        category="hardware",
    )
    plugin = cfg.nipype.plugin
    plugin_args: Dict[str, Any] = dict(cfg.nipype.plugin_args or {})
    if max_workers is not None:
        plugin_args["n_procs"] = max_workers
    # Build-phase status: emit at IMPORTANT level so it surfaces even under
    # --no-progress (typical in SLURM/CI), where the multi-minute build phase
    # would otherwise produce no output.
    bus.emit(
        f"Building {total} patient workflow(s)...",
        level=EventLevel.IMPORTANT,
        category="build",
    )
    meta_wf, built, build_failed = _build_batch_workflow(
        entry=entry,
        ids=ids,
        cfg=cfg,
        config_name=config,
        data_dir=data_dir,
        output_dir=output_dir,
        config_dir=config_dir,
        protocol=protocol,
        raw_data_dir=raw_data_dir,
        event_bus=bus,
        config_overrides=cli_overrides or None,
    )
    bus.emit(
        f"Build complete: {len(built)}/{total} workflows ready",
        level=EventLevel.IMPORTANT,
        category="build",
    )
    # -------------------------------------------------------------------------
    # SLURM grouped-submission branch (patient batch path).
    #
    # When slurm.enabled, each built patient workflow is partitioned into a
    # linear chain of resource-homogeneous stage jobs and submitted via sbatch
    # instead of being run in-process. All slurm imports are lazy/local so the
    # disabled path is byte-identical to today.
    # -------------------------------------------------------------------------
    if getattr(cfg, "slurm", None) is not None and cfg.slurm.enabled:
        renderer.detach()
        manifest = _run_slurm_batch(
            meta_wf=meta_wf,
            built=built,
            cfg=cfg,
            workflow_name=workflow,
            config_name=config,
            dry_run=dry_run,
            event_bus=bus,
        )
        click.echo(
            f"SLURM submission {'planned' if dry_run else 'complete'}: "
            f"{len(manifest)} patient chain(s).",
            err=True,
        )
        return
    if cfg.hardware.gpu_enabled:
        plugin_args.setdefault("n_gpu_procs", cfg.hardware.n_gpu_procs)
        plugin_args.setdefault("n_gpus", cfg.hardware.n_gpus)
    bus.emit(
        f"Workflows built: {len(built)}/{total} | "
        f"Plugin: {plugin} | n_procs: {plugin_args.get('n_procs', '?')} | "
        f"memory_gb: {plugin_args.get('memory_gb', '?')} | "
        f"n_gpu_procs: {plugin_args.get('n_gpu_procs', 'N/A')}",
        level=EventLevel.IMPORTANT,
        category="batch",
    )
    if graph:
        graph_dir = Path(cfg.nipype.working_dir) / "batch" / "workflow_graphs"
        graph_dir.mkdir(parents=True, exist_ok=True)
        meta_wf.write_graph(
            graph2use="flat",
            format="png",
            simple_form=True,
            dotfilename=str(graph_dir / "batch_workflow"),
        )
        bus.emit(
            f"Batch workflow graph written to: {graph_dir}",
            level=EventLevel.IMPORTANT,
            category="output",
        )
    if dry_run:
        bus.emit(
            f"[DRY RUN] Meta-workflow built for {len(built)} patient(s) but not executed",
            level=EventLevel.IMPORTANT,
            category="workflow",
        )
        # Build summary for dry run
        run_results_parallel: List[RunResult] = []
        for pid in built:
            run_results_parallel.append(
                RunResult(patient_id=pid, workflow=workflow, status=RunStatus.SKIPPED)
            )
        for pid in build_failed:
            run_results_parallel.append(
                RunResult(
                    patient_id=pid,
                    workflow=workflow,
                    status=RunStatus.BLOCKED,
                    error_message="Failed to build workflow",
                )
            )
        renderer.detach()
        batch_summary = BatchSummary.from_results(
            run_results_parallel,
            workflow=workflow,
            batch_elapsed=time.monotonic() - batch_start,
        )
        renderer.render_batch_summary(batch_summary)
        if build_failed:
            ctx.exit(1)
        return
    # Execute with retry support. Nipype caches node results, so
    # re-running the same meta-workflow automatically skips successful
    # nodes and only re-executes failures.
    last_error: Optional[Exception] = None
    workflow_progress: Optional[ClickNodeProgress] = None
    total_nodes = count_workflow_nodes(meta_wf)
    if output_config.show_progress and total_nodes > 0:
        workflow_progress = ClickNodeProgress(
            total=total_nodes,
            label=f"Nodes for {workflow}",
            enabled=output_config.show_progress,
        )
        workflow_progress.start()
    try:
        suppress_nipype_native_logging()
        # Apply cfg.nipype onto meta_wf AND the global nipype.config — this
        # path runs meta_wf.run() directly, bypassing NipypeExecutor, so
        # without this call hash_method="content" (and every other exec-level
        # setting) silently falls back to Nipype's package defaults. See
        # thesis.core.nipype.executor.apply_nipype_execution_config for why
        # the global-config update is load-bearing for file-input hashing.
        try:
            apply_nipype_execution_config(meta_wf, cfg.nipype)
        except Exception as cfg_exc:
            # Placeholder-style args, consistent with every other logger call here.
            logger.warning("Could not apply Nipype config to meta-workflow: {}", cfg_exc)
        status_callback = build_nipype_status_callback(
            progress=workflow_progress,
            event_bus=bus,
        )
        plugin_args["status_callback"] = status_callback
        for attempt in range(retries + 1):
            if attempt > 0:
                logger.info(
                    "Retry {}/{} -- re-running meta-workflow (cached nodes will be skipped)",
                    attempt,
                    retries,
                )
                bus.emit(
                    f"Retry {attempt}/{retries} -- "
                    f"re-running (completed nodes cached, only failures re-execute)",
                    level=EventLevel.WARNING,
                    category="retry",
                )
                time.sleep(5)  # brief pause between retries
            try:
                meta_wf.run(plugin=plugin, plugin_args=plugin_args)
                last_error = None
                break  # success
            except Exception as exc:
                last_error = exc
                logger.error("Batch attempt {}/{} failed: {}", attempt + 1, retries + 1, exc)
                # Attribute failed nodes to patients immediately so users see
                # which patients were affected (and that this is an execution-
                # phase failure, distinct from the build phase).
                failed_by_pid = _attribute_nodes_to_patients(status_callback.failed_nodes, built)
                attributed = ", ".join(sorted(failed_by_pid)) or "unattributed"
                bus.emit(
                    f"Batch execution failed (attempt {attempt + 1}/{retries + 1}). "
                    f"Failed nodes attributed to: {attributed}",
                    level=EventLevel.ERROR,
                    category="batch",
                )
    finally:
        if workflow_progress:
            workflow_progress.stop()
    # Build batch summary
    renderer.detach()
    batch_elapsed = time.monotonic() - batch_start
    run_results_parallel = []
    if last_error is None:
        for pid in built:
            res = RunResult(patient_id=pid, workflow=workflow, status=RunStatus.SUCCESS)
            try:
                patient_ctx = create_context(
                    patient_id=pid, config=cfg, data_dir=data_dir, output_dir=output_dir
                )
                _enrich_with_tract_similarity_metrics(res, cfg, patient_ctx)
            except Exception as _exc:
                logger.debug("tract_similarity enrichment skipped for {}: {}", pid, _exc)
            run_results_parallel.append(res)
    else:
        # Attribute failed/finished Nipype nodes back to specific patients so the
        # summary reflects which patients actually failed -- not "all of them".
        failed_by_pid = _attribute_nodes_to_patients(status_callback.failed_nodes, built)
        ok_by_pid = _attribute_nodes_to_patients(status_callback.finished_ok_nodes, built)
        classified = _classify_batch_failure(
            built=built,
            workflow=workflow,
            failed_by_pid=failed_by_pid,
            ok_by_pid=ok_by_pid,
            last_error=last_error,
        )
        for res in classified:
            if res.status == RunStatus.SUCCESS:
                try:
                    patient_ctx = create_context(
                        patient_id=res.patient_id,
                        config=cfg,
                        data_dir=data_dir,
                        output_dir=output_dir,
                    )
                    _enrich_with_tract_similarity_metrics(res, cfg, patient_ctx)
                except Exception as _exc:
                    logger.debug(
                        "tract_similarity enrichment skipped for {}: {}", res.patient_id, _exc
                    )
            run_results_parallel.append(res)
    for pid in build_failed:
        run_results_parallel.append(
            RunResult(
                patient_id=pid,
                workflow=workflow,
                status=RunStatus.BLOCKED,
                error_message="Failed to build workflow",
            )
        )
    batch_summary = BatchSummary.from_results(
        run_results_parallel,
        workflow=workflow,
        batch_elapsed=batch_elapsed,
    )
    batch_summary.retries = retries
    renderer.render_batch_summary(batch_summary)
    _render_batch_tractography_stats(
        run_results_parallel,
        config_name=config,
        config_dir=config_dir,
        data_dir=data_dir,
        output_dir=output_dir,
        protocol=protocol or entry.default_protocol,
    )
    if last_error is not None or build_failed:
        ctx.exit(1)
@main.command("list-workflows")
@click.option(
    "--config",
    "-c",
    type=str,
    default="default",
    show_default=True,
    help="Configuration name to read paths.scripts_dir from.",
)
@click.option(
    "--config-dir",
    # No ``exists=True`` here: Click validates defaults through the type, so a
    # missing ./config would abort the whole command. The config is only read
    # best-effort (for ``paths.scripts_dir``) and built-ins must still list.
    type=click.Path(file_okay=False, path_type=Path),  # type: ignore[type-var]
    default="config",
    show_default=True,
    help="Configuration directory",
)
def list_workflows(config: str, config_dir: Path) -> None:
    """
    List all available registered workflows.

    Auto-discovers and imports all ``thesis.workflows.*`` submodules to
    trigger self-registration. If ``paths.scripts_dir`` is set in the
    selected config, also scans that directory for user-supplied scripts
    (``*.py``, top-level, names not starting with ``_``).

    \b
    Example:
        thesis list-workflows
        thesis list-workflows -c my_config
    """
    import thesis.workflows as _wf_pkg

    # Auto-import every submodule/subpackage to trigger registrations
    for _finder, name, ispkg in pkgutil.iter_modules(_wf_pkg.__path__):
        _ensure_workflow_imported(name, is_package=ispkg)
    # Snapshot the built-in names before loading user scripts; a user script
    # whose name matches a built-in is not labelled as a user script below.
    builtin_names = set(WORKFLOW_REGISTRY.list())

    user_entries: List[Tuple[str, Path]] = []
    user_failures: List[Tuple[Path, str]] = []
    scripts_dir: Optional[Path] = None
    try:
        manager = ConfigManager(config_dir=config_dir)
        cfg = manager.load_config(config_name=config)
        scripts_dir = cfg.paths.scripts_dir
    except Exception as exc:  # pragma: no cover - config read is best-effort here
        logger.debug("list-workflows could not read scripts_dir from config: {}", exc)

    if scripts_dir is not None:
        from thesis.core.loader import discover_workflows_dir, load_workflow_script

        # Relative scripts_dir values are resolved against the current directory.
        if not scripts_dir.is_absolute():
            scripts_dir = Path.cwd() / scripts_dir
        for script_path in discover_workflows_dir(scripts_dir):
            try:
                loaded = load_workflow_script(script_path)
                if loaded.name not in builtin_names:
                    user_entries.append((loaded.name, script_path))
            except click.ClickException as exc:
                user_failures.append((script_path, exc.message))
            except Exception as exc:  # pragma: no cover - defensive
                user_failures.append((script_path, f"{type(exc).__name__}: {exc}"))

    entries = WORKFLOW_REGISTRY.all_entries()
    if not entries:
        click.echo("No workflows registered.")
    else:
        click.echo("Available workflows:")
        user_paths = dict(user_entries)
        for e in entries:
            proto_str = f" (protocol: {e.default_protocol})" if e.default_protocol else ""
            suffix = ""
            if e.name in user_paths:
                suffix = f" (user script: {user_paths[e.name]})"
            click.echo(f" {e.name:<20} {e.description}{proto_str}{suffix}")

    # Report scripts that could not be loaded, after the main listing.
    if user_failures:
        click.echo("")
        click.echo("User scripts that failed to load:")
        for path, msg in user_failures:
            click.echo(f" {path}: {msg}")
@main.command()
def info() -> None:
    """
    Display system and framework information.
    """
    import platform
    import shutil
    from importlib.metadata import PackageNotFoundError
    from importlib.metadata import version as package_version

    # ``sys`` comes from the module-level import; no local re-import needed.
    click.echo("=" * 60)
    click.echo("Thesis Framework Information")
    click.echo("=" * 60)
    click.echo(f"Version: {__version__}")
    click.echo(f"Python: {sys.version.split()[0]}")
    click.echo(f"Platform: {platform.system()} {platform.release()}")
    click.echo(f"Architecture: {platform.machine()}")
    click.echo()

    # Python dependencies: display name -> (import name, distribution name).
    # Import and distribution names can differ (``yaml`` ships as PyYAML), so
    # importability and the installed version are checked separately.
    deps = {
        "NumPy": ("numpy", "numpy"),
        "SciPy": ("scipy", "scipy"),
        "NiBabel": ("nibabel", "nibabel"),
        "Nipype": ("nipype", "nipype"),
        "PyYAML": ("yaml", "PyYAML"),
        "Pydantic": ("pydantic", "pydantic"),
        "Click": ("click", "click"),
        "Loguru": ("loguru", "loguru"),
    }
    click.echo("Dependencies:")
    for name, (module, distribution) in deps.items():
        try:
            __import__(module)
        except ImportError:
            click.echo(f" [--] {name}: not installed")
            continue
        try:
            version = package_version(distribution)
        except PackageNotFoundError:
            # Importable but without distribution metadata (e.g. vendored copy).
            version = "unknown"
        click.echo(f" [OK] {name}: {version}")
    click.echo()

    # External command-line tools: a tool counts as available when any of
    # its listed executables is found on PATH.
    external_tools = {
        "FSL": ["fsl", "probtrackx2"],
        "MRtrix3": ["tckgen", "mrconvert"],
        "ANTs": ["antsRegistration"],
    }
    click.echo("External Tools:")
    for tool_name, commands in external_tools.items():
        found = any(shutil.which(cmd) for cmd in commands)
        status = "[OK]" if found else "[--]"
        click.echo(f" {status} {tool_name}: {'available' if found else 'not found'}")
@main.group()
def stats() -> None:
    """Inspect and export batch tractography statistics.

    Run ``thesis stats collect -o <dir>`` to extract stats from an output
    directory.
    """
    # The body is empty by design: this Click group only namespaces its
    # subcommands (registered with ``@stats.command``). Click shows the
    # docstring above as the group's help text.
@stats.command("collect")
@click.option(
    "-o",
    "--output-dir",
    required=True,
    type=click.Path(exists=True, file_okay=False, path_type=Path),  # type: ignore[type-var]
    help="Root directory containing per-subject output subdirectories.",
)
@click.option(
    "-p",
    "--patient-id",
    "patient_ids",
    multiple=True,
    help="Limit to these patient IDs (repeat for multiple). "
    "Default: auto-discover all subjects.",
)
@click.option(
    "--out",
    "output_file",
    type=click.Path(dir_okay=False, path_type=Path),  # type: ignore[type-var]
    default=None,
    help="Destination JSON file. Default: <output-dir>/batch_stats/stats_<timestamp>.json.",
)
@click.option(
    "--sd-threshold",
    # A threshold <= 0 would flag any non-zero deviation as an outlier.
    type=click.FloatRange(min=0.0, min_open=True),
    default=2.0,
    show_default=True,
    help="Outlier detection threshold in standard deviations (must be > 0).",
)
@click.option(
    "--tractography-relpath",
    "tractography_relpath",
    type=str,
    default="tractography/probtrackx2",
    show_default=True,
    help=(
        "Relative path under each patient directory where the tractography "
        "run lives. Use 'tractography/mrtrix3' for the MRtrix3 backend."
    ),
)
@click.option(
    "--tract-similarity-subdir",
    "tract_similarity_subdir",
    type=str,
    default="tract_similarity",
    show_default=True,
    help=(
        "Relative path under each patient directory where "
        "tract_similarity / tract_similarity_hcp_loo wrote metrics.json. "
        "Attached under subject['tract_similarity'] when the file exists."
    ),
)
def stats_collect(
    output_dir: Path,
    patient_ids: Tuple[str, ...],
    output_file: Optional[Path],
    sd_threshold: float,
    tractography_relpath: str,
    tract_similarity_subdir: str,
) -> None:
    """Collect tractography stats from an existing output dir and write JSON.

    Reuses the same collectors as the post-run summary
    (``collect_batch_stats`` + ``detect_batch_outliers``) so the emitted
    schema matches what future runs write automatically. When a subject has
    a ``<tract_similarity_subdir>/metrics.json`` file (produced by the
    ``tract_similarity`` or ``tract_similarity_hcp_loo`` workflows), its
    contents are merged into that subject's record under a
    ``tract_similarity`` key — no separate file or post-processing needed.

    \b
    Example:
        thesis stats collect -o outputs/ --out /tmp/atlas_stats.json
        thesis stats collect -o outputs/ --tractography-relpath tractography/mrtrix3
    """
    from thesis.workflows.qc.checks import detect_batch_outliers
    from thesis.workflows.qc.statistics import collect_batch_stats

    # Drop repeated -p values (first-seen order kept) so the same ID is not
    # handed to the collector twice; None means "auto-discover all subjects".
    pids = list(dict.fromkeys(patient_ids)) if patient_ids else None
    subjects = collect_batch_stats(
        output_dir,
        patient_ids=pids,
        tractography_relpath=tractography_relpath,
        tract_similarity_subdir=tract_similarity_subdir,
    )
    if not subjects:
        raise click.ClickException(f"No subjects with tractography output found under {output_dir}")

    # With fewer than three subjects the cohort mean/SD are too degenerate for
    # SD-based outlier detection to be meaningful, so it is skipped.
    min_cohort_for_outliers = 3
    outliers: List[Dict[str, Any]] = []
    if len(subjects) >= min_cohort_for_outliers:
        outliers = detect_batch_outliers(subjects, sd_threshold=sd_threshold)

    if output_file is None:
        # output_dir is guaranteed to exist (click.Path(exists=True)).
        stats_dir = output_dir / "batch_stats"
        stats_dir.mkdir(exist_ok=True)
        output_file = stats_dir / f"stats_{time.strftime('%Y%m%d_%H%M%S')}.json"
    else:
        output_file.parent.mkdir(parents=True, exist_ok=True)

    payload = {"subjects": subjects, "outliers": outliers}
    # default=str keeps non-JSON values (e.g. Path objects) serialisable.
    output_file.write_text(json.dumps(payload, indent=2, default=str), encoding="utf-8")
    click.echo(f"Wrote {output_file}")
    click.echo(f" subjects: {len(subjects)}")
    click.echo(f" outliers: {len(outliers)}")
@main.group()
def slurm() -> None:
    """Inspect and retry grouped SLURM stage submissions."""
    # The body is empty by design: this Click group only namespaces the
    # ``status`` and ``resubmit`` subcommands registered with
    # ``@slurm.command``. Click shows the docstring as the group's help text.
def _collect_job_ids(manifest: Dict[str, List[Dict[str, Any]]]) -> List[str]:
"""Return every non-null ``job_id`` recorded in a manifest."""
return [
str(entry["job_id"])
for entries in manifest.values()
for entry in entries
if entry.get("job_id") is not None
]
@slurm.command("status")
@click.option(
    "--manifest",
    "manifest_path",
    required=True,
    type=click.Path(path_type=Path),  # type: ignore[type-var]
    help="Path to manifest.json (or a submit_dir containing one).",
)
def slurm_status(manifest_path: Path) -> None:
    """Show per-patient / per-stage SLURM states from a submission manifest.

    Reads the durable manifest written at submit time, queries ``sacct`` for the
    current state of every recorded job, and prints a per-``(patient, stage)``
    status table. Stages whose ``afterok`` predecessor failed are surfaced as
    ``BLOCKED``.

    \b
    Example:
        thesis slurm status --manifest $SCRATCH/thesis_slurm/manifest.json
    """
    from thesis.core.slurm.monitor import query_states, read_manifest, summarize

    try:
        manifest = read_manifest(manifest_path)
    except (FileNotFoundError, ValueError) as exc:
        # Clean CLI error for the user; keep the original cause chained.
        raise click.ClickException(str(exc)) from exc

    states = query_states(_collect_job_ids(manifest))
    summary = summarize(manifest, states)
    if not summary:
        # Otherwise an empty manifest would produce no output at all.
        click.echo("No stages recorded in manifest.")
        return

    for pid, rows in summary.items():
        click.echo(f"{pid}:")
        for row in rows:
            # Stages that were never submitted have no job id.
            job = row["job_id"] if row["job_id"] is not None else "-"
            click.echo(f" {row['stage']:<24} {row['compute']:<3} job={job:<10} {row['state']}")
@slurm.command("resubmit")
@click.option(
    "--manifest",
    "manifest_path",
    required=True,
    type=click.Path(path_type=Path),  # type: ignore[type-var]
    help="Path to manifest.json (or a submit_dir containing one).",
)
@click.option(
    "-c",
    "--config",
    "config_name",
    required=True,
    help="Configuration name providing slurm.* settings.",
)
@click.option(
    "--config-dir",
    type=click.Path(path_type=Path),  # type: ignore[type-var]
    default="config",
    help="Configuration directory.",
)
@click.option(
    "--retries",
    type=int,
    default=1,
    show_default=True,
    help="Maximum resubmission passes for still-failing stages.",
)
@click.option(
    "--dry-run",
    is_flag=True,
    default=False,
    help="Select failed stages and write the plan, but do not call sbatch.",
)
def slurm_resubmit(
    manifest_path: Path,
    config_name: str,
    config_dir: Path,
    retries: int,
    dry_run: bool,
) -> None:
    """Re-submit only failed / timed-out / OOM stages, re-chaining dependents.

    Reconciles the manifest against ``sacct``, selects stages in
    ``FAILED`` / ``TIMEOUT`` / ``OUT_OF_MEMORY``, and re-submits them (and their
    downstream dependents) via the Phase D submitter with refreshed ``afterok``
    chains. Nipype on-disk caching skips already-completed nodes.

    \b
    Example:
        thesis slurm resubmit --manifest $SCRATCH/thesis_slurm/manifest.json -c full_pipeline
    """
    from thesis.core.slurm.monitor import query_states, read_manifest, select_resubmit
    from thesis.core.slurm.submitter import resubmit_jobs

    try:
        manifest = read_manifest(manifest_path)
    except (FileNotFoundError, ValueError) as exc:
        # Clean CLI error for the user; keep the original cause chained.
        raise click.ClickException(str(exc)) from exc

    config_manager = ConfigManager(config_dir=config_dir)
    cfg = config_manager.load_config(config_name=config_name)
    slurm_cfg = getattr(cfg, "slurm", None)
    if slurm_cfg is None or not slurm_cfg.enabled:
        raise click.ClickException(
            f"Config '{config_name}' does not enable slurm (slurm.enabled is false)."
        )

    # ``--retries`` values below 1 still perform a single pass.
    passes = max(1, retries)
    any_selected = False
    for attempt in range(1, passes + 1):
        # NOTE(review): later passes re-query sacct right after the previous
        # submission without waiting, so freshly submitted jobs are likely still
        # pending and not selected; confirm that --retries > 1 behaves as intended.
        states = query_states(_collect_job_ids(manifest))
        to_resubmit = select_resubmit(manifest, states)
        if not to_resubmit:
            if not any_selected:
                click.echo("No failed/timeout/oom stages to resubmit.")
            break
        any_selected = True
        total = sum(len(v) for v in to_resubmit.values())
        click.echo(
            f"Pass {attempt}/{passes}: resubmitting {total} stage(s) "
            f"across {len(to_resubmit)} patient(s)."
        )
        # resubmit_jobs returns the manifest to use for the next pass.
        manifest = resubmit_jobs(manifest, to_resubmit, slurm_cfg, dry_run=dry_run)
        if dry_run:
            # No new states to observe in a dry run; one planning pass is enough.
            break
if __name__ == "__main__":
    # Allow executing this module directly as a script, in addition to the
    # ``thesis`` command shown in the command docstrings' examples.
    main()