"""
SynthSeg brain segmentation workflow.
Builds a standalone Nipype workflow that runs FreeSurfer's ``mri_synthseg``
on a T1-weighted image. The workflow can be executed on its own or embedded
inside a larger meta-workflow via Nipype's nested-workflow mechanism.
"""
from pathlib import Path
from typing import Dict
from nipype import Node, Workflow
from nipype.interfaces.base import isdefined
from thesis.core.config import PipelineConfig
from thesis.core.context import ProcessingContext
from thesis.core.contracts import attach_inputnode, attach_outputnode
from thesis.core.decorators import produces, requires, workflow
from thesis.core.logging import get_logger
from thesis.core.nipype.interfaces.freesurfer import SynthSeg
from thesis.core.path_declarations import OutputDir, PatientFile
from thesis.core.resources import cpu, gpu
from thesis.workflows.hcp.common import resolve_t1_path
# Module-scoped logger used by the T1 shim and the workflow builder below.
logger = get_logger(__name__)
[docs]
def resolve_t1_for_synthseg(config: PipelineConfig, context: ProcessingContext) -> Path:
    """Return the T1 image path that the SynthSeg workflow should consume.

    This is a thin compatibility shim. The composite ``tract_synthseg``
    meta-workflow and external preflight checks call it to obtain the path
    without building a :class:`PatientFile` declaration.

    Path resolution itself is delegated to :func:`resolve_t1_path`. A warning
    naming the returned fallback path is logged when neither
    ``synthseg.t1_image`` nor ``hcp.t1_image`` holds a truthy value. That
    covers a missing config section, a missing attribute, or an empty value.

    Args:
        config: Pipeline configuration; the ``synthseg`` and ``hcp`` sections
            are both optional.
        context: Processing context, forwarded unchanged to
            :func:`resolve_t1_path`.

    Returns:
        Exactly the path produced by :func:`resolve_t1_path`.
    """
    t1_path = resolve_t1_path(config, context)

    def _configured_t1(section_name):
        # A missing section and a missing attribute both count as "not configured".
        section = getattr(config, section_name, None)
        if section is None:
            return None
        return getattr(section, "t1_image", None)

    # Evaluate both candidates (synthseg first, then hcp) before testing them.
    candidates = (_configured_t1("synthseg"), _configured_t1("hcp"))
    if not any(candidates):
        logger.warning(
            "No t1_image configured for SynthSeg or HCP; using default path: {}",
            t1_path,
        )
    return t1_path
[docs]
def _optional_int(value, default: int | None) -> int | None:
    """Return ``int(value)``, or *default* when *value* is ``None``.

    This handles numeric SynthSeg options read from ``model_dump`` output.
    ``model_dump`` emits every declared field. If the ``synthseg`` model
    declares an option as optional and it is unset, the dict holds an explicit
    ``None``. ``dict.get(key, default)`` would pass that ``None`` straight
    through, and ``int(None)`` would then raise ``TypeError``.
    """
    return default if value is None else int(value)


def _resolve_compute_device(config: PipelineConfig, synthseg_cfg: dict) -> tuple[bool, object]:
    """Decide whether SynthSeg runs on CPU or GPU, and on which GPU device.

    Priority order:

    1. An explicit ``synthseg.cpu`` setting wins.
    2. Otherwise ``hardware.gpu_enabled=false`` forces CPU.
    3. Otherwise ``hardware.gpu_device``, if set, selects the device to expose.

    Args:
        config: Pipeline configuration; the ``hardware`` section is optional.
        synthseg_cfg: ``model_dump`` of the ``synthseg`` section, or ``{}``.

    Returns:
        ``(want_cpu, gpu_device)``. ``gpu_device`` is ``None`` unless GPU mode
        was chosen and a device is configured.
    """
    if bool(synthseg_cfg.get("cpu", False)):
        return True, None
    hw_cfg = getattr(config, "hardware", None)
    if hw_cfg is None:
        # No hardware section: default to GPU mode without pinning a device.
        return False, None
    if not bool(getattr(hw_cfg, "gpu_enabled", True)):
        logger.info("SynthSeg: hardware.gpu_enabled=false → forcing CPU mode")
        return True, None
    return False, getattr(hw_cfg, "gpu_device", None)


def _build_gpu_environ(gpu_device: object, synthseg_cfg: dict) -> Dict[str, str]:
    """Assemble the node-scoped environment for the GPU SynthSeg subprocess.

    Nipype passes this environment only to the ``mri_synthseg`` subprocess
    (the Neurodesk Singularity wrapper), so it reaches ONLY the FreeSurfer
    container. Other tools such as the probtrackx2 and ANTs containers run in
    separate nodes and never see it.

    The environment holds ``CUDA_VISIBLE_DEVICES`` when a device is pinned,
    plus any configured ``synthseg.gpu_runtime_env``. For example,
    ``SINGULARITY_NV=1`` together with
    ``SINGULARITYENV_LD_LIBRARY_PATH=<cuda runtime>`` lets the container's
    TensorFlow actually use the GPU. Values are stringified. Entries from
    ``gpu_runtime_env`` override a same-named ``CUDA_VISIBLE_DEVICES``.
    """
    gpu_env: Dict[str, str] = {}
    if gpu_device is not None:
        gpu_env["CUDA_VISIBLE_DEVICES"] = str(gpu_device)
        logger.info("SynthSeg: setting CUDA_VISIBLE_DEVICES={}", gpu_device)
    runtime_env = synthseg_cfg.get("gpu_runtime_env") or {}
    gpu_env.update({k: str(v) for k, v in runtime_env.items()})
    return gpu_env


@workflow(name="synthseg", description="FreeSurfer SynthSeg brain segmentation workflow.")
@requires(
    t1=PatientFile(
        default="T1w/T1w_acpc_dc_restore_1.25.nii",
        config_paths=["synthseg.t1_image", "hcp.t1_image"],
        optional=True,
    ),
)
@produces(seg_dir=OutputDir("segmentation/synthseg"))
def build_workflow(
    *,
    t1: Path | None,
    seg_dir: Path,
    config: PipelineConfig,
    context: ProcessingContext,
) -> Workflow:
    """Build a Nipype workflow that segments a T1w image with SynthSeg.

    The workflow contains a single node (``synthseg``) wrapping
    ``mri_synthseg``. Outputs land in
    ``{output_dir}/segmentation/synthseg/{patient_id}_synthseg.nii.gz``.
    Optional extras go to the same directory: ``{patient_id}_volumes.csv``
    when ``synthseg.vol`` is set, and ``{patient_id}_qc.csv`` when
    ``synthseg.qc`` is set.

    Execution mode follows :func:`_resolve_compute_device`. CPU mode honours
    ``synthseg.threads``, which defaults to 1. GPU mode tags the node with
    ``use_gpu`` and attaches a node-scoped environment.

    Args:
        t1: Resolved T1-weighted image path (declared via ``@requires``).
            It is only bound statically when the file already exists;
            otherwise it is expected to arrive via ``inputnode.input_image``.
        seg_dir: Resolved segmentation output directory.
        config: PipelineConfig (may include optional ``synthseg`` and
            ``hardware`` sections).
        context: ProcessingContext carrying ``patient_id``.

    Returns:
        Nipype Workflow ready to run or embed in a meta-workflow. It exposes
        ``inputnode.input_image`` and ``outputnode.segmentation``.
    """
    pid = context.patient_id
    wf = Workflow(name=f"synthseg_{pid}")
    seg_dir = seg_dir.resolve()
    out_seg = seg_dir / f"{pid}_synthseg.nii.gz"

    synthseg_section = getattr(config, "synthseg", None)
    synthseg_cfg = synthseg_section.model_dump(mode="python") if synthseg_section else {}
    want_parc = bool(synthseg_cfg.get("parc", False))
    want_robust = bool(synthseg_cfg.get("robust", False))
    want_fast = bool(synthseg_cfg.get("fast", False))
    want_vol = bool(synthseg_cfg.get("vol", False))
    want_qc = bool(synthseg_cfg.get("qc", False))
    crop_size = _optional_int(synthseg_cfg.get("crop"), None)
    # None-aware: an explicit threads=None in the dump must fall back to 1, not crash.
    n_threads = _optional_int(synthseg_cfg.get("threads"), 1)
    want_cpu, gpu_device = _resolve_compute_device(config, synthseg_cfg)

    # Probe the filesystem once, so the static node input and the inputnode
    # default below cannot disagree if the file appears in between.
    t1_ready = t1 is not None and t1.exists()

    seg_node = Node(SynthSeg(), name="synthseg")
    # Only set input_image statically when the file already exists.
    # In a composite meta-workflow (e.g. full_pipeline), an upstream stage
    # creates the file at runtime and it is injected via a Nipype connection
    # instead. Nipype's File(exists=True) trait raises at assignment time if
    # the path does not yet exist on disk.
    if t1_ready:
        seg_node.inputs.input_image = str(t1)
    seg_node.inputs.output_segmentation = str(out_seg)
    if want_parc:
        seg_node.inputs.parc = True
    if want_robust:
        seg_node.inputs.robust = True
    if want_fast:
        seg_node.inputs.fast = True

    if want_cpu:
        seg_node.inputs.cpu = True
        seg_node.inputs.threads = n_threads
        cpu(seg_node, mem_gb=6.0, threads=n_threads)
    else:
        # MultiProc.is_gpu_node() returns True on `inputs.use_gpu`. That lets
        # the scheduler gate concurrent SynthSeg leaves in a parallel batch
        # DAG against the configured `n_gpu_procs` budget.
        seg_node.inputs.use_gpu = True
        gpu(seg_node, mem_gb=6.0, est_minutes=15)
        gpu_env = _build_gpu_environ(gpu_device, synthseg_cfg)
        if gpu_env:
            # Merge into any pre-existing node environ rather than clobbering it.
            existing = dict(seg_node.inputs.environ) if isdefined(seg_node.inputs.environ) else {}
            existing.update(gpu_env)
            seg_node.inputs.environ = existing

    if crop_size is not None:
        seg_node.inputs.crop = crop_size
    if want_vol:
        seg_node.inputs.vol = str(seg_dir / f"{pid}_volumes.csv")
    if want_qc:
        seg_node.inputs.qc = str(seg_dir / f"{pid}_qc.csv")
    logger.info(
        "SynthSeg node configured | input={} | output={} | parc={} | robust={}",
        t1,
        out_seg,
        want_parc,
        want_robust,
    )
    wf.add_nodes([seg_node])

    # -- I/O contract --------------------------------------------------------
    # inputnode.input_image defaults to the resolved T1 (standalone runs) and
    # feeds the synthseg node. The meta-workflow overrides it with the runtime
    # T1 brain. outputnode.segmentation re-exposes the label map.
    inputnode = attach_inputnode(
        wf,
        ["input_image"],
        defaults={"input_image": str(t1) if t1_ready else None},
    )
    wf.connect(inputnode, "input_image", seg_node, "input_image")
    outputnode = attach_outputnode(wf, ["segmentation"])
    wf.connect(seg_node, "output_segmentation", outputnode, "segmentation")
    return wf