Source code for thesis.workflows.atlas._statistics
# src/thesis/workflows/atlas/_statistics.py
"""Numpy-based atlas statistics computation.
Computes five essential statistics over a cohort of patient tractography volumes:
- mean: Average connectivity across all subjects
- std: Standard deviation across all subjects
- std_error: Standard error of the mean (std / sqrt(n))
- cov: Coefficient of variation (std / mean), computed only where the mean exceeds a fraction of the maximum mean value; 0 elsewhere
- prob_threshold: Percentage of subjects where voxel exceeds presence threshold
"""
from __future__ import annotations
import numpy as np
from thesis.workflows.atlas._params import (
DEFAULT_COV_MEAN_THRESHOLD_PCT,
DEFAULT_PRESENCE_VALUE,
)
[docs]
def compute_atlas_statistics(
    data: np.ndarray,
    presence_value: float = DEFAULT_PRESENCE_VALUE,
    cov_mean_threshold_pct: float = DEFAULT_COV_MEAN_THRESHOLD_PCT,
    ddof: int = 0,
) -> dict[str, np.ndarray]:
    """Compute all atlas statistics from a stacked patient array.

    Every statistic is reduced over the first (subject) axis, so each output
    has the shape ``data.shape[1:]``.

    Args:
        data: Array (or array-like) of shape (n_subjects, X, Y, Z) with
            normalised values. Any trailing shape is accepted; the
            reduction is always over axis 0.
        presence_value: Threshold for the prob_threshold calculation. A voxel
            is counted as 'present' for a subject when its value is strictly
            greater than this threshold.
        cov_mean_threshold_pct: Fraction (not a percentage; e.g. 0.01 means
            1%) of the maximum of the mean map. cov is computed only where
            the mean is strictly greater than
            ``max(mean) * cov_mean_threshold_pct``; elsewhere cov is 0.
        ddof: Delta degrees of freedom passed to ``np.std``. The default 0
            gives the population standard deviation (the original
            behaviour). Use 1 for the sample standard deviation, which is
            the usual basis for the standard error of the mean.

    Returns:
        Dict with keys "mean", "std", "std_error", "cov" and
        "prob_threshold". Each value is a float32 array of shape
        ``data.shape[1:]``. std_error is ``std / sqrt(n_subjects)``, and
        prob_threshold is a percentage in [0, 100].

    Raises:
        ValueError: If ``data`` is a scalar (no subject axis), contains zero
            subjects, or ``ddof`` is not smaller than the number of subjects.

    Example:
        >>> stats = compute_atlas_statistics(
        ...     patient_stack,
        ...     presence_value=0.10,
        ...     cov_mean_threshold_pct=0.01,
        ... )
        >>> mean_map = stats["mean"]
    """
    data = np.asarray(data)
    if data.ndim == 0:
        raise ValueError("data must have a leading subject axis; got a scalar")
    n_subjects = int(data.shape[0])
    if n_subjects == 0:
        # Without this guard every statistic silently becomes NaN
        # (mean of an empty slice, 0/0 in the presence percentage).
        raise ValueError("data must contain at least one subject")
    if ddof >= n_subjects:
        raise ValueError(
            f"ddof ({ddof}) must be smaller than the number of subjects "
            f"({n_subjects})"
        )

    mean_map = np.mean(data, axis=0)
    std_map = np.std(data, axis=0, ddof=ddof)

    # Percentage of subjects whose value strictly exceeds presence_value.
    presence_count = np.count_nonzero(data > presence_value, axis=0)
    prob_threshold = presence_count / n_subjects * 100

    # cov is only reported where the mean exceeds a fraction of the strongest
    # mean signal; low-signal voxels are set to 0.
    cov_mean_threshold = np.max(mean_map) * cov_mean_threshold_pct
    cov_mask = mean_map > cov_mean_threshold
    # Outside the mask the divisor is replaced by 1.0, so low or zero means
    # there never trigger divide-by-zero; those voxels are zeroed anyway.
    safe_mean_map = np.where(cov_mask, mean_map, 1.0)
    cov_map = np.where(cov_mask, std_map / safe_mean_map, 0.0)

    return {
        "mean": mean_map.astype(np.float32),
        "std": std_map.astype(np.float32),
        "std_error": (std_map / np.sqrt(n_subjects)).astype(np.float32),
        "cov": cov_map.astype(np.float32),
        "prob_threshold": prob_threshold.astype(np.float32),
    }