"""MRtrix3 tractography workflow orchestration.
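Sibling of the HCP ProbTrackX2 workflow that replaces the FSL BedpostX +
ProbTrackX2 tail with an MRtrix3 chain (5ttgen → dhollander → msmt_csd →
mtnormalise → tckgen with ACT → tcksift2 → tckmap; the response and FOD
algorithms are configurable, mtnormalise and tcksift2 are optional). The ROI extraction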
pipeline (atlas → ANTs warp → validation → merging → DWI-grid resampling)
reuses the HCP ``MapNode``/``JoinNode`` factories.
Outputs land under ``<output_dir>/tractography/mrtrix3/``:
* ``_shared/`` — patient-level intermediates (mask, 5TT, GM-WM interface,
response functions, FODs, normalised FODs).
* Per-hemisphere tractography products (``tracks.tck``,
``sift2_weights.txt`` when SIFT2 is enabled, ``fdt_paths.nii.gz``,
``waytotal``), scoped into ``<hemisphere>/`` subdirectories (flat for ``both``).
"""
from dataclasses import dataclass
from pathlib import Path
from nipype import Function, JoinNode, Node, Workflow
from nipype.interfaces.utility import IdentityInterface, Merge
from thesis.core.config import PipelineConfig
from thesis.core.context import ProcessingContext
from thesis.core.contracts import MRTRIX3_ROI_OUTPUT_FIELDS, attach_inputnode, fan_out
from thesis.core.decorators import produces, requires, verify, workflow
from thesis.core.path_declarations import OutputDir, PatientFile
from ..hcp.nodes import (
prepare_atlas_extract_mapnode,
prepare_atlas_merge_joinnode,
prepare_atlas_transform_mapnode,
prepare_atlas_validate_mapnode,
prepare_final_merger,
prepare_outdir_router,
prepare_roi_dwi_resampler,
prepare_roi_validator,
prepare_streamline_warper,
prepare_synthseg_roi_extractor,
prepare_synthseg_seg_resampler,
prepare_waypoint_avoid_overlap_verifier,
)
from .common import get_validation_thresholds, resolve_with_fallback
from .config import MRtrix3Paths # noqa: F401 re-exported for tests
from .config import (
MRtrix3Params,
prepare_mrtrix3_params,
prepare_mrtrix3_paths,
)
from .config.paths import (
DEFAULT_BVAL_NAME,
DEFAULT_BVEC_NAME,
DEFAULT_DWI_NAME,
DEFAULT_MASK_NAME,
)
from .nodes import (
prepare_5tt_node,
prepare_5tt_to_dwi_node,
prepare_density_map_renamer,
prepare_dwi2mask_node,
prepare_dwi_import,
prepare_fod_node,
prepare_gmwmi_node,
prepare_mask_node,
prepare_mrtrix3_params_writer,
prepare_mtnormalise_node,
prepare_per_target_maps_node,
prepare_response_node,
prepare_tckgen_input_adapter,
prepare_tckgen_node,
prepare_tckmap_node,
prepare_tcksift2_node,
prepare_waytotal_writer,
)
from .verifiers import verify_requirements
from .verifiers.binaries import REQUIRED_BINARIES # noqa: F401 re-exported
# The five ROI outputs that every ROI stage (extract → transform → validate →
# DWI resample → waypoint/avoid verify) hands to the next under the same name.
ROI_FIELDS = ("seed", "waypoints_file", "stop_mask", "avoid_mask", "target_mask")
# Identity port map for Nipype ``connect``: each ROI field feeds the
# same-named input of the downstream stage.
ROI_TO_ROI = list(zip(ROI_FIELDS, ROI_FIELDS))
def _hemispheres_for_mode(mode: str) -> list[str]:
if mode == "both-separately":
return ["left", "right"]
if mode in ("left", "right", "both"):
return [mode]
return ["both"]
def _fdt_dir_from_paths(fdt_paths):
"""Return the parent directory of an MRtrix3 fdt_paths output file."""
from pathlib import Path as _P
if not fdt_paths:
return ""
path = fdt_paths[0] if isinstance(fdt_paths, (list, tuple)) else fdt_paths
return str(_P(path).parent)
def _warped_dir_from_paths(fdt_paths):
"""Derive ``<fdt_dir>/warped_streamlines`` from an fdt_paths output file.
Mirrors the HCP workflow so the streamline warper writes inside whichever
per-hemisphere subdir ``fdt_paths.nii.gz`` was written into, without baking
the hemisphere into the warper at build time.
"""
from pathlib import Path as _P
if not fdt_paths:
return ""
path = fdt_paths[0] if isinstance(fdt_paths, (list, tuple)) else fdt_paths
return str(_P(path).parent / "warped_streamlines")
def _tdi_out_file(out_dir):
"""Build the per-hemisphere ``tckmap`` output path ``<out_dir>/tdi_raw.nii.gz``."""
from pathlib import Path as _P
if not out_dir:
return ""
return str(_P(out_dir) / "tdi_raw.nii.gz")
def _pack_grad_fsl(bvec: str, bval: str) -> tuple[str, str]:
"""Pack runtime bvec/bval paths into the tuple MRConvert.grad_fsl expects."""
return (str(bvec), str(bval))
[docs]
@dataclass
class SharedSubgraph:
    """Handles to the patient-level shared nodes (FOD, 5TT, mask, gmwmi).

    Produced by ``_build_shared_subgraph`` and consumed by the per-hemisphere
    tractography subgraph and by ``build_workflow`` for input wiring.
    """

    # Brain-mask node: ``prepare_dwi2mask_node`` when mask_source is
    # 'dwi2mask', otherwise ``prepare_mask_node``.
    mask_node: Node
    # DWI import node; its ``out_file`` feeds response/FOD (and dwi2mask).
    dwi_import_node: Node
    # Output of ``prepare_5tt_to_dwi_node``; used as the ACT image by tckgen
    # and as ``fivett_file`` by tcksift2.
    fivett_node: Node
    # The 5ttgen node fed by the T1 ('fsl') or a parcellation ('freesurfer').
    fivett_t1_node: Node
    # GM-WM interface node; None unless seed_strategy == 'gmwmi'.
    gmwmi_node: Node | None
    # Response-function estimation node.
    response_node: Node
    # FOD estimation node.
    fod_node: Node
    # mtnormalise node; None unless use_mtnormalise and the FOD is multi-tissue.
    mtnormalise_node: Node | None
    # Node and output field carrying the WM FOD consumed downstream:
    # (fod_node, "wm_odf") or (mtnormalise_node, "out_file_wm").
    wm_fod_owner: Node
    wm_fod_field: str
    # True when fod_algorithm == 'msmt_csd'.
    is_multi_tissue: bool
    # True when mask_source == 'dwi2mask' (mask derived from the DWI itself).
    mask_from_dwi: bool
def _build_shared_subgraph(
    workflow: Workflow,
    *,
    t1_path: Path,
    mask_path: Path,
    dwi_image: Path,
    bvec: Path,
    bval: Path,
    params: MRtrix3Params,
    shared_out_dir: Path,
) -> SharedSubgraph:
    """Create and wire the patient-level nodes shared by every hemisphere.

    The subgraph covers DWI import, brain mask, 5TT generation (plus its
    DWI-grid counterpart), the optional GM-WM interface, response-function
    estimation, FOD estimation and optional ``mtnormalise``. Everything is
    written under *shared_out_dir*, which is created if missing.

    Args:
        workflow: Nipype workflow the nodes are added to and connected inside.
        t1_path: T1 image; handed to 5ttgen only when ``fivett_algorithm`` is
            ``"fsl"``.
        mask_path: brain mask; imported unless ``mask_source`` is
            ``"dwi2mask"``, and always used as the 5TT-to-DWI reference image.
        dwi_image: DWI series given to the DWI import node.
        bvec: FSL-format gradient directions given to the DWI import node.
        bval: FSL-format b-values given to the DWI import node.
        params: resolved MRtrix3 parameters (algorithm choices and toggles).
        shared_out_dir: directory for patient-level intermediates.

    Returns:
        SharedSubgraph: handles to the created nodes plus the node/field pair
        that carries the white-matter FOD consumed by tracking.
    """
    shared_out_dir.mkdir(parents=True, exist_ok=True)
    multi_tissue = params["fod_algorithm"] == "msmt_csd"

    dwi_import = prepare_dwi_import(
        dwi_image=dwi_image, bvec=bvec, bval=bval, out_dir=shared_out_dir
    )

    # Brain-mask source (config.tractography.mask_source): 'dwi2mask' derives
    # the mask from the imported DWI .mif; otherwise ('fsl_nodif', the
    # default) the HCP nodif_brain_mask is imported and dilated.
    from_dwi = params["mask_source"] == "dwi2mask"
    if from_dwi:
        brain_mask = prepare_dwi2mask_node(
            out_dir=shared_out_dir, dilate_voxels=params["mask_dilate_voxels"]
        )
    else:
        brain_mask = prepare_mask_node(
            mask_path=mask_path,
            out_dir=shared_out_dir,
            dilate_voxels=params["mask_dilate_voxels"],
        )

    tissue_algo = params["fivett_algorithm"]
    fivett_t1 = prepare_5tt_node(
        out_dir=shared_out_dir,
        algorithm=tissue_algo,
        # The T1 is handed over only for the 'fsl' algorithm.
        t1_path=None if tissue_algo != "fsl" else t1_path,
        parc_path=None,
    )
    fivett_dwi = prepare_5tt_to_dwi_node(out_dir=shared_out_dir, reference_image=mask_path)

    # The GM-WM interface is only consumed as the tckgen seed under
    # seed_strategy='gmwmi'; under ROI seeding 5tt2gmwmi would run for nothing.
    gmwmi: Node | None = None
    if params["seed_strategy"] == "gmwmi":
        gmwmi = prepare_gmwmi_node(out_dir=shared_out_dir)

    response = prepare_response_node(
        out_dir=shared_out_dir, algorithm=params["response_algorithm"]
    )
    fod = prepare_fod_node(out_dir=shared_out_dir, algorithm=params["fod_algorithm"])

    core_nodes = [dwi_import, brain_mask, fivett_t1, fivett_dwi, response, fod]
    workflow.add_nodes(core_nodes if gmwmi is None else [*core_nodes, gmwmi])

    edges = [
        (dwi_import, response, [("out_file", "in_file")]),
        (brain_mask, response, [("out_file", "in_mask")]),
        (dwi_import, fod, [("out_file", "in_file")]),
        (brain_mask, fod, [("out_file", "mask_file")]),
        (response, fod, [("wm_file", "wm_txt")]),
        (fivett_t1, fivett_dwi, [("out_file", "fivett_file")]),
    ]
    if gmwmi is not None:
        edges.append((fivett_dwi, gmwmi, [("out_file", "fivett_file")]))
    if from_dwi:
        # dwi2mask reads the imported .mif, which carries the gradient table.
        edges.append((dwi_import, brain_mask, [("out_file", "dwi_file")]))
    if multi_tissue:
        edges.append((response, fod, [("gm_file", "gm_txt"), ("csf_file", "csf_txt")]))

    # Downstream consumers read the WM FOD from whichever node produces its
    # final version: mtnormalise when it runs, the FOD node otherwise.
    normaliser: Node | None = None
    wm_owner, wm_field = fod, "wm_odf"
    if params["use_mtnormalise"] and multi_tissue:
        normaliser = prepare_mtnormalise_node(out_dir=shared_out_dir)
        workflow.add_nodes([normaliser])
        edges += [
            (
                fod,
                normaliser,
                [("wm_odf", "wm_fod"), ("gm_odf", "gm_fod"), ("csf_odf", "csf_fod")],
            ),
            (brain_mask, normaliser, [("out_file", "mask")]),
        ]
        wm_owner, wm_field = normaliser, "out_file_wm"

    workflow.connect(edges)

    return SharedSubgraph(
        mask_node=brain_mask,
        dwi_import_node=dwi_import,
        fivett_node=fivett_dwi,
        fivett_t1_node=fivett_t1,
        gmwmi_node=gmwmi,
        response_node=response,
        fod_node=fod,
        mtnormalise_node=normaliser,
        wm_fod_owner=wm_owner,
        wm_fod_field=wm_field,
        is_multi_tissue=multi_tissue,
        mask_from_dwi=from_dwi,
    )
def _prepare_synthseg(config, context, paths, out_dir):
    """Create the SynthSeg ROI branch: segmentation resampler, ROI extractor, validator.

    Returns:
        ``(resampler, extractor, validator)``, or ``(None, None, None)`` when
        ``prepare_synthseg_roi_extractor`` yields no extractor.

    When ``config.tractography.synthseg_roi_labels`` has a truthy ``roi_file``,
    it is resolved with ``resolve_with_fallback`` against ``context.input_dir``
    (with ``[context.data_dir]`` as the fallback list) and assigned as the
    resampler's static ``input_image``.
    """
    input_dir = Path(context.input_dir)
    extractor = prepare_synthseg_roi_extractor(config, input_dir, context, out_dir)
    if extractor is None:
        return None, None, None

    t1_reference = str(paths["t1_path"])
    resampler = prepare_synthseg_seg_resampler(t1_reference, out_dir)

    roi_labels = getattr(config.tractography, "synthseg_roi_labels", None) or {}
    if roi_labels.get("roi_file"):
        resolved = resolve_with_fallback(
            roi_labels["roi_file"], input_dir, [context.data_dir]
        )
        resampler.inputs.input_image = str(resolved)

    return (
        resampler,
        extractor,
        prepare_roi_validator(
            reference_image=t1_reference,
            name="synthseg_roi_validator",
            **get_validation_thresholds(config),
        ),
    )
def _build_tractography_subgraph(
    wf: Workflow,
    config: PipelineConfig,
    context: ProcessingContext,
    *,
    inputs: Node,
    inputnode: Node,
    paths: MRtrix3Paths,
    params: MRtrix3Params,
    shared: SharedSubgraph,
    out_dir: Path,
) -> tuple[Node | None, Node | None]:
    """Build the per-hemisphere ROI assembly + tractography tail.

    Seeding is controlled by ``params["seed_strategy"]``:

    * ``"roi"`` (default) — seed from the atlas/synthseg seed mask. Requires at
      least one ROI source; returns ``(None, None)`` when none are configured.
    * ``"gmwmi"`` — seed from the GM-WM interface (``-seed_gmwmi``, valid because
      ACT's 5TT is always wired). The ROI assembly (waypoint/avoid/stop/target →
      ``-include``/``-exclude``/``-mask``) is still built when sources are
      configured AND ``params["gmwmi_apply_roi_filters"]`` is True; otherwise the
      tail tracks whole-brain with no ROI filtering.

    Args:
        wf: workflow that every node built here is added to and wired inside.
        config: pipeline configuration, forwarded to the HCP ROI/warper factories.
        context: per-patient processing context, forwarded likewise.
        inputs: node whose ``hemisphere`` output drives every per-hemisphere
            stage (an iterable under ``both-separately``).
        inputnode: contract inputnode providing ``t1_brain``, ``seg_map``,
            ``dwi_mask`` and ``t1_to_dwi_transform``.
        paths: resolved paths; ``t1_path`` and ``mask_path`` are used here.
        params: resolved MRtrix3 parameters.
        shared: patient-level node handles from ``_build_shared_subgraph``.
        out_dir: base tractography output directory (scoped per hemisphere
            at runtime through the outdir router).

    Returns:
        ``(fdt_paths_writer, roi_terminus)``. ``fdt_paths_writer`` is the
        density-map node whose ``out_file`` feeds the workflow ``outputnode``
        (``None`` when no tractography tail was built). ``roi_terminus`` is the
        node carrying the final ``seed``/``stop_mask``/``avoid_mask``/
        ``target_mask`` ROI outputs (``None`` when no ROI assembly was built),
        exposed so the workflow can re-publish them on the contract outputnode.
    """
    use_gmwmi_seed = params["seed_strategy"] == "gmwmi"
    apply_roi_filters = bool(params["gmwmi_apply_roi_filters"])
    # Build every candidate ROI stage up front. The factories may return None
    # (presumably when their source is not configured); validate/merge are
    # derived from atlas_extract so they share its presence.
    atlas_extract = prepare_atlas_extract_mapnode(config, context, out_dir)
    # NOTE(review): atlas_transform is only wired below when atlas_extract
    # exists, yet it is added to the workflow whenever it is non-None. If this
    # factory can ever return a node while atlas_extract is None, it would run
    # without ROI inputs — confirm the two factories cannot diverge.
    atlas_transform = prepare_atlas_transform_mapnode(config, context, out_dir)
    atlas_validate = (
        prepare_atlas_validate_mapnode(
            get_validation_thresholds(config), str(paths["t1_path"]), out_dir
        )
        if atlas_extract is not None
        else None
    )
    atlas_merge = (
        prepare_atlas_merge_joinnode(atlas_validate.name, out_dir)
        if atlas_validate is not None
        else None
    )
    synthseg_resampler, synthseg_extract, synthseg_validate = _prepare_synthseg(
        config, context, paths, out_dir
    )
    # A final merger is only needed when both ROI sources are present.
    final_merge = (
        prepare_final_merger(out_dir)
        if (atlas_merge is not None and synthseg_validate is not None)
        else None
    )
    has_roi_sources = atlas_extract is not None or synthseg_extract is not None
    # ROI seeding has nowhere to draw a seed without an ROI source. (gmwmi seeding
    # uses the GM-WM interface, so it can run whole-brain with no ROI sources.)
    if not use_gmwmi_seed and not has_roi_sources:
        return None, None
    # Build the ROI assembly only when its masks are consumed:
    # * ROI seeding always needs it (the seed lives there);
    # * gmwmi seeding needs it only to apply include/exclude/stop filters.
    use_roi_assembly = has_roi_sources and (not use_gmwmi_seed or apply_roi_filters)
    if not use_roi_assembly:
        # Whole-brain gmwmi tracking: drop the ROI assembly entirely. Null the
        # handles so the add-nodes loop and the t1 fan-out skip them.
        atlas_extract = atlas_transform = atlas_validate = atlas_merge = None
        synthseg_resampler = synthseg_extract = synthseg_validate = None
        final_merge = None
    # Router that scopes every per-hemisphere output dir into <base>/<hemisphere>/
    # (flat for "both"). Mirrors the HCP probtrackx2_outdir router so that under
    # --hemisphere both-separately the left/right iterations of the ROI stages
    # and the tckgen tail no longer collide on shared output paths.
    outdir_router = prepare_outdir_router(out_dir, name="tractography_outdir")
    # Edges are collected here and connected in one call after the node loop;
    # edges touching a nulled (None) node are filtered out at that point.
    connections: list = []
    connections.append((inputs, outdir_router, [("hemisphere", "hemisphere")]))
    dwi_resampler: Node | None = None
    waypoint_avoid_verifier: Node | None = None
    tckgen_inputs_node: Node | None = None
    roi_terminus: Node | None = None
    if use_roi_assembly:
        dwi_resampler = prepare_roi_dwi_resampler(str(paths["mask_path"]), out_dir)
        # Pre-tckgen fail-fast guard (mirrors the HCP workflow): catches a
        # waypoint mask landing substantially inside the avoid mask, which makes
        # tckgen's -include/-exclude mutually contradictory and silently yields 0
        # streamlines (after which tcksift2 crashes downstream). Passthrough on
        # success.
        waypoint_avoid_verifier = prepare_waypoint_avoid_overlap_verifier()
        if atlas_extract is not None:
            connections.append((inputs, atlas_extract, [("hemisphere", "hemisphere")]))
            if atlas_transform is not None:
                connections.append((inputs, atlas_transform, [("hemisphere", "hemisphere")]))
                connections.append((atlas_extract, atlas_transform, ROI_TO_ROI))
            # Validate the warped ROIs when a transform stage exists, else the
            # extracted ROIs directly.
            connections.append(((atlas_transform or atlas_extract), atlas_validate, ROI_TO_ROI))
            # atlas_merge always exists when atlas_extract does (built from atlas_validate).
            connections.append((inputs, atlas_merge, [("hemisphere", "hemisphere")]))
            connections.append(
                (
                    atlas_validate,
                    atlas_merge,
                    [
                        ("seed", "seeds"),
                        ("waypoints_file", "waypoints_files"),
                        ("stop_mask", "stop_masks"),
                        ("avoid_mask", "avoid_masks"),
                        ("target_mask", "target_masks"),
                    ],
                )
            )
        if synthseg_resampler is not None and synthseg_extract is not None:
            connections.append((inputs, synthseg_extract, [("hemisphere", "hemisphere")]))
            connections.append(
                (synthseg_resampler, synthseg_extract, [("resampled_segmentation", "roi_file")])
            )
            if synthseg_validate is not None:
                connections.append((synthseg_extract, synthseg_validate, ROI_TO_ROI))
        # Pick the ROI terminus: the combined merger when both sources exist,
        # otherwise whichever single source was built.
        if final_merge is not None:
            connections.append((inputs, final_merge, [("hemisphere", "hemisphere")]))
            # The HCP final merger's left_*/right_* ports receive the atlas and
            # SynthSeg ROI sets respectively here.
            connections.append(
                (
                    atlas_merge,
                    final_merge,
                    [
                        ("seed", "left_seed"),
                        ("waypoints_file", "left_waypoints_file"),
                        ("stop_mask", "left_stop"),
                        ("avoid_mask", "left_avoid"),
                        ("target_mask", "left_target"),
                    ],
                )
            )
            connections.append(
                (
                    synthseg_validate,
                    final_merge,
                    [
                        ("seed", "right_seed"),
                        ("waypoints_file", "right_waypoints_file"),
                        ("stop_mask", "right_stop"),
                        ("avoid_mask", "right_avoid"),
                        ("target_mask", "right_target"),
                    ],
                )
            )
            roi_terminus = final_merge
        elif atlas_merge is not None:
            roi_terminus = atlas_merge
        elif synthseg_validate is not None:
            roi_terminus = synthseg_validate
        if roi_terminus is not None:
            # ROI terminus → DWI-grid resampling → waypoint/avoid guard.
            connections.append((inputs, dwi_resampler, [("hemisphere", "hemisphere")]))
            connections.append((roi_terminus, dwi_resampler, ROI_TO_ROI))
            connections.append((dwi_resampler, waypoint_avoid_verifier, ROI_TO_ROI))
            # Adapter: 5-tuple → tckgen-shaped inputs
            tckgen_inputs_node = prepare_tckgen_input_adapter(name="tckgen_inputs")
            connections.append(
                (
                    waypoint_avoid_verifier,
                    tckgen_inputs_node,
                    [
                        ("seed", "seed"),
                        ("waypoints_file", "waypoints_file"),
                        ("target_mask", "target_mask"),
                        ("avoid_mask", "avoid_mask"),
                        ("stop_mask", "stop_mask"),
                    ],
                )
            )
    tckgen_node = prepare_tckgen_node(params, out_dir=out_dir, name="tckgen")
    connections.append((shared.wm_fod_owner, tckgen_node, [(shared.wm_fod_field, "fod_file")]))
    connections.append((shared.fivett_node, tckgen_node, [("out_file", "act_file")]))
    # Scope tckgen's tracks.tck output per hemisphere.
    connections.append((outdir_router, tckgen_node, [("out_dir", "output_dir")]))
    # Seed source. gmwmi_node exists in the shared subgraph exactly when
    # seed_strategy='gmwmi'; -seed_gmwmi is valid because -act is wired above.
    if use_gmwmi_seed:
        assert shared.gmwmi_node is not None  # invariant: built for gmwmi seeding
        connections.append((shared.gmwmi_node, tckgen_node, [("out_file", "seed_gmwmi")]))
    else:
        assert tckgen_inputs_node is not None  # ROI seeding implies ROI assembly
        connections.append((tckgen_inputs_node, tckgen_node, [("seed_image", "seed_image")]))
    # ROI filters (-include/-exclude/-mask) — applied whenever the ROI assembly
    # exists, independent of the seeding mechanism.
    if tckgen_inputs_node is not None:
        connections.append(
            (
                tckgen_inputs_node,
                tckgen_node,
                [
                    ("include_masks", "include_masks"),
                    ("exclude_masks", "exclude_masks"),
                    ("mask_stop", "mask_stop"),
                ],
            )
        )
    # Optional SIFT2: when enabled, its weights_file (and mu_file) feed every
    # downstream consumer below; otherwise they run unweighted.
    weights_owner: Node | None = None
    if params["use_sift2"]:
        sift2_node = prepare_tcksift2_node(out_dir=out_dir, name="tcksift2")
        connections.append((tckgen_node, sift2_node, [("out_file", "tracks_file")]))
        connections.append((shared.wm_fod_owner, sift2_node, [(shared.wm_fod_field, "fod_file")]))
        connections.append((shared.fivett_node, sift2_node, [("out_file", "fivett_file")]))
        connections.append((outdir_router, sift2_node, [("out_dir", "output_dir")]))
        weights_owner = sift2_node
    else:
        sift2_node = None
    tckmap_node = prepare_tckmap_node(
        template_image=paths["mask_path"],
        out_dir=out_dir,
        name="tckmap",
    )
    connections.append((tckgen_node, tckmap_node, [("out_file", "in_file")]))
    # tckmap is a ComputeTDI interface (out_file, not output_dir): build the
    # per-hemisphere tdi_raw.nii.gz path inline from the scoped directory.
    connections.append((outdir_router, tckmap_node, [(("out_dir", _tdi_out_file), "out_file")]))
    if weights_owner is not None:
        connections.append((weights_owner, tckmap_node, [("weights_file", "tck_weights")]))
    fdt_paths_writer = prepare_density_map_renamer(out_dir=out_dir, name="fdt_paths")
    connections.append((tckmap_node, fdt_paths_writer, [("out_file", "in_file")]))
    connections.append((outdir_router, fdt_paths_writer, [("out_dir", "output_dir")]))
    waytotal_writer = prepare_waytotal_writer(out_dir=out_dir, name="waytotal_writer")
    connections.append((tckgen_node, waytotal_writer, [("out_file", "tracks_file")]))
    connections.append((outdir_router, waytotal_writer, [("out_dir", "output_dir")]))
    if weights_owner is not None:
        connections.append((weights_owner, waytotal_writer, [("weights_file", "weights_file")]))
    # Per-target maps require a target ROI, which only exists via the ROI
    # assembly. Whole-brain gmwmi tracking has no target, so skip the node.
    per_target_node: Node | None = None
    if waypoint_avoid_verifier is not None:
        per_target_node = prepare_per_target_maps_node(out_dir=out_dir, name="per_target_maps")
        connections.append((tckgen_node, per_target_node, [("out_file", "tracks_file")]))
        connections.append((fdt_paths_writer, per_target_node, [("out_file", "reference_image")]))
        connections.append(
            (waypoint_avoid_verifier, per_target_node, [("target_mask", "target_mask")])
        )
        connections.append((outdir_router, per_target_node, [("out_dir", "output_dir")]))
        if weights_owner is not None:
            connections.append((weights_owner, per_target_node, [("weights_file", "weights_file")]))
    params_writer = prepare_mrtrix3_params_writer(
        out_dir=out_dir,
        select=int(params["tckgen_select"]),
        name="mrtrix3_params",
    )
    connections.append((outdir_router, params_writer, [("out_dir", "output_dir")]))
    if weights_owner is not None:
        connections.append((weights_owner, params_writer, [("mu_file", "mu_file")]))
    warper = prepare_streamline_warper(config, context, out_dir)
    warp_gate: Node | None = None
    if warper is not None:
        # Both directory inputs derived from the (already per-hemisphere)
        # fdt_paths output so the warper writes inside the same scoped subdir and
        # depends on fdt_paths_writer for ordering.
        connections.append(
            (
                fdt_paths_writer,
                warper,
                [
                    (("out_file", _fdt_dir_from_paths), "fdt_paths"),
                    (("out_file", _warped_dir_from_paths), "output_dir"),
                ],
            )
        )
        # The warper warps every recognised output in that directory
        # (fdt_paths.nii.gz, waytotal.nii.gz, seeds_to_*.nii.gz). waytotal_writer
        # and per_target_node write the latter two and otherwise race the warper
        # (they fan off fdt_paths_writer / tckgen, not before the warper), so it
        # could run before they finish and silently skip those maps. Gate the
        # warper on their completion via a Merge feeding its ignored
        # _ordering_signal.
        ordering_sources: list[tuple[Node, str]] = [(waytotal_writer, "waytotal_file")]
        if per_target_node is not None:
            ordering_sources.append((per_target_node, "seeds_files"))
        warp_gate = Node(Merge(len(ordering_sources)), name="warp_gate")
        # Merge inputs are 1-based (in1, in2, ...).
        for _idx, (_src_node, _src_field) in enumerate(ordering_sources, start=1):
            connections.append((_src_node, warp_gate, [(_src_field, f"in{_idx}")]))
        connections.append((warp_gate, warper, [("out", "_ordering_signal")]))
    # Add every non-None node to the workflow (Nipype dedups).
    for node in (
        atlas_extract,
        atlas_transform,
        atlas_validate,
        atlas_merge,
        synthseg_resampler,
        synthseg_extract,
        synthseg_validate,
        final_merge,
        dwi_resampler,
        waypoint_avoid_verifier,
        outdir_router,
        tckgen_inputs_node,
        tckgen_node,
        sift2_node,
        tckmap_node,
        fdt_paths_writer,
        waytotal_writer,
        per_target_node,
        params_writer,
        warp_gate,
        warper,
    ):
        if node is not None:
            wf.add_nodes([node])
    # Drop edges whose endpoint was nulled above (e.g. ROI stages when no ROI
    # assembly was built), then connect the rest in one call.
    real_connections = [c for c in connections if c[0] is not None and c[1] is not None]
    if real_connections:
        wf.connect(real_connections)
    # -- Input-contract fan-out (per-hemisphere consumers) ------------------
    # t1_brain → VALIDATORS (regular Nodes, shared reference_image) + synthseg
    # resampler. DO NOT wire atlas_transform (MapNode per-source iterfields).
    t1_targets: list = []
    for n in (atlas_validate, synthseg_validate):
        if n is not None:
            t1_targets.append((n, "reference_image"))
    if synthseg_resampler is not None:
        t1_targets.append((synthseg_resampler, "reference"))
    if t1_targets:
        fan_out(wf, inputnode, "t1_brain", t1_targets)
    if synthseg_resampler is not None:
        # NOTE(review): this edge likely overrides any static input_image that
        # _prepare_synthseg set from synthseg_roi_labels.roi_file; in a
        # standalone run seg_map stays Undefined, which Nipype may propagate
        # onto input_image — confirm the roi_file fallback still takes effect.
        wf.connect(inputnode, "seg_map", synthseg_resampler, "input_image")
    if roi_terminus is not None and dwi_resampler is not None:
        wf.connect(
            [
                (
                    inputnode,
                    dwi_resampler,
                    [
                        ("dwi_mask", "reference"),
                        ("t1_to_dwi_transform", "t1_to_dwi_transform"),
                    ],
                ),
            ]
        )
    # tckmap reference grid is the DWI brain mask.
    wf.connect(inputnode, "dwi_mask", tckmap_node, "reference")
    return fdt_paths_writer, roi_terminus
[docs]
@workflow(
    name="mrtrix3",
    description=(
        "MRtrix3 tractography workflow: 5ttgen + dhollander + msmt_csd + "
        "mtnormalise + tckgen (ACT) + tcksift2 + tckmap. Reuses the HCP "
        "ROI extraction pipeline (MapNode/JoinNode based)."
    ),
    protocol="mrtrix3",
)
@requires(
    t1=PatientFile(
        default="T1w/T1w_acpc_dc_restore_brain.nii.gz",
        config_paths=["synthseg.t1_image", "hcp.t1_image"],
    ),
    mask=PatientFile(
        default=f"T1w/Diffusion/{DEFAULT_MASK_NAME}",
        config_path="hcp.mask_path",
    ),
    dwi=PatientFile(
        default=f"T1w/Diffusion/{DEFAULT_DWI_NAME}",
        config_path="mrtrix3.dwi_name",
    ),
    bvec=PatientFile(
        default=f"T1w/Diffusion/{DEFAULT_BVEC_NAME}",
        config_path="mrtrix3.bvec_name",
    ),
    bval=PatientFile(
        default=f"T1w/Diffusion/{DEFAULT_BVAL_NAME}",
        config_path="mrtrix3.bval_name",
    ),
)
@produces(tract_dir=OutputDir("tractography/mrtrix3"))
@verify(verify_requirements)
def build_workflow(
    *,
    t1: Path,
    mask: Path,
    dwi: Path,
    bvec: Path,
    bval: Path,
    tract_dir: Path,
    config: PipelineConfig,
    context: ProcessingContext,
) -> Workflow:
    """Build the MRtrix3 tractography workflow for one patient.

    Hemisphere modes via ``config.tractography.hemisphere``: ``"left"``/``"right"``
    (single hemisphere via fixed input), ``"both"`` (merged ROIs, default),
    ``"both-separately"`` (two hemispheres via Nipype ``iterables``).

    Args:
        t1: T1 brain image resolved by ``@requires``.
        mask: diffusion brain mask resolved by ``@requires``.
        dwi: DWI series resolved by ``@requires``.
        bvec: FSL-format gradient directions resolved by ``@requires``.
        bval: FSL-format b-values resolved by ``@requires``.
            Each of the five paths above replaces the corresponding
            ``prepare_mrtrix3_paths`` entry only when it exists on disk.
        tract_dir: output directory from ``@produces`` (``tractography/mrtrix3``);
            patient-level intermediates go to its ``_shared/`` subdirectory.
        config: pipeline configuration.
        context: per-patient processing context (``patient_id`` names the workflow).

    Returns:
        The assembled Nipype workflow. An ``outputnode`` exposing ``fdt_paths``
        plus the ROI contract fields is added only when a tractography tail
        was built.
    """
    wf = Workflow(name=f"mrtrix3_{context.patient_id}")
    paths = prepare_mrtrix3_paths(config, context)
    # Prefer @requires-resolved paths when they exist on disk; otherwise fall
    # back to what prepare_mrtrix3_paths returned (which may be a runtime path
    # produced by upstream stages in full_pipeline (mrtrix3 backend), or a mocked path).
    if t1.exists():
        paths = {**paths, "t1_path": t1}
    if mask.exists():
        paths = {**paths, "mask_path": mask}
    if dwi.exists():
        paths = {**paths, "dwi_image": dwi}
    if bvec.exists():
        paths = {**paths, "bvec": bvec}
    if bval.exists():
        paths = {**paths, "bval": bval}
    params = prepare_mrtrix3_params(config)
    base_tract_dir = tract_dir.resolve()
    shared = _build_shared_subgraph(
        wf,
        t1_path=paths["t1_path"],
        mask_path=paths["mask_path"],
        dwi_image=paths["dwi_image"],
        bvec=paths["bvec"],
        bval=paths["bval"],
        params=params,
        shared_out_dir=base_tract_dir / "_shared",
    )
    # Hemisphere driver: iterates over ["left", "right"] in both-separately
    # mode; otherwise a single fixed hemisphere value.
    inputs = Node(IdentityInterface(fields=["hemisphere"]), name="inputs")
    hemispheres = _hemispheres_for_mode(getattr(config.tractography, "hemisphere", "both"))
    if len(hemispheres) > 1:
        inputs.iterables = ("hemisphere", hemispheres)
    else:
        inputs.inputs.hemisphere = hemispheres[0]
    wf.add_nodes([inputs])
    # Seed the contract inputnode from the resolved paths so a standalone
    # ``-w mrtrix3`` run resolves statically (nothing upstream feeds this node).
    # ``seg_map`` and ``t1_to_dwi_transform`` stay Undefined: seg_map is only
    # supplied by tract_synthseg, and HCP T1/diffusion share a space so no
    # transform is needed. Embedded full_pipeline / tract_synthseg runs
    # override these via Nipype connections to this node.
    inputnode = attach_inputnode(
        wf,
        [
            "t1_brain",
            "dwi_mask",
            "t1_to_dwi_transform",
            "seg_map",
            "dwi_corrected",
            "dwi_bvec",
            "dwi_bval",
        ],
        defaults={
            "t1_brain": str(paths["t1_path"]),
            "dwi_mask": str(paths["mask_path"]),
            "dwi_corrected": str(paths["dwi_image"]),
            "dwi_bvec": str(paths["bvec"]),
            "dwi_bval": str(paths["bval"]),
        },
    )
    # Shared-node fan-out (all regular Nodes). The brain-mask node's input
    # depends on mask_source: fsl_nodif imports the runtime nodif_brain_mask
    # (dwi_mask), whereas dwi2mask derives the mask from the imported DWI .mif
    # (wired inside _build_shared_subgraph) and takes no dwi_mask edge here.
    if not shared.mask_from_dwi:
        wf.connect(inputnode, "dwi_mask", shared.mask_node, "mask_path")
    wf.connect(inputnode, "dwi_corrected", shared.dwi_import_node, "in_file")
    # Pack the runtime bvec/bval into the (bvec, bval) tuple consumed by the
    # DWI import node's grad_fsl input.
    grad_combiner = Node(
        Function(input_names=["bvec", "bval"], output_names=["grad_fsl"], function=_pack_grad_fsl),
        name="dwi_grad_combiner",
    )
    wf.add_nodes([grad_combiner])
    wf.connect(
        [
            (inputnode, grad_combiner, [("dwi_bvec", "bvec"), ("dwi_bval", "bval")]),
            (grad_combiner, shared.dwi_import_node, [("grad_fsl", "grad_fsl")]),
        ]
    )
    # 5ttgen input depends on the algorithm; other algorithms get no edge here.
    if params["fivett_algorithm"] == "fsl":
        wf.connect(inputnode, "t1_brain", shared.fivett_t1_node, "t1_path")
    elif params["fivett_algorithm"] == "freesurfer":
        # 5ttgen freesurfer consumes a FreeSurfer-LUT parcellation; the live
        # SynthSeg label map arrives on the contract seg_map field (supplied by
        # tract_synthseg). Standalone mrtrix3 leaves seg_map unset.
        wf.connect(inputnode, "seg_map", shared.fivett_t1_node, "parc_path")
    # 5TT → DWI grid: reference is the DWI mask, with an optional T1→DWI transform.
    wf.connect(
        [
            (
                inputnode,
                shared.fivett_node,
                [
                    ("dwi_mask", "reference"),
                    ("t1_to_dwi_transform", "t1_to_dwi_transform"),
                ],
            ),
        ]
    )
    fdt_paths_writer, roi_terminus = _build_tractography_subgraph(
        wf,
        config,
        context,
        inputs=inputs,
        inputnode=inputnode,
        paths=paths,
        params=params,
        shared=shared,
        out_dir=base_tract_dir,
    )
    # Output contract. ``fdt_paths`` is always exposed; the ROI terminus
    # outputs (roi_seed/roi_stop/roi_avoid/roi_target) are re-published so a
    # meta-workflow (e.g. tract_synthseg validation) can wire to them by
    # name instead of scanning the internal graph. They stay Undefined when
    # no ROI assembly was built (whole-brain gmwmi tracking).
    #
    # NOTE: under ``hemisphere: both-separately`` the outputnode is a JoinNode,
    # so each contract field collects a per-hemisphere list rather than a single
    # path; downstream consumers must handle the list form in that mode.
    if fdt_paths_writer is not None:
        contract_fields = ["fdt_paths", *MRTRIX3_ROI_OUTPUT_FIELDS]
        if len(hemispheres) > 1:
            outputnode = JoinNode(
                IdentityInterface(fields=contract_fields),
                joinsource="inputs",
                joinfield=contract_fields,
                name="outputnode",
            )
        else:
            outputnode = Node(IdentityInterface(fields=contract_fields), name="outputnode")
        wf.add_nodes([outputnode])
        wf.connect(fdt_paths_writer, "out_file", outputnode, "fdt_paths")
        if roi_terminus is not None:
            wf.connect(
                [
                    (
                        roi_terminus,
                        outputnode,
                        [
                            ("seed", "roi_seed"),
                            ("stop_mask", "roi_stop"),
                            ("avoid_mask", "roi_avoid"),
                            ("target_mask", "roi_target"),
                        ],
                    )
                ]
            )
    return wf