# Source code for thesis.workflows.learned_atlas.diagnostics

"""Phase-0 diagnostics and non-circular evaluation metrics for the learned atlas.

Framework-free: every function operates on in-memory numpy arrays (typically the
``(n_subjects, X, Y, Z)`` stack from ``atlas/_io._build_patient_stack``) and
returns plain floats / arrays / JSON-serialisable dicts. No torch, Nipype, or
workflow-orchestration dependencies, mirroring
:mod:`thesis.workflows.atlas.qc_metrics` and
:mod:`thesis.workflows.tract_similarity._metrics`.

Two families:

* Phase-0 diagnostics answer "is a learned deformable atlas worth training on
  this cohort?" before any GPU time is spent.
* Non-circular evaluation metrics score ``warp(T, field_i)`` against the held-out
  subject on its TRUE support (not the union support that flatters do-nothing
  models), plus a regression-to-mean control.
"""

from __future__ import annotations

import numpy as np

from thesis.core.logging import get_logger
from thesis.workflows.atlas.qc_metrics import build_occupancy_map, threshold_core_mask

# Module-level logger.
# NOTE(review): none of the functions in this file currently log through it.
logger = get_logger(__name__)

# Public API: Phase-0 diagnostics first, then the non-circular evaluation
# metrics and their one-call bundle (``evaluate_prediction``).
__all__ = [
    "residual_variance_map",
    "centroid_scatter",
    "occupancy_entropy",
    "mass_matched_residual_fraction",
    "phase0_go_no_go",
    "true_support_mask",
    "delta_on_true_support",
    "regression_to_mean_control",
    "evaluate_prediction",
]


def _validate_value_stack(value_maps: np.ndarray) -> np.ndarray:
    """Coerce ``value_maps`` to a float32 array with a leading subject axis.

    Args:
        value_maps: Array-like whose first axis indexes subjects.

    Returns:
        The input as a float32 ndarray (no copy if it already is one).

    Raises:
        ValueError: If the array has fewer than two dimensions or holds zero
            subjects.
    """
    coerced = np.asarray(value_maps, dtype=np.float32)
    if coerced.ndim < 2:
        raise ValueError("value_maps must have shape (n_subjects, ...)")
    n_subjects = coerced.shape[0]
    if n_subjects == 0:
        raise ValueError("value_maps must contain at least one subject")
    return coerced


def _same_shape(a: np.ndarray, b: np.ndarray) -> None:
    """Ensure ``a`` and ``b`` have exactly the same shape.

    Raises:
        ValueError: If the two shapes differ (no broadcasting is allowed).
    """
    if a.shape == b.shape:
        return
    raise ValueError(f"arrays must have identical shapes, got {a.shape} vs {b.shape}")


def residual_variance_map(value_maps: np.ndarray) -> np.ndarray:
    """Return the per-voxel variance of the cohort about its voxel-wise mean.

    Args:
        value_maps: Subject-first stack ``(n_subjects, ...)``; coerced to float32.

    Returns:
        float32 array with the shape of a single subject. It holds the
        population variance (numpy's default ``ddof=0``), accumulated in float32.
    """
    cohort = _validate_value_stack(value_maps)
    spread = np.var(cohort, axis=0, dtype=np.float32)
    return np.asarray(spread, dtype=np.float32)
def centroid_scatter(
    value_maps: np.ndarray,
    voxel_size_mm: tuple[float, float, float] = (1.0, 1.0, 1.0),
    eps: float = 1e-12,
) -> dict[str, float]:
    """Quantify how far per-subject density centroids scatter about their mean.

    Large scatter is direct evidence of spatial (not intensity) disagreement -
    what a learned deformation field is meant to remove.

    Each subject's centre of mass is computed in voxel-index space from its
    values with negatives clipped to zero. It is then scaled per axis by
    ``voxel_size_mm``. Subjects whose clipped total mass is ``<= eps`` are
    skipped.

    Args:
        value_maps: Subject-first ``(n_subjects, X, Y, Z)`` stack.
        voxel_size_mm: Per-axis voxel size used to convert indices to mm.
        eps: Minimum total mass for a subject to contribute a centroid.

    Returns:
        Dict with ``mean_distance_mm``/``rms_distance_mm``/``max_distance_mm``
        and ``n_subjects`` used (empty-mass subjects excluded). The distances
        are NaN and ``n_subjects`` is 0 when no subject has mass above ``eps``.

    Raises:
        ValueError: If ``voxel_size_mm`` is not length 3 or the stack is not 4-D.
    """
    stack = _validate_value_stack(value_maps)
    if len(voxel_size_mm) != 3:
        raise ValueError("voxel_size_mm must be length 3")
    if stack.ndim != 4:
        raise ValueError("centroid_scatter requires a (n_subjects, X, Y, Z) stack")

    scale = np.asarray(voxel_size_mm, dtype=np.float64)
    # The centre of mass along axis k only depends on the 1-D marginal mass
    # profile along k. Using 1-D coordinate vectors avoids materialising a dense
    # (3, X, Y, Z) float64 grid plus a full-volume product per axis.
    axis_coords = [np.arange(n, dtype=np.float64) for n in stack.shape[1:]]

    centroids: list[np.ndarray] = []
    for subject in stack:
        weights = np.clip(subject.astype(np.float64), 0.0, None)
        total = float(weights.sum())
        if total <= eps:
            continue
        profiles = (
            weights.sum(axis=(1, 2)),
            weights.sum(axis=(0, 2)),
            weights.sum(axis=(0, 1)),
        )
        com_vox = np.array(
            [float(coords @ profile) / total for coords, profile in zip(axis_coords, profiles)]
        )
        centroids.append(com_vox * scale)

    if not centroids:
        nan = float("nan")
        return {
            "mean_distance_mm": nan,
            "rms_distance_mm": nan,
            "max_distance_mm": nan,
            "n_subjects": 0,
        }

    coords = np.vstack(centroids)
    cohort_centroid = coords.mean(axis=0)
    distances = np.linalg.norm(coords - cohort_centroid, axis=1)
    return {
        "mean_distance_mm": float(np.mean(distances)),
        "rms_distance_mm": float(np.sqrt(np.mean(distances**2))),
        "max_distance_mm": float(np.max(distances)),
        "n_subjects": int(len(centroids)),
    }
def occupancy_entropy(
    subject_masks: np.ndarray,
    threshold: float = 0.0,
    restrict_to_support: bool = True,
    eps: float = 1e-12,
) -> dict[str, float]:
    """Mean per-voxel binary occupancy entropy across a cohort.

    The input is a DENSITY stack that is binarised here: voxels strictly above
    ``threshold`` count as occupied. This makes the >0 binarisation explicit
    rather than relying on the silent bool-coercion of ``build_occupancy_map``.

    Entropy ~0 means voxels are consistently on/off (a sharp template is
    recoverable). Entropy ~1 means subjects flip a coin per voxel. Occupancy
    values within ``eps`` of 0 or 1 contribute exactly zero entropy.

    Args:
        subject_masks: Subject-first density (or mask) stack.
        threshold: Strict lower bound for a voxel to count as occupied.
        restrict_to_support: If True, average only over voxels occupied by at
            least one subject (occupancy > ``eps``). Otherwise average over
            every voxel.
        eps: Clipping margin that keeps ``log2`` finite.

    Returns:
        Dict with ``mean_entropy_bits``, ``max_entropy_bits`` and
        ``support_voxels``. The max is taken over all voxels, regardless of
        ``restrict_to_support``.
    """
    occupied = np.asarray(subject_masks) > float(threshold)
    frac = build_occupancy_map(occupied).astype(np.float64)

    p_on = np.clip(frac, eps, 1.0 - eps)
    p_off = 1.0 - p_on
    entropy_bits = -(p_on * np.log2(p_on) + p_off * np.log2(p_off))
    saturated = (frac <= eps) | (frac >= 1.0 - eps)
    entropy_bits = np.where(saturated, 0.0, entropy_bits)

    if restrict_to_support:
        in_support = frac > eps
        support_count = int(np.count_nonzero(in_support))
        average = 0.0
        if support_count > 0:
            average = float(np.mean(entropy_bits[in_support]))
    else:
        support_count = int(entropy_bits.size)
        average = float(np.mean(entropy_bits))

    peak = float(np.max(entropy_bits)) if entropy_bits.size else 0.0
    return {
        "mean_entropy_bits": average,
        "max_entropy_bits": peak,
        "support_voxels": support_count,
    }
def mass_matched_residual_fraction(value_maps: np.ndarray, eps: float = 1e-12) -> float:
    """Cheap PROXY for the spatially-driven fraction of cohort residual.

    Every subject is rescaled to the cohort-mean total mass, and the summed
    per-voxel variance is measured again. The surviving fraction is a heuristic
    for disagreement that a deformation could move.

    NOTE: no alignment is performed, so this is a proxy, not a true lower
    bound. It is advisory in the go/no-go decision.

    Args:
        value_maps: Subject-first stack ``(n_subjects, ...)``.
        eps: Floor for the baseline residual and for per-subject masses when
            computing rescaling gains.

    Returns:
        Float in [0, 1]. Returns 0.0 for fewer than two subjects or when the
        cohort residual is negligible.
    """
    stack = _validate_value_stack(value_maps)
    n_subjects = stack.shape[0]
    if n_subjects < 2:
        return 0.0

    flat = stack.reshape(n_subjects, -1).astype(np.float64)
    baseline = float(np.var(flat, axis=0).sum())
    if baseline <= eps:
        return 0.0

    masses = flat.sum(axis=1, keepdims=True)
    target_mass = float(masses.mean())
    # Clip masses from below so an empty subject cannot cause division by zero.
    gains = target_mass / np.clip(masses, eps, None)
    matched = float(np.var(flat * gains, axis=0).sum())
    return float(np.clip(matched / baseline, 0.0, 1.0))
def phase0_go_no_go(
    value_maps: np.ndarray,
    voxel_size_mm: tuple[float, float, float] = (1.0, 1.0, 1.0),
    subject_mask_threshold: float = 0.0,
    core_occupancy_threshold: float = 0.5,
    min_centroid_scatter_mm: float = 1.0,
    max_occupancy_entropy_bits: float = 0.9,
) -> dict[str, object]:
    """Roll the Phase-0 diagnostics into a single go/no-go decision.

    The hard gate is a conjunction of two conditions:

    * spatial residual: mean centroid scatter of at least
      ``min_centroid_scatter_mm``;
    * a coherent bundle: mean support entropy of at most
      ``max_occupancy_entropy_bits``.

    The mass-matched residual fraction is reported as an ADVISORY proxy (not
    part of the conjunction) until validated.

    Args:
        value_maps: Subject-first ``(n_subjects, X, Y, Z)`` value stack.
        voxel_size_mm: Per-axis voxel size forwarded to ``centroid_scatter``.
        subject_mask_threshold: Values strictly above this count as occupied.
        core_occupancy_threshold: Forwarded to ``threshold_core_mask`` to count
            ``core_voxels``.
        min_centroid_scatter_mm: Minimum mean centroid distance for ``go``.
        max_occupancy_entropy_bits: Maximum mean support entropy for ``go``.

    Returns:
        Dict with ``go`` (bool), ``reasons``, sub-criteria, raw diagnostics,
        ``core_voxels`` and ``n_subjects``. Undefined diagnostics appear as
        float NaN, which ``json.dumps`` only emits with its default
        ``allow_nan=True``.
    """
    stack = _validate_value_stack(value_maps)
    threshold = float(subject_mask_threshold)
    masks = stack > threshold

    scatter = centroid_scatter(stack, voxel_size_mm=voxel_size_mm)
    entropy = occupancy_entropy(stack, threshold=threshold)
    mass_fraction = mass_matched_residual_fraction(stack)

    occupancy = build_occupancy_map(masks)
    core = threshold_core_mask(occupancy, core_occupancy_threshold)
    core_voxels = int(np.count_nonzero(core))

    mean_scatter = scatter["mean_distance_mm"]
    # NaN scatter means no subject contributed a finite centroid. That is a
    # different failure from "scatter too small" and is reported separately.
    scatter_defined = bool(np.isfinite(mean_scatter))
    scatter_ok = bool(scatter_defined and mean_scatter >= float(min_centroid_scatter_mm))
    entropy_ok = bool(entropy["mean_entropy_bits"] <= float(max_occupancy_entropy_bits))

    reasons: list[str] = []
    if not scatter_defined:
        reasons.append(
            "centroid scatter undefined (no subject with non-negligible positive "
            "mass, or non-finite values)"
        )
    elif not scatter_ok:
        reasons.append(
            f"centroid scatter {scatter['mean_distance_mm']:.2f} mm < "
            f"{min_centroid_scatter_mm:.2f} mm threshold (residual not spatial)"
        )
    if not entropy_ok:
        reasons.append(
            f"support occupancy entropy {entropy['mean_entropy_bits']:.2f} bits > "
            f"{max_occupancy_entropy_bits:.2f} (bundle topology too incoherent to recover)"
        )

    go = scatter_ok and entropy_ok
    if go:
        # Both gates passed, so no failure reason was recorded above.
        reasons.append("cohort residual is spatial and recoverable; learned atlas is warranted")

    return {
        "go": go,
        "reasons": reasons,
        "criteria": {
            "centroid_scatter_ok": scatter_ok,
            "occupancy_entropy_ok": entropy_ok,
        },
        "centroid_scatter": scatter,
        "occupancy_entropy": entropy,
        "mass_matched_residual_fraction_advisory": float(mass_fraction),
        "core_voxels": core_voxels,
        "n_subjects": int(stack.shape[0]),
    }
def true_support_mask(truth: np.ndarray, threshold: float = 0.0) -> np.ndarray:
    """Return a boolean mask of voxels where the ground truth exceeds ``threshold``.

    The comparison is strict and done in float64. NaN voxels therefore compare
    False and are excluded from the support.
    """
    values = np.asarray(truth, dtype=np.float64)
    cutoff = float(threshold)
    return np.greater(values, cutoff)
def delta_on_true_support(
    prediction: np.ndarray,
    truth: np.ndarray,
    loo_mean: np.ndarray,
    support_threshold: float = 0.0,
    eps: float = 1e-12,
) -> dict[str, float]:
    """Improvement of the prediction over the LOO mean on the true support.

    This is the headline non-circular metric. It asks whether ``warp(T, field_i)``
    reduces error versus the LOO averaging baseline, restricted to voxels where
    the held-out subject has real signal (``truth > support_threshold``).

    Args:
        prediction: Predicted volume; must match ``truth``'s shape.
        truth: Held-out subject volume defining the support.
        loo_mean: Leave-one-out cohort mean baseline; must match ``truth``.
        support_threshold: Strict lower bound on ``truth`` for support voxels.
        eps: Absolute floor on the baseline MAE below which
            ``relative_reduction`` is reported as NaN.

    Returns:
        Dict with the following entries:

        * ``mae_pred``;
        * ``mae_loo_mean``;
        * ``delta_mae`` (positive = improvement);
        * ``relative_reduction`` (delta / mae_loo_mean; range (-inf, 1]);
        * ``support_voxels``.

        The MAE/delta entries are NaN when the support is empty.

    Raises:
        ValueError: If the three arrays do not share one shape.
    """
    pred = np.asarray(prediction, dtype=np.float64)
    real = np.asarray(truth, dtype=np.float64)
    base = np.asarray(loo_mean, dtype=np.float64)
    _same_shape(pred, real)
    _same_shape(real, base)

    support = true_support_mask(real, threshold=support_threshold)
    voxel_count = int(np.count_nonzero(support))
    if not voxel_count:
        nan = float("nan")
        empty = {key: nan for key in ("mae_pred", "mae_loo_mean", "delta_mae", "relative_reduction")}
        empty["support_voxels"] = 0
        return empty

    target = real[support]
    model_error = float(np.abs(pred[support] - target).mean())
    baseline_error = float(np.abs(base[support] - target).mean())
    gain = baseline_error - model_error
    fraction = float(gain / baseline_error) if baseline_error > eps else float("nan")
    return {
        "mae_pred": model_error,
        "mae_loo_mean": baseline_error,
        "delta_mae": float(gain),
        "relative_reduction": fraction,
        "support_voxels": voxel_count,
    }
def regression_to_mean_control(
    prediction: np.ndarray,
    truth: np.ndarray,
    loo_mean: np.ndarray,
    support_threshold: float = 0.0,
    eps: float = 1e-12,
) -> dict[str, float]:
    """Detect a model that merely reproduces the cohort mean.

    Correlates the predicted departure ``(prediction - loo_mean)`` with the real
    departure ``(truth - loo_mean)`` over the true support. A do-nothing model
    (prediction ~= loo_mean) yields a near-zero correlation / energy ratio.

    Args:
        prediction: Predicted volume; must match ``truth``'s shape.
        truth: Held-out subject volume defining the support.
        loo_mean: Leave-one-out cohort mean baseline; must match ``truth``.
        support_threshold: Strict lower bound on ``truth`` for support voxels.
        eps: ABSOLUTE guard for the real-departure energy and for the Pearson
            denominator. Very small-magnitude departures can therefore yield
            NaN even when a correlation exists.

    Returns:
        Dict with the following entries:

        * ``departure_pearson`` (in [-1, 1]; NaN if constant, fewer than two
          support voxels, or empty support);
        * ``predicted_departure_energy``;
        * ``real_departure_energy``;
        * ``departure_energy_ratio`` (near 0 flags a do-nothing model).

    Raises:
        ValueError: If the three arrays do not share one shape.
    """
    pred = np.asarray(prediction, dtype=np.float64)
    real = np.asarray(truth, dtype=np.float64)
    base = np.asarray(loo_mean, dtype=np.float64)
    _same_shape(pred, real)
    _same_shape(real, base)

    support = true_support_mask(real, threshold=support_threshold)
    if not support.any():
        return {
            "departure_pearson": float("nan"),
            "predicted_departure_energy": 0.0,
            "real_departure_energy": 0.0,
            "departure_energy_ratio": float("nan"),
        }

    anchor = base[support]
    predicted = pred[support] - anchor
    observed = real[support] - anchor

    predicted_energy = float(np.sum(predicted**2))
    observed_energy = float(np.sum(observed**2))
    ratio = float(predicted_energy / observed_energy) if observed_energy > eps else float("nan")

    correlation = float("nan")
    if predicted.size >= 2:
        centred_pred = predicted - predicted.mean()
        centred_obs = observed - observed.mean()
        scale = float(np.sqrt(np.sum(centred_pred * centred_pred) * np.sum(centred_obs * centred_obs)))
        if scale > eps:
            correlation = float(np.sum(centred_pred * centred_obs) / scale)

    return {
        "departure_pearson": correlation,
        "predicted_departure_energy": predicted_energy,
        "real_departure_energy": observed_energy,
        "departure_energy_ratio": ratio,
    }
def evaluate_prediction(
    prediction: np.ndarray,
    truth: np.ndarray,
    loo_mean: np.ndarray,
    support_threshold: float = 0.0,
) -> dict[str, dict[str, float]]:
    """Compute both non-circular metrics for one held-out subject.

    Both metrics use the same ``support_threshold`` and their own default ``eps``.

    Returns:
        Dict mapping ``delta_on_true_support`` and ``regression_to_mean`` to the
        respective metric dicts.
    """
    shared = {"support_threshold": support_threshold}
    return {
        "delta_on_true_support": delta_on_true_support(prediction, truth, loo_mean, **shared),
        "regression_to_mean": regression_to_mean_control(prediction, truth, loo_mean, **shared),
    }