"""Combined tractography + SynthSeg segmentation workflow (backend-selectable).
Single meta-workflow that joins a SynthSeg segmentation with a tractography
backend chosen at build time from ``config.tractography.method``:
* ``probtrackx2`` / ``fsl`` → the ``hcp`` workflow (FSL BedpostX → ProbTrackX2)
* ``mrtrix3`` / ``tckgen`` → the ``mrtrix3`` workflow (5ttgen → dhollander →
msmt_csd → mtnormalise → tckgen ACT → tcksift2 → tckmap)
The two backends write to *separate* output folders (``tractography/probtrackx2``
vs ``tractography/mrtrix3``) — that separation is declared by each tractography
sub-workflow's own ``@produces`` and is preserved here unchanged.
Architecture::
meta-workflow (tract_synthseg_{patient_id})
├── inputnode ← contract: t1_brain, dwi_mask,
│ t1_to_dwi_transform + backend DWI fields
├── synthseg_wf (synthseg_{patient_id}) [own inputnode/outputnode]
├── tract_wf (hcp_{pid} | mrtrix3_{pid}) [own inputnode/outputnode]
└── outputnode ← contract: fdt_paths
Contract wiring (no node-name introspection)::
inputnode.t1_brain → synthseg_wf.inputnode.input_image
synthseg_wf.outputnode.segmentation → tract_wf.inputnode.seg_map
inputnode.<field> → tract_wf.inputnode.<field>
tract_wf.outputnode.fdt_paths → outputnode.fdt_paths
The SynthSeg label map is fed into the tractography workflow via a single
contract edge; the per-resampler / per-hemisphere fan-out lives inside the
tractography workflow. This enables role-per-source mixing without any
per-mask resampling step at the meta level:
atlas → seed, waypoint, target (transformed to subject space via ANTs)
SynthSeg → avoid (or any combination; already in subject space)
Usage::
thesis run -w tract_synthseg -p 114823 -c tract_synthseg # ProbTrackX2
thesis run -w tract_synthseg -p 114823 -c tract_synthseg_mrtrix3 # MRtrix3
"""
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, List, Mapping, cast
from nipype import Node, Workflow
from nipype.interfaces.utility import IdentityInterface
import thesis.workflows.hcp.workflow # noqa: F401 — populate registry + verifier
import thesis.workflows.mrtrix3.workflow # noqa: F401 — populate registry + verifier
import thesis.workflows.synthseg.workflow # noqa: F401 — populate registry
from thesis.core.config import PipelineConfig
from thesis.core.context import ProcessingContext
from thesis.core.contracts import attach_inputnode, attach_outputnode
from thesis.core.decorators import requires, verify, workflow
from thesis.core.exceptions import ConfigurationError
from thesis.core.logging import get_logger
from thesis.core.path_declarations import PatientFile
from thesis.core.registry import WORKFLOW_REGISTRY
from thesis.workflows.full_pipeline._core import select_backend
from thesis.workflows.hcp.config import prepare_hcp_paths
from thesis.workflows.hcp.workflow import verify_requirements as verify_hcp_requirements
from thesis.workflows.mrtrix3.config import prepare_mrtrix3_paths
from thesis.workflows.mrtrix3.workflow import verify_requirements as verify_mrtrix3_requirements
# Call signature of the registry factories invoked in this module: they are
# called as ``factory(config, context)`` and are expected to return a Workflow.
_WorkflowFactory = Callable[[PipelineConfig, ProcessingContext], Workflow]
# Call signature of a pre-run requirement checker: ``(config, context)`` in,
# list of human-readable error strings out (an empty list means all clear).
_Verifier = Callable[[PipelineConfig, ProcessingContext], List[str]]
# Module-level logger, named after this module.
logger = get_logger(__name__)
@dataclass(frozen=True)
class _BackendImpl:
    """Immutable record of what differs between the tractography backends.

    Everything else in the synthseg meta-workflow is backend-agnostic; only
    the values held here change with the selected method.

    Attributes:
        tract_workflow: Name under which the tractography sub-workflow is
            registered (``"hcp"`` for ProbTrackX2, ``"mrtrix3"`` for MRtrix3).
        path_field_map: Pairs of ``(inputnode_field, prepare_paths_key)`` that
            say which entry of the backend's path bundle seeds which default
            on the meta ``inputnode``.
        wire_roi_gate: ``True`` when the ``entry_gate`` ordering signal has to
            be forwarded to the tractography sub-workflow's
            ``inputnode.roi_transform_gate`` (ProbTrackX2 only; the MRtrix3
            gate is inert).
    """

    tract_workflow: str
    path_field_map: tuple[tuple[str, str], ...]
    wire_roi_gate: bool
# Path-bundle pairs that both backends seed on the meta inputnode.
_COMMON_PATH_FIELDS: tuple[tuple[str, str], ...] = (
    ("t1_brain", "t1_path"),
    ("dwi_mask", "mask_path"),
)

# Canonical backend name -> per-backend deltas. Keys are looked up with the
# ``name`` of the descriptor returned by ``select_backend``.
_IMPLS: dict[str, _BackendImpl] = {
    # FSL route: consumes the BedpostX sample volumes and wires the ROI gate.
    "probtrackx2": _BackendImpl(
        tract_workflow="hcp",
        path_field_map=(
            *_COMMON_PATH_FIELDS,
            ("bedpostx_thsamples", "thsamples"),
            ("bedpostx_phsamples", "phsamples"),
            ("bedpostx_fsamples", "fsamples"),
        ),
        wire_roi_gate=True,
    ),
    # MRtrix3 route: consumes the corrected DWI plus its gradient table files.
    "mrtrix3": _BackendImpl(
        tract_workflow="mrtrix3",
        path_field_map=(
            *_COMMON_PATH_FIELDS,
            ("dwi_corrected", "dwi_image"),
            ("dwi_bvec", "bvec"),
            ("dwi_bval", "bval"),
        ),
        wire_roi_gate=False,
    ),
}
def _resolve_method(config: PipelineConfig) -> str:
    """Read ``config.tractography.method``, falling back to ``probtrackx2``.

    The fallback applies when the ``tractography`` section is absent, when it
    has no ``method`` attribute, or when the configured value is falsy.
    """
    section = getattr(config, "tractography", None)
    chosen = getattr(section, "method", None)
    if chosen:
        return chosen
    return "probtrackx2"
def _resolve_impl(config: PipelineConfig) -> _BackendImpl:
    """Look up the :class:`_BackendImpl` for ``config.tractography.method``.

    The method string is normalised through :func:`select_backend`, so aliases
    (``fsl``, ``tckgen``) and the unknown-method error behave exactly as they
    do in ``full_pipeline``.
    """
    method = _resolve_method(config)
    descriptor = select_backend(method)
    return _IMPLS[descriptor.name]
def _prepare_paths(
    impl: _BackendImpl, config: PipelineConfig, context: ProcessingContext
) -> Mapping[str, Any]:
    """Return the static path bundle of the selected backend.

    The ``prepare_*`` helper is looked up by its module-level name on every
    call, so a test that monkeypatches either helper is honoured.
    """
    resolver = prepare_hcp_paths if impl.tract_workflow == "hcp" else prepare_mrtrix3_paths
    return resolver(config, context)
def _backend_verifier(impl: _BackendImpl) -> _Verifier:
    """Pick the pre-run requirement checker of the selected backend.

    The checker is looked up by module-level name at call time.
    """
    uses_hcp = impl.tract_workflow == "hcp"
    return verify_hcp_requirements if uses_hcp else verify_mrtrix3_requirements
def _build_synthseg_workflow(config: PipelineConfig, context: ProcessingContext) -> Workflow:
    """Instantiate the SynthSeg sub-workflow from the registry.

    The registry lookup happens on every call so the factory stays patchable.
    """
    entry = WORKFLOW_REGISTRY.get("synthseg")
    build = cast(_WorkflowFactory, entry.factory)
    return build(config, context)
def _build_tract_workflow(
    name: str, config: PipelineConfig, context: ProcessingContext
) -> Workflow:
    """Instantiate the tractography sub-workflow registered under *name*.

    The registry lookup happens on every call, not at import time.
    """
    entry = WORKFLOW_REGISTRY.get(name)
    build = cast(_WorkflowFactory, entry.factory)
    return build(config, context)
[docs]
def verify_requirements(config: PipelineConfig, context: ProcessingContext) -> List[str]:
    """Collect pre-run problems for the combined workflow.

    The backend named by ``config.tractography.method`` contributes its own
    checks first (HCP verifier for ProbTrackX2, MRtrix3 verifier for MRtrix3).
    The shared SynthSeg ROI label-file check is appended afterwards. The T1w
    image existence check is not performed here; it is declared through
    ``@requires(t1=...)`` on :func:`build_workflow`.

    Args:
        config: PipelineConfig
        context: ProcessingContext

    Returns:
        List of human-readable error strings (empty = all clear). If the
        configured method cannot be resolved, the list holds only that
        configuration error.
    """
    try:
        backend = _resolve_impl(config)
    except ConfigurationError as exc:
        # Without a resolvable backend no further check can be dispatched.
        return [str(exc)]

    problems: List[str] = list(_backend_verifier(backend)(config, context))

    # The label-file check only applies when synthseg_roi_labels names one.
    tract_section = getattr(config, "tractography", None)
    roi_labels_cfg = (
        getattr(tract_section, "synthseg_roi_labels", None) if tract_section else None
    )
    if not roi_labels_cfg or not roi_labels_cfg.get("label_file"):
        return problems

    from thesis.core.utils import resolve_path

    if context.input_dir:
        search_root = Path(context.input_dir).resolve()
    else:
        search_root = Path(".")
    raw_label_file = roi_labels_cfg["label_file"]

    # First try relative to input_dir, then fall back to data_dir.
    candidate = resolve_path(search_root, raw_label_file)
    if not candidate.exists():
        candidate = resolve_path(context.data_dir, raw_label_file)
        if not candidate.exists():
            problems.append(
                f"SynthSeg ROI label file not found: '{raw_label_file}' "
                "(checked under input_dir and data_dir)."
            )
    return problems
[docs]
@workflow(
    name="tract_synthseg",
    description=(
        "Tractography + SynthSeg combined workflow; backend selected by "
        "config.tractography.method (probtrackx2 → FSL ProbTrackX2, "
        "mrtrix3 → MRtrix3 tckgen). Writes to tractography/<backend>."
    ),
    protocol="tract_synthseg",
)
@requires(
    t1=PatientFile(
        default="T1w/T1w_acpc_dc_restore_brain.nii.gz",
        config_paths=["synthseg.t1_image", "hcp.t1_image"],
    ),
)
@verify(verify_requirements)
def build_workflow(
    *,
    t1: Path,
    config: PipelineConfig,
    context: ProcessingContext,
) -> Workflow:
    """Build the combined SynthSeg + tractography meta-workflow.

    The tractography backend is chosen from ``config.tractography.method``:
    ``probtrackx2``/``fsl`` builds the ``hcp`` workflow, ``mrtrix3``/``tckgen``
    builds the ``mrtrix3`` workflow. Each backend writes to its own
    ``tractography/<backend>`` folder.

    **ROI source mixing**

    Configure ``tractography.atlas_sources`` for atlas-sourced roles and
    ``tractography.synthseg_roi_labels`` for SynthSeg-sourced roles. Roles can
    be split arbitrarily:

    .. code-block:: yaml

        tractography:
          atlas_sources:
            - name: main            # atlas — seed / waypoint / target
              roi_file: data/masks/annotation_full.nii.gz
              label_file: data/masks/label_map.csv
              waypoint_labels:
                thalamus_left:
                  region_kind: seed
                  label_name: Thalamus-Left
          synthseg_roi_labels:      # SynthSeg — avoid / stop
            label_file: data/masks/synthseg_lut.csv
            waypoint_labels:
              csf:
                region_kind: avoid
                label_name: CSF

    The ``synthseg_roi_labels`` section does **not** include a ``roi_file``
    because the file is supplied at runtime by connecting the SynthSeg output.

    **Connection logic**

    The meta-workflow uses a strict I/O contract — no node-name introspection.
    ``inputnode`` forwards the backend's contract fields directly to
    ``tract_wf.inputnode``, and also connects ``inputnode.t1_brain`` to
    ``synthseg_wf.inputnode.input_image``. The single cross-workflow edge
    ``synthseg_wf.outputnode.segmentation → tract_wf.inputnode.seg_map``
    delivers the label map; per-resampler / per-hemisphere fan-out is handled
    inside the tractography workflow. The ``outputnode`` re-exposes
    ``fdt_paths``.

    Args:
        t1: T1w image path supplied through ``@requires``. It is only
            validated by the decorator and is not used in the body; the
            sub-workflows re-resolve it from the same config.
        config: PipelineConfig. Must satisfy the active backend's requirements.
            May include a ``synthseg`` section and/or a
            ``tractography.synthseg_roi_labels`` section.
        context: ProcessingContext with ``patient_id``, ``input_dir``,
            ``output_dir``.

    Returns:
        Nipype Workflow ready for ``NipypeExecutor`` or ``.run()``.
    """
    del t1  # validated by @requires; sub-workflows re-resolve from the same config

    pid = context.patient_id
    impl = _resolve_impl(config)

    meta = Workflow(name=f"tract_synthseg_{pid}")
    synthseg_wf = _build_synthseg_workflow(config, context)
    tract_wf = _build_tract_workflow(impl.tract_workflow, config, context)

    # -- Re-expose the contract --------------------------------------------
    # Standalone tract_synthseg runs have nothing feeding this meta inputnode, so
    # its fields would stay Undefined and — because the connections below forward
    # them into tract_wf.inputnode — would *overwrite* the build-time defaults the
    # tractography sub-workflow seeds on its own inputnode. That Undefined then
    # surfaces as missing-input errors (e.g. SynthSeg input_image, probtrackx2
    # samples, 5ttgen t1, grad bvec/bval) for every subject/hemisphere. Seed the
    # meta inputnode from the same paths dict the backend uses so standalone runs
    # resolve statically; embedded full_pipeline runs still override these via
    # standard Nipype connections.
    #
    # The bundle is resolved exactly once and reused further down for the
    # optional validation sub-workflow.
    backend_paths = _prepare_paths(impl, config, context)
    inputnode_defaults: dict[str, Any] = {}
    for field_name, path_key in impl.path_field_map:
        value = backend_paths.get(path_key)
        if not value:
            # Missing/empty entries stay unseeded rather than becoming "None".
            continue
        # bedpostx samples are lists of files (probtrack *samples are
        # InputMultiPath); seed list-typed values as a list of path strings (a
        # valid multi-file value), scalars as a single path string.
        inputnode_defaults[field_name] = (
            [str(item) for item in value] if isinstance(value, (list, tuple)) else str(value)
        )

    # Contract fields are the destination names the backend descriptor wires
    # (single source of truth, shared with full_pipeline) plus the always-present
    # optional intra-subject transform.
    desc = select_backend(_resolve_method(config))
    contract_fields = [dst for _, dst in desc.tractography_edges] + ["t1_to_dwi_transform"]
    inputnode = attach_inputnode(meta, contract_fields, defaults=inputnode_defaults)

    meta.connect(inputnode, "t1_brain", synthseg_wf, "inputnode.input_image")
    meta.connect(synthseg_wf, "outputnode.segmentation", tract_wf, "inputnode.seg_map")
    for field in contract_fields:
        meta.connect(inputnode, field, tract_wf, f"inputnode.{field}")

    outputnode = attach_outputnode(meta, ["fdt_paths"])
    meta.connect(tract_wf, "outputnode.fdt_paths", outputnode, "fdt_paths")

    logger.info(
        "Built tract_synthseg meta-workflow for {} (backend={}, contract-wired)",
        pid,
        impl.tract_workflow,
    )

    # -- Optionally attach validation sub-workflow ----------------------------
    validation_cfg = getattr(config, "validation", None)
    if getattr(validation_cfg, "check_rois", False):
        from .validation import build_validation_workflow

        # Reuse the bundle resolved above instead of resolving it a second time.
        brain_mask_path = str(backend_paths["mask_path"])
        validation_wf = build_validation_workflow(config, context, brain_mask_path)

        # Contract-wired — no node-name introspection. Both backends re-publish
        # their ROI terminus on outputnode.roi_{seed,stop,avoid,target}, so
        # connect to those stable contract outputs directly.
        meta.connect(
            [
                (
                    tract_wf,
                    validation_wf,
                    [
                        ("outputnode.roi_seed", "roi_collector.seed"),
                        ("outputnode.roi_stop", "roi_collector.stop_mask"),
                        ("outputnode.roi_avoid", "roi_collector.avoid_mask"),
                        ("outputnode.roi_target", "roi_collector.target_mask"),
                    ],
                )
            ]
        )
        logger.info(
            "Connected tract_wf.outputnode.roi_* -> validation_wf.roi_collector (contract-wired)"
        )

    # Fixed-name anchor for cross-workflow ordering. When embedded in
    # full_pipeline, the meta connects registration's completion to
    # ``entry_gate._ordering_signal``. For the ProbTrackX2 backend we forward
    # that into the hcp workflow's ``roi_transform_gate`` so the atlas ROI warp
    # (which reads the template->patient transforms by path) only runs once
    # registration has written them. Seeded with "" so standalone runs (no
    # external connect) leave the gate satisfied and unordered. The MRtrix3
    # backend has no such gate, so its entry_gate stays inert.
    entry_gate = Node(IdentityInterface(fields=["_ordering_signal"]), name="entry_gate")
    meta.add_nodes([entry_gate])
    if impl.wire_roi_gate:
        entry_gate.inputs._ordering_signal = ""
        meta.connect(entry_gate, "_ordering_signal", tract_wf, "inputnode.roi_transform_gate")

    return meta