# Source code for thesis.workflows.tract_synthseg.workflow

"""Combined tractography + SynthSeg segmentation workflow (backend-selectable).

Single meta-workflow that joins a SynthSeg segmentation with a tractography
backend chosen at build time from ``config.tractography.method``:

* ``probtrackx2`` / ``fsl``  → the ``hcp`` workflow (FSL BedpostX → ProbTrackX2)
* ``mrtrix3`` / ``tckgen``   → the ``mrtrix3`` workflow (5ttgen → dhollander →
  msmt_csd → mtnormalise → tckgen ACT → tcksift2 → tckmap)

The two backends write to *separate* output folders (``tractography/probtrackx2``
vs ``tractography/mrtrix3``) — that separation is declared by each tractography
sub-workflow's own ``@produces`` and is preserved here unchanged.

Architecture::

    meta-workflow  (tract_synthseg_{patient_id})
    ├── inputnode                     ← contract: t1_brain, dwi_mask,
    │                                    t1_to_dwi_transform + backend DWI fields
    ├── synthseg_wf  (synthseg_{patient_id})   [own inputnode/outputnode]
    ├── tract_wf     (hcp_{pid} | mrtrix3_{pid}) [own inputnode/outputnode]
    └── outputnode                    ← contract: fdt_paths

Contract wiring (no node-name introspection)::

    inputnode.t1_brain                  → synthseg_wf.inputnode.input_image
    synthseg_wf.outputnode.segmentation → tract_wf.inputnode.seg_map
    inputnode.<field>                   → tract_wf.inputnode.<field>
    tract_wf.outputnode.fdt_paths       → outputnode.fdt_paths

The SynthSeg label map is fed into the tractography workflow via a single
contract edge; the per-resampler / per-hemisphere fan-out lives inside the
tractography workflow. This enables role-per-source mixing without any
per-mask resampling step at the meta level::

    atlas   → seed, waypoint, target  (transformed to subject space via ANTs)
    SynthSeg → avoid  (or any combination; already in subject space)

Usage::

    thesis run -w tract_synthseg -p 114823 -c tract_synthseg           # ProbTrackX2
    thesis run -w tract_synthseg -p 114823 -c tract_synthseg_mrtrix3   # MRtrix3
"""

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, List, Mapping, cast

from nipype import Node, Workflow
from nipype.interfaces.utility import IdentityInterface

import thesis.workflows.hcp.workflow  # noqa: F401 — populate registry + verifier
import thesis.workflows.mrtrix3.workflow  # noqa: F401 — populate registry + verifier
import thesis.workflows.synthseg.workflow  # noqa: F401 — populate registry
from thesis.core.config import PipelineConfig
from thesis.core.context import ProcessingContext
from thesis.core.contracts import attach_inputnode, attach_outputnode
from thesis.core.decorators import requires, verify, workflow
from thesis.core.exceptions import ConfigurationError
from thesis.core.logging import get_logger
from thesis.core.path_declarations import PatientFile
from thesis.core.registry import WORKFLOW_REGISTRY
from thesis.workflows.full_pipeline._core import select_backend
from thesis.workflows.hcp.config import prepare_hcp_paths
from thesis.workflows.hcp.workflow import verify_requirements as verify_hcp_requirements
from thesis.workflows.mrtrix3.config import prepare_mrtrix3_paths
from thesis.workflows.mrtrix3.workflow import verify_requirements as verify_mrtrix3_requirements

_WorkflowFactory = Callable[[PipelineConfig, ProcessingContext], Workflow]
_Verifier = Callable[[PipelineConfig, ProcessingContext], List[str]]

logger = get_logger(__name__)


@dataclass(frozen=True)
class _BackendImpl:
    """Immutable record of what differs between tractography backends.

    One instance exists per backend; the meta-workflow builder reads these
    values instead of branching on the backend name everywhere.

    Attributes:
        tract_workflow: Name under which the tractography sub-workflow is
            registered (``"hcp"`` for ProbTrackX2, ``"mrtrix3"`` for MRtrix3).
        path_field_map: Pairs of ``(inputnode_field, prepare_paths_key)`` that
            say which resolved path seeds which meta ``inputnode`` field.
        wire_roi_gate: ``True`` when the ``entry_gate`` ordering signal must be
            forwarded to the sub-workflow's ``inputnode.roi_transform_gate``
            (ProbTrackX2 only; MRtrix3's gate is inert).
    """

    # Registry key of the tractography sub-workflow to build.
    tract_workflow: str
    # (meta inputnode field, key in the backend's prepared-paths mapping).
    path_field_map: tuple[tuple[str, str], ...]
    # Whether entry_gate is connected into the sub-workflow's ROI gate.
    wire_roi_gate: bool


# Backend table, keyed by the canonical backend name that ``_resolve_impl``
# obtains from ``select_backend(...).name``. Each ``path_field_map`` pair is
# ``(meta inputnode field, key in the backend's prepared-paths mapping)``.
_IMPLS: dict[str, _BackendImpl] = {
    # ProbTrackX2 runs through the ``hcp`` workflow and is the only backend
    # whose ROI transform gate is wired to ``entry_gate``.
    "probtrackx2": _BackendImpl(
        tract_workflow="hcp",
        wire_roi_gate=True,
        path_field_map=(
            ("t1_brain", "t1_path"),
            ("dwi_mask", "mask_path"),
            ("bedpostx_thsamples", "thsamples"),
            ("bedpostx_phsamples", "phsamples"),
            ("bedpostx_fsamples", "fsamples"),
        ),
    ),
    # MRtrix3 runs through the ``mrtrix3`` workflow; no ROI gate is wired.
    "mrtrix3": _BackendImpl(
        tract_workflow="mrtrix3",
        wire_roi_gate=False,
        path_field_map=(
            ("t1_brain", "t1_path"),
            ("dwi_mask", "mask_path"),
            ("dwi_corrected", "dwi_image"),
            ("dwi_bvec", "bvec"),
            ("dwi_bval", "bval"),
        ),
    ),
}


def _resolve_method(config: PipelineConfig) -> str:
    """Read the tractography method name from *config*.

    Falls back to ``"probtrackx2"`` when the ``tractography`` section is
    absent, or when its ``method`` attribute is missing or falsy.
    """
    section = getattr(config, "tractography", None)
    method = getattr(section, "method", None)
    if method:
        return method
    return "probtrackx2"


def _resolve_impl(config: PipelineConfig) -> _BackendImpl:
    """Look up the :class:`_BackendImpl` for the configured tractography method.

    The method string is passed through :func:`select_backend`, so alias
    handling (``fsl``, ``tckgen``) and the unknown-method error are the same
    ones ``full_pipeline`` uses. The descriptor's ``name`` is then used as the
    key into ``_IMPLS``.
    """
    descriptor = select_backend(_resolve_method(config))
    return _IMPLS[descriptor.name]


def _prepare_paths(
    impl: _BackendImpl, config: PipelineConfig, context: ProcessingContext
) -> Mapping[str, Any]:
    """Return the static path bundle of the selected backend.

    The ``prepare_*`` helper is picked by its module-level name at call time
    (not captured up front) so tests can monkeypatch either helper.
    """
    resolver = prepare_hcp_paths if impl.tract_workflow == "hcp" else prepare_mrtrix3_paths
    return resolver(config, context)


def _backend_verifier(impl: _BackendImpl) -> _Verifier:
    """Pick the pre-run requirement checker for the selected backend.

    The checker is looked up by module-level name at call time: the HCP one
    when ``impl.tract_workflow`` is ``"hcp"``, the MRtrix3 one otherwise.
    """
    uses_hcp = impl.tract_workflow == "hcp"
    return verify_hcp_requirements if uses_hcp else verify_mrtrix3_requirements


def _build_synthseg_workflow(config: PipelineConfig, context: ProcessingContext) -> Workflow:
    """Instantiate the SynthSeg sub-workflow via the workflow registry.

    The registry entry is fetched on every call rather than cached, which
    keeps the lookup patchable.
    """
    entry = WORKFLOW_REGISTRY.get("synthseg")
    build = cast(_WorkflowFactory, entry.factory)
    return build(config, context)


def _build_tract_workflow(
    name: str, config: PipelineConfig, context: ProcessingContext
) -> Workflow:
    """Instantiate the tractography sub-workflow registered under *name*.

    The registry entry is fetched at call time and its factory is invoked
    with ``(config, context)``.
    """
    entry = WORKFLOW_REGISTRY.get(name)
    build = cast(_WorkflowFactory, entry.factory)
    return build(config, context)


def verify_requirements(config: PipelineConfig, context: ProcessingContext) -> List[str]:
    """Run the pre-run checks of the combined workflow.

    The active backend's verifier (HCP for ProbTrackX2, MRtrix3 for the
    MRtrix3 backend) is selected from ``config.tractography.method`` and run
    first; the shared SynthSeg ROI label-file check is appended afterwards.

    The T1w image existence check is provided declaratively by
    ``@requires(t1=...)`` on :func:`build_workflow`.

    Args:
        config: PipelineConfig
        context: ProcessingContext

    Returns:
        List of human-readable error strings (empty = all clear). If the
        backend cannot be resolved, the list holds only that configuration
        error message.
    """
    try:
        backend = _resolve_impl(config)
    except ConfigurationError as exc:
        # Without a backend nothing else can be checked.
        return [str(exc)]

    problems: List[str] = list(_backend_verifier(backend)(config, context))

    # The label-file check only applies when synthseg_roi_labels names one.
    tract_section = getattr(config, "tractography", None)
    roi_section = getattr(tract_section, "synthseg_roi_labels", None) if tract_section else None
    if not roi_section or not roi_section.get("label_file"):
        return problems

    from thesis.core.utils import resolve_path

    base_dir = Path(context.input_dir).resolve() if context.input_dir else Path(".")
    raw_label = roi_section["label_file"]

    # Look under input_dir first, then fall back to data_dir.
    candidate = resolve_path(base_dir, raw_label)
    if not candidate.exists():
        candidate = resolve_path(context.data_dir, raw_label)
        if not candidate.exists():
            problems.append(
                f"SynthSeg ROI label file not found: '{raw_label}' "
                "(checked under input_dir and data_dir)."
            )
    return problems
@workflow(
    name="tract_synthseg",
    description=(
        "Tractography + SynthSeg combined workflow; backend selected by "
        "config.tractography.method (probtrackx2 → FSL ProbTrackX2, "
        "mrtrix3 → MRtrix3 tckgen). Writes to tractography/<backend>."
    ),
    protocol="tract_synthseg",
)
@requires(
    t1=PatientFile(
        default="T1w/T1w_acpc_dc_restore_brain.nii.gz",
        config_paths=["synthseg.t1_image", "hcp.t1_image"],
    ),
)
@verify(verify_requirements)
def build_workflow(
    *,
    t1: Path,
    config: PipelineConfig,
    context: ProcessingContext,
) -> Workflow:
    """Build the combined SynthSeg + tractography meta-workflow.

    The tractography backend is chosen from ``config.tractography.method``:
    ``probtrackx2``/``fsl`` builds the ``hcp`` workflow, ``mrtrix3``/``tckgen``
    builds the ``mrtrix3`` workflow. Each backend writes to its own
    ``tractography/<backend>`` folder.

    **ROI source mixing**

    Configure ``tractography.atlas_sources`` for atlas-sourced roles and
    ``tractography.synthseg_roi_labels`` for SynthSeg-sourced roles. Roles can
    be split arbitrarily:

    .. code-block:: yaml

        tractography:
          atlas_sources:
            - name: main                # atlas — seed / waypoint / target
              roi_file: data/masks/annotation_full.nii.gz
              label_file: data/masks/label_map.csv
              waypoint_labels:
                thalamus_left:
                  region_kind: seed
                  label_name: Thalamus-Left
          synthseg_roi_labels:          # SynthSeg — avoid / stop
            label_file: data/masks/synthseg_lut.csv
            waypoint_labels:
              csf:
                region_kind: avoid
                label_name: CSF

    The ``synthseg_roi_labels`` section does **not** include a ``roi_file``
    because the file is supplied at runtime by connecting the SynthSeg output.

    **Connection logic**

    The meta-workflow uses a strict I/O contract — no node-name introspection.
    ``inputnode`` forwards the backend's contract fields directly to
    ``tract_wf.inputnode``, and also connects ``inputnode.t1_brain`` to
    ``synthseg_wf.inputnode.input_image``. The single cross-workflow edge
    ``synthseg_wf.outputnode.segmentation → tract_wf.inputnode.seg_map``
    delivers the label map; per-resampler / per-hemisphere fan-out is handled
    inside the tractography workflow. The ``outputnode`` re-exposes
    ``fdt_paths``.

    Args:
        t1: T1w image path supplied through ``@requires``. It is not used in
            the body; the argument only exists so the requirement is
            validated before the build starts.
        config: PipelineConfig. Must satisfy the active backend's
            requirements. May include a ``synthseg`` section and/or a
            ``tractography.synthseg_roi_labels`` section.
        context: ProcessingContext with ``patient_id``, ``input_dir``,
            ``output_dir``.

    Returns:
        Nipype Workflow ready for ``NipypeExecutor`` or ``.run()``.
    """
    del t1  # validated by @requires; sub-workflows re-resolve from the same config

    pid = context.patient_id
    impl = _resolve_impl(config)

    meta = Workflow(name=f"tract_synthseg_{pid}")
    synthseg_wf = _build_synthseg_workflow(config, context)
    tract_wf = _build_tract_workflow(impl.tract_workflow, config, context)

    # -- Re-expose the contract --------------------------------------------
    # Standalone tract_synthseg runs have nothing feeding this meta inputnode, so
    # its fields would stay Undefined and — because the connections below forward
    # them into tract_wf.inputnode — would *overwrite* the build-time defaults the
    # tractography sub-workflow seeds on its own inputnode. That Undefined then
    # surfaces as missing-input errors (e.g. SynthSeg input_image, probtrackx2
    # samples, 5ttgen t1, grad bvec/bval) for every subject/hemisphere. Seed the
    # meta inputnode from the same paths dict the backend uses so standalone runs
    # resolve statically; embedded full_pipeline runs still override these via
    # standard Nipype connections.
    # The bundle is resolved once here and reused by the validation branch
    # below, so both always see the same paths.
    paths = _prepare_paths(impl, config, context)
    inputnode_defaults: dict[str, Any] = {}
    for field_name, paths_key in impl.path_field_map:
        value = paths.get(paths_key)
        if not value:
            # Nothing resolved for this key: leave the field unseeded.
            continue
        # bedpostx samples are lists of files (probtrack *samples are
        # InputMultiPath); seed list-typed values as a list of path strings (a
        # valid multi-file value), scalars as a single path string.
        inputnode_defaults[field_name] = (
            [str(item) for item in value] if isinstance(value, (list, tuple)) else str(value)
        )

    # Contract fields are the destination names the backend descriptor wires
    # (single source of truth, shared with full_pipeline) plus the always-present
    # optional intra-subject transform.
    desc = select_backend(_resolve_method(config))
    contract_fields = [dst for _, dst in desc.tractography_edges] + ["t1_to_dwi_transform"]

    inputnode = attach_inputnode(meta, contract_fields, defaults=inputnode_defaults)
    meta.connect(inputnode, "t1_brain", synthseg_wf, "inputnode.input_image")
    meta.connect(synthseg_wf, "outputnode.segmentation", tract_wf, "inputnode.seg_map")
    for field in contract_fields:
        meta.connect(inputnode, field, tract_wf, f"inputnode.{field}")

    outputnode = attach_outputnode(meta, ["fdt_paths"])
    meta.connect(tract_wf, "outputnode.fdt_paths", outputnode, "fdt_paths")

    logger.info(
        "Built tract_synthseg meta-workflow for {} (backend={}, contract-wired)",
        pid,
        impl.tract_workflow,
    )

    # -- Optionally attach validation sub-workflow ----------------------------
    validation_cfg = getattr(config, "validation", None)
    if getattr(validation_cfg, "check_rois", False):
        from .validation import build_validation_workflow

        # Reuse the bundle resolved above instead of resolving it a second time.
        brain_mask_path = str(paths["mask_path"])
        validation_wf = build_validation_workflow(config, context, brain_mask_path)

        # Contract-wired — no node-name introspection. Both backends re-publish
        # their ROI terminus on outputnode.roi_{seed,stop,avoid,target}, so
        # connect to those stable contract outputs directly.
        meta.connect(
            [
                (
                    tract_wf,
                    validation_wf,
                    [
                        ("outputnode.roi_seed", "roi_collector.seed"),
                        ("outputnode.roi_stop", "roi_collector.stop_mask"),
                        ("outputnode.roi_avoid", "roi_collector.avoid_mask"),
                        ("outputnode.roi_target", "roi_collector.target_mask"),
                    ],
                )
            ]
        )
        logger.info(
            "Connected tract_wf.outputnode.roi_* -> validation_wf.roi_collector (contract-wired)"
        )

    # Fixed-name anchor for cross-workflow ordering. When embedded in
    # full_pipeline, the meta connects registration's completion to
    # ``entry_gate._ordering_signal``. For the ProbTrackX2 backend we forward
    # that into the hcp workflow's ``roi_transform_gate`` so the atlas ROI warp
    # (which reads the template->patient transforms by path) only runs once
    # registration has written them. Seeded with "" so standalone runs (no
    # external connect) leave the gate satisfied and unordered. The MRtrix3
    # backend has no such gate, so its entry_gate stays inert.
    entry_gate = Node(IdentityInterface(fields=["_ordering_signal"]), name="entry_gate")
    meta.add_nodes([entry_gate])
    if impl.wire_roi_gate:
        entry_gate.inputs._ordering_signal = ""
        meta.connect(entry_gate, "_ordering_signal", tract_wf, "inputnode.roi_transform_gate")

    return meta