"""Phase-0 diagnostics and non-circular evaluation metrics for the learned atlas.
Framework-free: every function operates on in-memory numpy arrays (typically the
``(n_subjects, X, Y, Z)`` stack from ``atlas/_io._build_patient_stack``) and
returns plain floats / arrays / JSON-serialisable dicts. No torch, Nipype, or
workflow-orchestration dependencies, mirroring
:mod:`thesis.workflows.atlas.qc_metrics` and
:mod:`thesis.workflows.tract_similarity._metrics`.
Two families:
* Phase-0 diagnostics answer "is a learned deformable atlas worth training on
this cohort?" before any GPU time is spent.
* Non-circular evaluation metrics score ``warp(T, field_i)`` against the held-out
subject on its TRUE support (not the union support that flatters do-nothing
models), plus a regression-to-mean control.
"""
from __future__ import annotations
import numpy as np
from thesis.core.logging import get_logger
from thesis.workflows.atlas.qc_metrics import build_occupancy_map, threshold_core_mask
# Module-level logger from the project's logging helper.
# NOTE(review): not referenced by any function visible in this module chunk.
logger = get_logger(__name__)

# Public API: the Phase-0 diagnostics first, then the non-circular
# evaluation metrics scored against a held-out subject.
__all__ = [
    "residual_variance_map",
    "centroid_scatter",
    "occupancy_entropy",
    "mass_matched_residual_fraction",
    "phase0_go_no_go",
    "true_support_mask",
    "delta_on_true_support",
    "regression_to_mean_control",
    "evaluate_prediction",
]
def _validate_value_stack(value_maps: np.ndarray) -> np.ndarray:
"""Validate and coerce a subject-first numeric voxel stack to float32."""
stack = np.asarray(value_maps, dtype=np.float32)
if stack.ndim < 2:
raise ValueError("value_maps must have shape (n_subjects, ...)")
if stack.shape[0] < 1:
raise ValueError("value_maps must contain at least one subject")
return stack
def _same_shape(a: np.ndarray, b: np.ndarray) -> None:
"""Raise if two arrays do not share an identical shape."""
if a.shape != b.shape:
raise ValueError(f"arrays must have identical shapes, got {a.shape} vs {b.shape}")
[docs]
def residual_variance_map(value_maps: np.ndarray) -> np.ndarray:
    """Return the voxel-wise population variance of the cohort.

    Args:
        value_maps: Subject-first stack ``(n_subjects, ...)``; coerced to float32.

    Returns:
        float32 array shaped like a single subject volume, holding each
        voxel's variance about the cohort mean (``ddof=0``, float32 accumulation).
    """
    cohort = _validate_value_stack(value_maps)
    per_voxel_var = cohort.var(axis=0, dtype=np.float32)
    return np.array(per_voxel_var, dtype=np.float32)
[docs]
def centroid_scatter(
    value_maps: np.ndarray,
    voxel_size_mm: tuple[float, float, float] = (1.0, 1.0, 1.0),
    eps: float = 1e-12,
) -> dict[str, float]:
    """Quantify how far per-subject density centroids scatter about their mean.

    Large scatter is direct evidence of spatial (not intensity) disagreement -
    what a learned deformation field is meant to remove. Each subject's
    centroid is the centre of mass of its values after clipping negatives to
    zero, converted from voxel indices to millimetres via ``voxel_size_mm``.

    Args:
        value_maps: ``(n_subjects, X, Y, Z)`` stack.
        voxel_size_mm: Per-axis voxel size used to scale voxel indices to mm.
        eps: Subjects whose clipped total mass is ``<= eps`` are excluded.

    Returns:
        Dict with ``mean_distance_mm``/``rms_distance_mm``/``max_distance_mm``
        and ``n_subjects`` used (empty-mass subjects excluded). The distance
        entries are NaN and ``n_subjects`` is 0 when no subject has mass.

    Raises:
        ValueError: if ``voxel_size_mm`` is not length 3 or the stack is not 4-D.
    """
    stack = _validate_value_stack(value_maps)
    if len(voxel_size_mm) != 3:
        raise ValueError("voxel_size_mm must be length 3")
    if stack.ndim != 4:
        raise ValueError("centroid_scatter requires a (n_subjects, X, Y, Z) stack")
    scale = np.asarray(voxel_size_mm, dtype=np.float64)
    # The centre of mass along axis k depends only on the marginal mass
    # profile along k. 1-D index vectors therefore suffice, which avoids a
    # full 3 x (X, Y, Z) float64 coordinate grid and three volume-sized
    # products per subject.
    axis_indices = [np.arange(size, dtype=np.float64) for size in stack.shape[1:]]
    summed_axes = ((1, 2), (0, 2), (0, 1))
    centroids: list[np.ndarray] = []
    for subject in stack:
        weights = np.clip(subject.astype(np.float64), 0.0, None)
        total = float(weights.sum())
        if total <= eps:
            # No usable mass: the centroid is undefined, so skip the subject.
            continue
        com_vox = np.array(
            [
                float(np.dot(weights.sum(axis=others), indices)) / total
                for others, indices in zip(summed_axes, axis_indices)
            ]
        )
        centroids.append(com_vox * scale)
    if not centroids:
        nan = float("nan")
        return {
            "mean_distance_mm": nan,
            "rms_distance_mm": nan,
            "max_distance_mm": nan,
            "n_subjects": 0,
        }
    coords = np.vstack(centroids)
    cohort_centroid = coords.mean(axis=0)
    distances = np.linalg.norm(coords - cohort_centroid, axis=1)
    return {
        "mean_distance_mm": float(np.mean(distances)),
        "rms_distance_mm": float(np.sqrt(np.mean(distances**2))),
        "max_distance_mm": float(np.max(distances)),
        "n_subjects": int(len(centroids)),
    }
[docs]
def occupancy_entropy(
    subject_masks: np.ndarray,
    threshold: float = 0.0,
    restrict_to_support: bool = True,
    eps: float = 1e-12,
) -> dict[str, float]:
    """Average binary occupancy entropy (bits) per voxel over a cohort.

    The input is a DENSITY stack. It is binarised here with a strict
    ``> threshold`` test before being handed to ``build_occupancy_map``, so
    the binarisation rule is explicit rather than an implicit bool cast.
    Voxels whose occupancy frequency is within ``eps`` of 0 or 1 score 0 bits.
    Near 0 bits means subjects agree voxel by voxel (a sharp template is
    recoverable). Near 1 bit means occupancy is a coin flip.

    Returns:
        Dict with ``mean_entropy_bits`` (mean over the occupied support, or
        over every voxel when ``restrict_to_support`` is False),
        ``max_entropy_bits`` and ``support_voxels``.
    """
    occupied = np.asarray(subject_masks) > float(threshold)
    frequency = build_occupancy_map(occupied).astype(np.float64)
    lo, hi = eps, 1.0 - eps
    p_on = np.clip(frequency, lo, hi)
    p_off = 1.0 - p_on
    bits = -(p_on * np.log2(p_on) + p_off * np.log2(p_off))
    # Force exact zeros where occupancy is (numerically) always-off or always-on.
    degenerate = np.logical_or(frequency <= lo, frequency >= hi)
    bits = np.where(degenerate, 0.0, bits)

    if not restrict_to_support:
        n_support = int(bits.size)
        mean_entropy = float(np.mean(bits))
    else:
        support = frequency > eps
        n_support = int(np.count_nonzero(support))
        mean_entropy = 0.0 if n_support == 0 else float(np.mean(bits[support]))

    peak = float(np.max(bits)) if bits.size else 0.0
    return {
        "mean_entropy_bits": mean_entropy,
        "max_entropy_bits": peak,
        "support_voxels": n_support,
    }
[docs]
def mass_matched_residual_fraction(value_maps: np.ndarray, eps: float = 1e-12) -> float:
    """Heuristic share of the cohort residual that survives mass matching.

    Each subject is rescaled so its total mass equals the cohort-mean mass,
    and the summed per-voxel variance is measured again. The ratio of matched
    to raw residual is taken as a rough indicator of disagreement a
    deformation could move. No spatial alignment is done, so this is only a
    PROXY (not a lower bound) and is advisory in the go/no-go decision.

    Returns:
        Float clipped to [0, 1]. It is 0.0 for fewer than two subjects or a
        negligible (``<= eps``) raw residual.
    """
    stack = _validate_value_stack(value_maps)
    n_subjects = stack.shape[0]
    if n_subjects < 2:
        return 0.0
    rows = stack.reshape(n_subjects, -1).astype(np.float64)
    raw_residual = float(np.var(rows, axis=0).sum())
    if raw_residual <= eps:
        return 0.0
    totals = rows.sum(axis=1, keepdims=True)
    target = float(np.mean(totals))
    # Clip denominators so empty subjects do not divide by zero (their rows are 0 anyway).
    rescaled = rows * (target / np.clip(totals, eps, None))
    matched_residual = float(np.var(rescaled, axis=0).sum())
    return float(np.clip(matched_residual / raw_residual, 0.0, 1.0))
[docs]
def phase0_go_no_go(
    value_maps: np.ndarray,
    voxel_size_mm: tuple[float, float, float] = (1.0, 1.0, 1.0),
    subject_mask_threshold: float = 0.0,
    core_occupancy_threshold: float = 0.5,
    min_centroid_scatter_mm: float = 1.0,
    max_occupancy_entropy_bits: float = 0.9,
) -> dict[str, object]:
    """Roll the Phase-0 diagnostics into a single go/no-go decision.

    The hard gate is a conjunction of two conditions:

    * spatial residual: mean centroid scatter is at least
      ``min_centroid_scatter_mm``;
    * a coherent bundle: mean support entropy is at most
      ``max_occupancy_entropy_bits``.

    The mass-matched residual fraction is reported as an ADVISORY proxy and is
    not part of the conjunction until validated.

    Args:
        value_maps: Subject-first density stack ``(n_subjects, X, Y, Z)``.
        voxel_size_mm: Voxel size forwarded to ``centroid_scatter``.
        subject_mask_threshold: Values strictly above this count as occupied.
        core_occupancy_threshold: Forwarded to ``threshold_core_mask``.
        min_centroid_scatter_mm: Minimum mean scatter for a "spatial" residual.
        max_occupancy_entropy_bits: Maximum mean support entropy for "coherent".

    Returns:
        Dict with ``go`` (bool) and human-readable ``reasons``. It also holds
        per-criterion booleans under ``criteria``, the raw diagnostics, the
        advisory mass fraction, ``core_voxels`` and ``n_subjects``.
    """
    stack = _validate_value_stack(value_maps)
    masks = stack > float(subject_mask_threshold)
    scatter = centroid_scatter(stack, voxel_size_mm=voxel_size_mm)
    entropy = occupancy_entropy(stack, threshold=float(subject_mask_threshold))
    mass_fraction = mass_matched_residual_fraction(stack)
    occupancy = build_occupancy_map(masks)
    core = threshold_core_mask(occupancy, core_occupancy_threshold)
    core_voxels = int(np.count_nonzero(core))

    mean_scatter_mm = float(scatter["mean_distance_mm"])
    # centroid_scatter returns NaN when no subject has positive mass. That
    # means "cannot assess", not "scatter below threshold", so report it
    # with its own reason instead of printing "nan mm < ...".
    scatter_defined = bool(np.isfinite(mean_scatter_mm))
    scatter_ok = bool(scatter_defined and mean_scatter_mm >= float(min_centroid_scatter_mm))
    entropy_ok = bool(entropy["mean_entropy_bits"] <= float(max_occupancy_entropy_bits))

    reasons: list[str] = []
    if not scatter_defined:
        reasons.append(
            "centroid scatter undefined (no subject has positive mass); "
            "cannot establish a spatial residual"
        )
    elif not scatter_ok:
        reasons.append(
            f"centroid scatter {mean_scatter_mm:.2f} mm < "
            f"{min_centroid_scatter_mm:.2f} mm threshold (residual not spatial)"
        )
    if not entropy_ok:
        reasons.append(
            f"support occupancy entropy {entropy['mean_entropy_bits']:.2f} bits > "
            f"{max_occupancy_entropy_bits:.2f} (bundle topology too incoherent to recover)"
        )
    go = scatter_ok and entropy_ok
    if go:
        # Both gates passed, so no failure reason was recorded above.
        reasons.append("cohort residual is spatial and recoverable; learned atlas is warranted")
    return {
        "go": go,
        "reasons": reasons,
        "criteria": {
            "centroid_scatter_ok": scatter_ok,
            "occupancy_entropy_ok": entropy_ok,
        },
        "centroid_scatter": scatter,
        "occupancy_entropy": entropy,
        "mass_matched_residual_fraction_advisory": float(mass_fraction),
        "core_voxels": core_voxels,
        "n_subjects": int(stack.shape[0]),
    }
[docs]
def true_support_mask(truth: np.ndarray, threshold: float = 0.0) -> np.ndarray:
    """Boolean mask of voxels where the ground-truth volume exceeds ``threshold``.

    The volume is promoted to float64 before the strict ``>`` comparison, so
    float32/integer inputs are compared against the threshold in double precision.
    """
    cutoff = float(threshold)
    values = np.asarray(truth, dtype=np.float64)
    return np.greater(values, cutoff)
[docs]
def delta_on_true_support(
    prediction: np.ndarray,
    truth: np.ndarray,
    loo_mean: np.ndarray,
    support_threshold: float = 0.0,
    eps: float = 1e-12,
) -> dict[str, float]:
    """Score the prediction against the LOO-mean baseline on the true support.

    This is the headline non-circular metric. Only voxels where the held-out
    subject has real signal (``truth > support_threshold``) are considered.
    It asks whether ``warp(T, field_i)`` lowers the mean absolute error
    relative to simply averaging the remaining subjects.

    Returns:
        Dict with ``mae_pred``, ``mae_loo_mean``, ``delta_mae``
        (baseline minus prediction, so positive = improvement),
        ``relative_reduction`` (``delta_mae / mae_loo_mean``, range
        (-inf, 1], NaN if the baseline MAE is ``<= eps``) and
        ``support_voxels``. All metrics are NaN for an empty support.

    Raises:
        ValueError: if the three arrays do not share one shape.
    """
    pred, real, base = (
        np.asarray(arr, dtype=np.float64) for arr in (prediction, truth, loo_mean)
    )
    _same_shape(pred, real)
    _same_shape(real, base)
    support = true_support_mask(real, threshold=support_threshold)
    n_support = int(np.count_nonzero(support))
    if not n_support:
        undefined = float("nan")
        return {
            **dict.fromkeys(
                ("mae_pred", "mae_loo_mean", "delta_mae", "relative_reduction"),
                undefined,
            ),
            "support_voxels": 0,
        }
    target = real[support]
    mae_pred = float(np.mean(np.abs(pred[support] - target)))
    mae_base = float(np.mean(np.abs(base[support] - target)))
    gain = mae_base - mae_pred
    if mae_base > eps:
        relative = float(gain / mae_base)
    else:
        relative = float("nan")
    return {
        "mae_pred": mae_pred,
        "mae_loo_mean": mae_base,
        "delta_mae": float(gain),
        "relative_reduction": relative,
        "support_voxels": n_support,
    }
[docs]
def regression_to_mean_control(
    prediction: np.ndarray,
    truth: np.ndarray,
    loo_mean: np.ndarray,
    support_threshold: float = 0.0,
    eps: float = 1e-12,
) -> dict[str, float]:
    """Flag a model whose output merely reproduces the cohort mean.

    On the true support, the predicted departure ``prediction - loo_mean`` is
    compared with the real departure ``truth - loo_mean``. A do-nothing model
    (``prediction ~= loo_mean``) gives a near-zero energy ratio and an
    uninformative correlation.

    Returns:
        Dict with the following entries:

        * ``departure_pearson``: in [-1, 1]. NaN for fewer than two support
          voxels or a centred denominator ``<= eps``.
        * ``predicted_departure_energy`` and ``real_departure_energy``: sums
          of squared departures.
        * ``departure_energy_ratio``: predicted over real energy; near 0 flags
          a do-nothing model; NaN when the real energy is ``<= eps``.

        With an empty support, both energies are 0.0 and the rest are NaN.

    Raises:
        ValueError: if the three arrays do not share one shape.
    """
    pred, real, base = (
        np.asarray(arr, dtype=np.float64) for arr in (prediction, truth, loo_mean)
    )
    _same_shape(pred, real)
    _same_shape(real, base)
    support = true_support_mask(real, threshold=support_threshold)
    undefined = float("nan")
    if not np.any(support):
        return {
            "departure_pearson": undefined,
            "predicted_departure_energy": 0.0,
            "real_departure_energy": 0.0,
            "departure_energy_ratio": undefined,
        }
    anchor = base[support]
    predicted = pred[support] - anchor
    observed = real[support] - anchor
    pred_energy = float(np.sum(predicted**2))
    real_energy = float(np.sum(observed**2))
    if real_energy > eps:
        energy_ratio = float(pred_energy / real_energy)
    else:
        energy_ratio = undefined

    pearson = undefined
    if predicted.size >= 2:
        centred_pred = predicted - predicted.mean()
        centred_real = observed - observed.mean()
        denom = float(
            np.sqrt(np.sum(centred_pred * centred_pred) * np.sum(centred_real * centred_real))
        )
        if denom > eps:
            pearson = float(np.sum(centred_pred * centred_real) / denom)

    return {
        "departure_pearson": pearson,
        "predicted_departure_energy": pred_energy,
        "real_departure_energy": real_energy,
        "departure_energy_ratio": energy_ratio,
    }
[docs]
def evaluate_prediction(
    prediction: np.ndarray,
    truth: np.ndarray,
    loo_mean: np.ndarray,
    support_threshold: float = 0.0,
) -> dict[str, dict[str, float]]:
    """Run both non-circular metrics for a single held-out subject.

    Both metrics receive the same ``support_threshold``. The true-support
    delta is computed first, then the regression-to-mean control.

    Returns:
        Dict with ``delta_on_true_support`` and ``regression_to_mean`` sub-dicts.
    """
    shared = {"support_threshold": support_threshold}
    headline = delta_on_true_support(prediction, truth, loo_mean, **shared)
    control = regression_to_mean_control(prediction, truth, loo_mean, **shared)
    return {"delta_on_true_support": headline, "regression_to_mean": control}