Source code for thesis.workflows.tract_similarity._metrics

"""Similarity metrics between two normalized 3-D tractography volumes.

All functions take two ``numpy.ndarray`` volumes of identical shape (already
normalized to [0, 1]) and return plain ``float`` values suitable for JSON
serialisation. Volumes may be continuous densities or binary masks depending on
the metric family:

- Overlap metrics take binary masks (callers threshold upstream).
- Correlation / distribution metrics take continuous volumes.
- Spatial-distance metrics take binary masks and a voxel-size tuple in mm.

The Dice implementation is imported from
:mod:`thesis.workflows.atlas.qc_metrics` to avoid duplication.
"""

from __future__ import annotations

from typing import Tuple

import numpy as np
from scipy import ndimage

from thesis.workflows.atlas.qc_metrics import dice_score

# Public API: one entry point per metric family. The underscore-prefixed
# helpers further down in this module are deliberately not exported.
__all__ = [
    "overlap_metrics",
    "correlation_metrics",
    "distance_metrics",
    "distribution_metrics",
]


# ---------------------------------------------------------------------------
# Overlap (binary)
# ---------------------------------------------------------------------------


def overlap_metrics(mask_probtrackx: np.ndarray, mask_atlas: np.ndarray) -> dict:
    """Score the binary overlap of two equally shaped masks.

    Args:
        mask_probtrackx: Thresholded probtrackx2 mask (bool or 0/1).
        mask_atlas: Thresholded atlas mask (bool or 0/1), identical shape.

    Returns:
        Dict with ``dice``, ``jaccard``, ``volume_ratio``, ``volume_abs_diff``,
        ``sensitivity`` and ``precision``. Volumes are voxel counts (unitless).
        ``jaccard`` is 1.0 when both masks are empty; ``volume_ratio`` and
        ``sensitivity`` are NaN when the atlas mask is empty; ``precision`` is
        NaN when the probtrackx mask is empty.

    Raises:
        ValueError: If the two masks do not have the same shape.
    """
    probtrackx = np.asarray(mask_probtrackx, dtype=bool)
    atlas = np.asarray(mask_atlas, dtype=bool)
    if probtrackx.shape != atlas.shape:
        raise ValueError("masks must have identical shapes")

    n_probtrackx = int(np.count_nonzero(probtrackx))
    n_atlas = int(np.count_nonzero(atlas))
    n_both = int(np.count_nonzero(np.logical_and(probtrackx, atlas)))
    n_either = int(np.count_nonzero(np.logical_or(probtrackx, atlas)))

    # Two empty masks are treated as a perfect Jaccard match.
    if n_either > 0:
        jaccard = float(n_both / n_either)
    else:
        jaccard = 1.0

    # Ratios against the atlas volume are undefined for an empty atlas mask.
    if n_atlas > 0:
        volume_ratio = float(n_probtrackx / n_atlas)
        sensitivity = float(n_both / n_atlas)
    else:
        volume_ratio = float("nan")
        sensitivity = float("nan")

    if n_probtrackx > 0:
        precision = float(n_both / n_probtrackx)
    else:
        precision = float("nan")

    return {
        "dice": float(dice_score(probtrackx, atlas)),
        "jaccard": jaccard,
        "volume_ratio": volume_ratio,
        "volume_abs_diff": float(abs(n_probtrackx - n_atlas)),
        "sensitivity": sensitivity,
        "precision": precision,
    }
# --------------------------------------------------------------------------- # Correlation (continuous) # --------------------------------------------------------------------------- def _union_support(a: np.ndarray, b: np.ndarray, eps: float = 1e-12) -> np.ndarray: """Indices where at least one volume is non-negligibly non-zero.""" support: np.ndarray = (np.abs(a) > eps) | (np.abs(b) > eps) return support
def correlation_metrics(vol_probtrackx: np.ndarray, vol_atlas: np.ndarray) -> dict:
    """Compute voxelwise similarity metrics on continuous volumes.

    Args:
        vol_probtrackx: Normalized probtrackx2 density volume (float).
        vol_atlas: Normalized atlas density volume (float), identical shape.

    Returns:
        Dict with ``pearson``, ``spearman`` and ``cosine``. Values in [-1, 1].
        Returns NaN for metrics undefined on the given volumes (e.g. zero
        variance), and NaN for all three when both volumes are zero everywhere.

    Raises:
        ValueError: If the two volumes do not have the same shape.
    """
    vol_a = np.asarray(vol_probtrackx, dtype=np.float64)
    vol_b = np.asarray(vol_atlas, dtype=np.float64)
    # Validate BEFORE flattening: once ravelled, only the element count is
    # compared, so e.g. a (2, 3) volume would silently pass against a (3, 2).
    if vol_a.shape != vol_b.shape:
        raise ValueError("volumes must have identical shapes")
    a = vol_a.ravel()
    b = vol_b.ravel()

    # Restrict to union of non-zero support to avoid the huge zero-background
    # dominating the correlation.
    support = _union_support(a, b)
    if not np.any(support):
        return {"pearson": float("nan"), "spearman": float("nan"), "cosine": float("nan")}

    a_s = a[support]
    b_s = b[support]

    pearson = _pearson(a_s, b_s)
    # Spearman is the Pearson correlation of the average-tie ranks.
    spearman = _pearson(_rankdata(a_s), _rankdata(b_s))
    cosine = _cosine(a_s, b_s)
    return {"pearson": pearson, "spearman": spearman, "cosine": cosine}
def _pearson(a: np.ndarray, b: np.ndarray) -> float: if a.size < 2: return float("nan") da = a - a.mean() db = b - b.mean() denom = float(np.sqrt(np.sum(da * da) * np.sum(db * db))) if denom <= 0.0: return float("nan") return float(np.sum(da * db) / denom) def _rankdata(x: np.ndarray) -> np.ndarray: """Average-rank tiebreaking, equivalent to scipy.stats.rankdata(method='average').""" order = np.argsort(x, kind="mergesort") ranks = np.empty_like(order, dtype=np.float64) ranks[order] = np.arange(len(x), dtype=np.float64) # Average ranks within ties sorted_x = x[order] i = 0 while i < len(x): j = i while j + 1 < len(x) and sorted_x[j + 1] == sorted_x[i]: j += 1 if j > i: mean_rank = 0.5 * (i + j) ranks[order[i : j + 1]] = mean_rank i = j + 1 return ranks + 1.0 def _cosine(a: np.ndarray, b: np.ndarray) -> float: na = float(np.linalg.norm(a)) nb = float(np.linalg.norm(b)) if na == 0.0 or nb == 0.0: return float("nan") return float(np.dot(a, b) / (na * nb)) # --------------------------------------------------------------------------- # Spatial distance (binary) # ---------------------------------------------------------------------------
def distance_metrics(
    mask_probtrackx: np.ndarray,
    mask_atlas: np.ndarray,
    voxel_size_mm: Tuple[float, float, float] = (1.0, 1.0, 1.0),
) -> dict:
    """Measure how far apart two binary masks are in physical space.

    Args:
        mask_probtrackx: Thresholded probtrackx2 mask (bool or 0/1).
        mask_atlas: Thresholded atlas mask (bool or 0/1), identical shape.
        voxel_size_mm: Physical voxel size (sx, sy, sz) in mm, used to convert
            voxel coordinates to mm. Extracted from the NIfTI header by callers.

    Returns:
        Dict with ``hausdorff_95``, ``mean_surface`` and ``centroid`` (all in
        mm). All three are NaN when either mask is empty. ``hausdorff_95`` and
        ``mean_surface`` are the 95th percentile and the mean of the pooled
        surface-to-surface distances taken in both directions.

    Raises:
        ValueError: If the masks differ in shape or ``voxel_size_mm`` does not
            have exactly three entries.
    """
    probtrackx = np.asarray(mask_probtrackx, dtype=bool)
    atlas = np.asarray(mask_atlas, dtype=bool)
    if probtrackx.shape != atlas.shape:
        raise ValueError("masks must have identical shapes")
    if len(voxel_size_mm) != 3:
        raise ValueError("voxel_size_mm must be length 3")

    # Nothing to measure unless both sides contain at least one voxel.
    if not (np.any(probtrackx) and np.any(atlas)):
        return dict.fromkeys(("hausdorff_95", "mean_surface", "centroid"), float("nan"))

    mm_per_voxel = np.asarray(voxel_size_mm, dtype=np.float64)

    # Surface voxel coordinates, converted to millimetres.
    surface_probtrackx = _surface_points(probtrackx) * mm_per_voxel
    surface_atlas = _surface_points(atlas) * mm_per_voxel

    # Pool the nearest-neighbour distances of both directions.
    pooled = np.concatenate(
        [
            _directed_distances(surface_probtrackx, surface_atlas),
            _directed_distances(surface_atlas, surface_probtrackx),
        ]
    )

    centre_probtrackx, centre_atlas = (
        np.asarray(ndimage.center_of_mass(mask), dtype=np.float64) * mm_per_voxel
        for mask in (probtrackx, atlas)
    )

    return {
        "hausdorff_95": float(np.percentile(pooled, 95)),
        "mean_surface": float(np.mean(pooled)),
        "centroid": float(np.linalg.norm(centre_probtrackx - centre_atlas)),
    }
def _surface_points(mask: np.ndarray) -> np.ndarray: """Return voxel coordinates of the mask's surface layer (one-voxel shell).""" eroded = ndimage.binary_erosion(mask) surface = mask & ~eroded if not np.any(surface): # Fully thin mask (single voxel or sheet) — use the mask itself. surface = mask return np.argwhere(surface).astype(np.float64) def _directed_distances(a: np.ndarray, b: np.ndarray) -> np.ndarray: """For each point in *a*, return the distance to the nearest point in *b*. Uses a KDTree (scipy ``cKDTree``) nearest-neighbour query; the returned per-point distances drive both the Hausdorff-95 and mean-surface metrics. """ # Cap the point clouds to keep this tractable on 1 mm volumes. max_pts = 20000 if len(a) > max_pts: idx = np.random.default_rng(0).choice(len(a), size=max_pts, replace=False) a = a[idx] if len(b) > max_pts: idx = np.random.default_rng(0).choice(len(b), size=max_pts, replace=False) b = b[idx] from scipy.spatial import cKDTree # local import to keep startup light tree = cKDTree(b) distances, _ = tree.query(a, k=1) return np.asarray(distances, dtype=np.float64) # --------------------------------------------------------------------------- # Distribution (continuous, normalised) # ---------------------------------------------------------------------------
def distribution_metrics(
    vol_probtrackx: np.ndarray,
    vol_atlas: np.ndarray,
    n_bins: int = 64,
) -> dict:
    """Compare two continuous volumes at the level of their value distributions.

    Each volume is turned into a probability distribution over voxels
    (normalised to sum to 1 over its non-negative support). For NMI, both
    volumes are instead discretised into ``n_bins`` equal-width bins over
    [0, max].

    Args:
        vol_probtrackx: Normalized probtrackx2 density (non-negative float).
        vol_atlas: Normalized atlas density (non-negative float), identical shape.
        n_bins: Histogram bin count for NMI.

    Returns:
        Dict with ``nmi``, ``kl_symmetric`` and ``bhattacharyya``. The last is
        the sum over voxels of ``sqrt(p * q)`` for the two distributions.

    Raises:
        ValueError: If the two volumes do not have the same shape.
    """
    probtrackx = np.asarray(vol_probtrackx, dtype=np.float64)
    atlas = np.asarray(vol_atlas, dtype=np.float64)
    if probtrackx.shape != atlas.shape:
        raise ValueError("volumes must have identical shapes")

    result = {"nmi": _normalized_mutual_information(probtrackx, atlas, n_bins=n_bins)}

    p = _to_prob_distribution(probtrackx)
    q = _to_prob_distribution(atlas)

    # Symmetrised KL: the average of the two directed divergences.
    forward = _kl_divergence(p, q)
    backward = _kl_divergence(q, p)
    result["kl_symmetric"] = float(0.5 * (forward + backward))

    result["bhattacharyya"] = float(np.sum(np.sqrt(p * q)))
    return result
def _to_prob_distribution(vol: np.ndarray, eps: float = 1e-12) -> np.ndarray: """Clip negatives and normalise so the volume sums to 1.""" v = np.clip(vol, 0.0, None).ravel() total = float(v.sum()) if total <= eps: # Uniform fallback — avoids division by zero and makes KL symmetric # behave sensibly on all-zero inputs. return np.full_like(v, 1.0 / v.size) normalized: np.ndarray = v / total return normalized def _kl_divergence(p: np.ndarray, q: np.ndarray, eps: float = 1e-12) -> float: p_safe = np.clip(p, eps, None) q_safe = np.clip(q, eps, None) mask = p > eps return float(np.sum(p_safe[mask] * (np.log(p_safe[mask]) - np.log(q_safe[mask])))) def _normalized_mutual_information( a: np.ndarray, b: np.ndarray, n_bins: int = 64, eps: float = 1e-12 ) -> float: """Histogram-based NMI using the arithmetic-mean normalization. NMI = 2 * I(A; B) / (H(A) + H(B)) """ a_flat = a.ravel() b_flat = b.ravel() a_max = float(np.max(a_flat)) b_max = float(np.max(b_flat)) if a_max <= 0.0 and b_max <= 0.0: return float("nan") # Bin to the same number of bins, each over its own [0, max]. joint, _, _ = np.histogram2d( a_flat, b_flat, bins=n_bins, range=[[0.0, max(a_max, eps)], [0.0, max(b_max, eps)]], ) total = float(joint.sum()) if total <= 0.0: return float("nan") pxy = joint / total px = pxy.sum(axis=1) py = pxy.sum(axis=0) # Entropies def _h(p: np.ndarray) -> float: p_nz = p[p > eps] return float(-np.sum(p_nz * np.log(p_nz))) hx = _h(px) hy = _h(py) hxy = _h(pxy.ravel()) mi = hx + hy - hxy denom = hx + hy if denom <= 0.0: return float("nan") return float(2.0 * mi / denom)