# Source code for thesis.workflows.tract_similarity.hcp_loo

"""Cohort-scope workflow: per-HCP-subject leave-one-out similarity vs the cohort atlas.

Produces, for each numeric (HCP) subject under the cohort output directory, a
``metrics.json`` (and optionally four NIfTI volumes) under
``<subject>/<tract_similarity.output_subdir>/`` mirroring the per-patient
``tract_similarity`` workflow. The per-subject reference is a leave-one-out
cohort mean atlas, built in memory by subtracting that subject's volume from
the precomputed stack sum.

Registered as ``thesis run -w tract_similarity_hcp_loo -c <profile>``.
"""

from __future__ import annotations

from pathlib import Path
from typing import List, Tuple

import nipype.pipeline.engine as pe
from nipype.interfaces.utility import Function

from thesis.core.decorators import verify, workflow
from thesis.core.logging import get_logger

logger = get_logger(__name__)


def _discover_hcp_subjects(
    input_dir: Path | str,
    tractography_relpath: str,
) -> List[Tuple[str, List[Tuple[Path, float]]]]:
    """Discover numeric-pid HCP subject directories with valid tractography input.

    Mirrors :func:`thesis.workflows.atlas._io._collect_patient_inputs` but also
    preserves the patient ID and sorts deterministically. The returned list is
    sorted lexicographically by ``pid`` so that ``patient_inputs[i]`` and
    ``pids[i]`` stay aligned downstream.

    Args:
        input_dir: Cohort directory containing numeric patient subdirectories.
        tractography_relpath: Relative path under each subject directory where
            tractography outputs live (e.g. ``"tractography/probtrackx2"``).

    Returns:
        List of ``(pid, runs)`` pairs. ``runs`` is a list of
        ``(fdt_path, waytotal)`` tuples, identical in shape to one entry of
        ``_collect_patient_inputs``'s output.
    """
    base = Path(input_dir)
    pairs: List[Tuple[str, List[Tuple[Path, float]]]] = []

    for pdir in sorted(base.iterdir(), key=lambda d: d.name):
        if not pdir.is_dir() or not pdir.name.isdigit():
            continue
        probtrackx_dir = pdir / tractography_relpath
        hemi_dirs = [d for d in (probtrackx_dir / "left", probtrackx_dir / "right") if d.is_dir()]
        run_dirs = hemi_dirs or [probtrackx_dir]

        runs: List[Tuple[Path, float]] = []
        for run_dir in run_dirs:
            fdt_file = run_dir / "warped_streamlines" / "fdt_paths.nii.gz"
            waytotal_file = run_dir / "waytotal"
            if not fdt_file.is_file() or not waytotal_file.is_file():
                continue
            try:
                waytotal = float(waytotal_file.read_text().strip())
            except ValueError:
                continue
            if waytotal <= 0.0:
                continue
            runs.append((fdt_file, waytotal))

        if runs:
            pairs.append((pdir.name, runs))

    return pairs


def _hcp_loo_task(
    input_dir: str,
    tractography_relpath: str,
    normalization_method: str,
    output_subdir: str,
    subject_threshold_mode: str,
    subject_threshold_value: float,
    atlas_threshold_mode: str,
    atlas_threshold_value: float,
    n_bins: int,
    minimum_subjects: int,
    write_volumes: bool,
    learned_prediction_relpath: str = "",
    learned_support_threshold: float = 0.0,
    _ready_anchor: object = None,
) -> list:
    """Per-cohort task body: compute LOO similarity metrics for every HCP subject.

    Runs entirely in template space — the cohort atlas inputs are already
    aligned. For each subject ``i`` the leave-one-out atlas is
    ``(stack_sum - stack[i]) / (N - 1)``; metrics come from the existing
    ``tract_similarity._metrics`` family. Writes per-subject artefacts into
    ``<input_dir>/<pid>/<output_subdir>/``.

    All imports are local because this body is invoked as a Nipype
    ``Function`` and must pickle cleanly to the subprocess.

    Args:
        input_dir: Directory containing the numeric HCP subject directories.
        tractography_relpath: Per-subject relative path to tractography runs.
        normalization_method: Value of a ``NormalizationMethod`` member, used
            when stacking subject volumes.
        output_subdir: Per-subject output subdirectory for artefacts.
        subject_threshold_mode: Threshold mode for the subject volume (also
            applied to the learned prediction, when present).
        subject_threshold_value: Threshold value paired with the mode above.
        atlas_threshold_mode: Threshold mode for the LOO atlas volume.
        atlas_threshold_value: Threshold value paired with the mode above.
        n_bins: Histogram bin count passed to ``distribution_metrics``.
        minimum_subjects: Minimum number of discovered subjects. A LOO mean
            needs at least two subjects, so values below 2 are raised to 2.
        write_volumes: Also write four NIfTI volumes per subject when true.
        learned_prediction_relpath: Optional per-subject relative path to a
            learned-atlas prediction; empty string disables the learned arm.
        learned_support_threshold: Forwarded to ``evaluate_prediction``.
        _ready_anchor: Not read by this body. Presumably it exists so an
            upstream node can be wired in as an ordering dependency — confirm.

    Returns:
        List of patient IDs for which artefacts were successfully written.

    Raises:
        ProcessingError: If too few subjects are found, the stack build fails,
            or the stack's row count does not match the discovered subjects.
    """
    import json
    import sys
    from pathlib import Path

    import nibabel as nib
    import numpy as np

    from thesis.core.exceptions import ProcessingError
    from thesis.workflows.atlas._io import _build_patient_stack
    from thesis.workflows.atlas._params import NormalizationMethod
    from thesis.workflows.learned_atlas.diagnostics import evaluate_prediction
    from thesis.workflows.tract_similarity._io import voxel_size_mm
    from thesis.workflows.tract_similarity._metrics import (
        correlation_metrics,
        distance_metrics,
        distribution_metrics,
        overlap_metrics,
    )
    from thesis.workflows.tract_similarity.hcp_loo import _discover_hcp_subjects
    from thesis.workflows.tract_similarity.workflow import apply_threshold

    base = Path(input_dir)
    subjects = _discover_hcp_subjects(base, tractography_relpath)
    # (stack_sum - stack[i]) / (N - 1) is undefined for N < 2, and N == 0
    # would also fail below on patient_inputs[0]; never accept fewer than two.
    required = max(int(minimum_subjects), 2)
    if len(subjects) < required:
        raise ProcessingError(
            f"HCP LOO requires >= {required} subjects with valid "
            f"fdt+waytotal under {base}, found {len(subjects)}."
        )

    pids = [pid for pid, _ in subjects]
    patient_inputs = [runs for _, runs in subjects]

    # Load reference metadata from the first run for affine, header, shape.
    ref_file = patient_inputs[0][0][0]
    ref_img: nib.Nifti1Image = nib.load(str(ref_file))  # type: ignore[assignment]
    affine = np.asarray(ref_img.affine)
    header = ref_img.header
    _shape = ref_img.shape
    ref_shape: tuple[int, int, int] = (int(_shape[0]), int(_shape[1]), int(_shape[2]))

    try:
        stack = _build_patient_stack(
            patient_inputs, ref_shape, NormalizationMethod(normalization_method)
        )
    except ValueError as exc:
        raise ProcessingError(f"HCP LOO stack build failed under {base}: {exc}") from exc

    n = int(stack.shape[0])
    # Metrics are attributed via pids[i] <-> stack[i]; refuse to continue if
    # the stack does not have exactly one row per discovered subject.
    if n != len(pids):
        raise ProcessingError(
            f"HCP LOO stack has {n} rows but {len(pids)} subjects were "
            f"discovered under {base}; cannot align subjects to stack rows."
        )
    stack_sum = stack.sum(axis=0)
    spacing = voxel_size_mm(affine)
    successful: list = []

    for i, pid in enumerate(pids):
        try:
            vol_subj = stack[i]
            # Leave-one-out cohort mean: remove this subject from the sum.
            loo_mean = (stack_sum - vol_subj) / float(n - 1)

            thr_subj = apply_threshold(vol_subj, subject_threshold_mode, subject_threshold_value)
            thr_atlas = apply_threshold(loo_mean, atlas_threshold_mode, atlas_threshold_value)
            mask_subj = vol_subj > thr_subj
            mask_atlas = loo_mean > thr_atlas

            overlap: dict = overlap_metrics(mask_subj, mask_atlas)
            correlation: dict = correlation_metrics(vol_subj, loo_mean)
            distance_mm: dict = distance_metrics(mask_subj, mask_atlas, voxel_size_mm=spacing)
            distribution: dict = distribution_metrics(vol_subj, loo_mean, n_bins=int(n_bins))
            voxel_counts: dict = {
                "probtrackx": int(np.count_nonzero(mask_subj)),
                "atlas": int(np.count_nonzero(mask_atlas)),
                "intersection": int(np.count_nonzero(mask_subj & mask_atlas)),
                "union": int(np.count_nonzero(mask_subj | mask_atlas)),
            }
            metrics: dict = {
                "patient_id": pid,
                "thresholds": {
                    "subject": {
                        "mode": subject_threshold_mode,
                        "value": float(subject_threshold_value),
                        "effective": float(thr_subj),
                    },
                    "atlas": {
                        "mode": atlas_threshold_mode,
                        "value": float(atlas_threshold_value),
                        "effective": float(thr_atlas),
                    },
                },
                "voxel_counts": voxel_counts,
                "overlap": overlap,
                "correlation": correlation,
                "distance_mm": distance_mm,
                "distribution": distribution,
            }

            if learned_prediction_relpath:
                pred_path = base / pid / learned_prediction_relpath
                if pred_path.is_file():
                    _pred_obj = nib.load(str(pred_path)).dataobj  # type: ignore[attr-defined]
                    vol_pred = np.asarray(_pred_obj, dtype=np.float32)
                    if vol_pred.shape[:3] == ref_shape:
                        thr_pred = apply_threshold(
                            vol_pred, subject_threshold_mode, subject_threshold_value
                        )
                        mask_pred = vol_pred > thr_pred
                        metrics["learned_atlas"] = {
                            "overlap": overlap_metrics(mask_subj, mask_pred),
                            "correlation": correlation_metrics(vol_subj, vol_pred),
                            "distance_mm": distance_metrics(
                                mask_subj, mask_pred, voxel_size_mm=spacing
                            ),
                            "distribution": distribution_metrics(
                                vol_subj, vol_pred, n_bins=int(n_bins)
                            ),
                            "non_circular": evaluate_prediction(
                                vol_pred,
                                vol_subj,
                                loo_mean,
                                support_threshold=float(learned_support_threshold),
                            ),
                        }
                    else:
                        print(
                            f"[tract_similarity_hcp_loo] WARNING: {pid} learned "
                            f"prediction shape {vol_pred.shape[:3]} != {ref_shape}; "
                            "skipping learned arm",
                            file=sys.stderr,
                        )
                else:
                    print(
                        f"[tract_similarity_hcp_loo] {pid}: no learned prediction "
                        f"at {pred_path}; skipping learned arm",
                        file=sys.stderr,
                    )

            out_dir = base / pid / output_subdir
            out_dir.mkdir(parents=True, exist_ok=True)
            (out_dir / "metrics.json").write_text(json.dumps(metrics, indent=2), encoding="utf-8")

            if write_volumes:
                for name, vol, dtype in (
                    ("subject_normalized", vol_subj, np.float32),
                    ("atlas_normalized", loo_mean, np.float32),
                    ("subject_mask", mask_subj, np.uint8),
                    ("atlas_mask", mask_atlas, np.uint8),
                ):
                    img = nib.Nifti1Image(np.asarray(vol).astype(dtype), affine, header)
                    # When a header is supplied, nibabel keeps the header's
                    # on-disk dtype rather than the array's; set it explicitly
                    # so masks are stored as uint8 and maps as float32.
                    img.set_data_dtype(dtype)
                    nib.save(img, str(out_dir / f"{name}.nii.gz"))

            print(
                f"[tract_similarity_hcp_loo] {pid}: "
                f"dice={overlap['dice']:.4f}, "
                f"voxels_subj={voxel_counts['probtrackx']}, "
                f"voxels_atlas={voxel_counts['atlas']}",
                file=sys.stderr,
            )
            successful.append(pid)
        except Exception as exc:  # noqa: BLE001 — failure isolation per subject
            print(
                f"[tract_similarity_hcp_loo] WARNING: subject {pid} failed: {exc}",
                file=sys.stderr,
            )
            continue

    return successful


def _resolve_input_dir(output_dir: Path) -> Path:
    """Locate the directory that holds the numeric patient subdirectories.

    Mirror of atlas/_resolve_input_dir.

    Cohort-level CLI runs may pass ``<base>/cohort`` as ``output_dir`` while
    the numeric patient data lives in ``<base>``. The parent is returned only
    when ``output_dir`` has no numeric subdirectories of its own, is named
    ``cohort``, and its parent does contain numeric subdirectories; in every
    other case ``output_dir`` is returned unchanged.
    """

    def _has_numeric_children(directory: Path) -> bool:
        # A "patient dir" is any subdirectory whose name is all digits.
        return any(child.is_dir() and child.name.isdigit() for child in directory.iterdir())

    if _has_numeric_children(output_dir):
        return output_dir

    parent = output_dir.parent
    if output_dir.name != "cohort" or not parent.is_dir():
        return output_dir

    return parent if _has_numeric_children(parent) else output_dir


def verify_requirements(config, context) -> list:
    """Pre-flight checks for tract_similarity_hcp_loo.

    Returns a list of human-readable error strings; an empty list means the
    workflow is ready to build. The CLI surfaces these to the user verbatim.

    Checks, in order: ``context.output_dir`` is set and is an existing
    directory; at least one numeric subject with valid fdt+waytotal input is
    discoverable; and the number of such subjects meets
    ``config.tract_similarity.hcp_loo.minimum_subjects``. The minimum is never
    treated as less than 2, because a leave-one-out mean divides by ``N - 1``.
    """
    if context.output_dir is None:
        return ["output_dir is not set in the processing context"]
    out = Path(context.output_dir)
    if not out.is_dir():
        return [f"output_dir does not exist: {out}"]

    input_dir = _resolve_input_dir(out)
    tractography_relpath = config.atlas.tractography_relpath
    # A single subject leaves nothing to average in the LOO atlas, so enforce
    # a floor of two even if the configured minimum is lower.
    minimum = max(int(config.tract_similarity.hcp_loo.minimum_subjects), 2)

    subjects = _discover_hcp_subjects(input_dir, tractography_relpath)
    if not subjects:
        return [
            f"No numeric (HCP) subject directories with valid "
            f"'{tractography_relpath}/<run>/warped_streamlines/fdt_paths.nii.gz' "
            f"+ 'waytotal' found under {input_dir}. Run the upstream tractography "
            f"workflow first."
        ]
    if len(subjects) < minimum:
        return [
            f"HCP LOO requires at least N={minimum} subjects with valid fdt+waytotal "
            f"under {input_dir}, found {len(subjects)}. Run the upstream tractography "
            f"workflow first or lower tract_similarity.hcp_loo.minimum_subjects."
        ]
    return []
# Input names exposed on the Nipype ``Function`` node; they mirror the
# parameter list of ``_hcp_loo_task`` one-for-one.
_HCP_LOO_INPUTS = [
    # Discovery / stacking
    "input_dir",
    "tractography_relpath",
    "normalization_method",
    # Output location
    "output_subdir",
    # Thresholding
    "subject_threshold_mode",
    "subject_threshold_value",
    "atlas_threshold_mode",
    "atlas_threshold_value",
    # Metrics and run control
    "n_bins",
    "minimum_subjects",
    "write_volumes",
    # Optional learned-atlas arm
    "learned_prediction_relpath",
    "learned_support_threshold",
    # Dependency anchor (not read by the task body)
    "_ready_anchor",
]
@workflow(
    name="tract_similarity_hcp_loo",
    description=(
        "Per-HCP-subject leave-one-out tract similarity vs the cohort atlas. "
        "Emits metrics.json + optional NIfTI volumes per subject in the same "
        "format as the per-patient tract_similarity workflow."
    ),
    protocol="hcp",
    scope="cohort",
)
@verify(verify_requirements)
def build_workflow(*, config, context) -> "pe.Workflow":
    """Build the cohort-scope HCP-LOO tract similarity workflow.

    The workflow contains a single ``Function`` node that runs
    ``_hcp_loo_task`` over the resolved cohort input directory, with its
    parameters taken from ``config.atlas`` and ``config.tract_similarity``.

    Raises:
        ValueError: If ``context.output_dir`` is not set.
    """
    if context.output_dir is None:
        raise ValueError("output_dir must be set before building the workflow")

    wf = pe.Workflow(name="tract_similarity_hcp_loo")
    if context.working_dir:
        wf.base_dir = str(context.working_dir)

    input_dir = _resolve_input_dir(Path(context.output_dir))
    sim_cfg = config.tract_similarity
    loo_cfg = sim_cfg.hcp_loo
    atlas_cfg = config.atlas

    node = pe.Node(
        Function(
            input_names=_HCP_LOO_INPUTS,
            output_names=["successful_subjects"],
            function=_hcp_loo_task,
        ),
        name="compute_hcp_loo_metrics",
    )

    # ``_ready_anchor`` is intentionally left unset so the task's default applies.
    task_inputs = {
        "input_dir": str(input_dir),
        "tractography_relpath": atlas_cfg.tractography_relpath,
        "normalization_method": atlas_cfg.normalization_method.value,
        "output_subdir": sim_cfg.output_subdir,
        "subject_threshold_mode": sim_cfg.subject_threshold.mode,
        "subject_threshold_value": float(sim_cfg.subject_threshold.value),
        "atlas_threshold_mode": sim_cfg.atlas_threshold.mode,
        "atlas_threshold_value": float(sim_cfg.atlas_threshold.value),
        "n_bins": int(sim_cfg.n_bins),
        "minimum_subjects": int(loo_cfg.minimum_subjects),
        "write_volumes": bool(loo_cfg.write_volumes),
        "learned_prediction_relpath": loo_cfg.learned_prediction_relpath,
        "learned_support_threshold": float(loo_cfg.learned_support_threshold),
    }
    for input_name, input_value in task_inputs.items():
        setattr(node.inputs, input_name, input_value)

    wf.add_nodes([node])
    logger.info(
        "Built tract_similarity_hcp_loo workflow rooted at {} (min_subjects={}, write_volumes={})",
        input_dir,
        loo_cfg.minimum_subjects,
        loo_cfg.write_volumes,
    )
    return wf