"""Volume loading, normalisation, and resolution helpers.
Handles:
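- Loading ``fdt_paths.nii.gz`` and dividing by a scalar waytotal to produce a
probability-like density in approximately [0, 1] (high-count voxels may
exceed 1), summing the left/right hemisphere runs when probtrackx2 wrote
them separately.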
- Loading a warped atlas mean volume, clipping negatives to zero, and scaling
it by its maximum into [0, 1].
- Extracting voxel spacing (mm) from a NIfTI header for distance metrics.
- Resolving the warped-atlas file from either an explicit path or a glob
pattern — the ANTs-generated suffix (e.g. ``_SyN_template_to_patient``)
varies with the registration transform type, so glob fallback is useful.
- Discovering the per-patient subdirectories under a cohort output root.
"""
from __future__ import annotations
from pathlib import Path
from typing import Tuple
import nibabel as nib
import numpy as np
__all__ = [
    "discover_patient_dirs",
    "load_fdt_paths_normalized",
    "load_probtrackx_volume",
    "load_atlas_normalized",
    "resolve_atlas_file",
    "voxel_size_mm",
]

# Directory names under a cohort output root that hold cohort-level or
# scratch outputs rather than one patient's results; discover_patient_dirs
# excludes them when scanning for patients.
_NON_PATIENT_DIR_NAMES = frozenset(("cohort", "temp", "work"))
[docs]
def discover_patient_dirs(input_dir: Path | str, *, sort: bool = True) -> list[Path]:
    """Return the patient subdirectories under a cohort output root.

    Every immediate child directory of *input_dir* counts as a patient
    directory unless its name is one of the reserved cohort scratch names
    (``cohort``, ``temp``, ``work``); regular files are ignored. The cohort
    task bodies and verifiers share this helper so the scan rule has a single
    source of truth.

    Args:
        input_dir: Cohort output root to scan.
        sort: When True (default), return the directories sorted by name;
            otherwise keep the order yielded by ``Path.iterdir``.

    Returns:
        List of patient directory paths.
    """
    root = Path(input_dir)
    patients: list[Path] = []
    for child in root.iterdir():
        if not child.is_dir():
            continue
        if child.name in _NON_PATIENT_DIR_NAMES:
            # Reserved cohort-level / scratch directory, not a patient.
            continue
        patients.append(child)
    if sort:
        patients.sort()
    return patients
def _read_waytotal(path: Path) -> float:
    """Read the scalar streamline count from the probtrackx2 ``waytotal`` file.

    Only the first whitespace-separated token is used; any further tokens are
    ignored.

    Args:
        path: Path to the ``waytotal`` text file.

    Returns:
        The waytotal as a positive, finite float.

    Raises:
        ValueError: If the file is empty, its first token is not a number, or
            the value is not a positive finite number.
    """
    tokens = Path(path).read_text().split()
    if tokens:
        try:
            value = float(tokens[0])
        except ValueError as err:
            raise ValueError(
                f"waytotal must be a number; got {tokens[0]!r} from {path}"
            ) from err
    else:
        # An empty file carries no streamline count; report it as a zero total.
        value = 0.0
    # The chained comparison rejects zero, negatives and +inf, and NaN fails
    # every comparison so it is rejected too. A non-finite total would
    # otherwise turn the normalised volume into NaN/zeros without any error.
    if not 0.0 < value < float("inf"):
        raise ValueError(
            f"waytotal must be a positive finite number; got {value!r} from {path}"
        )
    return value
[docs]
def load_fdt_paths_normalized(
    fdt_path: Path | str,
    waytotal_path: Path | str,
) -> Tuple[np.ndarray, np.ndarray]:
    """Load ``fdt_paths`` and scale it by the run's waytotal.

    Each voxel value of the fdt image is divided by the scalar read from the
    ``waytotal`` file, giving a probability-like density.

    Args:
        fdt_path: Path to ``fdt_paths.nii.gz``.
        waytotal_path: Path to the ``waytotal`` text file.

    Returns:
        Tuple ``(volume, affine)`` where *volume* is float32 in approximately
        [0, 1] (may exceed 1 in occasional high-count voxels) and *affine* is
        the 4x4 NIfTI affine of the fdt image as float64.

    Raises:
        ValueError: If the waytotal file cannot be interpreted as a positive
            count (raised by ``_read_waytotal``).
    """
    image = nib.load(str(fdt_path))
    counts = np.asarray(image.dataobj, dtype=np.float32)  # type: ignore[attr-defined]
    scale = np.float32(_read_waytotal(Path(waytotal_path)))
    affine = np.asarray(image.affine, dtype=np.float64)  # type: ignore[attr-defined]
    return counts / scale, affine
def _discover_run_dirs(probtrackx_dir: Path) -> list[Path]:
    """Return the probtrackx run directories for one patient.

    Mirrors the atlas workflow's ``_discover_probtrackx_run_dirs``. With
    ``hemisphere: "both-separately"`` probtrackx2 writes into
    ``probtrackx_dir/left/`` and ``probtrackx_dir/right/``; whichever of those
    exist as directories are returned, left before right. When neither exists,
    the results are expected directly in ``probtrackx_dir``, which is returned
    on its own.
    """
    found: list[Path] = []
    for side in ("left", "right"):
        candidate = probtrackx_dir / side
        if candidate.is_dir():
            found.append(candidate)
    if found:
        return found
    return [probtrackx_dir]
[docs]
def load_probtrackx_volume(
    probtrackx_dir: Path | str,
    fdt_name: str = "fdt_paths.nii.gz",
    waytotal_name: str = "waytotal",
) -> Tuple[np.ndarray, np.ndarray]:
    """Load a patient's probtrackx2 density, summing hemispheres when present.

    Auto-detects layout:

    - Single run: ``probtrackx_dir/fdt_paths.nii.gz`` + ``probtrackx_dir/waytotal``.
    - ``both-separately``: sums the waytotal-normalised volumes from
      ``probtrackx_dir/left/`` and ``probtrackx_dir/right/``.

    A run directory missing either file is skipped; it is only an error when
    no run directory holds a valid pair. Matches how
    ``atlas._io._build_patient_volume`` combines hemispheres so that
    comparisons against the cohort mean atlas are like-for-like.

    Args:
        probtrackx_dir: Directory under the patient output (e.g.
            ``tractography/probtrackx2``).
        fdt_name: Name of the fdt_paths file within each run directory.
        waytotal_name: Name of the waytotal text file within each run directory.

    Returns:
        Tuple ``(volume, affine)``. *volume* is float32; *affine* is taken from
        the first run loaded.

    Raises:
        FileNotFoundError: If no valid (fdt_paths, waytotal) pair is found.
        ValueError: If two runs produce volumes of different shapes, or a
            waytotal file is invalid (raised while loading the run).
    """
    base = Path(probtrackx_dir)
    total: np.ndarray | None = None
    affine: np.ndarray | None = None
    for run_dir in _discover_run_dirs(base):
        fdt = run_dir / fdt_name
        waytotal = run_dir / waytotal_name
        if not (fdt.is_file() and waytotal.is_file()):
            continue
        vol, aff = load_fdt_paths_normalized(fdt, waytotal)
        if total is None:
            total, affine = vol, aff
        elif vol.shape != total.shape:
            # Fail with a clear message instead of numpy's broadcasting error.
            raise ValueError(
                f"Cannot combine probtrackx runs under {base}: {run_dir} has "
                f"shape {vol.shape}, expected {total.shape}."
            )
        else:
            # NOTE(review): the runs are assumed to share one voxel grid; only
            # the first run's affine is returned, as before.
            # ``total`` is a fresh array produced by the loader, so summing in
            # place is safe and avoids stacking every run into a temporary.
            total += vol
    if total is None or affine is None:
        raise FileNotFoundError(
            f"No valid fdt_paths/waytotal pair found under {base}. "
            f"Expected '{fdt_name}' + '{waytotal_name}' directly in {base}, "
            f"or in hemisphere subdirectories (left/, right/)."
        )
    return total.astype(np.float32, copy=False), affine
[docs]
def load_atlas_normalized(atlas_path: Path | str) -> Tuple[np.ndarray, np.ndarray]:
    """Load a warped atlas map and rescale it into [0, 1].

    Non-finite voxels (NaN / +-inf) are treated as empty (0) and negative
    values are clipped to 0. The result is then divided by its maximum, so the
    peak maps to 1. This is max-scaling after clipping rather than a true
    min-max normalisation: a volume whose smallest value is positive keeps a
    nonzero minimum. An all-zero (or all-nonpositive) volume is returned as
    zeros.

    Args:
        atlas_path: Path to the warped atlas NIfTI in patient native space.

    Returns:
        Tuple ``(volume, affine)`` where *volume* is float32 in [0, 1] and
        *affine* is the 4x4 NIfTI affine as float64.
    """
    img = nib.load(str(atlas_path))
    arr = np.asarray(img.dataobj, dtype=np.float32)  # type: ignore[attr-defined]
    # A single NaN would make np.max return NaN, which fails the ``> 0`` test
    # below and silently skips scaling, breaking the [0, 1] contract.
    arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
    arr = np.clip(arr, 0.0, None)
    # ``initial=0.0`` keeps an empty volume from raising; since arr >= 0 it
    # never changes the maximum of a non-empty volume.
    max_val = float(np.max(arr, initial=0.0))
    if max_val > 0.0:
        arr = arr / np.float32(max_val)
    return arr, np.asarray(img.affine, dtype=np.float64)  # type: ignore[attr-defined]
[docs]
def resolve_atlas_file(base_dir: Path | str, path_or_glob: str) -> Path:
    """Resolve the warped atlas file, supporting either an explicit path or a glob.

    An existing file at ``base_dir / path_or_glob`` is returned directly.
    Otherwise *path_or_glob* is treated as a glob relative to *base_dir*, and
    exactly one regular file must match it. Directories that happen to match
    the pattern are ignored.

    Args:
        base_dir: Patient output directory (used as the root for relative
            paths/globs).
        path_or_glob: Either an explicit relative path (e.g.
            ``"atlas_in_patient_space/atlas_mean_SyN_template_to_patient_in_patient_space.nii.gz"``)
            or a glob pattern (e.g. ``"atlas_in_patient_space/atlas_mean*.nii.gz"``).

    Returns:
        Path to the atlas file, joined onto *base_dir* (absolute only if
        *base_dir* is absolute).

    Raises:
        FileNotFoundError: If no file matches.
        ValueError: If the glob matches more than one file.
    """
    base = Path(base_dir)
    candidate = base / path_or_glob
    if candidate.is_file():
        return candidate
    # Only regular files are valid atlas volumes; a directory sharing the
    # pattern must neither be returned nor make the match ambiguous.
    matches = sorted(p for p in base.glob(path_or_glob) if p.is_file())
    if not matches:
        raise FileNotFoundError(
            f"No atlas file matches '{path_or_glob}' under {base}. Run the "
            f"'atlas_to_patient' workflow first, or set "
            f"'tract_similarity.atlas_relpath' to the correct pattern."
        )
    if len(matches) > 1:
        names = ", ".join(m.name for m in matches)
        raise ValueError(
            f"Glob '{path_or_glob}' under {base} matches {len(matches)} files "
            f"({names}). Set 'tract_similarity.atlas_relpath' to a specific file."
        )
    return matches[0]
[docs]
def voxel_size_mm(affine: np.ndarray) -> Tuple[float, float, float]:
    """Return the voxel spacing along each of the three voxel axes.

    The spacing of voxel axis *j* is the Euclidean length of column *j* of the
    affine's upper-left 3x3 block, i.e. the world-space distance covered by
    one voxel step along that axis, so pure rotations do not change it.
    Values are in the affine's world units (millimetres under the usual
    NIfTI convention).

    Raises:
        ValueError: If *affine* is not 4x4.
    """
    matrix = np.asarray(affine, dtype=np.float64)
    if matrix.shape != (4, 4):
        raise ValueError("affine must be shape (4, 4)")
    step_lengths = np.linalg.norm(matrix[:3, :3], axis=0)
    sx, sy, sz = (float(length) for length in step_lengths)
    return sx, sy, sz