# Source code for thesis.workflows.qc.roi_overlay

"""Generate QC overlay figures for ROI masks on anatomical backgrounds.

All plotting is performed with the ``Agg`` backend so the module works
in headless environments (e.g. compute nodes without an X server).

Example:
    >>> from pathlib import Path
    >>> from thesis.workflows.qc.roi_overlay import generate_roi_overlays
    >>> paths = generate_roi_overlays(  # doctest: +SKIP
    ...     roi_paths={"seed": Path("seed.nii.gz"), "waypoint": Path("wp.nii.gz")},
    ...     background_image=Path("T1w.nii.gz"),
    ...     output_dir=Path("qc/roi_overlays"),
    ... )
"""

from __future__ import annotations

from pathlib import Path
from typing import Dict, List, Mapping, Union

from thesis.workflows.qc._plotting import configure_headless_matplotlib

# Must run before ``matplotlib.pyplot`` is imported below (hence the E402
# suppressions) so the headless backend is already in effect when pyplot
# loads; presumably this selects ``Agg`` as the module docstring describes.
configure_headless_matplotlib()

import matplotlib.pyplot as plt  # noqa: E402

from thesis.core.logging import get_logger  # noqa: E402

# Module logger. Calls in this module use ``{}`` placeholders (not ``%s``),
# so get_logger presumably returns a brace-formatting logger — confirm.
logger = get_logger(__name__)

# Probe for nibabel without making it a hard import-time requirement; the
# flag is checked later by _validate_dependencies().
try:
    import nibabel as nib  # noqa: F401
except ImportError:  # pragma: no cover
    NIBABEL_AVAILABLE = False
else:
    NIBABEL_AVAILABLE = True

# Probe for nilearn's plotting entry point. The catch is deliberately broad:
# any failure while importing the plotting stack is logged and only disables
# overlay generation, instead of breaking import of this module.
try:
    from nilearn.plotting import plot_roi
except Exception as import_error:  # pragma: no cover
    logger.debug("Failed to import nilearn plotting support: {}", import_error)
    NILEARN_AVAILABLE = False
else:
    NILEARN_AVAILABLE = True

# Matplotlib colormap names assigned to ROIs in iteration order, cycling when
# there are more ROIs than entries. Both the per-ROI figures and the combined
# figure index into this palette the same way, so each ROI keeps the same
# colormap in every figure. A tuple keeps the palette immutable at runtime.
_ROI_COLOURS = (
    "Reds",
    "Blues",
    "Greens",
    "Purples",
    "Oranges",
    "YlOrBr",
)

# Only the public entry point is exported; the ``_``-prefixed helpers are internal.
__all__ = ["generate_roi_overlays"]


def _validate_dependencies() -> None:
    """Fail fast when an optional plotting dependency could not be imported.

    The availability flags are read at call time. nibabel is checked before
    nilearn, so only the first missing package is reported.

    Raises:
        ImportError: If nibabel or nilearn is unavailable.
    """
    requirements = (
        (
            NIBABEL_AVAILABLE,
            "nibabel is required for QC overlay generation. " "Install with: pip install nibabel",
        ),
        (
            NILEARN_AVAILABLE,
            "nilearn is required for QC overlay generation. " "Install with: pip install nilearn",
        ),
    )
    for is_available, message in requirements:
        if not is_available:
            raise ImportError(message)


def _resolve_path(path: Union[str, Path]) -> Path:
    """Coerce *path* to a :class:`Path`, failing early if nothing exists there.

    Args:
        path: String or :class:`Path` pointing at an input image.

    Returns:
        The same location as a :class:`Path` (not made absolute).

    Raises:
        FileNotFoundError: If no file or directory exists at *path*.
    """
    candidate = Path(path)
    if candidate.exists():
        return candidate
    raise FileNotFoundError(f"File not found: {candidate}")


def _generate_single_roi_overlay(
    roi_path: Path,
    roi_name: str,
    background_image: Path,
    output_dir: Path,
    display_mode: str = "ortho",
    cmap: str = "Reds",
) -> Path:
    """Render a single ROI mask on an anatomical background.

    The nilearn display is always closed, even if saving fails, so repeated
    calls do not accumulate open matplotlib figures.

    Args:
        roi_path: NIfTI binary mask.
        roi_name: Human-readable label used in the figure title and filename.
            Path separators (``/`` and ``\\``) are replaced with ``_`` in the
            filename so the PNG always lands directly inside *output_dir*;
            the title keeps the name unchanged.
        background_image: T1w (or FA) anatomical image.
        output_dir: Directory for the output PNG.
        display_mode: nilearn display mode (``"ortho"``, ``"x"``, ``"y"``, ``"z"``).
        cmap: Matplotlib colourmap for the overlay.

    Returns:
        Path to the generated PNG file (``<sanitized roi_name>_overlay.png``
        inside *output_dir*).
    """
    # Stop the ROI name from being treated as a sub-directory (which would not
    # exist) or escaping output_dir via "..".
    safe_name = roi_name.replace("/", "_").replace("\\", "_")
    out_path = output_dir / f"{safe_name}_overlay.png"
    display = plot_roi(
        roi_img=str(roi_path),
        bg_img=str(background_image),
        display_mode=display_mode,
        cmap=cmap,
        title=f"ROI: {roi_name}",
    )
    try:
        display.savefig(str(out_path), dpi=150)
    finally:
        # Release the figure even when savefig raises.
        display.close()
    logger.debug("Saved ROI overlay: {}", out_path)
    return out_path


def _generate_combined_overlay(
    roi_paths: Dict[str, Path],
    background_image: Path,
    output_dir: Path,
    display_mode: str = "ortho",
) -> Path:
    """Render all ROI masks on a single combined figure.

    The first ROI is drawn with ``plot_roi``, which also creates the display
    and draws the background. Every further ROI is layered onto the same
    display with ``add_overlay(..., transparency=0.4)``. Colormaps come from
    ``_ROI_COLOURS`` in mapping order and cycle when there are more ROIs than
    colormaps. The display is closed even if rendering or saving fails.

    Args:
        roi_paths: Mapping of ROI name to NIfTI mask path (must be non-empty).
        background_image: Anatomical background image.
        output_dir: Directory for the output PNG.
        display_mode: nilearn display mode.

    Returns:
        Path to the generated PNG file (``all_rois_overlay.png`` inside
        *output_dir*).

    Raises:
        ValueError: If ``roi_paths`` is empty.
    """
    if not roi_paths:
        raise ValueError("roi_paths must contain at least one ROI")

    out_path = output_dir / "all_rois_overlay.png"
    mask_paths = list(roi_paths.values())

    # The first ROI creates the display whose axes the others are drawn on.
    display = plot_roi(
        roi_img=str(mask_paths[0]),
        bg_img=str(background_image),
        display_mode=display_mode,
        cmap=_ROI_COLOURS[0],
        title="All ROIs (combined)",
    )
    try:
        # Index from 1 so each ROI gets the same colormap as its per-ROI figure.
        for idx, mask_path in enumerate(mask_paths[1:], start=1):
            cmap = _ROI_COLOURS[idx % len(_ROI_COLOURS)]
            display.add_overlay(str(mask_path), cmap=cmap, transparency=0.4)
        display.savefig(str(out_path), dpi=150)
    finally:
        # Release the figure even when an overlay or savefig call raises.
        display.close()
    logger.debug("Saved combined ROI overlay: {}", out_path)
    return out_path


def generate_roi_overlays(
    roi_paths: Mapping[str, Union[str, Path]],
    background_image: Union[str, Path],
    output_dir: Union[str, Path],
    display_mode: str = "ortho",
) -> List[Path]:
    """Generate QC overlay PNGs for each ROI mask on a background image.

    Produces one PNG per ROI mask. When more than one ROI is given, it also
    produces a combined figure showing all ROIs in different colours. Runs
    headlessly (``Agg`` backend).

    Args:
        roi_paths: Mapping of ROI name to NIfTI mask path.
        background_image: T1w (or FA) image used as the anatomical background.
        output_dir: Directory to write PNGs into. Created if it does not exist.
        display_mode: nilearn ``display_mode`` parameter (``"ortho"``,
            ``"x"``, ``"y"``, ``"z"``).

    Returns:
        List of :class:`Path` objects for every generated PNG: the per-ROI
        figures in mapping order, followed by the combined figure if one was
        produced.

    Raises:
        ImportError: If nibabel or nilearn is not installed.
        FileNotFoundError: If the background image or any ROI mask does not
            exist.
        ValueError: If ``roi_paths`` is empty.

    Example:
        >>> paths = generate_roi_overlays(  # doctest: +SKIP
        ...     roi_paths={"seed": "seed.nii.gz"},
        ...     background_image="T1w.nii.gz",
        ...     output_dir="qc/roi_overlays",
        ... )
    """
    _validate_dependencies()
    if not roi_paths:
        raise ValueError("roi_paths must contain at least one ROI")

    # Validate every input path before touching the filesystem, so a bad path
    # fails fast without creating the output directory.
    bg = _resolve_path(background_image)
    resolved: Dict[str, Path] = {
        name: _resolve_path(path) for name, path in roi_paths.items()
    }

    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    generated: List[Path] = []
    try:
        # One figure per ROI. The colormap index matches the one used for the
        # same ROI in the combined figure.
        for idx, (name, roi_path) in enumerate(resolved.items()):
            generated.append(
                _generate_single_roi_overlay(
                    roi_path=roi_path,
                    roi_name=name,
                    background_image=bg,
                    output_dir=out,
                    display_mode=display_mode,
                    cmap=_ROI_COLOURS[idx % len(_ROI_COLOURS)],
                )
            )

        # A combined overlay only adds information when there are several ROIs.
        if len(resolved) > 1:
            generated.append(
                _generate_combined_overlay(
                    roi_paths=resolved,
                    background_image=bg,
                    output_dir=out,
                    display_mode=display_mode,
                )
            )
    finally:
        # Ensure matplotlib releases figure memory, also when plotting fails.
        # NOTE: this closes *every* open pyplot figure, not only ours.
        plt.close("all")

    logger.info(
        "Generated {} QC ROI overlay(s) in {}",
        len(generated),
        out,
    )
    return generated