# Source code for thesis.workflows.hcp.workflow

"""HCP ProbTrackX2 workflow orchestration."""

from pathlib import Path

from nipype import JoinNode, Node, Workflow
from nipype.interfaces.utility import IdentityInterface

from thesis.core.config import PipelineConfig
from thesis.core.context import ProcessingContext
from thesis.core.contracts import HCP_ROI_OUTPUT_FIELDS, attach_inputnode, fan_out
from thesis.core.decorators import produces, requires, verify, workflow
from thesis.core.path_declarations import (
    GlobGroup,
    GlobGroupResult,
    OutputDir,
    PatientDir,
    PatientFile,
)
from thesis.core.utils import write_list_file

from .common import (
    get_atlas_sources,
    get_role_kinds,
    get_validation_thresholds,
    resolve_with_fallback,
)
from .config import prepare_hcp_paths, prepare_seed_path, prepare_tractography_params
from .nodes import (
    prepare_atlas_extract_mapnode,
    prepare_atlas_merge_joinnode,
    prepare_atlas_transform_mapnode,
    prepare_atlas_validate_mapnode,
    prepare_final_merger,
    prepare_probtrackx2_node,
    prepare_probtrackx2_outdir_router,
    prepare_probtrackx2_params_writer,
    prepare_roi_dwi_resampler,
    prepare_roi_validator,
    prepare_streamline_warper,
    prepare_synthseg_roi_extractor,
    prepare_synthseg_seg_resampler,
    prepare_waypoint_avoid_overlap_verifier,
)

# Re-exported from the workflow module so existing tests / external callers can
# still import these names from ``thesis.workflows.hcp.workflow``.
from .verifiers import (  # noqa: F401
    _verify_atlas_transform,
    _verify_bedpostx_samples,
    _verify_brain_mask,
    _verify_roi_inputs,
    _verify_seed_inputs,
    verify_requirements,
)

# Names of the five ROI outputs every ROI-producing node exposes.
ROI_FIELDS = ("seed", "waypoints_file", "stop_mask", "avoid_mask", "target_mask")

# Identity wiring: each ROI field connects to the same-named input downstream.
ROI_TO_ROI = list(zip(ROI_FIELDS, ROI_FIELDS))

# ProbTrackX2 input names, listed in the same order as ROI_FIELDS.
_PROBTRACK_INPUTS = ("seed", "waypoints", "stop_mask", "avoid_mp", "target_masks")

# (ROI output field, ProbTrackX2 input) pairs used to wire ROIs into the tracker.
PROBTRACK_WIRING = list(zip(ROI_FIELDS, _PROBTRACK_INPUTS))


def _hemispheres_for_mode(mode: str) -> list[str]:
    """Translate a hemisphere mode into the list of hemispheres to process.

    ``"both-separately"`` expands to ``["left", "right"]``. ``"left"``,
    ``"right"`` and ``"both"`` are returned as a single-element list. Any other
    value falls back to ``["both"]``.
    """
    if mode == "both-separately":
        return ["left", "right"]
    single = mode if mode in ("left", "right", "both") else "both"
    return [single]


def _make_target_list(path):
    """Return ``[path]`` for a truthy *path*, otherwise an empty list.

    Used to adapt a single target mask path to ProbTrackX2's ``target_masks``
    input, which takes a list.
    """
    if not path:
        return []
    return [path]


def _warped_dir_from_outdir(out_dir):
    """Map a ProbTrackX2 output directory to its ``warped_streamlines`` subdir.

    The value comes from ``probtrackx2_outdir_router.out_dir`` — the exact
    per-hemisphere directory ProbTrackX2 writes ``fdt_paths.nii.gz`` into — so
    the warper lands next to whichever hemisphere subdir was produced, with no
    need to reverse-engineer it from ProbTrackX2's output file paths.

    Returns an empty string when *out_dir* is falsy.
    """
    # Imported inside the body so the function carries its own dependency when
    # it is passed to ``connect`` as an inline modifier function.
    from pathlib import Path as _Path

    if out_dir:
        return str(_Path(out_dir) / "warped_streamlines")
    return ""


def _prepare_synthseg(config, context, paths, out_dir):
    """Build the SynthSeg ROI nodes, if SynthSeg ROIs are available.

    Returns a ``(resampler, extractor, validator)`` triple. When the extractor
    builder returns ``None`` the result is ``(None, None, None)`` and no other
    node is built.
    """
    roi_extractor = prepare_synthseg_roi_extractor(
        config, Path(context.input_dir), context, out_dir
    )
    if roi_extractor is None:
        return None, None, None

    t1_reference = str(paths["t1_path"])
    seg_resampler = prepare_synthseg_seg_resampler(t1_reference, out_dir)

    # An explicitly configured ROI file becomes the resampler's static input.
    label_cfg = getattr(config.tractography, "synthseg_roi_labels", None) or {}
    configured_roi_file = label_cfg.get("roi_file")
    if configured_roi_file:
        resolved = resolve_with_fallback(
            configured_roi_file, Path(context.input_dir), [context.data_dir]
        )
        seg_resampler.inputs.input_image = str(resolved)

    roi_validator = prepare_roi_validator(
        reference_image=t1_reference,
        name="synthseg_roi_validator",
        **get_validation_thresholds(config),
    )
    return seg_resampler, roi_extractor, roi_validator


def _kinds_present(config) -> set[str]:
    """Collect every ROI role kind declared by any configured ROI source.

    Combines the role kinds (seed/waypoint/stop/avoid/target) of each atlas
    source with those of the ``synthseg_roi_labels`` tractography setting.
    """
    role_sets = [get_role_kinds(atlas_source) for atlas_source in get_atlas_sources(config)]
    role_sets.append(get_role_kinds(getattr(config.tractography, "synthseg_roi_labels", None)))

    found: set[str] = set()
    for roles in role_sets:
        found |= roles
    return found


def _configure_fixed_path_inputs(probtrack, config, input_dir, out_dir):
    """Give ProbTrackX2 static seed/waypoint inputs when no ROI sources exist.

    The seed is always set. Waypoints are set only when
    ``config.tractography.waypoint_masks`` is non-empty, in which case the
    resolved mask paths are written to ``<out_dir>/waypoints.txt`` and that
    list file is handed to ProbTrackX2.
    """
    probtrack.inputs.seed = str(prepare_seed_path(config, input_dir))

    masks = config.tractography.waypoint_masks
    if not masks:
        return

    resolved_masks = []
    for mask_path in masks:
        resolved_masks.append(resolve_with_fallback(str(mask_path), input_dir, []))
    list_file = write_list_file(resolved_masks, out_dir / "waypoints.txt")
    probtrack.inputs.waypoints = str(list_file)


@workflow(
    name="hcp",
    description="HCP ProbTrackX2 tractography workflow using FSL ProbTrackX2.",
    protocol="hcp",
)
@requires(
    t1=PatientFile(
        default="T1w/T1w_acpc_dc_restore_brain.nii.gz",
        config_paths=["synthseg.t1_image", "hcp.t1_image"],
    ),
    mask=PatientFile(
        default="T1w/Diffusion/nodif_brain_mask.nii",
        config_path="hcp.mask_path",
    ),
    bedpostx=GlobGroup(
        items={
            "thsamples": "merged_th*samples.nii.gz",
            "phsamples": "merged_ph*samples.nii.gz",
            "fsamples": "merged_f*samples.nii.gz",
        },
        primary_dir=PatientDir(
            default="T1w/Diffusion.bedpostX",
            config_path="hcp.bedpostx_dir",
        ),
    ),
)
@produces(tract_dir=OutputDir("tractography/probtrackx2"))
@verify(verify_requirements)
def build_workflow(
    *,
    t1: Path,
    mask: Path,
    bedpostx: GlobGroupResult,
    tract_dir: Path,
    config: PipelineConfig,
    context: ProcessingContext,
) -> Workflow:
    """Build the HCP ProbTrackX2 tractography workflow.

    Hemisphere modes via ``config.tractography.hemisphere``: ``"left"``/``"right"``
    (single hemisphere via fixed input), ``"both"`` (merged ROIs, default),
    ``"both-separately"`` (two hemispheres via Nipype ``iterables``).

    Args:
        t1: T1 image path declared by ``@requires``; used as ``t1_path`` only
            when it exists on disk.
        mask: Brain mask path declared by ``@requires``; used as ``mask_path``
            only when it exists on disk.
        bedpostx: BedpostX glob result declared by ``@requires``; its samples
            and directory are used only when ``_found_dir`` is not ``None``.
        tract_dir: Output directory declared by ``@produces``; resolved to an
            absolute path and used as the root for every node's outputs.
        config: Pipeline configuration (tractography, hardware, ROI sources).
        context: Per-patient processing context.

    Returns:
        A Nipype ``Workflow`` named ``hcp_<patient_id>`` with an ``inputnode``
        (runtime-overridable inputs) and an ``outputnode`` exposing
        ``fdt_paths`` plus the ``HCP_ROI_OUTPUT_FIELDS`` contract fields.
    """
    wf = Workflow(name=f"hcp_{context.patient_id}")
    paths = prepare_hcp_paths(config, context)

    # Prefer @requires-resolved paths when they exist on disk; otherwise fall
    # back to what prepare_hcp_paths returned (which may be a runtime path
    # produced by upstream stages in full_pipeline, or a mocked path in tests).
    if t1.exists():
        paths = {**paths, "t1_path": t1}
    if mask.exists():
        paths = {**paths, "mask_path": mask}
    if bedpostx._found_dir is not None:
        paths = {
            **paths,
            "thsamples": bedpostx.thsamples,
            "phsamples": bedpostx.phsamples,
            "fsamples": bedpostx.fsamples,
            "bedpostx_dir": bedpostx._found_dir,
            "samples_base_name": bedpostx._found_dir / "merged",
        }

    params = prepare_tractography_params(config)
    input_dir = paths["input_dir"]
    out_dir = tract_dir.resolve()

    inputs = Node(IdentityInterface(fields=["hemisphere"]), name="inputs")
    hemispheres = _hemispheres_for_mode(getattr(config.tractography, "hemisphere", "both"))
    if len(hemispheres) > 1:
        inputs.iterables = ("hemisphere", hemispheres)
    else:
        inputs.inputs.hemisphere = hemispheres[0]

    atlas_extract = prepare_atlas_extract_mapnode(config, context, out_dir)
    atlas_transform = prepare_atlas_transform_mapnode(config, context, out_dir)
    atlas_validate = (
        prepare_atlas_validate_mapnode(
            get_validation_thresholds(config), str(paths["t1_path"]), out_dir
        )
        if atlas_extract is not None
        else None
    )
    atlas_merge = (
        prepare_atlas_merge_joinnode(atlas_validate.name, out_dir)
        if atlas_validate is not None
        else None
    )

    synthseg_resampler, synthseg_extract, synthseg_validate = _prepare_synthseg(
        config, context, paths, out_dir
    )

    if atlas_merge is not None and synthseg_validate is not None:
        final_merge = prepare_final_merger(out_dir)
    else:
        final_merge = None

    has_any_roi = atlas_extract is not None or synthseg_extract is not None
    dwi_resampler = (
        prepare_roi_dwi_resampler(str(paths["mask_path"]), out_dir) if has_any_roi else None
    )

    # Pre-ProbTrackX2 fail-fast: catches waypoint masks landing inside avoid
    # (the silent waytotal=0 failure mode that surfaces when a warp or
    # hemisphere route goes wrong).
    waypoint_avoid_verifier = prepare_waypoint_avoid_overlap_verifier() if has_any_roi else None

    probtrack = prepare_probtrackx2_node(
        params,
        paths,
        str(out_dir),
        use_gpu=getattr(config.hardware, "gpu_enabled", False),
        gpu_runtime_env=getattr(config.tractography, "gpu_runtime_env", None),
        gpu_slot_cost=int(getattr(config.hardware, "n_gpu_procs", 1)),
    )
    params_writer = prepare_probtrackx2_params_writer(
        out_dir=out_dir, n_samples=int(params["n_samples"])
    )
    warper = prepare_streamline_warper(config, context, out_dir)

    # Adapter node that emits ``<out_dir>/<hemisphere>`` (or ``<out_dir>`` when
    # hemisphere is "both"). Lets probtrackx2 + params_writer + warper share a
    # single per-iteration output root, so they don't trample each other when
    # the workflow runs under ``--hemisphere both-separately``. Only relevant
    # when multiple hemispheres iterate; for single-hemisphere runs we still
    # build the node (it just emits the base path) to keep wiring uniform.
    probtrack_outdir_router = prepare_probtrackx2_outdir_router(out_dir)

    connections: list = []

    # Route the active hemisphere into every stage that writes per-hemisphere
    # outputs, so identically-named files never collide across iterations.
    # Without these, ab65d63 only fixed the extractor and downstream stages
    # raced — producing waytotal=0 for ~20% of subjects under both-separately.
    connections.append((inputs, probtrack_outdir_router, [("hemisphere", "hemisphere")]))
    if atlas_extract is not None:
        connections.append((inputs, atlas_extract, [("hemisphere", "hemisphere")]))
    if atlas_transform is not None:
        connections.append((inputs, atlas_transform, [("hemisphere", "hemisphere")]))
    if atlas_extract is not None and atlas_transform is not None:
        connections.append((atlas_extract, atlas_transform, ROI_TO_ROI))
    if atlas_validate is not None:
        # Validate the transformed ROIs when a transform stage exists,
        # otherwise the raw extracted ones.
        upstream_for_validate = atlas_transform or atlas_extract
        connections.append((upstream_for_validate, atlas_validate, ROI_TO_ROI))
    if atlas_merge is not None and atlas_validate is not None:
        connections.append((inputs, atlas_merge, [("hemisphere", "hemisphere")]))
        # JoinNode wiring: collect each ROI field across MapNode iterations.
        connections.append(
            (
                atlas_validate,
                atlas_merge,
                [
                    ("seed", "seeds"),
                    ("waypoints_file", "waypoints_files"),
                    ("stop_mask", "stop_masks"),
                    ("avoid_mask", "avoid_masks"),
                    ("target_mask", "target_masks"),
                ],
            )
        )

    if synthseg_resampler is not None and synthseg_extract is not None:
        connections.append((inputs, synthseg_extract, [("hemisphere", "hemisphere")]))
        connections.append(
            (synthseg_resampler, synthseg_extract, [("resampled_segmentation", "roi_file")])
        )
    if synthseg_validate is not None:
        connections.append((synthseg_extract, synthseg_validate, ROI_TO_ROI))

    if final_merge is not None:
        connections.append((inputs, final_merge, [("hemisphere", "hemisphere")]))
        connections.append(
            (
                atlas_merge,
                final_merge,
                [
                    ("seed", "left_seed"),
                    ("waypoints_file", "left_waypoints_file"),
                    ("stop_mask", "left_stop"),
                    ("avoid_mask", "left_avoid"),
                    ("target_mask", "left_target"),
                ],
            )
        )
        connections.append(
            (
                synthseg_validate,
                final_merge,
                [
                    ("seed", "right_seed"),
                    ("waypoints_file", "right_waypoints_file"),
                    ("stop_mask", "right_stop"),
                    ("avoid_mask", "right_avoid"),
                    ("target_mask", "right_target"),
                ],
            )
        )
        roi_terminus = final_merge
    elif atlas_merge is not None:
        roi_terminus = atlas_merge
    elif synthseg_validate is not None:
        roi_terminus = synthseg_validate
    else:
        roi_terminus = None

    if roi_terminus is not None and dwi_resampler is not None:
        connections.append((inputs, dwi_resampler, [("hemisphere", "hemisphere")]))
        connections.append((roi_terminus, dwi_resampler, ROI_TO_ROI))

    # Route per-hemisphere out_dir into the three nodes that write into the
    # ProbTrackX2 directory (probtrack itself, the params_writer, and the
    # streamline warper). Without this, both hemisphere iterations under
    # both-separately write fdt_paths.nii.gz / waytotal / tractography_params.json
    # to the same path and the last writer wins.
    connections.append((probtrack_outdir_router, probtrack, [("out_dir", "out_dir")]))
    connections.append((probtrack_outdir_router, params_writer, [("out_dir", "output_dir")]))

    kinds = _kinds_present(config)
    if dwi_resampler is not None:
        # waypoint_avoid_verifier and dwi_resampler are gated on the same
        # has_any_roi predicate, so the verifier always exists in this branch.
        # Pass all five ROI fields through it as a fail-fast guard — pure
        # passthrough on success, PipelineError on catastrophic overlap.
        assert waypoint_avoid_verifier is not None  # for mypy / invariant doc
        connections.append((dwi_resampler, waypoint_avoid_verifier, ROI_TO_ROI))
        probtrack_source_node = waypoint_avoid_verifier

        # Role kind that must be declared for each ROI field to be wired into
        # ProbTrackX2. Built once rather than per comprehension iteration.
        kind_for_field = {
            "seed": "seed",
            "waypoints_file": "waypoint",
            "stop_mask": "stop",
            "avoid_mask": "avoid",
            "target_mask": "target",
        }
        probtrack_source_pairs = [
            (out_f, in_f) for out_f, in_f in PROBTRACK_WIRING if kind_for_field[out_f] in kinds
        ]
        # Wrap target_mask in a list inline via Nipype's connect-with-function
        # syntax — ProbTrackX2's ``target_masks`` input wants a list of paths.
        plain_pairs = [(o, i) for o, i in probtrack_source_pairs if o != "target_mask"]
        if plain_pairs:
            connections.append((probtrack_source_node, probtrack, plain_pairs))
        if "target" in kinds:
            probtrack.inputs.os2t = True
            connections.append(
                (
                    probtrack_source_node,
                    probtrack,
                    [(("target_mask", _make_target_list), "target_masks")],
                )
            )
    else:
        _configure_fixed_path_inputs(probtrack, config, input_dir, out_dir)

    if warper is not None:
        # Feed the warper's directories from the outdir router — the single
        # source of truth for the per-hemisphere directory ProbTrackX2 writes
        # into. ``fdt_paths`` is the directory the warper scans for outputs and
        # ``output_dir`` is ``<router_out_dir>/warped_streamlines``, avoiding any
        # reverse-engineering of directories from ProbTrackX2's output paths.
        connections.append(
            (
                probtrack_outdir_router,
                warper,
                [
                    ("out_dir", "fdt_paths"),
                    (("out_dir", _warped_dir_from_outdir), "output_dir"),
                ],
            )
        )
        # Ordering dependency: the router path is available before ProbTrackX2
        # runs, so wire probtrackx2.fdt_paths into the warper's ignored ordering
        # input. This forces the warper to run only after ProbTrackX2 has written
        # its outputs into that directory (otherwise the warper scans an empty
        # dir and silently produces no warped streamlines).
        connections.append((probtrack, warper, [("fdt_paths", "_ordering_signal")]))

    # Add every non-None node we created; Nipype dedups.
    for node in (
        inputs,
        atlas_extract,
        atlas_transform,
        atlas_validate,
        atlas_merge,
        synthseg_resampler,
        synthseg_extract,
        synthseg_validate,
        final_merge,
        dwi_resampler,
        waypoint_avoid_verifier,
        probtrack_outdir_router,
        probtrack,
        params_writer,
        warper,
    ):
        if node is not None:
            wf.add_nodes([node])

    # Drop any edge whose endpoint was never built.
    real_connections = [c for c in connections if c[0] is not None and c[1] is not None]
    if real_connections:
        wf.connect(real_connections)

    # -- Input contract: runtime-overridable inputs fanned to internal nodes -
    # Standalone runs need build-time defaults on the contract node itself, because
    # the fan_out edges below overwrite any static values previously set on the
    # consumer nodes (atlas_validate, probtrack, dwi_resampler, ...). Without
    # defaults, an inputnode field that no meta-workflow has connected stays
    # Undefined, which then propagates to downstream Function nodes and surfaces
    # as ``TypeError: missing 1 required positional argument`` at execute time.
    # ``t1_to_dwi_transform``, ``seg_map`` and ``roi_transform_gate`` are
    # genuinely optional / meta-only — leave them Undefined; their consumers are
    # either gated by conditionals (``dwi_resampler is not None``,
    # ``synthseg_resampler is not None``) or accept Undefined (ordering signal).
    # Meta-workflow edges still override these defaults — that's standard Nipype
    # connection semantics.
    inputnode_defaults: dict = {
        "t1_brain": str(paths["t1_path"]),
        "dwi_mask": str(paths["mask_path"]),
    }
    for field, key in (
        ("bedpostx_thsamples", "thsamples"),
        ("bedpostx_phsamples", "phsamples"),
        ("bedpostx_fsamples", "fsamples"),
    ):
        value = paths.get(key)
        if not value:
            continue
        # bedpostx samples are lists of files (probtrack thsamples/phsamples/
        # fsamples are InputMultiPath). Seed as a list of path strings, not a
        # stringified list — this default overrides probtrack's build-time value
        # via the inputnode->probtrack connection at runtime, so it must itself be
        # a valid multi-file value.
        inputnode_defaults[field] = (
            [str(item) for item in value] if isinstance(value, (list, tuple)) else [str(value)]
        )

    inputnode = attach_inputnode(
        wf,
        [
            "t1_brain",
            "dwi_mask",
            "t1_to_dwi_transform",
            "seg_map",
            "bedpostx_thsamples",
            "bedpostx_phsamples",
            "bedpostx_fsamples",
            "roi_transform_gate",
        ],
        defaults=inputnode_defaults,
    )

    t1_targets: list = []
    for n in (atlas_validate, synthseg_validate):
        if n is not None:
            t1_targets.append((n, "reference_image"))
    if synthseg_resampler is not None:
        t1_targets.append((synthseg_resampler, "reference"))
    fan_out(wf, inputnode, "t1_brain", t1_targets)

    if synthseg_resampler is not None:
        # NOTE(review): _prepare_synthseg may already have set a static
        # ``input_image`` from config — confirm that an unconnected (Undefined)
        # ``seg_map`` does not clobber that value at run time.
        wf.connect(inputnode, "seg_map", synthseg_resampler, "input_image")

    # No warp_field field: atlas_transform is a MapNode whose warp_field /
    # reference_image are per-source iterfields resolved from config at build
    # time; a scalar runtime override would break per-source iteration. (The old
    # full_pipeline broadcast fireants reverse_transforms to legacy roi_transformer
    # nodes the MapNode design no longer produces, so this was already a no-op.)

    # Ordering-only gate: block the atlas ROI warp until the patient->template
    # registration has written its transforms (which atlas_transform reads by
    # path). full_pipeline drives roi_transform_gate from registration via
    # tract_synthseg.entry_gate; standalone runs leave it unset (transforms are
    # precomputed). Broadcast onto the MapNode's non-iterfield _ordering_signal,
    # the same mechanism as `hemisphere`; the task ignores the value.
    if atlas_transform is not None:
        wf.connect(inputnode, "roi_transform_gate", atlas_transform, "_ordering_signal")

    if dwi_resampler is not None:
        wf.connect(
            [
                (
                    inputnode,
                    dwi_resampler,
                    [
                        ("dwi_mask", "reference"),
                        ("t1_to_dwi_transform", "t1_to_dwi_transform"),
                    ],
                ),
            ]
        )

    wf.connect(
        [
            (
                inputnode,
                probtrack,
                [
                    ("dwi_mask", "mask"),
                    ("bedpostx_thsamples", "thsamples"),
                    ("bedpostx_phsamples", "phsamples"),
                    ("bedpostx_fsamples", "fsamples"),
                ],
            ),
        ]
    )

    # -- Output contract: fdt_paths, joined across the hemisphere iterable ----
    # The ROI terminus (atlas_merge / final_merger / synthseg_validate) is a
    # single combined Node, so its seed/stop/avoid/target outputs are scalar and
    # are re-published as regular (non-join) outputnode fields under the stable
    # contract names roi_seed/roi_stop/roi_avoid/roi_target. A meta-workflow
    # (e.g. tract_synthseg ROI validation) wires to these instead of scanning the
    # internal graph for a node name. They stay Undefined when no ROI assembly
    # was built. fdt_paths remains the only field joined across hemispheres.
    contract_fields = ["fdt_paths", *HCP_ROI_OUTPUT_FIELDS]
    if len(hemispheres) > 1:
        outputnode = JoinNode(
            IdentityInterface(fields=contract_fields),
            joinsource="inputs",
            joinfield="fdt_paths",
            name="outputnode",
        )
    else:
        outputnode = Node(IdentityInterface(fields=contract_fields), name="outputnode")
    wf.add_nodes([outputnode])
    wf.connect(probtrack, "fdt_paths", outputnode, "fdt_paths")
    if roi_terminus is not None:
        wf.connect(
            [
                (
                    roi_terminus,
                    outputnode,
                    [
                        ("seed", "roi_seed"),
                        ("stop_mask", "roi_stop"),
                        ("avoid_mask", "roi_avoid"),
                        ("target_mask", "roi_target"),
                    ],
                )
            ]
        )

    return wf