"""Backend-agnostic scaffolding for the unified full_pipeline meta-workflow.
This module factors out the parts of the meta-workflow that are independent
of the tractography backend (ProbTrackX2 vs MRtrix3): the cross-stage edges
that wire preprocess, registration, atlas_to_patient, and tract_similarity
together, plus the :class:`BackendDescriptor` capturing the per-backend deltas.
``full_pipeline`` builds its sub-workflows, then calls
:func:`wire_core_scaffold` to install the backend-agnostic contract edges and
:func:`finalize_tract_similarity_merge` to install the stage-5 barrier. The
backend is chosen with :func:`select_backend` from ``tractography.method``.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Mapping
import nipype.pipeline.engine as pe
from nipype import Node
from nipype.interfaces.utility import Merge
from thesis.core.config import PipelineConfig
from thesis.core.exceptions import ConfigurationError
from thesis.core.logging import get_logger
from thesis.workflows.preprocess.config import PreprocessConfig
logger = get_logger(__name__)
# Public API of this module: the names exported by a star-import.
__all__ = [
    "BackendDescriptor",
    "select_backend",
    "CoreScaffold",
    "wire_core_scaffold",
    "finalize_tract_similarity_merge",
    "apply_preproc_override",
]
[docs]
@dataclass(frozen=True)
class BackendDescriptor:
    """Immutable record of what differs between tractography backends.

    Instances are frozen, so a descriptor can be shared between several
    lookup keys without risk of one caller mutating it for the others.

    Attributes:
        name: Canonical backend name (``"probtrackx2"`` | ``"mrtrix3"``).
        run_bedpostx: Whether preprocess should run BedpostX.
        required_binaries: External binaries the verifier checks on PATH.
        tractography_edges: Field pairs to connect, each written as
            ``(preprocess.outputnode.<src>, tract_synthseg.inputnode.<dst>)``.
    """

    name: str
    run_bedpostx: bool
    required_binaries: tuple[str, ...]
    tractography_edges: tuple[tuple[str, str], ...]
# Edges from preprocess.outputnode to the tractography inputnode that are
# present for every backend, written as (source_field, destination_field).
#
# The intra-subject DWI->T1 transform is intentionally NOT listed here:
# preprocess only exposes it when run_coregistration is on, so
# full_pipeline.build_workflow wires it conditionally.
_COMMON_EDGES: tuple[tuple[str, str], ...] = (
    ("t1_brain", "t1_brain"),
    ("dwi_mask", "dwi_mask"),
)
# ProbTrackX2 backend: preprocess runs BedpostX, and the three BedpostX sample
# outputs are forwarded under the same field names after the common edges.
_PROBTRACKX2 = BackendDescriptor(
    name="probtrackx2",
    run_bedpostx=True,
    required_binaries=("probtrackx2",),
    tractography_edges=(
        *_COMMON_EDGES,
        ("bedpostx_thsamples", "bedpostx_thsamples"),
        ("bedpostx_phsamples", "bedpostx_phsamples"),
        ("bedpostx_fsamples", "bedpostx_fsamples"),
    ),
)
# MRtrix3 backend: BedpostX is not run. After the common edges, the corrected
# DWI and its bvec/bval fields are forwarded; note that ``rotated_bvecs`` and
# ``modified_bval`` land on differently named destination fields.
_MRTRIX3 = BackendDescriptor(
    name="mrtrix3",
    run_bedpostx=False,
    required_binaries=(),  # filled from mrtrix3 REQUIRED_BINARIES at verify time
    tractography_edges=(
        *_COMMON_EDGES,
        ("dwi_corrected", "dwi_corrected"),
        ("rotated_bvecs", "dwi_bvec"),
        ("modified_bval", "dwi_bval"),
    ),
)
# Lookup table from accepted ``tractography.method`` values to descriptors.
# Two keys resolve to each descriptor: "fsl" maps to the ProbTrackX2
# descriptor and "tckgen" to the MRtrix3 one. Lookups are exact-match.
_BACKENDS: dict[str, BackendDescriptor] = {
    "probtrackx2": _PROBTRACKX2,
    "fsl": _PROBTRACKX2,
    "mrtrix3": _MRTRIX3,
    "tckgen": _MRTRIX3,
}
def _referenced_registration_jobs(config: PipelineConfig) -> list[str]:
    """List the registration jobs that the transform stage consumes.

    Every entry of ``config.transforms.jobs`` may carry a
    ``from_registration`` attribute naming the registration job whose
    transforms drive it. Those names are collected without duplicates, in
    the order they are first encountered. Jobs without the attribute, or
    with a falsy value for it, contribute nothing.

    When no job names a registration job (including when ``jobs`` is
    missing, ``None`` or empty), the implicit default ``patient_to_template``
    is returned so the legacy unsuffixed transform fan-out is still emitted.

    Args:
        config: Pipeline configuration; only ``config.transforms`` is read.

    Returns:
        The referenced registration job names, or
        ``["patient_to_template"]`` when none are referenced.
    """
    jobs = getattr(config.transforms, "jobs", []) or []
    referenced: list[str] = []
    for job in jobs:
        reg_name = getattr(job, "from_registration", None)
        # Skip unset references and names that were already recorded.
        if not reg_name or reg_name in referenced:
            continue
        referenced.append(reg_name)
    return referenced or ["patient_to_template"]
[docs]
def select_backend(method: str) -> BackendDescriptor:
    """Look up the :class:`BackendDescriptor` for ``tractography.method``.

    The lookup is an exact key match against the module's backend table.

    Args:
        method: The configured ``tractography.method`` value.

    Returns:
        The descriptor registered under *method*.

    Raises:
        ConfigurationError: when *method* is not a recognised backend. The
            message lists the accepted keys, and the underlying ``KeyError``
            is chained as the cause.
    """
    try:
        descriptor = _BACKENDS[method]
    except KeyError as exc:
        raise ConfigurationError(
            f"Unknown tractography.method {method!r}. Expected one of {sorted(_BACKENDS)}."
        ) from exc
    return descriptor
[docs]
@dataclass
class CoreScaffold:
    """Bundle of handles needed to finish the cross-stage wiring.

    Attributes:
        meta: The meta-workflow being assembled.
        preprocess_wf: Patient-level preprocessing sub-workflow.
        registration_wf: Patient-to-template registration sub-workflow.
        atlas_to_patient_wf: Template-to-patient transform sub-workflow.
        tract_similarity_wf: Tract similarity / metrics sub-workflow.
        reg_method: Registration backend (``"ants"`` | ``"fireants"``).
    """

    meta: pe.Workflow
    preprocess_wf: pe.Workflow
    registration_wf: pe.Workflow
    atlas_to_patient_wf: pe.Workflow
    tract_similarity_wf: pe.Workflow
    reg_method: str
[docs]
def apply_preproc_override(
    config: PipelineConfig, override: Mapping[str, Any] | None
) -> PipelineConfig:
    """Produce a config whose ``preprocess`` section has overrides applied.

    Used by backend variants that must structurally disable a preprocess
    feature (e.g. the MRtrix3 variant forces ``run_bedpostx=False`` so
    BedpostX never runs and the bedpostx_done gate field is never created).

    Args:
        config: The original pipeline config.
        override: Mapping of ``PreprocessConfig`` field names to override
            values. ``None`` or an empty mapping means "nothing to do".

    Returns:
        ``config`` itself when there is nothing to override; otherwise the
        result of ``config.model_copy`` with the ``preprocess`` entry
        replaced by the overridden :class:`PreprocessConfig`.
    """
    if not override:
        return config

    current = getattr(config, "preprocess", {}) or {}
    if not isinstance(current, PreprocessConfig):
        # Raw section (or none at all): overlay the overrides on top of the
        # existing keys and build a PreprocessConfig from the result.
        combined = dict(current)
        combined.update(override)
        patched = PreprocessConfig.model_validate(combined)
    else:
        # NOTE(review): this branch goes through model_copy(update=...) while
        # the other goes through model_validate; if these are pydantic v2
        # models, the update values are not validated here. Confirm that is
        # intended.
        patched = current.model_copy(update=dict(override))

    return config.model_copy(update={"preprocess": patched})
[docs]
def wire_core_scaffold(
    meta: pe.Workflow,
    config: PipelineConfig,
    *,
    preprocess_wf: pe.Workflow,
    registration_wf: pe.Workflow,
    atlas_to_patient_wf: pe.Workflow,
    tract_similarity_wf: pe.Workflow,
) -> CoreScaffold:
    """Install the backend-agnostic edges between the sub-workflows.

    All edges are contract-to-contract (no node-name introspection) and are
    connected in this order:

    * preprocess.outputnode.t1_brain → registration.inputnode.moving_image
    * registration.outputnode.transform → atlas_to_patient.entry_gate.ready
    * (FireANTs only) for each referenced registration job,
      registration.outputnode.{reverse,forward}_transforms[__<job>] →
      atlas_to_patient.entry_gate.{reverse,forward}_transforms[__<job>]
    * preprocess.outputnode.dwi_mask → atlas_to_patient.entry_gate.preprocess_done

    The runtime T1-brain injection on
    ``atlas_to_patient.entry_gate.reference_image`` is left to the caller
    (it sources the value from ``preprocess.outputnode.t1_brain``).
    ``tract_similarity_wf`` is not wired here; it is only carried into the
    returned scaffold.

    Args:
        meta: The meta-workflow being assembled.
        config: Validated pipeline configuration.
        preprocess_wf: Pre-built preprocess sub-workflow.
        registration_wf: Pre-built registration sub-workflow.
        atlas_to_patient_wf: Pre-built atlas_to_patient sub-workflow.
        tract_similarity_wf: Pre-built tract_similarity sub-workflow.

    Returns:
        A :class:`CoreScaffold` carrying the sub-workflow handles and the
        resolved registration backend name (``"ants"`` when
        ``config.registration`` has no ``method`` attribute).
    """
    reg_method = getattr(config.registration, "method", "ants")

    # Stage 1 → Stage 2.
    meta.connect(
        preprocess_wf, "outputnode.t1_brain", registration_wf, "inputnode.moving_image"
    )

    # Stage 2 → Stage 4.
    meta.connect(
        registration_wf, "outputnode.transform", atlas_to_patient_wf, "entry_gate.ready"
    )

    if reg_method == "fireants":
        # One reverse/forward edge pair per registration job referenced by
        # transforms.jobs[*].from_registration (default: patient_to_template).
        # The default job uses the LEGACY unsuffixed fields
        # (reverse_transforms / forward_transforms); explicit jobs use the
        # per-job suffixed fields (reverse_transforms__<job> /
        # forward_transforms__<job>) that the registration outputnode and the
        # atlas_to_patient entry_gate both publish.
        for regjob in _referenced_registration_jobs(config):
            suffix = "" if regjob == "patient_to_template" else f"__{regjob}"
            for direction in ("reverse", "forward"):
                field = f"{direction}_transforms{suffix}"
                meta.connect(
                    registration_wf,
                    f"outputnode.{field}",
                    atlas_to_patient_wf,
                    f"entry_gate.{field}",
                )

    # Stage 1 → Stage 4: block atlas transforms until the DWI brain mask
    # exists, so per-job reference_image overrides pointing at preprocess
    # outputs (e.g. nodif_brain_mask.nii.gz) don't crash with
    # FileNotFoundError.
    meta.connect(
        preprocess_wf, "outputnode.dwi_mask", atlas_to_patient_wf, "entry_gate.preprocess_done"
    )

    return CoreScaffold(
        meta=meta,
        preprocess_wf=preprocess_wf,
        registration_wf=registration_wf,
        atlas_to_patient_wf=atlas_to_patient_wf,
        tract_similarity_wf=tract_similarity_wf,
        reg_method=reg_method,
    )
[docs]
def finalize_tract_similarity_merge(
    scaffold: CoreScaffold,
    fdt_sources: list[tuple[pe.Workflow, str]],
) -> Node:
    """Install the barrier in front of the tract_similarity stage.

    A ``Merge`` node named ``similarity_merge`` is created with
    ``len(fdt_sources) + 1`` inputs. The backend ``fdt_paths`` sources fill
    ``in1..inK`` in the order given, ``atlas_to_patient.exit_gate.done``
    takes the last slot, and the node's ``out`` output feeds
    ``tract_similarity.entry_gate.ready``.

    Design intent: ``Merge`` is used instead of :class:`IdentityInterface`
    to get a "wait for all" barrier, i.e. ``out`` should only be produced
    once every input is set. An IdentityInterface forwards per field, which
    would let downstream nodes start before all upstream stages complete.

    Args:
        scaffold: The scaffold returned by :func:`wire_core_scaffold`.
        fdt_sources: ``(sub_workflow, output_port)`` pairs producing the
            backend's ``fdt_paths`` outputs. One entry for single-hemisphere
            runs, two for ``hemisphere=both-separately``.

    Returns:
        The Merge node, in case the caller needs additional wiring.
    """
    # Every upstream output the barrier waits on; the atlas exit gate always
    # occupies the final Merge slot.
    barrier_inputs = [*fdt_sources, (scaffold.atlas_to_patient_wf, "exit_gate.done")]
    similarity_merge = Node(Merge(len(barrier_inputs)), name="similarity_merge")

    # Merge input ports are 1-based: in1, in2, ...
    for slot, (upstream_wf, upstream_port) in enumerate(barrier_inputs, start=1):
        scaffold.meta.connect(upstream_wf, upstream_port, similarity_merge, f"in{slot}")

    scaffold.meta.connect(
        similarity_merge, "out", scaffold.tract_similarity_wf, "entry_gate.ready"
    )
    return similarity_merge