# Source code for thesis.workflows.tract_similarity.sweep

"""Cohort-wide grid search over tract_similarity binarisation thresholds.

Loads every patient's normalised probtrackx2 density and warped atlas once,
evaluates Dice across the configured 2-D grid of
``(subject_threshold.value, atlas_threshold.value)``, aggregates Dice across
the cohort per grid cell, and reports the cell that maximises the chosen
aggregation (mean or median). Outputs are a per-patient CSV table, a grid
summary CSV, a JSON best-cell record, and an optional heatmap PNG.

Registered as ``thesis run -w tract_similarity_sweep -c <config>``. Cohort
one-shot — no per-patient fan-out, no IdentityInterface gate.
"""

from __future__ import annotations

from pathlib import Path
from typing import Dict, List, Tuple

import nipype.pipeline.engine as pe
from nipype.interfaces.utility import Function

from thesis.core.config import PipelineConfig
from thesis.core.context import ProcessingContext
from thesis.core.decorators import verify, workflow
from thesis.core.logging import get_logger
from thesis.workflows.tract_similarity._io import discover_patient_dirs
from thesis.workflows.tract_similarity.workflow import _resolve_cohort_input_dir

# Module-level logger; build_sweep_workflow uses it to report the grid size.
logger = get_logger(__name__)

# Names exported by ``from thesis.workflows.tract_similarity.sweep import *``.
# The underscore-prefixed helpers are listed explicitly because star-imports
# skip underscore names unless they appear in ``__all__``. Note that
# ``_sweep_dice_task`` is not listed.
__all__ = [
    "build_sweep_workflow",
    "verify_sweep_requirements",
    "_expand_grid",
    "_sweep_patient_dice",
    "_aggregate_grid",
    "_pick_best",
    "_write_heatmap",
]


# ---------------------------------------------------------------------------
# Grid expansion
# ---------------------------------------------------------------------------


def _expand_grid(spec: object) -> List[float]:
    """Expand a ``ThresholdGridConfig`` (or equivalent dict) to a sorted list.

    Accepts either a Pydantic ``ThresholdGridConfig`` or a plain dict with the
    same fields. The ``start``/``stop``/``step`` form yields an inclusive,
    evenly spaced grid: ``round((stop - start) / step)`` steps give
    ``n_steps + 1`` points spanning ``start`` through ``stop``.
    """
    import numpy as np

    if hasattr(spec, "model_dump"):
        data = spec.model_dump()
    elif isinstance(spec, dict):
        data = dict(spec)
    else:
        raise TypeError(
            f"_expand_grid expects a ThresholdGridConfig or dict; got {type(spec).__name__}"
        )

    values = data.get("values")
    if values:
        out = sorted(float(v) for v in values)
        if not out or not all(0.0 < v < 1.0 for v in out):
            raise ValueError("values must be a non-empty list of floats in (0, 1)")
        return out

    start = data.get("start")
    stop = data.get("stop")
    step = data.get("step")
    if start is None or stop is None or step is None:
        raise ValueError("grid spec needs values OR start/stop/step")
    if stop <= start:
        raise ValueError("stop must be > start")
    if step <= 0:
        raise ValueError("step must be > 0")

    n_steps = int(round((stop - start) / step))
    grid = [round(float(v), 12) for v in np.linspace(start, stop, n_steps + 1)]
    if not grid:
        raise ValueError(f"empty grid for start={start}, stop={stop}, step={step}")
    return grid


# ---------------------------------------------------------------------------
# Per-patient Dice grid
# ---------------------------------------------------------------------------


def _sweep_patient_dice(
    vol_prob: "object",
    vol_atlas: "object",
    subject_grid: List[float],
    atlas_grid: List[float],
    mode: str = "fraction",
) -> Dict[Tuple[float, float], float]:
    """Compute one patient's Dice score for every (subject, atlas) threshold pair.

    Args:
        vol_prob: Subject density volume; converted with ``np.asarray``.
        vol_atlas: Atlas volume; must have the same shape as ``vol_prob``.
        subject_grid: Thresholds applied to ``vol_prob``.
        atlas_grid: Thresholds applied to ``vol_atlas``.
        mode: ``"fraction"`` keeps voxels above ``threshold * peak`` (peak =
            array max floored at ``1e-12``, mirroring ``apply_threshold``);
            ``"absolute"`` keeps voxels above ``threshold`` itself.

    Returns:
        ``{(subject_thr, atlas_thr): dice}`` over the full product, iterated
        subject-major.

    Raises:
        ValueError: on a shape mismatch, or on an unrecognised ``mode``
            (detected when the first mask is built).

    Each mask is built once per threshold and reused across the product.
    Dice follows the project convention (``empty_value=1.0`` when both masks
    are empty) via :func:`thesis.workflows.atlas.qc_metrics.dice_score`.
    """
    import numpy as np

    from thesis.workflows.atlas.qc_metrics import dice_score

    subj = np.asarray(vol_prob)
    atl = np.asarray(vol_atlas)
    if subj.shape != atl.shape:
        raise ValueError(f"shape mismatch: subject {subj.shape} vs atlas {atl.shape}")

    # Peaks are computed once per volume rather than once per threshold.
    floor = 1e-12
    subj_peak = max(float(np.max(subj)), floor)
    atl_peak = max(float(np.max(atl)), floor)

    def _binarise(arr, peak: float, thr: float):
        if mode == "fraction":
            cut = float(thr) * peak
        elif mode == "absolute":
            cut = float(thr)
        else:
            raise ValueError(f"unknown threshold mode: {mode!r}")
        return arr > cut

    subject_masks = {}
    for thr in subject_grid:
        subject_masks[thr] = _binarise(subj, subj_peak, thr)

    atlas_masks = {}
    for thr in atlas_grid:
        atlas_masks[thr] = _binarise(atl, atl_peak, thr)

    return {
        (s_thr, a_thr): float(dice_score(s_mask, a_mask))
        for s_thr, s_mask in subject_masks.items()
        for a_thr, a_mask in atlas_masks.items()
    }


# ---------------------------------------------------------------------------
# Cohort aggregation and best-cell selection
# ---------------------------------------------------------------------------


def _aggregate_grid(
    per_patient: Dict[str, Dict[Tuple[float, float], float]],
    aggregation: str,
) -> Dict[Tuple[float, float], dict]:
    """Collapse per-patient Dice grids into per-cell summary statistics."""
    import numpy as np

    if aggregation not in ("mean", "median"):
        raise ValueError(f"aggregation must be 'mean' or 'median'; got {aggregation!r}")

    cells: Dict[Tuple[float, float], List[float]] = {}
    for grid in per_patient.values():
        for cell, dice in grid.items():
            cells.setdefault(cell, []).append(float(dice))

    summary: Dict[Tuple[float, float], dict] = {}
    for cell, values in cells.items():
        arr = np.asarray(values, dtype=np.float64)
        summary[cell] = {
            "n": int(arr.size),
            "mean_dice": float(np.mean(arr)) if arr.size else float("nan"),
            "median_dice": float(np.median(arr)) if arr.size else float("nan"),
            "std_dice": float(np.std(arr)) if arr.size else float("nan"),
        }
    return summary


def _pick_best(
    grid_summary: Dict[Tuple[float, float], dict],
    aggregation: str,
) -> Tuple[float, float, float]:
    """Return ``(best_subject_thr, best_atlas_thr, best_score)`` by aggregation."""
    if not grid_summary:
        raise ValueError("grid_summary is empty")
    key = "mean_dice" if aggregation == "mean" else "median_dice"
    best_cell, best_stats = max(grid_summary.items(), key=lambda kv: kv[1][key])
    return best_cell[0], best_cell[1], float(best_stats[key])


# ---------------------------------------------------------------------------
# Heatmap rendering
# ---------------------------------------------------------------------------


def _write_heatmap(
    grid_summary: Dict[Tuple[float, float], dict],
    subject_grid: List[float],
    atlas_grid: List[float],
    out_path: Path,
    aggregation: str = "mean",
) -> None:
    """Render the aggregated Dice grid as a PNG heatmap with the argmax marked.

    Rows correspond to ``subject_grid`` (y axis, drawn with ``origin="lower"``)
    and columns to ``atlas_grid`` (x axis). Cells absent from ``grid_summary``
    are left as NaN. The cell chosen by :func:`_pick_best` is marked with a
    red star.

    Args:
        grid_summary: ``{(subject_thr, atlas_thr): stats}`` from
            :func:`_aggregate_grid`.
        subject_grid: Subject thresholds, one heatmap row each.
        atlas_grid: Atlas thresholds, one heatmap column each.
        out_path: Destination PNG path.
        aggregation: ``"mean"`` plots ``mean_dice``; anything else plots
            ``median_dice``.

    Raises:
        ValueError: propagated from :func:`_pick_best`, or from ``list.index``
            when the best cell's thresholds are not on the provided axes.
    """
    import matplotlib

    # Force the non-interactive Agg backend so rendering needs no display.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import numpy as np

    key = "mean_dice" if aggregation == "mean" else "median_dice"
    grid = np.full((len(subject_grid), len(atlas_grid)), np.nan, dtype=np.float64)
    for i, s_thr in enumerate(subject_grid):
        for j, a_thr in enumerate(atlas_grid):
            stats = grid_summary.get((s_thr, a_thr))
            if stats is not None:
                grid[i, j] = stats[key]

    # Locate the argmax before creating the figure so a lookup failure here
    # cannot leave a figure open.
    best_s, best_a, _ = _pick_best(grid_summary, aggregation)
    best_i = subject_grid.index(best_s)
    best_j = atlas_grid.index(best_a)

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        im = ax.imshow(grid, origin="lower", aspect="auto", cmap="viridis")
        ax.set_xticks(range(len(atlas_grid)))
        ax.set_xticklabels([f"{v:.3f}" for v in atlas_grid], rotation=90, fontsize=7)
        ax.set_yticks(range(len(subject_grid)))
        ax.set_yticklabels([f"{v:.3f}" for v in subject_grid], fontsize=7)
        ax.set_xlabel("atlas_threshold.value")
        ax.set_ylabel("subject_threshold.value")
        ax.set_title(f"tract_similarity_sweep: {key} across cohort")
        ax.scatter([best_j], [best_i], marker="*", s=180, c="red", edgecolors="white")
        fig.colorbar(im, ax=ax, label=key)
        fig.tight_layout()
        fig.savefig(str(out_path), dpi=120)
    finally:
        # pyplot keeps every figure alive until it is closed, so release it
        # even if rendering or savefig raises.
        plt.close(fig)


# ---------------------------------------------------------------------------
# Nipype task body (runs in subprocess — all imports local)
# ---------------------------------------------------------------------------


def _sweep_dice_task(
    input_dir: str,
    cohort_output_dir: str,
    output_subdir: str,
    probtrackx_relpath: str,
    fdt_name: str,
    waytotal_name: str,
    atlas_relpath: str,
    subject_threshold_grid: list,
    atlas_threshold_grid: list,
    threshold_mode: str,
    aggregation: str,
    emit_heatmap: bool,
) -> list:
    """Cohort sweep task: load each patient once, evaluate the Dice grid, aggregate.

    Body of the ``sweep_dice`` nipype ``Function`` node. All imports are
    local because the function runs in a subprocess, as the section header
    notes.

    Args:
        input_dir: Cohort directory scanned with ``discover_patient_dirs``. Each
            discovered directory is one patient; its name is the patient id.
        cohort_output_dir: Root under which ``output_subdir`` is created.
        output_subdir: Subdirectory (created if missing) that receives all outputs.
        probtrackx_relpath: Per-patient relative path of the probtrackx directory.
        fdt_name: Passed through to ``load_probtrackx_volume``.
        waytotal_name: Passed through to ``load_probtrackx_volume``.
        atlas_relpath: Passed with the patient dir to ``resolve_atlas_file``.
        subject_threshold_grid: Thresholds applied to the subject volume.
        atlas_threshold_grid: Thresholds applied to the atlas volume.
        threshold_mode: Forwarded to ``_sweep_patient_dice`` as ``mode``.
        aggregation: ``"mean"`` or ``"median"``; selects the best cell.
        emit_heatmap: When true, also writes ``sweep_heatmap.png``.

    Returns:
        String paths of the files written: ``sweep_per_patient.csv``,
        ``sweep_summary.csv``, ``sweep_best.json`` and, if requested,
        ``sweep_heatmap.png``.

    Raises:
        RuntimeError: if no patient could be evaluated.
        Exceptions other than ``FileNotFoundError``/``ValueError`` from the
        loaders, and any error from the Dice/aggregation/writing steps,
        propagate and abort the task.

    Patients with missing or invalid inputs, or with mismatched volume
    shapes, are skipped; their ids are recorded in ``sweep_best.json``.
    """
    import csv
    import gc
    import json
    import sys
    from pathlib import Path

    from thesis.workflows.tract_similarity._io import (
        discover_patient_dirs,
        load_atlas_normalized,
        load_probtrackx_volume,
        resolve_atlas_file,
    )
    from thesis.workflows.tract_similarity.sweep import (
        _aggregate_grid,
        _pick_best,
        _sweep_patient_dice,
        _write_heatmap,
    )

    in_path = Path(input_dir)
    out_root = Path(cohort_output_dir) / output_subdir
    out_root.mkdir(parents=True, exist_ok=True)

    patient_dirs = discover_patient_dirs(in_path)

    # pid -> {(subject_thr, atlas_thr): dice}; `skipped` collects ids that
    # could not be evaluated.
    per_patient: dict = {}
    skipped: list = []
    for pdir in patient_dirs:
        pid = pdir.name
        # Only missing files and loader ValueErrors count as per-patient
        # skips; any other exception aborts the whole cohort run.
        try:
            probtrackx_dir = pdir / probtrackx_relpath
            atlas_file = resolve_atlas_file(pdir, atlas_relpath)
            vol_prob, _ = load_probtrackx_volume(
                probtrackx_dir, fdt_name=fdt_name, waytotal_name=waytotal_name
            )
            vol_atlas, _ = load_atlas_normalized(atlas_file)
        except (FileNotFoundError, ValueError) as exc:
            print(
                f"[tract_similarity_sweep] skipping {pid}: {exc}",
                file=sys.stderr,
            )
            skipped.append(pid)
            continue

        # Checked here so a mismatched patient is skipped; _sweep_patient_dice
        # would otherwise raise ValueError and abort the whole loop.
        if vol_prob.shape != vol_atlas.shape:
            print(
                f"[tract_similarity_sweep] skipping {pid}: shape mismatch "
                f"{vol_prob.shape} vs {vol_atlas.shape}",
                file=sys.stderr,
            )
            skipped.append(pid)
            del vol_prob, vol_atlas
            continue

        per_patient[pid] = _sweep_patient_dice(
            vol_prob, vol_atlas, subject_threshold_grid, atlas_threshold_grid, threshold_mode
        )
        print(
            f"[tract_similarity_sweep] {pid}: evaluated " f"{len(per_patient[pid])} grid cells",
            file=sys.stderr,
        )
        # Release this patient's volumes before loading the next one to keep
        # peak memory down.
        del vol_prob, vol_atlas
        gc.collect()

    if not per_patient:
        raise RuntimeError(
            f"No usable patient inputs found under {in_path}; "
            f"skipped={skipped}. Run 'hcp' and 'atlas_to_patient' first."
        )

    # _aggregate_grid raises ValueError unless aggregation is mean/median.
    grid_summary = _aggregate_grid(per_patient, aggregation)
    best_s, best_a, best_score = _pick_best(grid_summary, aggregation)

    # Long format: one row per (patient, grid cell).
    per_patient_csv = out_root / "sweep_per_patient.csv"
    with per_patient_csv.open("w", encoding="utf-8", newline="") as fh:
        writer = csv.writer(fh)
        writer.writerow(["patient_id", "subject_thr", "atlas_thr", "dice"])
        for pid, grid in per_patient.items():
            for (s_thr, a_thr), dice in grid.items():
                writer.writerow([pid, s_thr, a_thr, dice])

    # One row per grid cell, ordered by (subject_thr, atlas_thr).
    summary_csv = out_root / "sweep_summary.csv"
    with summary_csv.open("w", encoding="utf-8", newline="") as fh:
        writer = csv.writer(fh)
        writer.writerow(["subject_thr", "atlas_thr", "n", "mean_dice", "median_dice", "std_dice"])
        for (s_thr, a_thr), stats in sorted(grid_summary.items()):
            writer.writerow(
                [
                    s_thr,
                    a_thr,
                    stats["n"],
                    stats["mean_dice"],
                    stats["median_dice"],
                    stats["std_dice"],
                ]
            )

    # Best-cell record, including the sweep configuration and skipped ids.
    best_stats = grid_summary[(best_s, best_a)]
    best_json = out_root / "sweep_best.json"
    best_json.write_text(
        json.dumps(
            {
                "best_subject_threshold": best_s,
                "best_atlas_threshold": best_a,
                "best_mean_dice": best_stats["mean_dice"],
                "best_median_dice": best_stats["median_dice"],
                "best_std_dice": best_stats["std_dice"],
                "best_score": best_score,
                "n_patients": best_stats["n"],
                "aggregation": aggregation,
                "threshold_mode": threshold_mode,
                "subject_grid": list(subject_threshold_grid),
                "atlas_grid": list(atlas_threshold_grid),
                "skipped_patients": skipped,
            },
            indent=2,
        ),
        encoding="utf-8",
    )

    outputs = [str(per_patient_csv), str(summary_csv), str(best_json)]

    # Optional heatmap; written last, after all tabular outputs exist.
    if emit_heatmap:
        heatmap_png = out_root / "sweep_heatmap.png"
        _write_heatmap(
            grid_summary,
            list(subject_threshold_grid),
            list(atlas_threshold_grid),
            heatmap_png,
            aggregation=aggregation,
        )
        outputs.append(str(heatmap_png))

    print(
        f"[tract_similarity_sweep] best: subject_thr={best_s} atlas_thr={best_a} "
        f"{aggregation}_dice={best_score:.4f} (n={best_stats['n']})",
        file=sys.stderr,
    )
    return outputs


# ---------------------------------------------------------------------------
# Verifier
# ---------------------------------------------------------------------------


def verify_sweep_requirements(config: PipelineConfig, context: ProcessingContext) -> List[str]:
    """Pre-flight checks for the cohort-wide sweep.

    Returns a list of human-readable problems; an empty list means the sweep
    can be built. Checks stop at the first failure: missing ``output_dir``,
    then a missing cohort input directory, then an input directory with no
    patient subdirectories. ``config`` is part of the verifier signature but
    is not read here.
    """
    if context.output_dir is None:
        return ["output_dir is not set in the processing context"]

    cohort_input = _resolve_cohort_input_dir(Path(context.output_dir))
    if not cohort_input.is_dir():
        return [f"cohort input directory does not exist: {cohort_input}"]

    if discover_patient_dirs(cohort_input, sort=False):
        return []
    return [
        f"No patient subdirectories found under {cohort_input}. "
        f"Run 'hcp' and 'atlas_to_patient' for at least one patient first."
    ]
# ---------------------------------------------------------------------------
# Workflow factory
# ---------------------------------------------------------------------------
@workflow(
    name="tract_similarity_sweep",
    description=(
        "Cohort-wide grid search over subject and atlas binarization "
        "thresholds, maximizing mean (or median) Dice across patients."
    ),
    protocol="hcp",
    scope="cohort",
)
@verify(verify_sweep_requirements)
def build_sweep_workflow(*, config: PipelineConfig, context: ProcessingContext) -> pe.Workflow:
    """Assemble the single-node cohort workflow for the tract_similarity threshold sweep.

    Expands both threshold grids from ``config.tract_similarity_sweep`` and
    wires them, together with the probtrackx/atlas locations from
    ``config.tract_similarity``, into one ``sweep_dice`` Function node that
    runs :func:`_sweep_dice_task`.

    Raises:
        ValueError: if ``context.output_dir`` is unset, or (from
            :func:`_expand_grid`) if a grid spec is invalid.
    """
    if context.output_dir is None:
        raise ValueError("output_dir must be set before building the workflow")

    sweep_wf = pe.Workflow(name="tract_similarity_sweep")
    if context.working_dir:
        sweep_wf.base_dir = str(context.working_dir)

    ts_cfg = config.tract_similarity
    sweep_cfg = config.tract_similarity_sweep
    subject_grid = _expand_grid(sweep_cfg.subject_threshold_grid)
    atlas_grid = _expand_grid(sweep_cfg.atlas_threshold_grid)

    cohort_out = Path(context.output_dir)
    # Single source of truth for the task's keyword inputs: the key order is
    # also the Function node's input_names order.
    task_inputs = {
        "input_dir": str(_resolve_cohort_input_dir(cohort_out)),
        "cohort_output_dir": str(cohort_out),
        "output_subdir": sweep_cfg.output_subdir,
        "probtrackx_relpath": ts_cfg.probtrackx_relpath,
        "fdt_name": ts_cfg.fdt_name,
        "waytotal_name": ts_cfg.waytotal_name,
        "atlas_relpath": ts_cfg.atlas_relpath,
        "subject_threshold_grid": subject_grid,
        "atlas_threshold_grid": atlas_grid,
        # NOTE(review): mode is fixed to "fraction" rather than read from
        # config — confirm this matches the mode tract_similarity itself uses.
        "threshold_mode": "fraction",
        "aggregation": sweep_cfg.aggregation,
        "emit_heatmap": bool(sweep_cfg.emit_heatmap),
    }

    sweep_node = pe.Node(
        Function(
            input_names=list(task_inputs),
            output_names=["generated_files"],
            function=_sweep_dice_task,
        ),
        name="sweep_dice",
    )
    for input_name, value in task_inputs.items():
        setattr(sweep_node.inputs, input_name, value)
    sweep_wf.add_nodes([sweep_node])

    n_subject, n_atlas = len(subject_grid), len(atlas_grid)
    logger.info(
        "Built tract_similarity_sweep workflow: {} subject thresholds x "
        "{} atlas thresholds = {} cells, aggregation={}",
        n_subject,
        n_atlas,
        n_subject * n_atlas,
        sweep_cfg.aggregation,
    )
    return sweep_wf