"""Collect tractography statistics from per-patient output directories.
Reads ``waytotal``, ``fdt_paths.nii.gz``, and ROI mask files to produce
a structured summary of tractography results for one or more subjects.
Works with both ProbTrackX2 (``tractography/probtrackx2``, the default)
and MRtrix3 (``tractography/mrtrix3``) layouts via the
``tractography_relpath`` parameter.
For cross-method comparison the collector also emits a unified
**Normalized Connection Strength (NCS)** per subject and per hemisphere:
* ProbTrackX2:
``NCS = target_total_streamlines / (n_samples × |seed_voxels|)`` —
per-sample arrival probability in ``[0, 1]``.
* MRtrix3:
``NCS = mu × target_total_streamlines`` — Fibre Bundle Capacity
(units of fibre cross-sectional area). The target sum is already
SIFT2-weighted (the per-target TDI uses ``-tck_weights_in``), so
multiplying by ``mu`` from ``sift2_mu.txt`` gives the FBC value.
The two values are different physical quantities; cross-method
comparison should use Spearman rank / z-scored profiles, not absolute
equality.
Example:
>>> from pathlib import Path
>>> from thesis.workflows.qc.statistics import collect_tractography_stats
>>> stats = collect_tractography_stats(Path("outputs/114823"))
>>> print(stats["waytotal"])
42351
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from thesis.core.logging import get_logger
# Module-level logger; log calls in this module pass ``{}``-style
# placeholders plus positional arguments.
logger = get_logger(__name__)

# nibabel is optional. When it cannot be imported, the NIfTI-reading helpers
# (``_fdt_paths_stats`` and ``_roi_voxel_counts``) return empty dicts
# instead of raising.
try:
    import nibabel as nib

    NIBABEL_AVAILABLE = True
except ImportError:  # pragma: no cover
    NIBABEL_AVAILABLE = False

# Names exported by ``from ... import *``. ``discover_roi_dir_files`` has no
# leading underscore but is not listed here, so star-imports do not export it.
__all__ = [
    "collect_tractography_stats",
    "collect_batch_stats",
    "format_stats_table",
]
def _detect_backend(tractography_relpath: str) -> str:
    """Infer the tractography backend from a relative output path.

    The check is a case-insensitive substring test for ``"mrtrix3"``.
    Any other path is reported as ``"probtrackx2"``. That backend is the
    historical default and matches the bare ``waytotal`` schema.

    Args:
        tractography_relpath: Path of the run relative to the patient
            output directory (e.g. ``"tractography/mrtrix3"``).

    Returns:
        ``"mrtrix3"`` or ``"probtrackx2"``.
    """
    is_mrtrix = "mrtrix3" in tractography_relpath.lower()
    return "mrtrix3" if is_mrtrix else "probtrackx2"
def _read_tractography_params(tractography_dir: Path) -> Dict[str, Any]:
    """Load ``tractography_params.json`` from a run directory.

    Keys stored in the file are returned unchanged: ``backend`` and
    ``n_samples`` for ProbTrackX2; ``backend``, ``select`` and ``mu`` for
    MRtrix3.

    Args:
        tractography_dir: Run directory that may contain the JSON file.

    Returns:
        The parsed JSON object. Returns ``{}`` when the file is absent,
        cannot be read or parsed, or does not hold a JSON object.
    """
    params_path = tractography_dir / "tractography_params.json"
    if params_path.is_file():
        try:
            loaded: Any = json.loads(params_path.read_text())
        except (OSError, ValueError) as err:
            logger.debug("Could not read {}: {}", params_path, err)
        else:
            if isinstance(loaded, dict):
                return loaded
    return {}
def _read_sift2_mu_file(tractography_dir: Path) -> Optional[float]:
    """Return the SIFT2 proportionality coefficient from ``sift2_mu.txt``.

    The MRtrix3 ``mu`` value converts summed SIFT2 weights into Fibre
    Bundle Capacity. The file is split on whitespace, and the first token
    that parses as a float is returned.

    Args:
        tractography_dir: Run directory that may contain ``sift2_mu.txt``.

    Returns:
        The parsed value. Returns ``None`` if the file is missing, cannot
        be read, or holds no numeric token.
    """
    mu_path = tractography_dir / "sift2_mu.txt"
    if not mu_path.is_file():
        return None
    try:
        tokens = mu_path.read_text().split()
    except OSError:
        return None
    for tok in tokens:
        try:
            value = float(tok)
        except ValueError:
            continue
        return value
    return None
def _classify_roi_voxel_count(
    roi_counts: Dict[str, int],
    keyword: str,
) -> int:
    """Total the voxel counts of every ROI whose name mentions ``keyword``.

    ROI names such as ``main/medulla-seed_transformed`` or a
    hemisphere-prefixed ``left:main/medulla-seed_transformed`` are matched
    by a case-insensitive substring test. Non-positive counts are ignored.
    This includes the ``-1`` used to mark unreadable masks.

    Args:
        roi_counts: Mapping of ROI name to voxel count. Names may carry a
            hemisphere prefix.
        keyword: Substring to look for, e.g. ``"seed"`` or ``"target"``.

    Returns:
        Sum of the matching positive counts, or ``0`` when none match.
    """
    needle = keyword.lower()
    return sum(
        int(n_voxels)
        for roi_name, n_voxels in roi_counts.items()
        if needle in roi_name.lower() and int(n_voxels) > 0
    )
def _compute_ncs(
    backend: str,
    target_streamlines: Optional[float],
    seed_voxels: int,
    n_samples: Optional[int] = None,
    mu: Optional[float] = None,
) -> Optional[float]:
    """Compute the Normalized Connection Strength for one run.

    * ProbTrackX2: ``target_streamlines / (n_samples * seed_voxels)``,
      the per-sample arrival probability.
    * MRtrix3: ``mu * target_streamlines`` (Fibre Bundle Capacity).

    Args:
        backend: ``"probtrackx2"`` or ``"mrtrix3"``. Any other value
            yields ``None``.
        target_streamlines: Streamline total (or summed SIFT2 weight)
            reaching the target.
        seed_voxels: Number of seed voxels; used by ProbTrackX2 only.
        n_samples: Samples per seed voxel; used by ProbTrackX2 only. The
            value may come straight from a JSON file, so numeric strings
            are accepted.
        mu: SIFT2 proportionality coefficient; used by MRtrix3 only.

    Returns:
        The NCS value. Returns ``None`` when any input required by the
        active backend is missing, non-numeric, or non-positive.
    """
    if target_streamlines is None or target_streamlines <= 0:
        return None
    if backend == "probtrackx2":
        if not n_samples or seed_voxels <= 0:
            return None
        try:
            denom = float(n_samples) * float(seed_voxels)
        except (TypeError, ValueError):
            return None
        # ``not denom > 0`` also rejects NaN.
        if not denom > 0:
            return None
        return float(target_streamlines) / denom
    if backend == "mrtrix3":
        if mu is None:
            return None
        try:
            mu_value = float(mu)
        except (TypeError, ValueError):
            return None
        # A zero/negative (or NaN) mu would produce a meaningless FBC; treat
        # it as missing, as the contract above promises.
        if not mu_value > 0:
            return None
        return mu_value * float(target_streamlines)
    return None
def _discover_tractography_dirs(tractography_dir: Path) -> List[Path]:
    """List the run directories that make up one tractography output.

    Existing ``left/`` and ``right/`` subdirectories are treated as
    independent hemisphere runs and returned in that order. When neither
    exists, the base directory itself is the only run.

    Args:
        tractography_dir: Base tractography output directory (ProbTrackX2
            or MRtrix3 layout).

    Returns:
        Non-empty list of run directories to inspect.
    """
    found: List[Path] = []
    for side in ("left", "right"):
        side_dir = tractography_dir / side
        if side_dir.is_dir():
            found.append(side_dir)
    return found if found else [tractography_dir]
def _merge_space_stats(stats_list: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Merge one or more ``_fdt_paths_stats`` dictionaries into one summary.

    How each field is merged:

    * ``total_streamlines``, ``nonzero_voxels`` and ``total_voxels`` are
      summed across runs.
    * ``max_density`` is the maximum over runs.
    * ``mean_density`` and ``volume_fraction`` are recomputed from the sums.
    * ``median_density`` cannot be recovered exactly without the voxel
      data. It is approximated as the median of the per-run medians, taken
      only over runs that actually have non-zero voxels. An empty run
      reports a placeholder median of ``0.0`` that would otherwise drag
      the merged value down.

    Args:
        stats_list: Per-run stats dictionaries.

    Returns:
        Aggregated stats dictionary, or ``{}`` when *stats_list* is empty.
    """
    if not stats_list:
        return {}
    total_streamlines = sum(int(s.get("total_streamlines", 0)) for s in stats_list)
    nonzero_voxels = sum(int(s.get("nonzero_voxels", 0)) for s in stats_list)
    total_voxels = sum(int(s.get("total_voxels", 0)) for s in stats_list)
    max_density = max(int(s.get("max_density", 0)) for s in stats_list)
    if total_voxels == 0 or nonzero_voxels == 0:
        return {
            "total_streamlines": total_streamlines,
            "nonzero_voxels": nonzero_voxels,
            "total_voxels": total_voxels,
            "volume_fraction": 0.0,
            "mean_density": 0.0,
            "max_density": max_density,
            "median_density": 0.0,
        }
    mean_density = round(total_streamlines / nonzero_voxels, 2)
    volume_fraction = round(nonzero_voxels / total_voxels, 6)
    # Only runs with data contribute a meaningful median (see docstring).
    median_candidates = [
        float(s.get("median_density", 0.0))
        for s in stats_list
        if int(s.get("nonzero_voxels", 0)) > 0
    ]
    median_density = (
        round(float(np.median(median_candidates)), 2) if median_candidates else 0.0
    )
    return {
        "total_streamlines": total_streamlines,
        "nonzero_voxels": nonzero_voxels,
        "total_voxels": total_voxels,
        "volume_fraction": volume_fraction,
        "mean_density": mean_density,
        "max_density": max_density,
        "median_density": median_density,
    }
def _prefix_roi_counts(roi_counts: Dict[str, int], prefix: str) -> Dict[str, int]:
    """Return a copy of *roi_counts* with every ROI name prefixed ``"<prefix>:"``.

    Used when hemisphere-specific outputs are merged into one mapping.
    """
    prefixed: Dict[str, int] = {}
    for roi_name, n_voxels in roi_counts.items():
        prefixed[f"{prefix}:{roi_name}"] = n_voxels
    return prefixed
def _read_waytotal(tractography_dir: Path) -> Optional[float]:
    """Parse the ``waytotal`` file written by ProbTrackX2 or MRtrix3.

    ProbTrackX2 stores an integer streamline count. MRtrix3 stores a float
    (the sum of SIFT2 weights). Both parse cleanly as ``float``.

    Args:
        tractography_dir: Run directory (e.g. ``probtrackx2/`` or
            ``mrtrix3/{left,right}/``).

    Returns:
        The waytotal as a float, or ``None`` when the file is absent,
        unreadable, or not a single number.
    """
    waytotal_path = tractography_dir / "waytotal"
    if not waytotal_path.exists():
        return None
    try:
        raw = waytotal_path.read_text()
        value = float(raw.strip())
    except (ValueError, OSError) as err:
        logger.debug("Could not read waytotal from {}: {}", waytotal_path, err)
        return None
    return value
def _fdt_paths_stats(fdt_path: Path) -> Dict[str, Any]:
    """Summarise a track-density image such as ``fdt_paths.nii.gz``.

    Only voxels with a value strictly greater than zero count as "non-zero".
    Density statistics are computed over those voxels alone.

    Args:
        fdt_path: Path to the track-density NIfTI file.

    Returns:
        Dictionary with ``total_streamlines``, ``nonzero_voxels``,
        ``total_voxels``, ``volume_fraction``, ``mean_density``,
        ``max_density`` and ``median_density``. Returns ``{}`` when
        nibabel is not installed.
    """
    if not NIBABEL_AVAILABLE:
        return {}
    volume: np.ndarray = np.asarray(nib.load(str(fdt_path)).dataobj)  # type: ignore[attr-defined]
    positive = volume[volume > 0]
    n_total = int(volume.size)
    n_positive = int(positive.size)
    if n_positive == 0:
        return {
            "total_streamlines": 0,
            "nonzero_voxels": 0,
            "total_voxels": n_total,
            "volume_fraction": 0.0,
            "mean_density": 0.0,
            "max_density": 0.0,
            "median_density": 0.0,
        }
    # Streamline totals and max are truncated to int, as callers expect.
    return {
        "total_streamlines": int(positive.sum()),
        "nonzero_voxels": n_positive,
        "total_voxels": n_total,
        "volume_fraction": round(n_positive / n_total, 6),
        "mean_density": round(float(positive.mean()), 2),
        "max_density": int(positive.max()),
        "median_density": round(float(np.median(positive)), 2),
    }
def discover_roi_dir_files(tractography_dir: Path) -> Tuple[Optional[Path], List[Path]]:
    """Find the ROI mask directory of a tractography run and list its masks.

    The candidate subdirectories are checked in priority order:
    ``rois_merged/``, then ``rois_transformed/``, then ``rois/``. The first
    one holding at least one ``*.nii*`` file (searched recursively) wins.
    Callers apply their own per-file name mapping.

    Args:
        tractography_dir: The tractography output directory
            (e.g. ``.../tractography/probtrackx2``).

    Returns:
        ``(roi_dir, nii_files)`` with the masks sorted by path. The result
        is ``(None, [])`` when no candidate directory contains a mask.
    """
    for subdir in ("rois_merged", "rois_transformed", "rois"):
        candidate = tractography_dir / subdir
        if not candidate.is_dir():
            continue
        masks = sorted(candidate.rglob("*.nii*"))
        if masks:
            return candidate, masks
    return None, []
def _roi_voxel_counts(tractography_dir: Path) -> Dict[str, int]:
    """Count the non-zero voxels of every ROI mask in a run directory.

    Masks are located with :func:`discover_roi_dir_files`. Each mask is
    keyed by its POSIX path relative to the ROI directory, with
    ``.nii.gz`` / ``.nii`` removed. A mask that cannot be loaded is logged
    and recorded as ``-1``.

    Args:
        tractography_dir: The tractography run directory.

    Returns:
        Mapping of ROI name to non-zero voxel count. Returns ``{}`` when
        nibabel is unavailable or no ROI directory is found.
    """
    if not NIBABEL_AVAILABLE:
        return {}
    roi_dir, mask_files = discover_roi_dir_files(tractography_dir)
    if roi_dir is None:
        return {}
    voxel_counts: Dict[str, int] = {}
    for mask_path in mask_files:
        relative = mask_path.relative_to(roi_dir).as_posix()
        roi_name = relative.replace(".nii.gz", "").replace(".nii", "")
        try:
            mask = np.asarray(nib.load(str(mask_path)).dataobj)  # type: ignore[attr-defined]
            n_voxels = int(np.count_nonzero(mask))
        except Exception as err:
            logger.debug("Could not read ROI {}: {}", mask_path, err)
            n_voxels = -1
        voxel_counts[roi_name] = n_voxels
    return voxel_counts
def _sum_target_streamlines(
    conn: Optional[Dict[str, Dict[str, Any]]],
) -> Optional[float]:
    """Return one run's target streamline total from its connectivity stats.

    Prefers the canonical ``"target_dwi"`` entry. Otherwise it sums
    ``total_streamlines`` over every entry. Returns ``None`` when *conn* is
    empty, or when the summed total is not positive.
    """
    if not conn:
        return None
    if "target_dwi" in conn:
        return float(conn["target_dwi"].get("total_streamlines", 0) or 0)
    total = 0.0
    for entry in conn.values():
        total += float(entry.get("total_streamlines", 0) or 0)
    return total if total > 0 else None


def _merge_run_params(per_run_data: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Build top-level ``tractography_params`` from per-run params (first run wins)."""
    top_params: Dict[str, Any] = {}
    for entry in per_run_data:
        for key, val in (entry["params"] or {}).items():
            if key == "mu":
                continue  # mu is per-hemisphere; keep it at run level
            top_params.setdefault(key, val)
    # Promote mu only when there is a single run, to avoid implying that two
    # hemispheres share the same mu (they don't).
    if len(per_run_data) == 1 and "mu" in (per_run_data[0]["params"] or {}):
        top_params["mu"] = per_run_data[0]["params"]["mu"]
    return top_params


def _compute_ncs_by_run(
    backend: str,
    per_run_data: List[Dict[str, Any]],
    seed_voxel_count: Dict[str, int],
    top_params: Dict[str, Any],
) -> Dict[str, float]:
    """Compute per-run (left/right) and combined Normalized Connection Strength."""
    n_samples = top_params.get("n_samples") if backend == "probtrackx2" else None
    ncs: Dict[str, float] = {}
    combined_target_streamlines = 0.0
    have_combined_target = False
    for entry in per_run_data:
        run_ts = _sum_target_streamlines(entry["conn"])
        if run_ts is None:
            continue
        have_combined_target = True
        combined_target_streamlines += run_ts
        run_params = entry["params"] or {}
        run_ncs = _compute_ncs(
            backend=backend,
            target_streamlines=run_ts,
            seed_voxels=int(entry["seed_voxels"]),
            n_samples=run_params.get("n_samples", n_samples),
            mu=run_params.get("mu"),
        )
        # A base-directory run is named "combined"; it is reported once, below.
        if run_ncs is not None and entry["name"] != "combined":
            ncs[entry["name"]] = run_ncs

    if not have_combined_target:
        return ncs

    # Combined NCS: aggregate seed voxels and aggregate target streamlines.
    combined_seed = seed_voxel_count.get("combined") or sum(seed_voxel_count.values())
    # MRtrix3: only a single-run (top-level) mu is used here. Averaging mu
    # across hemispheres would lose the per-run cross-sectional-area meaning.
    combined_mu = top_params.get("mu") if backend == "mrtrix3" else None
    combined_ncs = _compute_ncs(
        backend=backend,
        target_streamlines=combined_target_streamlines,
        seed_voxels=int(combined_seed),
        n_samples=n_samples,
        mu=combined_mu,
    )
    if combined_ncs is not None:
        ncs["combined"] = combined_ncs
    elif backend == "mrtrix3":
        left_ncs = ncs.get("left")
        right_ncs = ncs.get("right")
        if left_ncs is not None and right_ncs is not None:
            # MRtrix3 FBC is additive across disjoint hemispheres, so sum the
            # per-run NCS when each hemisphere has its own mu.
            ncs["combined"] = left_ncs + right_ncs
    return ncs


def _read_tract_similarity_metrics(ts_file: Path) -> Dict[str, Any]:
    """Return the non-null metric blocks of a tract_similarity ``metrics.json``.

    Only metric content is copied, not e.g. the duplicated ``patient_id``,
    so the per-subject dict stays flat. Returns ``{}`` when the file is
    missing, unreadable, not valid UTF-8 JSON, or not a JSON object.
    """
    if not ts_file.is_file():
        return {}
    try:
        ts_data: Any = json.loads(ts_file.read_text(encoding="utf-8"))
    except (OSError, ValueError) as exc:
        # ValueError covers both JSONDecodeError and UnicodeDecodeError.
        logger.debug("Could not parse {}: {}", ts_file, exc)
        return {}
    if not isinstance(ts_data, dict):
        logger.debug("Could not parse {}: expected a JSON object", ts_file)
        return {}
    return {
        key: ts_data.get(key)
        for key in (
            "overlap",
            "correlation",
            "distance_mm",
            "distribution",
            "voxel_counts",
            "thresholds",
        )
        if ts_data.get(key) is not None
    }


def collect_tractography_stats(
    patient_output: Union[str, Path],
    patient_id: Optional[str] = None,
    tractography_relpath: str = "tractography/probtrackx2",
    tract_similarity_subdir: str = "tract_similarity",
) -> Dict[str, Any]:
    """Collect tractography statistics for a single subject.

    Reads output files from a tractography run (ProbTrackX2 or MRtrix3)
    and returns a structured dictionary of statistics. When ``left/``
    and/or ``right/`` subdirectories exist, each is treated as a separate
    hemisphere run and aggregated.

    Args:
        patient_output: Patient-level output directory
            (e.g. ``outputs/114823/``).
        patient_id: Optional patient identifier. Defaults to the name of
            *patient_output*.
        tractography_relpath: Relative path under ``patient_output`` where
            the tractography run lives. Defaults to
            ``"tractography/probtrackx2"``. Use ``"tractography/mrtrix3"``
            for the MRtrix3 backend.
        tract_similarity_subdir: Relative path under ``patient_output``
            where the ``tract_similarity`` / ``tract_similarity_hcp_loo``
            workflow wrote ``metrics.json``. The file is read
            opportunistically and attached under
            ``result["tract_similarity"]`` when present.

    Returns:
        Dictionary with keys:

        - ``patient_id``: subject identifier.
        - ``backend``: ``"probtrackx2"`` or ``"mrtrix3"`` (derived from
          *tractography_relpath*).
        - ``waytotal``: sum of the per-run ``waytotal`` values, or ``None``
          when none could be read.
        - ``subject_space``: merged fdt_paths statistics in subject space.
        - ``template_space``: merged fdt_paths statistics in template
          space. Empty dict if ``warped_streamlines`` is not present.
        - ``roi_voxel_counts``: non-zero voxel count per ROI mask. Names
          are hemisphere-prefixed when there are several runs.
        - ``hemisphere_runs``: per-run statistics. Only present with more
          than one run.
        - ``connectivity_maps``: connectivity map stats of the first run.
          Only present when available.
        - ``seed_voxel_count`` / ``target_voxel_count``: per-run counts plus
          ``"combined"``.
        - ``tractography_params``: merged per-run parameters.
        - ``ncs``: Normalized Connection Strength per run and
          ``"combined"`` (see module docstring).
        - ``synthseg_qc`` / ``synthseg_volumes``: SynthSeg results, when
          available.
        - ``tract_similarity``: overlap / correlation / distance_mm /
          distribution / voxel_counts / thresholds blocks copied from
          ``<tract_similarity_subdir>/metrics.json``. Only present when
          that file exists.

    Raises:
        FileNotFoundError: If *patient_output* does not exist.

    Example:
        >>> stats = collect_tractography_stats("outputs/114823")
        >>> stats["waytotal"]
        42351
    """
    out = Path(patient_output)
    if not out.exists():
        raise FileNotFoundError(f"Patient output directory not found: {out}")

    tractography_dir = out / tractography_relpath
    run_dirs = _discover_tractography_dirs(tractography_dir)
    multi_run = len(run_dirs) > 1
    pid = patient_id or out.name
    backend = _detect_backend(tractography_relpath)
    result: Dict[str, Any] = {"patient_id": pid, "backend": backend}

    # Import the connectivity collector once (lazily, as before) instead of
    # once per run. It is optional: if unavailable, no run reports maps.
    conn_collector: Any = None
    try:
        from thesis.workflows.qc.checks import collect_connectivity_map_stats

        conn_collector = collect_connectivity_map_stats
    except Exception as exc:
        logger.debug("Connectivity map stats unavailable: {}", exc)

    hemisphere_runs: Dict[str, Dict[str, Any]] = {}
    waytotals: List[float] = []
    subject_stats: List[Dict[str, Any]] = []
    template_stats: List[Dict[str, Any]] = []
    roi_counts_merged: Dict[str, int] = {}
    seed_voxel_count: Dict[str, int] = {}
    target_voxel_count: Dict[str, int] = {}
    per_run_data: List[Dict[str, Any]] = []

    for run_dir in run_dirs:
        # The base directory (no left/right split) is reported as "combined".
        run_name = run_dir.name if run_dir != tractography_dir else "combined"
        run_stats: Dict[str, Any] = {}

        waytotal = _read_waytotal(run_dir)
        run_stats["waytotal"] = waytotal
        if waytotal is not None:
            waytotals.append(waytotal)

        subject_fdt = run_dir / "fdt_paths.nii.gz"
        if subject_fdt.exists():
            run_stats["subject_space"] = _fdt_paths_stats(subject_fdt)
            subject_stats.append(run_stats["subject_space"])
        else:
            run_stats["subject_space"] = {}

        template_fdt = run_dir / "warped_streamlines" / "fdt_paths.nii.gz"
        if template_fdt.exists():
            run_stats["template_space"] = _fdt_paths_stats(template_fdt)
            template_stats.append(run_stats["template_space"])
        else:
            run_stats["template_space"] = {}

        run_roi_counts = _roi_voxel_counts(run_dir)
        run_seed_voxels = _classify_roi_voxel_count(run_roi_counts, "seed")
        run_target_voxels = _classify_roi_voxel_count(run_roi_counts, "target")
        # Prefix ROI names per hemisphere so left/right masks with the same
        # file name do not overwrite each other in the merged mapping.
        if multi_run:
            run_stats["roi_voxel_counts"] = _prefix_roi_counts(run_roi_counts, run_name)
        else:
            run_stats["roi_voxel_counts"] = run_roi_counts
        roi_counts_merged.update(run_stats["roi_voxel_counts"])

        # Per-target connectivity stats for the run (needed for NCS).
        run_conn_stats: Dict[str, Dict[str, Any]] = {}
        if conn_collector is not None:
            try:
                run_conn_stats = conn_collector(run_dir)
            except Exception as exc:
                logger.debug("Could not collect connectivity maps in {}: {}", run_dir, exc)
                run_conn_stats = {}
        if run_conn_stats:
            run_stats["connectivity_maps"] = run_conn_stats

        # Per-run params and SIFT2 mu.
        run_params = _read_tractography_params(run_dir)
        if run_params:
            run_stats["tractography_params"] = run_params
        run_mu = _read_sift2_mu_file(run_dir)
        if run_mu is not None:
            run_stats.setdefault("tractography_params", {})["mu"] = run_mu

        per_run_data.append(
            {
                "name": run_name,
                "stats": run_stats,
                "seed_voxels": run_seed_voxels,
                "target_voxels": run_target_voxels,
                "conn": run_conn_stats,
                "params": run_stats.get("tractography_params", {}),
            }
        )
        if run_seed_voxels:
            seed_voxel_count[run_name] = run_seed_voxels
        if run_target_voxels:
            target_voxel_count[run_name] = run_target_voxels
        if multi_run:
            hemisphere_runs[run_name] = run_stats

    result["waytotal"] = sum(waytotals) if waytotals else None
    result["subject_space"] = _merge_space_stats(subject_stats)
    result["template_space"] = _merge_space_stats(template_stats)
    result["roi_voxel_counts"] = roi_counts_merged
    if hemisphere_runs:
        result["hemisphere_runs"] = hemisphere_runs

    # Top-level connectivity maps mirror the first run. That is the run
    # itself in the single-run case, or the first hemisphere otherwise.
    # Reusing the already-collected result avoids loading every connectivity
    # volume a second time.
    first_conn = per_run_data[0]["conn"] if per_run_data else {}
    if first_conn:
        result["connectivity_maps"] = first_conn

    # Aggregate seed/target voxel counts.
    if seed_voxel_count or target_voxel_count:
        seeds_payload: Dict[str, int] = dict(seed_voxel_count)
        targets_payload: Dict[str, int] = dict(target_voxel_count)
        seeds_payload["combined"] = sum(seed_voxel_count.values())
        targets_payload["combined"] = sum(target_voxel_count.values())
        result["seed_voxel_count"] = seeds_payload
        result["target_voxel_count"] = targets_payload

    top_params = _merge_run_params(per_run_data)
    if top_params:
        result["tractography_params"] = top_params

    ncs = _compute_ncs_by_run(backend, per_run_data, seed_voxel_count, top_params)
    if ncs:
        result["ncs"] = ncs

    # SynthSeg QC and volume stats (optional; best effort).
    try:
        from thesis.workflows.qc.checks import parse_synthseg_qc_csv, parse_synthseg_volumes_csv

        synthseg_qc = parse_synthseg_qc_csv(out)
        if synthseg_qc:
            result["synthseg_qc"] = synthseg_qc
        synthseg_vols = parse_synthseg_volumes_csv(out)
        if synthseg_vols:
            result["synthseg_volumes"] = synthseg_vols
    except Exception as exc:
        logger.debug("SynthSeg stats unavailable for {}: {}", pid, exc)

    # tract_similarity metrics are attached only when that workflow has been
    # run for this subject.
    ts_block = _read_tract_similarity_metrics(out / tract_similarity_subdir / "metrics.json")
    if ts_block:
        result["tract_similarity"] = ts_block

    logger.debug("Collected stats for {}: waytotal={}", pid, result["waytotal"])
    return result
def collect_batch_stats(
    output_base: Union[str, Path],
    patient_ids: Optional[List[str]] = None,
    tractography_relpath: str = "tractography/probtrackx2",
    tract_similarity_subdir: str = "tract_similarity",
) -> List[Dict[str, Any]]:
    """Collect tractography statistics for multiple subjects.

    If *patient_ids* is ``None``, subject directories are discovered
    automatically. Every subdirectory of *output_base* that contains
    ``<tractography_relpath>/`` is included, in sorted name order.

    Args:
        output_base: Parent directory containing per-subject output
            directories (e.g. ``outputs/``).
        patient_ids: Explicit list of subject IDs. If ``None``,
            auto-discovers from *output_base*.
        tractography_relpath: Relative path under each patient directory
            where the tractography run lives. Defaults to
            ``"tractography/probtrackx2"``. Use ``"tractography/mrtrix3"``
            for the MRtrix3 backend.
        tract_similarity_subdir: Relative path under each patient directory
            where ``tract_similarity`` / ``tract_similarity_hcp_loo`` wrote
            ``metrics.json``. When present, that file's content is attached
            under ``subject["tract_similarity"]``.

    Returns:
        List of per-subject statistics dictionaries, in the same format as
        :func:`collect_tractography_stats`. Subjects whose directory is
        missing, or whose collection raised, are logged and skipped.

    Example:
        >>> all_stats = collect_batch_stats("outputs/")
        >>> for s in all_stats:
        ...     print(s["patient_id"], s["waytotal"])
    """
    base = Path(output_base)
    if patient_ids is None:
        # Auto-discover subjects; a missing base directory yields no subjects.
        patient_ids = (
            [
                child.name
                for child in sorted(base.iterdir())
                if child.is_dir() and (child / tractography_relpath).is_dir()
            ]
            if base.is_dir()
            else []
        )

    results: List[Dict[str, Any]] = []
    for pid in patient_ids:
        patient_dir = base / pid
        if not patient_dir.exists():
            logger.warning("Patient directory not found: {}", patient_dir)
            continue
        # Batch boundary: one broken subject must not abort the whole batch.
        try:
            stats = collect_tractography_stats(
                patient_dir,
                patient_id=pid,
                tractography_relpath=tractography_relpath,
                tract_similarity_subdir=tract_similarity_subdir,
            )
        except Exception as exc:
            logger.warning("Failed to collect stats for {}: {}", pid, exc)
            continue
        results.append(stats)
    return results