"""Similarity metrics between two normalized 3-D tractography volumes.
All public functions take two ``numpy.ndarray`` volumes of identical shape
(already normalized to [0, 1]) and return a ``dict`` of plain ``float`` values
suitable for JSON serialisation. Volumes may be continuous densities or binary
masks depending on the metric family:
- Overlap metrics take binary masks (callers threshold upstream).
- Correlation / distribution metrics take continuous volumes.
- Spatial-distance metrics take binary masks and a voxel-size tuple in mm.
The Dice implementation is imported from
:mod:`thesis.workflows.atlas.qc_metrics` to avoid duplication.
"""
from __future__ import annotations
from typing import Tuple
import numpy as np
from scipy import ndimage
from thesis.workflows.atlas.qc_metrics import dice_score
# Public API: the four metric-family entry points defined in this module.
# Underscore-prefixed helpers are deliberately not exported.
__all__ = [
"overlap_metrics",
"correlation_metrics",
"distance_metrics",
"distribution_metrics",
]
# ---------------------------------------------------------------------------
# Overlap (binary)
# ---------------------------------------------------------------------------
[docs]
def overlap_metrics(mask_probtrackx: np.ndarray, mask_atlas: np.ndarray) -> dict:
    """Summarise how well two binary masks overlap.

    Both inputs are cast to boolean first, so every non-zero voxel counts as
    being inside the mask. Volumes are expressed as voxel counts.

    Args:
        mask_probtrackx: Thresholded probtrackx2 mask (bool or 0/1).
        mask_atlas: Thresholded atlas mask (bool or 0/1), identical shape.

    Returns:
        Dict with the keys ``dice`` (delegated to ``dice_score``),
        ``jaccard`` (intersection / union, 1.0 when both masks are empty),
        ``volume_ratio`` (probtrackx count / atlas count),
        ``volume_abs_diff`` (absolute difference of the two counts),
        ``sensitivity`` (intersection / atlas count) and ``precision``
        (intersection / probtrackx count). Ratios whose denominator is zero
        are NaN.

    Raises:
        ValueError: If the two masks differ in shape.
    """
    probtrackx = np.asarray(mask_probtrackx, dtype=bool)
    atlas = np.asarray(mask_atlas, dtype=bool)
    if probtrackx.shape != atlas.shape:
        raise ValueError("masks must have identical shapes")

    undefined = float("nan")
    n_probtrackx = int(np.count_nonzero(probtrackx))
    n_atlas = int(np.count_nonzero(atlas))
    n_both = int(np.count_nonzero(np.logical_and(probtrackx, atlas)))
    n_either = int(np.count_nonzero(np.logical_or(probtrackx, atlas)))

    metrics = {"dice": float(dice_score(probtrackx, atlas))}

    # Two empty masks agree perfectly, hence 1.0 rather than NaN here.
    metrics["jaccard"] = float(n_both / n_either) if n_either > 0 else 1.0

    if n_atlas > 0:
        metrics["volume_ratio"] = float(n_probtrackx / n_atlas)
    else:
        metrics["volume_ratio"] = undefined

    metrics["volume_abs_diff"] = float(abs(n_probtrackx - n_atlas))

    if n_atlas > 0:
        metrics["sensitivity"] = float(n_both / n_atlas)
    else:
        metrics["sensitivity"] = undefined

    if n_probtrackx > 0:
        metrics["precision"] = float(n_both / n_probtrackx)
    else:
        metrics["precision"] = undefined

    return metrics
# ---------------------------------------------------------------------------
# Correlation (continuous)
# ---------------------------------------------------------------------------
def _union_support(a: np.ndarray, b: np.ndarray, eps: float = 1e-12) -> np.ndarray:
"""Indices where at least one volume is non-negligibly non-zero."""
support: np.ndarray = (np.abs(a) > eps) | (np.abs(b) > eps)
return support
[docs]
def correlation_metrics(vol_probtrackx: np.ndarray, vol_atlas: np.ndarray) -> dict:
    """Compute voxelwise similarity metrics on continuous volumes.

    Only voxels where at least one of the two volumes is non-negligibly
    non-zero are compared, so the large all-zero background does not dominate
    the correlations.

    Args:
        vol_probtrackx: Normalized probtrackx2 density volume (float).
        vol_atlas: Normalized atlas density volume (float), identical shape.

    Returns:
        Dict with ``pearson``, ``spearman`` and ``cosine`` (nominally in
        [-1, 1]). A metric that is undefined on the given volumes (e.g. zero
        variance) is NaN; all three are NaN when both volumes are zero
        everywhere.

    Raises:
        ValueError: If the two volumes differ in shape.
    """
    vol_a = np.asarray(vol_probtrackx, dtype=np.float64)
    vol_b = np.asarray(vol_atlas, dtype=np.float64)
    # Compare shapes *before* flattening: once ravelled, only the element
    # count is left, so e.g. (2, 3) vs (3, 2) would wrongly be accepted.
    if vol_a.shape != vol_b.shape:
        raise ValueError("volumes must have identical shapes")
    a = vol_a.ravel()
    b = vol_b.ravel()

    # Restrict to the union of non-zero support to avoid the huge
    # zero-background dominating the correlation.
    support = _union_support(a, b)
    if not np.any(support):
        nan = float("nan")
        return {"pearson": nan, "spearman": nan, "cosine": nan}

    a_s = a[support]
    b_s = b[support]
    return {
        "pearson": _pearson(a_s, b_s),
        # Spearman is Pearson computed on average-tie ranks.
        "spearman": _pearson(_rankdata(a_s), _rankdata(b_s)),
        "cosine": _cosine(a_s, b_s),
    }
def _pearson(a: np.ndarray, b: np.ndarray) -> float:
if a.size < 2:
return float("nan")
da = a - a.mean()
db = b - b.mean()
denom = float(np.sqrt(np.sum(da * da) * np.sum(db * db)))
if denom <= 0.0:
return float("nan")
return float(np.sum(da * db) / denom)
def _rankdata(x: np.ndarray) -> np.ndarray:
"""Average-rank tiebreaking, equivalent to scipy.stats.rankdata(method='average')."""
order = np.argsort(x, kind="mergesort")
ranks = np.empty_like(order, dtype=np.float64)
ranks[order] = np.arange(len(x), dtype=np.float64)
# Average ranks within ties
sorted_x = x[order]
i = 0
while i < len(x):
j = i
while j + 1 < len(x) and sorted_x[j + 1] == sorted_x[i]:
j += 1
if j > i:
mean_rank = 0.5 * (i + j)
ranks[order[i : j + 1]] = mean_rank
i = j + 1
return ranks + 1.0
def _cosine(a: np.ndarray, b: np.ndarray) -> float:
na = float(np.linalg.norm(a))
nb = float(np.linalg.norm(b))
if na == 0.0 or nb == 0.0:
return float("nan")
return float(np.dot(a, b) / (na * nb))
# ---------------------------------------------------------------------------
# Spatial distance (binary)
# ---------------------------------------------------------------------------
[docs]
def distance_metrics(
    mask_probtrackx: np.ndarray,
    mask_atlas: np.ndarray,
    voxel_size_mm: Tuple[float, float, float] = (1.0, 1.0, 1.0),
) -> dict:
    """Measure how far apart two binary masks are in physical space.

    Surface voxels of each mask are converted to millimetres, nearest-surface
    distances are computed in both directions and pooled, and the pooled
    sample yields the 95th-percentile and mean distances. The centroid metric
    is the distance between the two centres of mass.

    Args:
        mask_probtrackx: Thresholded probtrackx2 mask (bool or 0/1).
        mask_atlas: Thresholded atlas mask (bool or 0/1), identical shape.
        voxel_size_mm: Physical voxel size (sx, sy, sz) in mm for distance
            conversion. Extracted from the NIfTI header by callers.

    Returns:
        Dict with ``hausdorff_95``, ``mean_surface`` and ``centroid`` (all in
        mm). All three are NaN when either mask is empty.

    Raises:
        ValueError: If the masks differ in shape or ``voxel_size_mm`` does not
            have exactly three entries.
    """
    probtrackx = np.asarray(mask_probtrackx, dtype=bool)
    atlas = np.asarray(mask_atlas, dtype=bool)

    if probtrackx.shape != atlas.shape:
        raise ValueError("masks must have identical shapes")
    if len(voxel_size_mm) != 3:
        raise ValueError("voxel_size_mm must be length 3")

    # Distances are undefined as soon as one side has no voxels at all.
    if not (np.any(probtrackx) and np.any(atlas)):
        return dict.fromkeys(("hausdorff_95", "mean_surface", "centroid"), float("nan"))

    mm_per_voxel = np.asarray(voxel_size_mm, dtype=np.float64)

    # Surface point clouds in millimetres.
    shell_probtrackx = _surface_points(probtrackx) * mm_per_voxel
    shell_atlas = _surface_points(atlas) * mm_per_voxel

    # Pool both directions so the summary statistics are symmetric.
    forward = _directed_distances(shell_probtrackx, shell_atlas)
    backward = _directed_distances(shell_atlas, shell_probtrackx)
    pooled = np.concatenate([forward, backward])

    result = {
        "hausdorff_95": float(np.percentile(pooled, 95)),
        "mean_surface": float(np.mean(pooled)),
    }

    centre_probtrackx = np.asarray(ndimage.center_of_mass(probtrackx), dtype=np.float64) * mm_per_voxel
    centre_atlas = np.asarray(ndimage.center_of_mass(atlas), dtype=np.float64) * mm_per_voxel
    result["centroid"] = float(np.linalg.norm(centre_probtrackx - centre_atlas))
    return result
def _surface_points(mask: np.ndarray) -> np.ndarray:
"""Return voxel coordinates of the mask's surface layer (one-voxel shell)."""
eroded = ndimage.binary_erosion(mask)
surface = mask & ~eroded
if not np.any(surface):
# Fully thin mask (single voxel or sheet) — use the mask itself.
surface = mask
return np.argwhere(surface).astype(np.float64)
def _directed_distances(a: np.ndarray, b: np.ndarray) -> np.ndarray:
"""For each point in *a*, return the distance to the nearest point in *b*.
Uses a KDTree (scipy ``cKDTree``) nearest-neighbour query; the returned
per-point distances drive both the Hausdorff-95 and mean-surface metrics.
"""
# Cap the point clouds to keep this tractable on 1 mm volumes.
max_pts = 20000
if len(a) > max_pts:
idx = np.random.default_rng(0).choice(len(a), size=max_pts, replace=False)
a = a[idx]
if len(b) > max_pts:
idx = np.random.default_rng(0).choice(len(b), size=max_pts, replace=False)
b = b[idx]
from scipy.spatial import cKDTree # local import to keep startup light
tree = cKDTree(b)
distances, _ = tree.query(a, k=1)
return np.asarray(distances, dtype=np.float64)
# ---------------------------------------------------------------------------
# Distribution (continuous, normalised)
# ---------------------------------------------------------------------------
[docs]
def distribution_metrics(
    vol_probtrackx: np.ndarray,
    vol_atlas: np.ndarray,
    n_bins: int = 64,
) -> dict:
    """Compare two continuous volumes as probability distributions over voxels.

    NMI is computed on the raw volumes after histogram discretisation into
    ``n_bins`` equal-width bins. For the other two metrics each volume is
    first clipped to non-negative values and normalised to sum to 1.

    Args:
        vol_probtrackx: Normalized probtrackx2 density (non-negative float).
        vol_atlas: Normalized atlas density (non-negative float), identical shape.
        n_bins: Histogram bin count for NMI.

    Returns:
        Dict with ``nmi``, ``kl_symmetric`` (mean of the two directed KL
        divergences) and ``bhattacharyya`` (the sum over voxels of
        ``sqrt(p * q)``).

    Raises:
        ValueError: If the two volumes differ in shape.
    """
    density_a = np.asarray(vol_probtrackx, dtype=np.float64)
    density_b = np.asarray(vol_atlas, dtype=np.float64)
    if density_a.shape != density_b.shape:
        raise ValueError("volumes must have identical shapes")

    result = {"nmi": _normalized_mutual_information(density_a, density_b, n_bins=n_bins)}

    prob_a = _to_prob_distribution(density_a)
    prob_b = _to_prob_distribution(density_b)

    # KL is asymmetric, so average the two directions.
    forward = _kl_divergence(prob_a, prob_b)
    backward = _kl_divergence(prob_b, prob_a)
    result["kl_symmetric"] = float(0.5 * (forward + backward))

    result["bhattacharyya"] = float(np.sum(np.sqrt(prob_a * prob_b)))
    return result
def _to_prob_distribution(vol: np.ndarray, eps: float = 1e-12) -> np.ndarray:
"""Clip negatives and normalise so the volume sums to 1."""
v = np.clip(vol, 0.0, None).ravel()
total = float(v.sum())
if total <= eps:
# Uniform fallback — avoids division by zero and makes KL symmetric
# behave sensibly on all-zero inputs.
return np.full_like(v, 1.0 / v.size)
normalized: np.ndarray = v / total
return normalized
def _kl_divergence(p: np.ndarray, q: np.ndarray, eps: float = 1e-12) -> float:
p_safe = np.clip(p, eps, None)
q_safe = np.clip(q, eps, None)
mask = p > eps
return float(np.sum(p_safe[mask] * (np.log(p_safe[mask]) - np.log(q_safe[mask]))))
def _normalized_mutual_information(
a: np.ndarray, b: np.ndarray, n_bins: int = 64, eps: float = 1e-12
) -> float:
"""Histogram-based NMI using the arithmetic-mean normalization.
NMI = 2 * I(A; B) / (H(A) + H(B))
"""
a_flat = a.ravel()
b_flat = b.ravel()
a_max = float(np.max(a_flat))
b_max = float(np.max(b_flat))
if a_max <= 0.0 and b_max <= 0.0:
return float("nan")
# Bin to the same number of bins, each over its own [0, max].
joint, _, _ = np.histogram2d(
a_flat,
b_flat,
bins=n_bins,
range=[[0.0, max(a_max, eps)], [0.0, max(b_max, eps)]],
)
total = float(joint.sum())
if total <= 0.0:
return float("nan")
pxy = joint / total
px = pxy.sum(axis=1)
py = pxy.sum(axis=0)
# Entropies
def _h(p: np.ndarray) -> float:
p_nz = p[p > eps]
return float(-np.sum(p_nz * np.log(p_nz)))
hx = _h(px)
hy = _h(py)
hxy = _h(pxy.ravel())
mi = hx + hy - hxy
denom = hx + hy
if denom <= 0.0:
return float("nan")
return float(2.0 * mi / denom)