"""
Custom FreeSurfer Nipype interfaces.
Provides CommandLine-based Nipype interfaces for FreeSurfer tools that do not
have upstream wrappers in nipype.interfaces.freesurfer, enabling them to
participate in Nipype workflows with full caching and parallelisation support.
Classes:
SynthSeg: Nipype interface for mri_synthseg (brain MRI segmentation).
SynthStrip: Nipype interface for mri_synthstrip (brain extraction).
"""
import os
import shlex
import subprocess
import sys
from typing import Any
from nipype.interfaces.base import (
CommandLine,
CommandLineInputSpec,
File,
TraitedSpec,
isdefined,
traits,
)
from thesis.core.logging import get_logger
# Public API of this module: only the two interface classes are exported.
__all__ = ["SynthSeg", "SynthStrip"]

# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)
def _is_gpu_oom(error_text: str) -> bool:
"""Return True when the error string describes a GPU out-of-memory event."""
oom_markers = ("ResourceExhaustedError", "OOM when allocating", "out of memory")
return any(marker in error_text for marker in oom_markers)
# ---------------------------------------------------------------------------
# Input / Output specifications
# ---------------------------------------------------------------------------
class SynthSegInputSpec(CommandLineInputSpec):
    """Input specification for mri_synthseg.

    Every trait that declares an ``argstr`` is rendered onto the
    ``mri_synthseg`` command line by Nipype. ``use_gpu`` has no ``argstr``;
    it never reaches the command line and only marks the node for the
    scheduler.
    """

    # --- Required I/O --------------------------------------------------------
    # ``File(exists=True)`` only validates existing regular files, so the
    # directory-input mode of mri_synthseg cannot be used through this trait.
    # The description states that rather than advertising directory support.
    input_image = File(
        argstr="--i %s",
        mandatory=True,
        exists=True,
        desc=(
            "Input brain scan: a single NIfTI/MGZ file, or a plain-text file whose "
            "lines list individual image paths. Directories are not accepted by "
            "this interface (the trait requires an existing file)."
        ),
    )
    output_segmentation = File(
        argstr="--o %s",
        mandatory=True,
        hash_files=False,
        desc=(
            "Destination for the output segmentation. Must match the type of "
            "--i: an image path for a single scan, or a plain-text file listing "
            "one output path per input line."
        ),
    )

    # --- Segmentation options -----------------------------------------------
    parc = traits.Bool(
        argstr="--parc",
        desc="Enable cortical parcellation in addition to whole-brain segmentation.",
    )
    robust = traits.Bool(
        argstr="--robust",
        desc=(
            "Use the robust variant, recommended for clinical data with poor contrast "
            "or unusual acquisition parameters."
        ),
    )
    fast = traits.Bool(
        argstr="--fast",
        desc="Disable postprocessing for ~2× speed at a small accuracy cost.",
    )
    ct = traits.Bool(
        argstr="--ct",
        desc="Optimise for CT input by clipping to [0, 80] Hounsfield units.",
    )
    v1 = traits.Bool(
        argstr="--v1",
        desc="Run the original SynthSeg 1.0 model instead of the default 2.0 model.",
    )

    # --- Optional outputs ----------------------------------------------------
    # hash_files=False: these are destinations, so their (not yet existing)
    # contents must not be hashed when Nipype computes the cache key.
    vol = File(
        argstr="--vol %s",
        hash_files=False,
        desc="Path for a CSV file containing estimated volumes per brain region.",
    )
    qc = File(
        argstr="--qc %s",
        hash_files=False,
        desc="Path for a CSV file containing per-scan quality-control scores.",
    )
    post = File(
        argstr="--post %s",
        hash_files=False,
        desc="Path for a NIfTI file containing 3-D posterior probability maps.",
    )
    resample = File(
        argstr="--resample %s",
        hash_files=False,
        desc="Save the internally resampled (1 mm isotropic) image to this path.",
    )

    # --- Performance --------------------------------------------------------
    cpu = traits.Bool(
        argstr="--cpu",
        desc=(
            "Force CPU execution. GPU (~6 s/scan) is used by default when available; "
            "CPU takes ~2 min/scan."
        ),
    )
    # No argstr: scheduler-only flag, never passed to mri_synthseg.
    use_gpu = traits.Bool(
        False,
        usedefault=True,
        desc=(
            "Marks this node as requiring a GPU slot. "
            "GPU serialization is handled at the scheduler level via "
            "``n_gpu_procs=1`` in ``plugin_args``."
        ),
    )
    threads = traits.Int(
        argstr="--threads %d",
        desc="Number of CPU threads to use (default: 1). Only relevant with --cpu.",
    )

    # --- Crop ---------------------------------------------------------------
    crop = traits.Int(
        argstr="--crop %d",
        desc=(
            "Crop input volumes to this size (must be divisible by 32) before "
            "processing. Reduces memory usage at the cost of field-of-view."
        ),
    )
class SynthSegOutputSpec(TraitedSpec):
    """Files reported by the ``SynthSeg`` interface after ``mri_synthseg`` runs.

    ``output_segmentation`` is always reported. The other fields are filled
    only when the input of the same name was set.
    """

    # Always produced.
    output_segmentation = File(
        desc="Brain segmentation label map (NIfTI/MGZ).",
    )

    # Optional extras, mirroring the optional-output inputs.
    vol = File(
        desc="CSV file with per-region volumes (if --vol was specified).",
    )
    qc = File(
        desc="CSV file with QC scores (if --qc was specified).",
    )
    post = File(
        desc="Posterior probability map (if --post was specified).",
    )
    resample = File(
        desc="Resampled image at 1 mm isotropic (if --resample was specified).",
    )
class SynthStripInputSpec(CommandLineInputSpec):
    """Inputs accepted by the ``SynthStrip`` interface (``mri_synthstrip``).

    None of these traits declare an ``argstr``: ``SynthStrip`` assembles its own
    argument list rather than relying on Nipype's argument formatting.
    """

    # Source volume; must already exist on disk.
    input_image = File(
        desc="Input brain MRI volume to strip.",
        exists=True,
        mandatory=True,
    )

    # Destinations; hash_files=False keeps their contents out of the cache key.
    output_image = File(
        desc="Destination for the skull-stripped brain image.",
        hash_files=False,
        mandatory=True,
    )
    output_mask = File(
        desc="Destination for the binary brain mask.",
        hash_files=False,
        mandatory=True,
    )

    # GPU controls: default is CPU-only execution.
    use_gpu = traits.Bool(
        False,
        desc="Request GPU execution through the FreeSurfer Python launcher.",
        usedefault=True,
    )
    gpu_device = traits.Int(
        desc="Optional CUDA device index exposed via CUDA_VISIBLE_DEVICES.",
    )
class SynthStripOutputSpec(TraitedSpec):
    """Files reported by the ``SynthStrip`` interface after ``mri_synthstrip`` runs."""

    # Both outputs are always produced on success.
    output_image = File(
        desc="Skull-stripped brain image.",
    )
    output_mask = File(
        desc="Binary brain mask.",
    )
# ---------------------------------------------------------------------------
# Interfaces
# ---------------------------------------------------------------------------
class SynthSeg(CommandLine):
    """
    Nipype interface for FreeSurfer's ``mri_synthseg`` segmentation tool.

    SynthSeg segments brain MRI scans (and optionally performs cortical
    parcellation) without requiring a specific acquisition protocol or
    contrast. It is robust to clinical data and non-standard sequences.

    Because this class subclasses ``CommandLine``, it integrates transparently
    with Nipype's workflow engine:

    - **Caching**: results are hashed and re-used automatically.
    - **Parallelisation**: place it in a ``MapNode`` or connect it to the
      workflow graph; the MultiProc / SGEGraph plugins handle scheduling.
    - **Provenance**: command line, inputs, and outputs are recorded.

    GPU behaviour:
        Unless ``cpu=True`` is set, the first attempt runs with the node's
        normal environment. If that attempt fails and its error output looks
        like a GPU out-of-memory event, the command is re-run once with
        ``CUDA_VISIBLE_DEVICES=""`` added to the node's ``environ`` so that
        no GPU is visible (CPU fallback).

    Requirements:
        ``mri_synthseg`` must be on ``PATH`` (provided by FreeSurfer ≥ 7.3).

    Examples:
        Basic segmentation::

            from thesis.core.nipype.interfaces.freesurfer import SynthSeg
            seg = SynthSeg(
                input_image="sub-01_T1w.nii.gz",
                output_segmentation="sub-01_synthseg.nii.gz",
            )
            result = seg.run()

        With parcellation, volumes CSV, and CPU execution::

            seg = SynthSeg(
                input_image="sub-01_T1w.nii.gz",
                output_segmentation="sub-01_synthseg.nii.gz",
                parc=True,
                vol="sub-01_volumes.csv",
                cpu=True,
                threads=4,
            )

        Inside a Nipype workflow::

            from nipype import Node, Workflow
            from thesis.core.nipype.interfaces.freesurfer import SynthSeg
            seg_node = Node(SynthSeg(), name="synthseg")
            seg_node.inputs.input_image = "T1w.nii.gz"
            seg_node.inputs.output_segmentation = "T1w_synthseg.nii.gz"
            seg_node.inputs.parc = True
            wf = Workflow(name="seg_workflow")
            wf.add_nodes([seg_node])
            wf.run()

        Batch segmentation via MapNode::

            from nipype import MapNode
            seg_mapnode = MapNode(
                SynthSeg(parc=True, cpu=True, threads=4),
                iterfield=["input_image", "output_segmentation"],
                name="synthseg_batch",
            )
            seg_mapnode.inputs.input_image = ["sub-01_T1w.nii.gz", "sub-02_T1w.nii.gz"]
            seg_mapnode.inputs.output_segmentation = ["sub-01_seg.nii.gz", "sub-02_seg.nii.gz"]
    """

    _cmd = "mri_synthseg"
    input_spec = SynthSegInputSpec
    output_spec = SynthSegOutputSpec

    def _run_interface(self, runtime, correct_return_codes=(0,)):
        """Execute mri_synthseg, retrying on CPU after a GPU out-of-memory error.

        GPU serialization is expected to be handled at the scheduler level
        (``n_gpu_procs=1`` in ``plugin_args``); no process-level lock is taken.

        A GPU OOM failure can surface in two ways, depending on the Nipype
        version. Newer releases return normally with a failing
        ``runtime.returncode`` and leave the ``success_codes`` check to the
        runtime context manager. Older releases raise ``RuntimeError`` from
        ``CommandLine._run_interface``. Both are handled here.

        Args:
            runtime: Nipype runtime bunch for this execution.
            correct_return_codes: Exit codes treated as success; forwarded to
                ``CommandLine._run_interface``.

        Returns:
            The runtime as updated by the last attempt that ran.

        Raises:
            RuntimeError: Re-raised when Nipype reports a failure as an
                exception and it does not look like a GPU OOM.
        """
        # An explicit cpu=True request never uses the GPU, so no fallback applies.
        if isdefined(self.inputs.cpu) and self.inputs.cpu:
            return super()._run_interface(runtime, correct_return_codes)

        try:
            runtime = super()._run_interface(runtime, correct_return_codes)
        except RuntimeError as exc:
            # Older Nipype: the failure (including its stderr) is in the message.
            if not _is_gpu_oom(str(exc)):
                raise
        else:
            # Newer Nipype: a failed run returns normally, so inspect the exit
            # code and captured output ourselves.
            failed = runtime.returncode not in correct_return_codes
            if not (failed and _is_gpu_oom(self._captured_output(runtime))):
                return runtime

        logger.warning("SynthSeg: GPU OOM — falling back to CPU (CUDA_VISIBLE_DEVICES='')")
        return self._run_on_cpu(runtime, correct_return_codes)

    @staticmethod
    def _captured_output(runtime) -> str:
        """Join whatever stderr/stdout/merged text Nipype captured on *runtime*."""
        chunks = (getattr(runtime, name, None) for name in ("stderr", "stdout", "merged"))
        return "\n".join(str(chunk) for chunk in chunks if chunk)

    def _run_on_cpu(self, runtime, correct_return_codes):
        """Re-run the command with all GPUs hidden, restoring ``environ`` afterwards."""
        original_environ = (
            dict(self.inputs.environ) if isdefined(self.inputs.environ) else {}
        )
        # Merge rather than replace, so caller-supplied variables survive the retry.
        self.inputs.environ = {**original_environ, "CUDA_VISIBLE_DEVICES": ""}
        try:
            return super()._run_interface(runtime, correct_return_codes)
        finally:
            self.inputs.environ = original_environ

    def _list_outputs(self) -> dict[str, Any]:
        """Map declared input paths to expected output files.

        ``output_segmentation`` is always reported. ``vol``, ``qc``, ``post``
        and ``resample`` are reported only when the matching input was set.
        Paths are made absolute relative to the current working directory.
        """
        outputs: dict[str, Any] = self.output_spec().get()
        outputs["output_segmentation"] = os.path.abspath(self.inputs.output_segmentation)
        for field in ("vol", "qc", "post", "resample"):
            value = getattr(self.inputs, field)
            if isdefined(value):
                outputs[field] = os.path.abspath(str(value))
        return outputs
class SynthStrip(CommandLine):
    """Nipype CommandLine interface for FreeSurfer's ``mri_synthstrip`` brain extraction tool.

    ``mri_synthstrip`` performs learning-based skull-stripping without requiring a
    specific acquisition protocol or contrast. This wrapper overrides
    ``_run_interface`` and bypasses Nipype's ``cmdline``-based execution. It runs
    FreeSurfer's ``mri_synthstrip`` Python script under the current interpreter
    (``sys.executable``), which allows a GPU attempt with an automatic CPU fallback.

    Requirements:
        ``FREESURFER_HOME`` must be set in the environment, and the script must
        exist at ``$FREESURFER_HOME/python/scripts/mri_synthstrip``. The current
        interpreter must be able to import the script's dependencies (for example
        ``surfa``; see ``_build_launcher_cmd``).

    Environment:
        The subprocess inherits ``os.environ``, overlaid with the ``environ``
        input. If ``gpu_device`` is defined, ``CUDA_VISIBLE_DEVICES`` is then set
        to it.

    GPU behaviour:
        When ``use_gpu=True`` the script is invoked with the ``-g`` flag. If that
        run exits with a code outside the accepted success codes, its stdout and
        stderr are logged at WARNING level. The node then retries once with the
        same launcher but without ``-g`` (CPU).

    Attributes:
        _cmd (str): ``"mri_synthstrip"``. Not used to execute anything; Nipype's
            generic ``cmdline`` property still reports it.
        input_spec: :class:`SynthStripInputSpec`
        output_spec: :class:`SynthStripOutputSpec`

    Example:
        Basic skull-stripping inside a Nipype workflow::

            from nipype import Node
            from thesis.core.nipype.interfaces.freesurfer import SynthStrip
            node = Node(SynthStrip(), name="synthstrip")
            node.inputs.input_image = "sub-01_T1w.nii.gz"
            node.inputs.output_image = "sub-01_brain.nii.gz"
            node.inputs.output_mask = "sub-01_mask.nii.gz"
            result = node.run()
    """

    _cmd = "mri_synthstrip"
    input_spec = SynthStripInputSpec
    output_spec = SynthStripOutputSpec

    def _build_launcher_cmd(self, use_gpu: bool) -> list[str]:
        """Build a SynthStrip command using the active conda Python interpreter.

        Uses ``sys.executable`` so that the CUDA-capable ``torch`` in the conda
        environment is available for GPU execution. FreeSurfer's versioned
        ``site-packages`` directory is intentionally **not** injected: it contains
        ``surfa`` compiled for FreeSurfer's embedded Python 3.8, which is
        ABI-incompatible with a newer conda Python. Instead, ``surfa`` must be
        installed directly in the conda environment::

            pip install surfa

        Only ``$FREESURFER_HOME/python/packages`` is appended to ``sys.path`` so
        that any non-compiled FreeSurfer resources remain accessible.

        Args:
            use_gpu: Append the ``-g`` flag for GPU-accelerated execution.

        Returns:
            Command list suitable for ``subprocess.run``.

        Raises:
            RuntimeError: If ``FREESURFER_HOME`` is unset or the script is missing.
        """
        fs_home = os.environ.get("FREESURFER_HOME")
        if not fs_home:
            raise RuntimeError("mri_synthstrip failed: FREESURFER_HOME is not set")
        script = os.path.join(fs_home, "python", "scripts", "mri_synthstrip")
        # Fail fast with a clear message instead of an opaque runpy traceback
        # from inside the subprocess.
        if not os.path.isfile(script):
            raise RuntimeError(f"mri_synthstrip failed: script not found at {script}")

        # Inject only the non-compiled packages dir; the versioned site-packages
        # (containing the Python-3.8-compiled surfa .so files) is excluded to
        # prevent ABI mismatches with the conda Python.
        extra_paths: list[str] = []
        packages_dir = os.path.join(fs_home, "python", "packages")
        if os.path.isdir(packages_dir):
            extra_paths.append(packages_dir)

        # With ``python -c``, sys.argv[0] is "-c"; rewrite it so the script sees
        # its own path as argv[0] and the CLI flags as argv[1:].
        bootstrap = (
            "import sys, runpy; "
            f"[sys.path.append(p) for p in {extra_paths!r} if p not in sys.path]; "
            f"sys.argv = [{script!r}] + sys.argv[1:]; "
            f"runpy.run_path({script!r}, run_name='__main__')"
        )
        cmd = [
            sys.executable,
            "-c",
            bootstrap,
            "-i",
            str(self.inputs.input_image),
            "-o",
            str(self.inputs.output_image),
            "-m",
            str(self.inputs.output_mask),
        ]
        if use_gpu:
            cmd.append("-g")
        logger.debug(
            "Using SynthStrip launcher: python={} script={} gpu={}",
            sys.executable,
            script,
            use_gpu,
        )
        return cmd

    @staticmethod
    def _run_command(
        runtime: Any,
        cmd: list[str],
        env: dict[str, str],
    ) -> subprocess.CompletedProcess[str]:
        """Execute a subprocess and mirror the result onto the Nipype runtime."""
        result = subprocess.run(cmd, capture_output=True, text=True, check=False, env=env)
        runtime.cmdline = subprocess.list2cmdline(cmd) if os.name == "nt" else shlex.join(cmd)
        runtime.stdout = result.stdout
        runtime.stderr = result.stderr
        runtime.returncode = result.returncode
        return result

    def _run_interface(self, runtime: Any, _correct_return_codes: tuple[int, ...] = (0,)) -> Any:
        """Execute ``mri_synthstrip`` with GPU/CPU fallback.

        Args:
            runtime: Nipype runtime bunch. ``cmdline``, ``stdout``, ``stderr``,
                ``returncode`` and ``success_codes`` are written onto it.
            _correct_return_codes: Exit codes treated as success.

        Returns:
            The updated runtime.

        Raises:
            RuntimeError: If the launcher cannot be built, or the final attempt
                exits with a code outside ``_correct_return_codes``.
        """
        success_codes = tuple(_correct_return_codes)
        # Nipype's rtc context manager checks runtime.success_codes on exit; set
        # it first so the attribute exists even when we raise RuntimeError below.
        runtime.success_codes = success_codes

        env = os.environ.copy()
        # Honour the inherited ``environ`` input, as CommandLine itself does.
        if isdefined(self.inputs.environ):
            env.update(self.inputs.environ)
        if isdefined(self.inputs.gpu_device):
            env["CUDA_VISIBLE_DEVICES"] = str(self.inputs.gpu_device)

        use_gpu = bool(self.inputs.use_gpu)
        result = self._run_command(runtime, self._build_launcher_cmd(use_gpu=use_gpu), env)

        if use_gpu and result.returncode not in success_codes:
            gpu_stdout = (result.stdout or "").strip()
            gpu_stderr = (result.stderr or "").strip()
            logger.warning(
                "SynthStrip GPU run failed; retrying on CPU. stdout={} stderr={}",
                gpu_stdout or "<empty>",
                gpu_stderr or "<empty>",
            )
            result = self._run_command(runtime, self._build_launcher_cmd(use_gpu=False), env)

        if result.returncode not in success_codes:
            error_text = (result.stderr or result.stdout or "unknown error").strip()
            logger.error("mri_synthstrip failed: {}", error_text)
            raise RuntimeError(f"mri_synthstrip failed: {error_text}")
        return runtime

    def _list_outputs(self) -> dict[str, Any]:
        """Map declared output paths to absolute filesystem locations."""
        outputs: dict[str, Any] = self.output_spec().get()
        outputs["output_image"] = os.path.abspath(self.inputs.output_image)
        outputs["output_mask"] = os.path.abspath(self.inputs.output_mask)
        return outputs