Source code for thesis.core.nipype.interfaces.freesurfer

"""
Custom FreeSurfer Nipype interfaces.

Provides CommandLine-based Nipype interfaces for FreeSurfer tools that do not
have upstream wrappers in nipype.interfaces.freesurfer, enabling them to
participate in Nipype workflows with full caching and parallelisation support.

Classes:
    SynthSeg: Nipype interface for mri_synthseg (brain MRI segmentation).
    SynthStrip: Nipype interface for mri_synthstrip (brain extraction).
"""

import os
import shlex
import subprocess
import sys
from typing import Any

from nipype.interfaces.base import (
    CommandLine,
    CommandLineInputSpec,
    File,
    TraitedSpec,
    isdefined,
    traits,
)

from thesis.core.logging import get_logger

# Module-level logger from the project's logging helper.
# NOTE(review): log calls in this module use ``{}`` placeholders, so get_logger
# presumably returns a brace-formatting (loguru-style) logger — confirm.
logger = get_logger(__name__)

# Public API: only the two interfaces are exported; the spec classes and the
# private ``_is_gpu_oom`` helper are intentionally left out.
__all__ = ["SynthSeg", "SynthStrip"]


def _is_gpu_oom(error_text: str) -> bool:
    """Return True when the error string describes a GPU out-of-memory event."""
    oom_markers = ("ResourceExhaustedError", "OOM when allocating", "out of memory")
    return any(marker in error_text for marker in oom_markers)


# ---------------------------------------------------------------------------
# Input / Output specifications
# ---------------------------------------------------------------------------


class SynthSegInputSpec(CommandLineInputSpec):
    """Input specification for mri_synthseg.

    Boolean traits map to bare ``--flag`` arguments and path/int traits to
    ``--flag value`` pairs through each trait's ``argstr``. ``use_gpu`` has no
    ``argstr`` and therefore never appears on the command line.
    """

    input_image = File(
        argstr="--i %s",
        mandatory=True,
        # exists=True makes Nipype's File trait require an existing regular
        # file, so directory inputs are rejected by this interface even though
        # mri_synthseg itself can process a directory of scans.
        exists=True,
        desc=(
            "Input brain scan. Accepts a single NIfTI/MGZ file or a plain-text "
            "file whose lines list individual image paths. Directories are not "
            "accepted by this interface (the trait requires an existing file)."
        ),
    )
    output_segmentation = File(
        argstr="--o %s",
        mandatory=True,
        hash_files=False,
        desc=(
            "Destination for the output segmentation. Must match the type of "
            "--i (image file → image file, text file → text file)."
        ),
    )

    # --- Segmentation options -----------------------------------------------
    parc = traits.Bool(
        argstr="--parc",
        desc="Enable cortical parcellation in addition to whole-brain segmentation.",
    )
    robust = traits.Bool(
        argstr="--robust",
        desc=(
            "Use the robust variant, recommended for clinical data with poor contrast "
            "or unusual acquisition parameters."
        ),
    )
    fast = traits.Bool(
        argstr="--fast",
        desc="Disable postprocessing for ~2× speed at a small accuracy cost.",
    )
    ct = traits.Bool(
        argstr="--ct",
        desc="Optimise for CT input by clipping to [0, 80] Hounsfield units.",
    )
    v1 = traits.Bool(
        argstr="--v1",
        desc="Run the original SynthSeg 1.0 model instead of the default 2.0 model.",
    )

    # --- Optional outputs ----------------------------------------------------
    # hash_files=False: these are destinations, so their (possibly absent)
    # contents must not feed into Nipype's cache hash.
    vol = File(
        argstr="--vol %s",
        hash_files=False,
        desc="Path for a CSV file containing estimated volumes per brain region.",
    )
    qc = File(
        argstr="--qc %s",
        hash_files=False,
        desc="Path for a CSV file containing per-scan quality-control scores.",
    )
    post = File(
        argstr="--post %s",
        hash_files=False,
        desc="Path for a NIfTI file containing 3-D posterior probability maps.",
    )
    resample = File(
        argstr="--resample %s",
        hash_files=False,
        desc="Save the internally resampled (1 mm isotropic) image to this path.",
    )

    # --- Performance --------------------------------------------------------
    cpu = traits.Bool(
        argstr="--cpu",
        desc=(
            "Force CPU execution. GPU (~6 s/scan) is used by default when available; "
            "CPU takes ~2 min/scan."
        ),
    )
    # No argstr: a scheduling hint only; it is never passed to mri_synthseg.
    use_gpu = traits.Bool(
        False,
        usedefault=True,
        desc=(
            "Marks this node as requiring a GPU slot. "
            "GPU serialization is handled at the scheduler level via "
            "``n_gpu_procs=1`` in ``plugin_args``."
        ),
    )
    threads = traits.Int(
        argstr="--threads %d",
        desc="Number of CPU threads to use (default: 1). Only relevant with --cpu.",
    )

    # --- Crop ---------------------------------------------------------------
    crop = traits.Int(
        argstr="--crop %d",
        desc=(
            "Crop input volumes to this size (must be divisible by 32) before "
            "processing. Reduces memory usage at the cost of field-of-view."
        ),
    )


class SynthSegOutputSpec(TraitedSpec):
    """Output specification for mri_synthseg.

    Populated by ``SynthSeg._list_outputs``: ``output_segmentation`` is always
    set (as an absolute path); the optional fields are only filled in when the
    matching input (``vol``, ``qc``, ``post``, ``resample``) was defined.
    """

    output_segmentation = File(desc="Brain segmentation label map (NIfTI/MGZ).")
    vol = File(desc="CSV file with per-region volumes (if --vol was specified).")
    qc = File(desc="CSV file with QC scores (if --qc was specified).")
    post = File(desc="Posterior probability map (if --post was specified).")
    resample = File(desc="Resampled image at 1 mm isotropic (if --resample was specified).")


class SynthStripInputSpec(CommandLineInputSpec):
    """Input specification for ``mri_synthstrip``.

    None of these traits declares an ``argstr``: ``SynthStrip`` assembles the
    launcher argument list itself (``-i``/``-o``/``-m`` plus an optional
    ``-g``) instead of relying on Nipype's command-line generation.
    """

    input_image = File(
        mandatory=True,
        exists=True,
        desc="Input brain MRI volume to strip.",
    )
    # hash_files=False on the two destinations: they are outputs, so their
    # contents must not participate in Nipype's cache hash.
    output_image = File(
        mandatory=True,
        hash_files=False,
        desc="Destination for the skull-stripped brain image.",
    )
    output_mask = File(
        mandatory=True,
        hash_files=False,
        desc="Destination for the binary brain mask.",
    )
    # When True, the first attempt passes -g to the launcher; a failed GPU
    # attempt is retried once without -g.
    use_gpu = traits.Bool(
        False,
        usedefault=True,
        desc="Request GPU execution through the FreeSurfer Python launcher.",
    )
    # Undefined by default; when set, SynthStrip exports it as
    # CUDA_VISIBLE_DEVICES for the GPU attempt and any CPU retry.
    gpu_device = traits.Int(
        desc="Optional CUDA device index exposed via CUDA_VISIBLE_DEVICES.",
    )


class SynthStripOutputSpec(TraitedSpec):
    """Output specification for ``mri_synthstrip``.

    Both fields are filled by ``SynthStrip._list_outputs`` with the absolute
    forms of the ``output_image`` / ``output_mask`` input paths.
    """

    output_image = File(desc="Skull-stripped brain image.")
    output_mask = File(desc="Binary brain mask.")


# ---------------------------------------------------------------------------
# Interface
# ---------------------------------------------------------------------------


class SynthSeg(CommandLine):
    """
    Nipype interface for FreeSurfer's ``mri_synthseg`` segmentation tool.

    SynthSeg segments brain MRI scans (and optionally performs cortical
    parcellation) without requiring a specific acquisition protocol or
    contrast. It is robust to clinical data and non-standard sequences.

    Because this class subclasses ``CommandLine``, it integrates transparently
    with Nipype's workflow engine:

    - **Caching**: results are hashed and re-used automatically.
    - **Parallelisation**: place it in a ``MapNode`` or connect it to the
      workflow graph; the MultiProc / SGEGraph plugins handle scheduling.
    - **Provenance**: command line, inputs, and outputs are recorded.

    Unless ``cpu=True`` is set, a run that fails with a GPU out-of-memory
    error is retried once with ``CUDA_VISIBLE_DEVICES=""`` (CPU fallback).

    Requirements:
        ``mri_synthseg`` must be on ``PATH`` (provided by FreeSurfer ≥ 7.3).

    Examples:
        Basic segmentation::

            from thesis.core.nipype.interfaces.freesurfer import SynthSeg

            seg = SynthSeg(
                input_image="sub-01_T1w.nii.gz",
                output_segmentation="sub-01_synthseg.nii.gz",
            )
            result = seg.run()

        With parcellation, volumes CSV, and CPU execution::

            seg = SynthSeg(
                input_image="sub-01_T1w.nii.gz",
                output_segmentation="sub-01_synthseg.nii.gz",
                parc=True,
                vol="sub-01_volumes.csv",
                cpu=True,
                threads=4,
            )

        Inside a Nipype workflow::

            from nipype import Node, Workflow
            from thesis.core.nipype.interfaces.freesurfer import SynthSeg

            seg_node = Node(SynthSeg(), name="synthseg")
            seg_node.inputs.input_image = "T1w.nii.gz"
            seg_node.inputs.output_segmentation = "T1w_synthseg.nii.gz"
            seg_node.inputs.parc = True

            wf = Workflow(name="seg_workflow")
            wf.add_nodes([seg_node])
            wf.run()

        Batch segmentation via MapNode::

            from nipype import MapNode

            seg_mapnode = MapNode(
                SynthSeg(parc=True, cpu=True, threads=4),
                iterfield=["input_image", "output_segmentation"],
                name="synthseg_batch",
            )
            seg_mapnode.inputs.input_image = ["sub-01_T1w.nii.gz", "sub-02_T1w.nii.gz"]
            seg_mapnode.inputs.output_segmentation = ["sub-01_seg.nii.gz", "sub-02_seg.nii.gz"]
    """

    _cmd = "mri_synthseg"
    input_spec = SynthSegInputSpec
    output_spec = SynthSegOutputSpec

    def _run_interface(self, runtime: Any) -> Any:
        """
        Execute mri_synthseg with a GPU out-of-memory fallback to CPU.

        GPU serialization is handled at the scheduler level via
        ``n_gpu_procs=1`` in ``plugin_args`` — no process-level lock needed.

        If ``cpu=True`` was set explicitly, the command runs once in CPU mode.
        Otherwise it runs once and, if that run failed with a GPU
        out-of-memory error, it is re-run with ``CUDA_VISIBLE_DEVICES=""``
        merged into the node's ``environ``.

        Two failure styles of ``CommandLine._run_interface`` are handled:

        * it returns normally with a ``runtime.returncode`` outside
          ``runtime.success_codes`` and leaves raising to Nipype's runtime
          context manager — the OOM text is then read from the captured
          ``runtime.stderr`` / ``runtime.stdout``;
        * it raises ``RuntimeError`` itself — the OOM text is then read from
          the exception message.

        Args:
            runtime: Nipype runtime object passed in by ``BaseInterface.run``.

        Returns:
            The runtime of the last command execution.

        Raises:
            RuntimeError: Re-raised from the base class when it is not an
                out-of-memory error.
        """
        # Honour an explicit cpu=True flag — run directly, no fallback.
        if isdefined(self.inputs.cpu) and self.inputs.cpu:
            return super()._run_interface(runtime)

        try:
            runtime = super()._run_interface(runtime)
        except RuntimeError as exc:
            if not _is_gpu_oom(str(exc)):
                raise
            return self._run_cpu_fallback(runtime)

        # A non-zero exit does not necessarily raise above (the return-code
        # check may be deferred to Nipype's context manager), so inspect the
        # result here; otherwise the OOM fallback would never trigger.
        success_codes = getattr(runtime, "success_codes", (0,))
        returncode = getattr(runtime, "returncode", None)
        if returncode not in success_codes and _is_gpu_oom(self._captured_output(runtime)):
            return self._run_cpu_fallback(runtime)
        return runtime

    @staticmethod
    def _captured_output(runtime: Any) -> str:
        """Join the stderr and stdout text captured on *runtime*; either may be missing."""
        streams = (getattr(runtime, "stderr", None), getattr(runtime, "stdout", None))
        return "\n".join(str(stream) for stream in streams if stream)

    def _run_cpu_fallback(self, runtime: Any) -> Any:
        """Re-run mri_synthseg with every GPU hidden, then restore ``environ``."""
        logger.warning("SynthSeg: GPU OOM — falling back to CPU (CUDA_VISIBLE_DEVICES='')")
        original_environ = (
            dict(self.inputs.environ) if isdefined(self.inputs.environ) else {}
        )
        # Merge rather than replace so user-supplied variables survive the retry.
        self.inputs.environ = {**original_environ, "CUDA_VISIBLE_DEVICES": ""}
        try:
            return super()._run_interface(runtime)
        finally:
            self.inputs.environ = original_environ

    def _list_outputs(self) -> dict[str, Any]:
        """Map declared input paths to expected output files.

        ``output_segmentation`` is always reported; ``vol``, ``qc``, ``post``
        and ``resample`` only when the corresponding input was defined. All
        paths are made absolute relative to the current working directory.
        """
        outputs: dict[str, Any] = self.output_spec().get()
        outputs["output_segmentation"] = os.path.abspath(self.inputs.output_segmentation)
        for field in ("vol", "qc", "post", "resample"):
            value = getattr(self.inputs, field, None)
            if isdefined(value):
                outputs[field] = os.path.abspath(str(value))
        return outputs
class SynthStrip(CommandLine):
    """Nipype CommandLine interface for FreeSurfer's ``mri_synthstrip`` brain extraction tool.

    ``mri_synthstrip`` performs learning-based skull-stripping without requiring a
    specific acquisition protocol or contrast. This wrapper drives it through a
    custom ``_run_interface`` that bypasses the standard Nipype ``cmdline()``
    machinery so that GPU execution can be handled via the FreeSurfer Python
    launcher with automatic CPU fallback.

    Requirements:
        ``FREESURFER_HOME`` must be set, either in the process environment or in
        the node's ``environ`` input. FreeSurfer ≥ 7 is required for the
        ``mri_synthstrip`` script to be present at
        ``$FREESURFER_HOME/python/scripts/mri_synthstrip``. ``surfa`` must be
        importable from the interpreter running this code (``sys.executable``).

    GPU behaviour:
        When ``use_gpu=True`` the FreeSurfer Python launcher is invoked with the
        ``-g`` flag and, if ``gpu_device`` is defined, ``CUDA_VISIBLE_DEVICES`` is
        set accordingly. If the GPU run exits non-zero the node retries once with
        the same launcher but without ``-g`` (CPU fallback). The GPU run's stdout
        and stderr are logged at WARNING level before the retry.

    Attributes:
        _cmd (str): Nominal command name (``"mri_synthstrip"``) kept for Nipype's
            ``CommandLine`` machinery; ``_run_interface`` never executes it
            directly — both GPU and CPU runs go through the Python launcher.
        input_spec: :class:`SynthStripInputSpec`
        output_spec: :class:`SynthStripOutputSpec`

    Example:
        Basic skull-stripping inside a Nipype workflow::

            from nipype import Node
            from thesis.core.nipype.interfaces.freesurfer import SynthStrip

            node = Node(SynthStrip(), name="synthstrip")
            node.inputs.input_image = "sub-01_T1w.nii.gz"
            node.inputs.output_image = "sub-01_brain.nii.gz"
            node.inputs.output_mask = "sub-01_mask.nii.gz"
            result = node.run()
    """

    _cmd = "mri_synthstrip"
    input_spec = SynthStripInputSpec
    output_spec = SynthStripOutputSpec

    def _resolve_environ(self) -> dict[str, str]:
        """Return a copy of ``os.environ`` overlaid with the node's ``environ`` input.

        This interface spawns its own subprocess instead of using
        ``CommandLine``'s runner, so the ``environ`` input has to be merged
        explicitly or it would be silently ignored.
        """
        env = os.environ.copy()
        if isdefined(self.inputs.environ):
            env.update(self.inputs.environ)
        return env

    def _build_launcher_cmd(self, use_gpu: bool) -> list[str]:
        """Build a SynthStrip command using the active conda Python interpreter.

        Uses ``sys.executable`` so that the CUDA-capable ``torch`` in the conda
        environment is available for GPU execution.

        FreeSurfer's versioned ``site-packages`` directory is intentionally
        **not** injected because it contains ``surfa`` compiled for FreeSurfer's
        embedded Python 3.8, which is ABI-incompatible with a newer conda
        Python. Instead, ``surfa`` must be installed directly in the conda
        environment::

            pip install surfa

        Only ``$FREESURFER_HOME/python/packages`` is appended to ``sys.path`` so
        that any non-compiled FreeSurfer resources remain accessible.
        ``FREESURFER_HOME`` is read from the process environment overlaid with
        the node's ``environ`` input.

        Args:
            use_gpu: Append the ``-g`` flag for GPU-accelerated execution.

        Returns:
            Command list suitable for ``subprocess.run``.

        Raises:
            RuntimeError: If ``FREESURFER_HOME`` is unset/empty or the launcher
                script does not exist.
        """
        fs_home = self._resolve_environ().get("FREESURFER_HOME")
        if not fs_home:
            raise RuntimeError("mri_synthstrip failed: FREESURFER_HOME is not set")

        script = os.path.join(fs_home, "python", "scripts", "mri_synthstrip")
        # Fail fast with a clear message instead of letting runpy fail inside
        # the child process (which would also trigger a pointless CPU retry).
        if not os.path.isfile(script):
            raise RuntimeError(f"mri_synthstrip failed: launcher script not found at {script}")

        # Inject only the non-compiled packages dir; the versioned site-packages
        # (containing the Python-3.8-compiled surfa .so files) is excluded to
        # prevent ABI mismatches with the conda Python.
        extra_paths: list[str] = []
        packages_dir = os.path.join(fs_home, "python", "packages")
        if os.path.isdir(packages_dir):
            extra_paths.append(packages_dir)

        # Executed via ``python -c``: extend sys.path, make the script see its
        # own name as argv[0] followed by our -i/-o/-m/-g arguments, then run it
        # as __main__.
        bootstrap = (
            "import sys, runpy; "
            f"[sys.path.append(p) for p in {extra_paths!r} if p not in sys.path]; "
            f"sys.argv = [{script!r}] + sys.argv[1:]; "
            f"runpy.run_path({script!r}, run_name='__main__')"
        )
        cmd = [
            sys.executable,
            "-c",
            bootstrap,
            "-i",
            self.inputs.input_image,
            "-o",
            self.inputs.output_image,
            "-m",
            self.inputs.output_mask,
        ]
        if use_gpu:
            cmd.append("-g")
        logger.debug(
            "Using SynthStrip launcher: python={} script={} gpu={}",
            sys.executable,
            script,
            use_gpu,
        )
        return cmd

    @staticmethod
    def _run_command(
        runtime: Any,
        cmd: list[str],
        env: dict[str, str],
    ) -> subprocess.CompletedProcess[str]:
        """Execute a subprocess and mirror the result onto the Nipype runtime.

        Overwrites ``runtime.cmdline``, ``stdout``, ``stderr`` and ``returncode``;
        never raises on a non-zero exit (``check=False``).
        """
        result = subprocess.run(cmd, capture_output=True, text=True, check=False, env=env)
        runtime.cmdline = subprocess.list2cmdline(cmd) if os.name == "nt" else shlex.join(cmd)
        runtime.stdout = result.stdout
        runtime.stderr = result.stderr
        runtime.returncode = result.returncode
        return result

    def _run_interface(self, runtime: Any, _correct_return_codes: tuple[int, ...] = (0,)) -> Any:
        """Execute ``mri_synthstrip`` with GPU/CPU fallback.

        Args:
            runtime: Nipype runtime object; its ``cmdline``, ``stdout``,
                ``stderr`` and ``returncode`` reflect the last subprocess run.
            _correct_return_codes: Accepted for signature compatibility with
                ``CommandLine._run_interface``; success is always exit code 0.

        Returns:
            The updated ``runtime``.

        Raises:
            RuntimeError: If the launcher command cannot be built, or if the
                final (GPU or CPU) run exits non-zero.
        """
        # Nipype's rtc context manager checks runtime.success_codes on exit; initialise
        # it here so the attribute exists even when we raise RuntimeError below.
        if not hasattr(runtime, "success_codes"):
            runtime.success_codes = (0,)

        cmd = self._build_launcher_cmd(use_gpu=self.inputs.use_gpu)
        env = self._resolve_environ()
        if isdefined(self.inputs.gpu_device):
            env["CUDA_VISIBLE_DEVICES"] = str(self.inputs.gpu_device)

        result = self._run_command(runtime, cmd, env)

        if result.returncode != 0 and self.inputs.use_gpu:
            gpu_stdout = (result.stdout or "").strip()
            gpu_stderr = (result.stderr or "").strip()
            logger.warning(
                "SynthStrip GPU run failed; retrying on CPU. stdout={} stderr={}",
                gpu_stdout or "<empty>",
                gpu_stderr or "<empty>",
            )
            cpu_cmd = self._build_launcher_cmd(use_gpu=False)
            result = self._run_command(runtime, cpu_cmd, env)

        if result.returncode != 0:
            error_text = (result.stderr or result.stdout or "unknown error").strip()
            logger.error("mri_synthstrip failed: {}", error_text)
            raise RuntimeError(f"mri_synthstrip failed: {error_text}")
        return runtime

    def _list_outputs(self) -> dict[str, Any]:
        """Map declared output paths to absolute filesystem locations."""
        outputs: dict[str, Any] = self.output_spec().get()
        outputs["output_image"] = os.path.abspath(self.inputs.output_image)
        outputs["output_mask"] = os.path.abspath(self.inputs.output_mask)
        return outputs