# Source code for thesis.workflows.full_pipeline._core

"""Backend-agnostic scaffolding for the unified full_pipeline meta-workflow.

This module factors out the parts of the meta-workflow that are independent
of the tractography backend (ProbTrackX2 vs MRtrix3): the cross-stage edges
that wire preprocess, registration, atlas_to_patient, and tract_similarity
together, plus the :class:`BackendDescriptor` capturing the per-backend deltas.

``full_pipeline`` builds its sub-workflows, then calls
:func:`wire_core_scaffold` to install the backend-agnostic contract edges and
:func:`finalize_tract_similarity_merge` to install the stage-5 barrier. The
backend is chosen with :func:`select_backend` from ``tractography.method``.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Mapping

import nipype.pipeline.engine as pe
from nipype import Node
from nipype.interfaces.utility import Merge

from thesis.core.config import PipelineConfig
from thesis.core.exceptions import ConfigurationError
from thesis.core.logging import get_logger
from thesis.workflows.preprocess.config import PreprocessConfig

logger = get_logger(__name__)

__all__ = [
    "BackendDescriptor",
    "select_backend",
    "CoreScaffold",
    "wire_core_scaffold",
    "finalize_tract_similarity_merge",
    "apply_preproc_override",
]


@dataclass(frozen=True)
class BackendDescriptor:
    """Immutable record of what differs between tractography backends.

    Instances are frozen: fields cannot be reassigned after construction.

    Attributes:
        name: Canonical backend name (``"probtrackx2"`` | ``"mrtrix3"``).
        run_bedpostx: Whether preprocess should run BedpostX.
        required_binaries: External binaries the verifier checks on PATH.
        tractography_edges: ``(src, dst)`` field pairs to connect, where
            ``src`` is a ``preprocess.outputnode`` field and ``dst`` is a
            ``tract_synthseg.inputnode`` field.
    """

    # Field order is part of the positional constructor signature.
    name: str
    run_bedpostx: bool
    required_binaries: tuple[str, ...]
    tractography_edges: tuple[tuple[str, str], ...]
# Edges from preprocess.outputnode to tractography.inputnode that every
# backend needs. The intra-subject DWI->T1 transform is deliberately absent:
# preprocess only exposes it when run_coregistration is on, so
# full_pipeline.build_workflow wires it conditionally.
_COMMON_EDGES: tuple[tuple[str, str], ...] = (
    ("t1_brain", "t1_brain"),
    ("dwi_mask", "dwi_mask"),
)

# ProbTrackX2 consumes the BedpostX sample volumes, so BedpostX must run.
_PROBTRACKX2 = BackendDescriptor(
    name="probtrackx2",
    run_bedpostx=True,
    required_binaries=("probtrackx2",),
    tractography_edges=(
        *_COMMON_EDGES,
        ("bedpostx_thsamples", "bedpostx_thsamples"),
        ("bedpostx_phsamples", "bedpostx_phsamples"),
        ("bedpostx_fsamples", "bedpostx_fsamples"),
    ),
)

# MRtrix3 works from the corrected DWI plus gradient tables; no BedpostX.
_MRTRIX3 = BackendDescriptor(
    name="mrtrix3",
    run_bedpostx=False,
    required_binaries=(),  # filled from mrtrix3 REQUIRED_BINARIES at verify time
    tractography_edges=(
        *_COMMON_EDGES,
        ("dwi_corrected", "dwi_corrected"),
        ("rotated_bvecs", "dwi_bvec"),
        ("modified_bval", "dwi_bval"),
    ),
)

# Accepted ``tractography.method`` values; "fsl" and "tckgen" are aliases.
_BACKENDS: dict[str, BackendDescriptor] = {
    "probtrackx2": _PROBTRACKX2,
    "fsl": _PROBTRACKX2,
    "mrtrix3": _MRTRIX3,
    "tckgen": _MRTRIX3,
}


def _referenced_registration_jobs(config: PipelineConfig) -> list[str]:
    """List the registration jobs that transform jobs actually refer to.

    Reads ``from_registration`` off every entry of ``config.transforms.jobs``
    and collects the truthy values, de-duplicated, in first-seen order. A
    missing or empty ``jobs`` attribute, or jobs that never set
    ``from_registration``, yield the implicit default
    ``["patient_to_template"]`` so the legacy unsuffixed transform fan-out is
    still emitted.

    Args:
        config: Pipeline configuration exposing ``transforms``.

    Returns:
        Registration job names in deterministic first-seen order; never empty.
    """
    jobs = getattr(config.transforms, "jobs", []) or []
    ordered: list[str] = []
    for job in jobs:
        regname = getattr(job, "from_registration", None)
        # Skip jobs without a reference and names already collected.
        if not regname or regname in ordered:
            continue
        ordered.append(regname)
    return ordered or ["patient_to_template"]
def select_backend(method: str) -> BackendDescriptor:
    """Look up the :class:`BackendDescriptor` for ``tractography.method``.

    Args:
        method: Backend key as written in the config; must be a key of the
            module-level ``_BACKENDS`` table.

    Returns:
        The descriptor registered under *method*.

    Raises:
        ConfigurationError: when *method* is not a recognised backend. The
            originating ``KeyError`` is chained as the cause.
    """
    try:
        descriptor = _BACKENDS[method]
    except KeyError as exc:
        message = (
            f"Unknown tractography.method {method!r}. "
            f"Expected one of {sorted(_BACKENDS)}."
        )
        raise ConfigurationError(message) from exc
    return descriptor
@dataclass
class CoreScaffold:
    """Bundle of workflow handles used for cross-stage wiring.

    Attributes:
        meta: The meta-workflow being assembled.
        preprocess_wf: Patient-level preprocessing sub-workflow.
        registration_wf: Patient-to-template registration sub-workflow.
        atlas_to_patient_wf: Template-to-patient transform sub-workflow.
        tract_similarity_wf: Tract similarity / metrics sub-workflow.
        reg_method: Registration backend (``"ants"`` | ``"fireants"``).
    """

    # Field order is part of the positional constructor signature.
    meta: pe.Workflow
    preprocess_wf: pe.Workflow
    registration_wf: pe.Workflow
    atlas_to_patient_wf: pe.Workflow
    tract_similarity_wf: pe.Workflow
    reg_method: str
def apply_preproc_override(
    config: PipelineConfig, override: Mapping[str, Any] | None
) -> PipelineConfig:
    """Produce a config whose ``preprocess`` section has *override* merged in.

    Backend variants use this to structurally disable a preprocess feature
    (e.g. the MRtrix3 variant forces ``run_bedpostx=False`` so BedpostX never
    runs and the bedpostx_done gate field is never created).

    Args:
        config: The original pipeline config.
        override: Mapping of ``PreprocessConfig`` field names to override
            values. ``None`` or an empty mapping is a no-op.

    Returns:
        *config* itself when there is nothing to override; otherwise a copy
        made with ``config.model_copy`` whose ``preprocess`` entry is the
        merged ``PreprocessConfig``.
    """
    # Nothing to merge: hand back the very same object.
    if not override:
        return config

    updates = dict(override)
    current = getattr(config, "preprocess", {}) or {}

    if isinstance(current, PreprocessConfig):
        # NOTE(review): pydantic's model_copy(update=...) does not validate
        # the updated values, unlike the model_validate branch below —
        # confirm callers only pass known-good overrides.
        replacement = current.model_copy(update=updates)
    else:
        # Mapping-like preprocess section: merge, then build a validated model.
        replacement = PreprocessConfig.model_validate({**dict(current), **updates})

    return config.model_copy(update={"preprocess": replacement})
def wire_core_scaffold(
    meta: pe.Workflow,
    config: PipelineConfig,
    *,
    preprocess_wf: pe.Workflow,
    registration_wf: pe.Workflow,
    atlas_to_patient_wf: pe.Workflow,
    tract_similarity_wf: pe.Workflow,
) -> CoreScaffold:
    """Install the backend-agnostic contract edges between sub-workflows.

    Every edge is contract-to-contract (no node-name introspection):

    * preprocess.outputnode.t1_brain → registration.inputnode.moving_image
    * registration.outputnode.transform → atlas_to_patient.entry_gate.ready
    * (FireANTs only) registration.outputnode.{reverse,forward}_transforms →
      atlas_to_patient.entry_gate.{reverse,forward}_transforms, once per
      referenced registration job
    * preprocess.outputnode.dwi_mask →
      atlas_to_patient.entry_gate.preprocess_done

    The runtime T1-brain injection on
    ``atlas_to_patient.entry_gate.reference_image`` is left to the caller (it
    sources the value from ``preprocess.outputnode.t1_brain``).

    Args:
        meta: The meta-workflow being assembled; edges are added to it.
        config: Validated pipeline configuration.
        preprocess_wf: Pre-built preprocess sub-workflow.
        registration_wf: Pre-built registration sub-workflow.
        atlas_to_patient_wf: Pre-built atlas_to_patient sub-workflow.
        tract_similarity_wf: Pre-built tract_similarity sub-workflow. Not
            wired here; only carried on the returned scaffold.

    Returns:
        A :class:`CoreScaffold` carrying the sub-workflow handles and the
        resolved registration backend name (``"ants"`` when the config does
        not set ``registration.method``).
    """
    reg_method = getattr(config.registration, "method", "ants")
    connect = meta.connect

    # Stage 1 → Stage 2.
    connect(
        preprocess_wf, "outputnode.t1_brain",
        registration_wf, "inputnode.moving_image",
    )

    # Stage 2 → Stage 4.
    connect(
        registration_wf, "outputnode.transform",
        atlas_to_patient_wf, "entry_gate.ready",
    )

    if reg_method == "fireants":
        # One reverse/forward edge pair per registration job referenced by
        # transforms.jobs[*].from_registration (default
        # {patient_to_template}). The default job maps onto the LEGACY
        # unsuffixed fields; explicit jobs use the per-job suffixed fields
        # (forward_transforms__<job> / reverse_transforms__<job>) that both
        # the registration outputnode and the atlas_to_patient entry_gate
        # publish.
        for regjob in _referenced_registration_jobs(config):
            suffix = "" if regjob == "patient_to_template" else f"__{regjob}"
            port_pairs = (
                (
                    f"outputnode.reverse_transforms{suffix}",
                    f"entry_gate.reverse_transforms{suffix}",
                ),
                (
                    f"outputnode.forward_transforms{suffix}",
                    f"entry_gate.forward_transforms{suffix}",
                ),
            )
            for out_port, gate_port in port_pairs:
                connect(registration_wf, out_port, atlas_to_patient_wf, gate_port)

    # Stage 1 → Stage 4: hold atlas transforms back until the DWI brain mask
    # exists, so per-job reference_image overrides pointing at preprocess
    # outputs (e.g. nodif_brain_mask.nii.gz) don't crash with
    # FileNotFoundError.
    connect(
        preprocess_wf, "outputnode.dwi_mask",
        atlas_to_patient_wf, "entry_gate.preprocess_done",
    )

    return CoreScaffold(
        meta=meta,
        preprocess_wf=preprocess_wf,
        registration_wf=registration_wf,
        atlas_to_patient_wf=atlas_to_patient_wf,
        tract_similarity_wf=tract_similarity_wf,
        reg_method=reg_method,
    )
def finalize_tract_similarity_merge(
    scaffold: CoreScaffold,
    fdt_sources: list[tuple[pe.Workflow, str]],
) -> Node:
    """Install the barrier in front of the tract_similarity stage.

    Creates a ``Merge`` node named ``similarity_merge`` with
    ``len(fdt_sources) + 1`` inputs: ``in1..inK`` take the backend
    ``fdt_paths`` sources in the order given, and the last input takes
    ``atlas_to_patient.exit_gate.done``. The node's ``out`` is connected to
    ``tract_similarity.entry_gate.ready``.

    ``Merge`` is chosen over :class:`IdentityInterface` because ``Merge.out``
    fires only when every input is set — a true "wait for all" barrier —
    whereas IdentityInterface forwards per-field and would let downstream
    nodes fire before all upstream stages complete.

    Args:
        scaffold: The scaffold returned by :func:`wire_core_scaffold`.
        fdt_sources: ``(sub_workflow, output_port)`` pairs producing the
            backend's ``fdt_paths`` outputs. One entry for single-hemisphere
            runs, two for ``hemisphere=both-separately``.

    Returns:
        The Merge node, in case the caller needs additional wiring.
    """
    # The extra slot (the last one) is reserved for the atlas exit gate.
    n_inputs = len(fdt_sources) + 1
    similarity_merge = Node(Merge(n_inputs), name="similarity_merge")
    connect = scaffold.meta.connect

    # Merge inputs are numbered from 1: in1, in2, ...
    for index, (source_wf, source_port) in enumerate(fdt_sources, start=1):
        connect(source_wf, source_port, similarity_merge, f"in{index}")

    connect(
        scaffold.atlas_to_patient_wf, "exit_gate.done",
        similarity_merge, f"in{n_inputs}",
    )
    connect(
        similarity_merge, "out",
        scaffold.tract_similarity_wf, "entry_gate.ready",
    )
    return similarity_merge