Source code for thesis.workflows.atlas._statistics

# src/thesis/workflows/atlas/_statistics.py
"""Numpy-based atlas statistics computation.

Computes five essential statistics over a cohort of patient tractography volumes:
- mean: Average connectivity across all subjects
- std: Population standard deviation (ddof=0) across all subjects
- std_error: Standard error of the mean (std / sqrt(n))
- cov: Coefficient of variation (std / mean) where the mean signal is sufficiently strong; 0 elsewhere
- prob_threshold: Percentage of subjects where voxel exceeds presence threshold
"""

from __future__ import annotations

import numpy as np

from thesis.workflows.atlas._params import (
    DEFAULT_COV_MEAN_THRESHOLD_PCT,
    DEFAULT_PRESENCE_VALUE,
)


def compute_atlas_statistics(
    data: np.ndarray,
    presence_value: float = DEFAULT_PRESENCE_VALUE,
    cov_mean_threshold_pct: float = DEFAULT_COV_MEAN_THRESHOLD_PCT,
) -> dict[str, np.ndarray]:
    """Compute all atlas statistics from a stacked patient array.

    Every statistic is reduced over axis 0 (the subject axis), so each
    output map has the shape of a single subject volume.

    Args:
        data: Numpy array of shape (n_subjects, X, Y, Z) with normalised
            values.
        presence_value: Threshold for prob_threshold calculation. A voxel is
            counted as 'present' when its value strictly exceeds this
            threshold.
        cov_mean_threshold_pct: Fraction (not a percentage, despite the
            name) of the largest value in the mean map. Voxels whose mean
            does not strictly exceed ``max(mean) * cov_mean_threshold_pct``
            are treated as low-signal and get a cov of 0.

    Returns:
        Dict mapping statistic names to float32 numpy arrays, each of shape
        (X, Y, Z):

        - ``"mean"``: mean across subjects.
        - ``"std"``: population standard deviation (ddof=0) across subjects.
        - ``"std_error"``: ``std / sqrt(n_subjects)``.
        - ``"cov"``: ``std / mean`` where the mean exceeds the low-signal
          threshold, 0 elsewhere.
        - ``"prob_threshold"``: percentage (0-100) of subjects whose value
          exceeds ``presence_value``.

    Example:
        >>> stats = compute_atlas_statistics(
        ...     patient_stack,
        ...     presence_value=0.10,
        ...     cov_mean_threshold_pct=0.01,
        ... )
        >>> mean_map = stats["mean"]
    """
    n_subjects = int(data.shape[0])

    mean_map = np.mean(data, axis=0)
    std_map = np.std(data, axis=0)  # NumPy default: population std (ddof=0)

    # Use the NaN-ignoring maximum: with a plain np.max a single NaN voxel
    # would turn the global threshold into NaN, every comparison below would
    # be False, and the entire cov map would silently collapse to zeros.
    # NaN-mean voxels still fail the `>` test, so only they get cov = 0.
    max_mean = np.nanmax(mean_map)
    cov_mean_threshold = max_mean * cov_mean_threshold_pct

    # Percentage of subjects in which each voxel exceeds the presence value.
    presence_mask = data > presence_value
    prob_threshold = (np.sum(presence_mask, axis=0) / n_subjects) * 100

    # Substitute 1.0 for masked-out means so the division never hits a zero
    # (or tiny) denominator; those voxels are overwritten with 0.0 anyway.
    cov_mask = mean_map > cov_mean_threshold
    safe_mean_map = np.where(cov_mask, mean_map, 1.0)
    cov_map = np.where(cov_mask, std_map / safe_mean_map, 0.0)

    return {
        "mean": mean_map.astype(np.float32),
        "std": std_map.astype(np.float32),
        "std_error": (std_map / np.sqrt(n_subjects)).astype(np.float32),
        "cov": cov_map.astype(np.float32),
        "prob_threshold": prob_threshold.astype(np.float32),
    }