# Source code for thesis.workflows.qc.track_density

"""Generate QC figures for track density maps on template backgrounds.

Overlays ``fdt_paths.nii.gz`` on a template (or subject) image at
configurable percentile thresholds.  All plotting uses the ``Agg``
backend so the module works in headless environments.

Example:
    >>> from pathlib import Path
    >>> from thesis.workflows.qc.track_density import generate_track_density_figures
    >>> paths = generate_track_density_figures(
    ...     fdt_paths=Path("fdt_paths.nii.gz"),
    ...     template_image=Path("MNI152_T1_2mm.nii.gz"),
    ...     output_dir=Path("qc/normtracks"),
    ...     thresholds=[50.0, 90.0, 99.0],
    ... )
"""

from __future__ import annotations

from pathlib import Path
from typing import List, Optional, Union

from thesis.workflows.qc._plotting import configure_headless_matplotlib

configure_headless_matplotlib()

import matplotlib.pyplot as plt  # noqa: E402
import numpy as np  # noqa: E402

from thesis.core.logging import get_logger  # noqa: E402

logger = get_logger(__name__)

# Optional heavy dependencies are probed once at import time and the outcome is
# recorded in a flag, so the module itself always imports and
# ``_validate_dependencies`` can report a missing package when it is needed.
try:
    import nibabel as nib
except ImportError:  # pragma: no cover
    NIBABEL_AVAILABLE = False
else:
    NIBABEL_AVAILABLE = True

try:
    from nilearn.plotting import plot_stat_map
except Exception as exc:  # pragma: no cover
    # Deliberately broad (any Exception, not only ImportError): presumably
    # nilearn's plotting import can fail in other ways too; any failure just
    # marks nilearn as unavailable and is logged at debug level.
    logger.debug("Failed to import nilearn plotting support: {}", exc)
    NILEARN_AVAILABLE = False
else:
    NILEARN_AVAILABLE = True

__all__ = ["generate_track_density_figures"]

# Percentiles (of the density map's positive voxels) plotted when the caller
# passes no thresholds. ``generate_track_density_figures`` copies this list
# before use, so the module-level default is never mutated.
_DEFAULT_THRESHOLDS: List[float] = [50.0, 90.0, 99.0]


def _validate_dependencies() -> None:
    """Ensure the optional plotting stack is importable.

    The availability flags are checked in order (nibabel first, then
    nilearn), and an :class:`ImportError` naming the first missing package,
    with an install hint, is raised.

    Raises:
        ImportError: If nibabel or nilearn failed to import.
    """
    requirements = (
        (NIBABEL_AVAILABLE, "nibabel"),
        (NILEARN_AVAILABLE, "nilearn"),
    )
    for is_available, package in requirements:
        if is_available:
            continue
        raise ImportError(
            f"{package} is required for track density QC figures. "
            f"Install with: pip install {package}"
        )


def _resolve_path(path: Union[str, Path]) -> Path:
    """Convert *path* to a :class:`Path` and verify it names an existing file.

    Args:
        path: Filesystem path as a string or :class:`Path`.

    Returns:
        The input as a :class:`Path` (not made absolute or resolved).

    Raises:
        FileNotFoundError: If *path* does not exist, or exists but is not a
            regular file (e.g. a directory). Without the file check, a
            directory would slip through here and only fail later inside the
            image loader with a far less helpful error.
    """
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(f"File not found: {p}")
    if not p.is_file():
        raise FileNotFoundError(f"Not a file: {p}")
    return p


def _compute_threshold_value(data: np.ndarray, percentile: float) -> float:
    """Translate a percentile into an absolute intensity cut-off.

    Only strictly positive voxels take part, so the zero background (and
    any negative or NaN entries, which fail the ``> 0`` test) does not pull
    the percentile down.

    Args:
        data: Array of track density values (3-D for ``fdt_paths``).
        percentile: Percentile to evaluate, in the range (0, 100).

    Returns:
        The value at *percentile* among the positive voxels as a Python
        ``float``; ``0.0`` (with a warning logged) when no voxel is positive.
    """
    positive_voxels = np.extract(data > 0, data)
    if not positive_voxels.size:
        logger.warning("fdt_paths contains no non-zero voxels — threshold will be 0")
        return 0.0
    return float(np.percentile(positive_voxels, percentile))


def _percentile_tag(pct: float) -> str:
    """Return a filename-safe tag that uniquely identifies percentile *pct*.

    Integer percentiles keep the historical zero-padded form (``50`` ->
    ``"50"``, ``5`` -> ``"05"``). Fractional ones keep their decimals, with
    the point replaced by an underscore (``99.5`` -> ``"99_5"``), so e.g.
    99.0 and 99.5 no longer map to the same file.
    """
    if float(pct).is_integer():
        return f"{int(pct):02d}"
    whole, _, frac = np.format_float_positional(float(pct), trim="-").partition(".")
    return f"{whole.zfill(2)}_{frac}"


def _percentile_label(pct: float) -> str:
    """Return *pct* as an English ordinal for titles (``"1st"``, ``"50th"``, ``"99.5th"``)."""
    if not float(pct).is_integer():
        return f"{np.format_float_positional(float(pct), trim='-')}th"
    n = int(pct)
    # 11th-13th (and the rest of the teens) are irregular.
    if 10 <= n % 100 <= 20:
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
    return f"{n}{suffix}"


def generate_track_density_figures(
    fdt_paths: Union[str, Path],
    template_image: Union[str, Path],
    output_dir: Union[str, Path],
    thresholds: Optional[List[float]] = None,
    display_mode: str = "ortho",
    prefix: str = "track_density",
) -> List[Path]:
    """Generate QC figures overlaying track density on a template image.

    For every requested percentile threshold, the track density map is
    thresholded at that percentile of its positive voxels and overlaid on
    the template background. One PNG is produced per distinct threshold,
    named ``{prefix}_p{tag}.png``. The tag is the zero-padded integer
    percentile (``p50``, ``p99``) or, for a fractional percentile, the value
    with its decimal point replaced by an underscore (``p99_5``).

    Args:
        fdt_paths: Path to ``fdt_paths.nii.gz`` (track density map).
        template_image: Path to the background image (MNI template for
            template-space maps, T1w for subject-space maps).
        output_dir: Directory to write PNGs into. Created if it does not
            exist (only after all inputs have been validated).
        thresholds: Percentile thresholds (default ``[50, 90, 99]``). Each
            value must be in the range (0, 100). Duplicates are plotted once.
        display_mode: nilearn ``display_mode`` parameter.
        prefix: Filename prefix for the output PNGs.

    Returns:
        List of :class:`Path` objects for every generated PNG, in ascending
        threshold order.

    Raises:
        ImportError: If nibabel or nilearn is not installed.
        FileNotFoundError: If *fdt_paths* or *template_image* is not an
            existing file.
        ValueError: If *thresholds* contains values outside (0, 100).

    Example:
        >>> figures = generate_track_density_figures(
        ...     "fdt_paths.nii.gz",
        ...     "MNI152_T1_2mm.nii.gz",
        ...     "qc/normtracks",
        ... )
    """
    _validate_dependencies()
    fdt_path = _resolve_path(fdt_paths)
    bg_path = _resolve_path(template_image)

    if thresholds is None:
        thresholds = list(_DEFAULT_THRESHOLDS)

    # Validate before touching the filesystem so bad input has no side effects.
    for t in thresholds:
        if not (0.0 < t < 100.0):
            raise ValueError(f"Each threshold must be between 0 and 100 (exclusive), got {t}")

    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    # Load the density map once. Its voxels define the absolute thresholds, and
    # the in-memory image is reused for every overlay instead of re-reading
    # the file from disk for each threshold.
    img = nib.load(str(fdt_path))
    data: np.ndarray = np.asarray(img.dataobj)  # type: ignore[attr-defined]

    generated: List[Path] = []
    try:
        # A repeated threshold would only overwrite the same PNG, so plot each once.
        for pct in sorted(set(thresholds)):
            abs_thresh = _compute_threshold_value(data, pct)
            out_path = out / f"{prefix}_p{_percentile_tag(pct)}.png"
            title = (
                f"Track density ≥ {_percentile_label(pct)} percentile "
                f"(thresh={abs_thresh:.1f})"
            )
            display = plot_stat_map(
                stat_map_img=img,
                bg_img=str(bg_path),
                threshold=abs_thresh,
                display_mode=display_mode,
                title=title,
                colorbar=True,
            )
            try:
                display.savefig(str(out_path), dpi=150)
            finally:
                # Close the display even if saving fails, so its figure is not leaked.
                display.close()
            logger.debug("Saved track density figure: {}", out_path)
            generated.append(out_path)
    finally:
        # Ensure matplotlib releases figure memory even if plotting fails part-way.
        plt.close("all")

    logger.info(
        "Generated {} track density figure(s) in {}",
        len(generated),
        out,
    )
    return generated