"""HCP ProbTrackX2 workflow orchestration."""
from pathlib import Path
from nipype import JoinNode, Node, Workflow
from nipype.interfaces.utility import IdentityInterface
from thesis.core.config import PipelineConfig
from thesis.core.context import ProcessingContext
from thesis.core.contracts import HCP_ROI_OUTPUT_FIELDS, attach_inputnode, fan_out
from thesis.core.decorators import produces, requires, verify, workflow
from thesis.core.path_declarations import (
GlobGroup,
GlobGroupResult,
OutputDir,
PatientDir,
PatientFile,
)
from thesis.core.utils import write_list_file
from .common import (
get_atlas_sources,
get_role_kinds,
get_validation_thresholds,
resolve_with_fallback,
)
from .config import prepare_hcp_paths, prepare_seed_path, prepare_tractography_params
from .nodes import (
prepare_atlas_extract_mapnode,
prepare_atlas_merge_joinnode,
prepare_atlas_transform_mapnode,
prepare_atlas_validate_mapnode,
prepare_final_merger,
prepare_probtrackx2_node,
prepare_probtrackx2_outdir_router,
prepare_probtrackx2_params_writer,
prepare_roi_dwi_resampler,
prepare_roi_validator,
prepare_streamline_warper,
prepare_synthseg_roi_extractor,
prepare_synthseg_seg_resampler,
prepare_waypoint_avoid_overlap_verifier,
)
# Re-exported from the workflow module so existing tests / external callers can
# still import these names from ``thesis.workflows.hcp.workflow``.
from .verifiers import ( # noqa: F401
_verify_atlas_transform,
_verify_bedpostx_samples,
_verify_brain_mask,
_verify_roi_inputs,
_verify_seed_inputs,
verify_requirements,
)
# Names of the ROI outputs that the ROI-producing nodes in this module pass along.
ROI_FIELDS = ("seed", "waypoints_file", "stop_mask", "avoid_mask", "target_mask")

# Identity wiring: every ROI field feeds the same-named input of the next node.
ROI_TO_ROI = list(zip(ROI_FIELDS, ROI_FIELDS))

# (ROI output field, ProbTrackX2 node input) pairs, in ROI_FIELDS order.
PROBTRACK_WIRING = list(
    zip(
        ROI_FIELDS,
        ("seed", "waypoints", "stop_mask", "avoid_mp", "target_masks"),
    )
)
def _hemispheres_for_mode(mode: str) -> list[str]:
if mode == "both-separately":
return ["left", "right"]
if mode in ("left", "right", "both"):
return [mode]
return ["both"]
def _make_target_list(path):
"""Wrap a single target mask path as a list for ProbTrackX2's ``target_masks`` input."""
return [path] if path else []
def _warped_dir_from_outdir(out_dir):
"""Return ``<out_dir>/warped_streamlines`` for the streamline warper.
Fed from ``probtrackx2_outdir_router.out_dir`` — the exact per-hemisphere
directory ProbTrackX2 writes its ``fdt_paths.nii.gz`` into — so the warper
writes alongside whichever hemisphere subdir was produced, without
reverse-engineering it from ProbTrackX2's output file paths.
"""
from pathlib import Path as _P
if not out_dir:
return ""
return str(_P(out_dir) / "warped_streamlines")
def _prepare_synthseg(config, context, paths, out_dir):
    """Build the SynthSeg ROI nodes for the workflow.

    Returns a ``(resampler, extractor, validator)`` triple. When
    ``prepare_synthseg_roi_extractor`` returns ``None`` (synthseg is absent),
    nothing else is built and ``(None, None, None)`` is returned.

    If ``config.tractography.synthseg_roi_labels`` carries a ``roi_file``
    entry, it is resolved against ``context.input_dir`` (falling back to
    ``context.data_dir``) and set as the resampler's ``input_image``.
    """
    roi_extractor = prepare_synthseg_roi_extractor(
        config, Path(context.input_dir), context, out_dir
    )
    if roi_extractor is None:
        return None, None, None

    seg_resampler = prepare_synthseg_seg_resampler(str(paths["t1_path"]), out_dir)

    roi_labels = getattr(config.tractography, "synthseg_roi_labels", None) or {}
    configured_roi_file = roi_labels.get("roi_file")
    if configured_roi_file:
        resolved_roi_file = resolve_with_fallback(
            configured_roi_file, Path(context.input_dir), [context.data_dir]
        )
        seg_resampler.inputs.input_image = str(resolved_roi_file)

    thresholds = get_validation_thresholds(config)
    roi_validator = prepare_roi_validator(
        reference_image=str(paths["t1_path"]),
        name="synthseg_roi_validator",
        **thresholds,
    )
    return seg_resampler, roi_extractor, roi_validator
def _kinds_present(config) -> set[str]:
    """Collect every ROI role kind declared by the configured ROI sources.

    The result is the union of ``get_role_kinds`` over each atlas source from
    ``get_atlas_sources(config)`` plus the ``synthseg_roi_labels`` entry of
    ``config.tractography`` (``None`` when that attribute is missing). Role
    kinds are seed/waypoint/stop/avoid/target.
    """
    found: set[str] = set()
    for atlas_source in get_atlas_sources(config):
        found |= get_role_kinds(atlas_source)
    synthseg_labels = getattr(config.tractography, "synthseg_roi_labels", None)
    found |= get_role_kinds(synthseg_labels)
    return found
def _configure_fixed_path_inputs(probtrack, config, input_dir, out_dir):
    """Give *probtrack* static seed/waypoint inputs when no ROI sources are configured.

    The seed comes from ``prepare_seed_path``. If
    ``config.tractography.waypoint_masks`` is non-empty, each mask is resolved
    against *input_dir*, the resolved paths are written to
    ``<out_dir>/waypoints.txt`` via ``write_list_file``, and that file becomes
    the node's ``waypoints`` input.
    """
    seed_path = prepare_seed_path(config, input_dir)
    probtrack.inputs.seed = str(seed_path)

    masks = config.tractography.waypoint_masks
    if not masks:
        # No waypoints configured: leave the node's ``waypoints`` input untouched.
        return

    resolved_masks = [resolve_with_fallback(str(mask), input_dir, []) for mask in masks]
    list_file = write_list_file(resolved_masks, out_dir / "waypoints.txt")
    probtrack.inputs.waypoints = str(list_file)
# Workflow entry point.
@workflow(
    name="hcp",
    description="HCP ProbTrackX2 tractography workflow using FSL ProbTrackX2.",
    protocol="hcp",
)
@requires(
    t1=PatientFile(
        default="T1w/T1w_acpc_dc_restore_brain.nii.gz",
        config_paths=["synthseg.t1_image", "hcp.t1_image"],
    ),
    mask=PatientFile(
        default="T1w/Diffusion/nodif_brain_mask.nii",
        config_path="hcp.mask_path",
    ),
    bedpostx=GlobGroup(
        items={
            "thsamples": "merged_th*samples.nii.gz",
            "phsamples": "merged_ph*samples.nii.gz",
            "fsamples": "merged_f*samples.nii.gz",
        },
        primary_dir=PatientDir(
            default="T1w/Diffusion.bedpostX",
            config_path="hcp.bedpostx_dir",
        ),
    ),
)
@produces(tract_dir=OutputDir("tractography/probtrackx2"))
@verify(verify_requirements)
def build_workflow(
    *,
    t1: Path,
    mask: Path,
    bedpostx: GlobGroupResult,
    tract_dir: Path,
    config: PipelineConfig,
    context: ProcessingContext,
) -> Workflow:
    """Build the HCP ProbTrackX2 tractography workflow.

    Hemisphere modes via ``config.tractography.hemisphere``: ``"left"``/``"right"``
    (single hemisphere via fixed input), ``"both"`` (merged ROIs, default),
    ``"both-separately"`` (two hemispheres via Nipype ``iterables``).

    The function only assembles the graph; it does not run it. Optional stages
    (atlas ROI chain, SynthSeg ROI chain, final merger, DWI resampler,
    waypoint/avoid verifier, streamline warper) are included only when their
    ``prepare_*`` factory returns a node.

    Args:
        t1: T1 image declared via ``@requires``. Replaces ``paths["t1_path"]``
            only when the file exists on disk.
        mask: Brain mask declared via ``@requires``. Replaces
            ``paths["mask_path"]`` only when the file exists on disk.
        bedpostx: Resolved BedpostX sample group. When its ``_found_dir`` is not
            ``None``, its samples and directory replace the corresponding
            entries of ``paths``.
        tract_dir: Output directory declared via ``@produces``; resolved to an
            absolute path and used as the base output directory.
        config: Pipeline configuration (``config.tractography`` and
            ``config.hardware`` are read here).
        context: Processing context; ``context.patient_id`` names the workflow.

    Returns:
        A ``Workflow`` named ``hcp_<patient_id>`` carrying an input contract
        node (created by ``attach_inputnode``) and an ``outputnode`` exposing
        ``fdt_paths`` plus the ``HCP_ROI_OUTPUT_FIELDS``.
    """
    wf = Workflow(name=f"hcp_{context.patient_id}")

    # -- Resolve input paths ---------------------------------------------------
    paths = prepare_hcp_paths(config, context)
    # Prefer @requires-resolved paths when they exist on disk; otherwise fall
    # back to what prepare_hcp_paths returned (which may be a runtime path
    # produced by upstream stages in full_pipeline, or a mocked path in tests).
    if t1.exists():
        paths = {**paths, "t1_path": t1}
    if mask.exists():
        paths = {**paths, "mask_path": mask}
    if bedpostx._found_dir is not None:
        paths = {
            **paths,
            "thsamples": bedpostx.thsamples,
            "phsamples": bedpostx.phsamples,
            "fsamples": bedpostx.fsamples,
            "bedpostx_dir": bedpostx._found_dir,
            "samples_base_name": bedpostx._found_dir / "merged",
        }
    params = prepare_tractography_params(config)
    input_dir = paths["input_dir"]
    out_dir = tract_dir.resolve()

    # -- Hemisphere source node ------------------------------------------------
    # One hemisphere: set it as a fixed input. Several: iterate over them.
    inputs = Node(IdentityInterface(fields=["hemisphere"]), name="inputs")
    hemispheres = _hemispheres_for_mode(getattr(config.tractography, "hemisphere", "both"))
    if len(hemispheres) > 1:
        inputs.iterables = ("hemisphere", hemispheres)
    else:
        inputs.inputs.hemisphere = hemispheres[0]

    # -- Build the (optional) ROI nodes ----------------------------------------
    # Atlas chain: extract -> (transform) -> validate -> merge. The validate and
    # merge nodes are only built when the extractor exists.
    atlas_extract = prepare_atlas_extract_mapnode(config, context, out_dir)
    atlas_transform = prepare_atlas_transform_mapnode(config, context, out_dir)
    atlas_validate = (
        prepare_atlas_validate_mapnode(
            get_validation_thresholds(config), str(paths["t1_path"]), out_dir
        )
        if atlas_extract is not None
        else None
    )
    atlas_merge = (
        prepare_atlas_merge_joinnode(atlas_validate.name, out_dir)
        if atlas_validate is not None
        else None
    )
    synthseg_resampler, synthseg_extract, synthseg_validate = _prepare_synthseg(
        config, context, paths, out_dir
    )
    # The final merger is only needed when both ROI chains produce output.
    if atlas_merge is not None and synthseg_validate is not None:
        final_merge = prepare_final_merger(out_dir)
    else:
        final_merge = None
    has_any_roi = atlas_extract is not None or synthseg_extract is not None
    dwi_resampler = (
        prepare_roi_dwi_resampler(str(paths["mask_path"]), out_dir) if has_any_roi else None
    )
    # Pre-ProbTrackX2 fail-fast: catches waypoint masks landing inside avoid
    # (the silent waytotal=0 failure mode that surfaces when a warp or
    # hemisphere route goes wrong).
    waypoint_avoid_verifier = prepare_waypoint_avoid_overlap_verifier() if has_any_roi else None

    # -- Build the tractography nodes ------------------------------------------
    probtrack = prepare_probtrackx2_node(
        params,
        paths,
        str(out_dir),
        use_gpu=getattr(config.hardware, "gpu_enabled", False),
        gpu_runtime_env=getattr(config.tractography, "gpu_runtime_env", None),
        gpu_slot_cost=int(getattr(config.hardware, "n_gpu_procs", 1)),
    )
    params_writer = prepare_probtrackx2_params_writer(
        out_dir=out_dir, n_samples=int(params["n_samples"])
    )
    warper = prepare_streamline_warper(config, context, out_dir)
    # Adapter node that emits ``<out_dir>/<hemisphere>`` (or ``<out_dir>`` when
    # hemisphere is "both"). Lets probtrackx2 + params_writer + warper share a
    # single per-iteration output root, so they don't trample each other when
    # the workflow runs under ``--hemisphere both-separately``. Only relevant
    # when multiple hemispheres iterate; for single-hemisphere runs we still
    # build the node (it just emits the base path) to keep wiring uniform.
    probtrack_outdir_router = prepare_probtrackx2_outdir_router(out_dir)

    # -- Internal wiring -------------------------------------------------------
    # Connections are collected first and applied in one ``wf.connect`` call
    # further down, after every node has been added to the workflow.
    connections: list = []
    # Route the active hemisphere into every stage that writes per-hemisphere
    # outputs, so identically-named files never collide across iterations.
    # Without these, ab65d63 only fixed the extractor and downstream stages
    # raced — producing waytotal=0 for ~20% of subjects under both-separately.
    connections.append((inputs, probtrack_outdir_router, [("hemisphere", "hemisphere")]))
    if atlas_extract is not None:
        connections.append((inputs, atlas_extract, [("hemisphere", "hemisphere")]))
    if atlas_transform is not None:
        connections.append((inputs, atlas_transform, [("hemisphere", "hemisphere")]))
    if atlas_extract is not None and atlas_transform is not None:
        connections.append((atlas_extract, atlas_transform, ROI_TO_ROI))
    if atlas_validate is not None:
        # Validate the transformed ROIs when a transform node exists, otherwise
        # the extractor's output directly.
        upstream_for_validate = atlas_transform or atlas_extract
        connections.append((upstream_for_validate, atlas_validate, ROI_TO_ROI))
    if atlas_merge is not None and atlas_validate is not None:
        connections.append((inputs, atlas_merge, [("hemisphere", "hemisphere")]))
        # JoinNode wiring: collect each ROI field across MapNode iterations.
        connections.append(
            (
                atlas_validate,
                atlas_merge,
                [
                    ("seed", "seeds"),
                    ("waypoints_file", "waypoints_files"),
                    ("stop_mask", "stop_masks"),
                    ("avoid_mask", "avoid_masks"),
                    ("target_mask", "target_masks"),
                ],
            )
        )
    if synthseg_resampler is not None and synthseg_extract is not None:
        connections.append((inputs, synthseg_extract, [("hemisphere", "hemisphere")]))
        connections.append(
            (synthseg_resampler, synthseg_extract, [("resampled_segmentation", "roi_file")])
        )
    if synthseg_validate is not None:
        connections.append((synthseg_extract, synthseg_validate, ROI_TO_ROI))
    # Pick the ROI terminus: the last node of whichever ROI assembly was built.
    # With both chains present, the final merger combines them; its two input
    # sets are named left_*/right_* (atlas chain -> left_*, SynthSeg -> right_*).
    if final_merge is not None:
        connections.append((inputs, final_merge, [("hemisphere", "hemisphere")]))
        connections.append(
            (
                atlas_merge,
                final_merge,
                [
                    ("seed", "left_seed"),
                    ("waypoints_file", "left_waypoints_file"),
                    ("stop_mask", "left_stop"),
                    ("avoid_mask", "left_avoid"),
                    ("target_mask", "left_target"),
                ],
            )
        )
        connections.append(
            (
                synthseg_validate,
                final_merge,
                [
                    ("seed", "right_seed"),
                    ("waypoints_file", "right_waypoints_file"),
                    ("stop_mask", "right_stop"),
                    ("avoid_mask", "right_avoid"),
                    ("target_mask", "right_target"),
                ],
            )
        )
        roi_terminus = final_merge
    elif atlas_merge is not None:
        roi_terminus = atlas_merge
    elif synthseg_validate is not None:
        roi_terminus = synthseg_validate
    else:
        roi_terminus = None
    if roi_terminus is not None and dwi_resampler is not None:
        connections.append((inputs, dwi_resampler, [("hemisphere", "hemisphere")]))
        connections.append((roi_terminus, dwi_resampler, ROI_TO_ROI))
    # Route per-hemisphere out_dir into the three nodes that write into the
    # ProbTrackX2 directory (probtrack itself, the params_writer, and the
    # streamline warper). Without this, both hemisphere iterations under
    # both-separately write fdt_paths.nii.gz / waytotal / tractography_params.json
    # to the same path and the last writer wins.
    connections.append((probtrack_outdir_router, probtrack, [("out_dir", "out_dir")]))
    connections.append((probtrack_outdir_router, params_writer, [("out_dir", "output_dir")]))
    kinds = _kinds_present(config)
    if dwi_resampler is not None:
        # waypoint_avoid_verifier and dwi_resampler are gated on the same
        # has_any_roi predicate, so the verifier always exists in this branch.
        # Pass all five ROI fields through it as a fail-fast guard — pure
        # passthrough on success, PipelineError on catastrophic overlap.
        assert waypoint_avoid_verifier is not None  # for mypy / invariant doc
        connections.append((dwi_resampler, waypoint_avoid_verifier, ROI_TO_ROI))
        probtrack_source_node = waypoint_avoid_verifier
        # Keep only the wiring pairs whose ROI role kind is actually declared
        # in the configuration.
        probtrack_source_pairs = [
            (out_f, in_f)
            for out_f, in_f in PROBTRACK_WIRING
            if {
                "seed": "seed",
                "waypoints_file": "waypoint",
                "stop_mask": "stop",
                "avoid_mask": "avoid",
                "target_mask": "target",
            }[out_f]
            in kinds
        ]
        # Wrap target_mask in a list inline via Nipype's connect-with-function
        # syntax — ProbTrackX2's ``target_masks`` input wants a list of paths.
        plain_pairs = [(o, i) for o, i in probtrack_source_pairs if o != "target_mask"]
        if plain_pairs:
            connections.append((probtrack_source_node, probtrack, plain_pairs))
        if "target" in kinds:
            probtrack.inputs.os2t = True
            connections.append(
                (
                    probtrack_source_node,
                    probtrack,
                    [(("target_mask", _make_target_list), "target_masks")],
                )
            )
    else:
        # No ROI assembly: fall back to the static seed/waypoint paths from config.
        _configure_fixed_path_inputs(probtrack, config, input_dir, out_dir)
    if warper is not None:
        # Feed the warper's directories from the outdir router — the single
        # source of truth for the per-hemisphere directory ProbTrackX2 writes
        # into. ``fdt_paths`` is the directory the warper scans for outputs and
        # ``output_dir`` is ``<router_out_dir>/warped_streamlines``, avoiding any
        # reverse-engineering of directories from ProbTrackX2's output paths.
        connections.append(
            (
                probtrack_outdir_router,
                warper,
                [
                    ("out_dir", "fdt_paths"),
                    (("out_dir", _warped_dir_from_outdir), "output_dir"),
                ],
            )
        )
        # Ordering dependency: the router path is available before ProbTrackX2
        # runs, so wire probtrackx2.fdt_paths into the warper's ignored ordering
        # input. This forces the warper to run only after ProbTrackX2 has written
        # its outputs into that directory (otherwise the warper scans an empty
        # dir and silently produces no warped streamlines).
        connections.append((probtrack, warper, [("fdt_paths", "_ordering_signal")]))
    # Add every non-None node we created; Nipype dedups.
    for node in (
        inputs,
        atlas_extract,
        atlas_transform,
        atlas_validate,
        atlas_merge,
        synthseg_resampler,
        synthseg_extract,
        synthseg_validate,
        final_merge,
        dwi_resampler,
        waypoint_avoid_verifier,
        probtrack_outdir_router,
        probtrack,
        params_writer,
        warper,
    ):
        if node is not None:
            wf.add_nodes([node])
    # Defensive filter: drop any edge whose source or destination was not built.
    real_connections = [c for c in connections if c[0] is not None and c[1] is not None]
    if real_connections:
        wf.connect(real_connections)
    # -- Input contract: runtime-overridable inputs fanned to internal nodes -
    # Standalone runs need build-time defaults on the contract node itself, because
    # the fan_out edges below overwrite any static values previously set on the
    # consumer nodes (atlas_validate, probtrack, dwi_resampler, ...). Without
    # defaults, an inputnode field that no meta-workflow has connected stays
    # Undefined, which then propagates to downstream Function nodes and surfaces
    # as ``TypeError: missing 1 required positional argument`` at execute time.
    # ``t1_to_dwi_transform``, ``seg_map`` and ``roi_transform_gate`` are
    # genuinely optional / meta-only — leave them Undefined; their consumers are
    # either gated by conditionals (``dwi_resampler is not None``,
    # ``synthseg_resampler is not None``) or accept Undefined (ordering signal).
    # Meta-workflow edges still override these defaults — that's standard Nipype
    # connection semantics.
    inputnode_defaults: dict = {
        "t1_brain": str(paths["t1_path"]),
        "dwi_mask": str(paths["mask_path"]),
    }
    for field, key in (
        ("bedpostx_thsamples", "thsamples"),
        ("bedpostx_phsamples", "phsamples"),
        ("bedpostx_fsamples", "fsamples"),
    ):
        value = paths.get(key)
        if not value:
            # Missing or empty sample entry: leave this contract field without a default.
            continue
        # bedpostx samples are lists of files (probtrack thsamples/phsamples/
        # fsamples are InputMultiPath). Seed as a list of path strings, not a
        # stringified list — this default overrides probtrack's build-time value
        # via the inputnode->probtrack connection at runtime, so it must itself be
        # a valid multi-file value.
        inputnode_defaults[field] = (
            [str(item) for item in value] if isinstance(value, (list, tuple)) else [str(value)]
        )
    inputnode = attach_inputnode(
        wf,
        [
            "t1_brain",
            "dwi_mask",
            "t1_to_dwi_transform",
            "seg_map",
            "bedpostx_thsamples",
            "bedpostx_phsamples",
            "bedpostx_fsamples",
            "roi_transform_gate",
        ],
        defaults=inputnode_defaults,
    )
    # Every node that takes the T1 image as its reference gets it from the
    # contract node's ``t1_brain`` field.
    t1_targets: list = []
    for n in (atlas_validate, synthseg_validate):
        if n is not None:
            t1_targets.append((n, "reference_image"))
    if synthseg_resampler is not None:
        t1_targets.append((synthseg_resampler, "reference"))
    fan_out(wf, inputnode, "t1_brain", t1_targets)
    if synthseg_resampler is not None:
        wf.connect(inputnode, "seg_map", synthseg_resampler, "input_image")
    # No warp_field field: atlas_transform is a MapNode whose warp_field /
    # reference_image are per-source iterfields resolved from config at build
    # time; a scalar runtime override would break per-source iteration. (The old
    # full_pipeline broadcast fireants reverse_transforms to legacy roi_transformer
    # nodes the MapNode design no longer produces, so this was already a no-op.)
    # Ordering-only gate: block the atlas ROI warp until the patient->template
    # registration has written its transforms (which atlas_transform reads by
    # path). full_pipeline drives roi_transform_gate from registration via
    # tract_synthseg.entry_gate; standalone runs leave it unset (transforms are
    # precomputed). Broadcast onto the MapNode's non-iterfield _ordering_signal,
    # the same mechanism as `hemisphere`; the task ignores the value.
    if atlas_transform is not None:
        wf.connect(inputnode, "roi_transform_gate", atlas_transform, "_ordering_signal")
    if dwi_resampler is not None:
        wf.connect(
            [
                (
                    inputnode,
                    dwi_resampler,
                    [
                        ("dwi_mask", "reference"),
                        ("t1_to_dwi_transform", "t1_to_dwi_transform"),
                    ],
                ),
            ]
        )
    wf.connect(
        [
            (
                inputnode,
                probtrack,
                [
                    ("dwi_mask", "mask"),
                    ("bedpostx_thsamples", "thsamples"),
                    ("bedpostx_phsamples", "phsamples"),
                    ("bedpostx_fsamples", "fsamples"),
                ],
            ),
        ]
    )
    # -- Output contract: fdt_paths, joined across the hemisphere iterable ----
    # The ROI terminus (atlas_merge / final_merger / synthseg_validate) is a
    # single combined Node, so its seed/stop/avoid/target outputs are scalar and
    # are re-published as regular (non-join) outputnode fields under the stable
    # contract names roi_seed/roi_stop/roi_avoid/roi_target. A meta-workflow
    # (e.g. tract_synthseg ROI validation) wires to these instead of scanning the
    # internal graph for a node name. They stay Undefined when no ROI assembly
    # was built. fdt_paths remains the only field joined across hemispheres.
    # NOTE(review): when ``inputs`` iterates over hemispheres, the ROI terminus
    # sits downstream of that iterable (it receives ``hemisphere`` from
    # ``inputs``). Confirm which hemisphere's ROI outputs these non-join fields
    # end up carrying in that mode.
    contract_fields = ["fdt_paths", *HCP_ROI_OUTPUT_FIELDS]
    if len(hemispheres) > 1:
        outputnode = JoinNode(
            IdentityInterface(fields=contract_fields),
            joinsource="inputs",
            joinfield="fdt_paths",
            name="outputnode",
        )
    else:
        outputnode = Node(IdentityInterface(fields=contract_fields), name="outputnode")
    wf.add_nodes([outputnode])
    wf.connect(probtrack, "fdt_paths", outputnode, "fdt_paths")
    if roi_terminus is not None:
        wf.connect(
            [
                (
                    roi_terminus,
                    outputnode,
                    [
                        ("seed", "roi_seed"),
                        ("stop_mask", "roi_stop"),
                        ("avoid_mask", "roi_avoid"),
                        ("target_mask", "roi_target"),
                    ],
                )
            ]
        )
    return wf