Source code for thesis.workflows.hcp.nodes.roi_mapnode
"""MapNode/JoinNode factories for the atlas ROI pipeline.
These replace the per-source ``prepare_roi_extractor``/``prepare_roi_transformer``/
``prepare_roi_validator``/``prepare_roi_merger`` factory pattern used by the
legacy HCP workflow. The new design uses Nipype's native ``MapNode`` to iterate
over atlas sources and a regular ``Node`` to merge the MapNode's aggregated list
outputs (originally a ``JoinNode``; see ``prepare_atlas_merge_joinnode``) —
eliminating the per-source Python loop and the binary merger reduce tree.
"""
from pathlib import Path
from typing import Optional
from nipype import MapNode, Node
from nipype.interfaces.utility import Function
from thesis.core.config import PipelineConfig
from thesis.core.context import ProcessingContext
from ..common import (
ValidationThresholds,
format_patient_path,
get_atlas_sources,
get_atlas_transform_spec,
get_trial_root,
resolve_with_fallback,
)
from ..operations import (
extract_rois_task,
merge_rois_task,
transform_rois_task,
validate_warped_rois_passthrough_task,
)
# Output names shared by every node factory in this module (used as each
# Function interface's ``output_names``, and as the validate MapNode's iterfield).
_ROI_FIELDS = ("seed", "waypoints_file", "stop_mask", "avoid_mask", "target_mask")
[docs]
def prepare_atlas_extract_mapnode(
    config: PipelineConfig,
    context: ProcessingContext,
    out_dir: Path,
    name: str = "atlas_extract",
) -> Optional[MapNode]:
    """Create the ``MapNode`` that calls ``extract_rois_task`` for every atlas source.

    Args:
        config: Pipeline configuration from which the atlas sources are read.
        context: Processing context. ``input_dir`` (or the current working
            directory when it is unset) and ``data_dir`` are handed to
            ``resolve_with_fallback`` to locate each source's files.
        out_dir: Base output directory; each source is given
            ``<out_dir>/rois/<source name>`` as its ``output_dir``.
        name: Nipype node name.

    Returns:
        The configured ``MapNode``, or ``None`` when the config declares no
        atlas sources.
    """
    sources = get_atlas_sources(config)
    if not sources:
        return None

    if context.input_dir:
        search_root = Path(context.input_dir).resolve()
    else:
        search_root = Path(".").resolve()

    # One list per iterated input, in interface order; entry i of every list
    # belongs to atlas source i.
    per_source: dict[str, list] = {
        "roi_file": [],
        "label_file": [],
        "waypoint_labels": [],
        "output_dir": [],
    }

    for position, source in enumerate(sources, start=1):
        # Sources without a name get a positional default (atlas_1, atlas_2, ...).
        source_name = str(source.get("name", f"atlas_{position}"))

        roi_raw = source.get("roi_file", "")
        per_source["roi_file"].append(
            str(resolve_with_fallback(roi_raw, search_root, [context.data_dir]))
        )

        # A missing/empty label file is forwarded as "" rather than resolved.
        label_raw = source.get("label_file", "")
        if label_raw:
            label_path = str(
                resolve_with_fallback(label_raw, search_root, [context.data_dir])
            )
        else:
            label_path = ""
        per_source["label_file"].append(label_path)

        per_source["waypoint_labels"].append(source.get("waypoint_labels", {}))
        per_source["output_dir"].append(str(out_dir / "rois" / source_name))

    extractor = Function(
        input_names=[*per_source, "hemisphere"],
        output_names=list(_ROI_FIELDS),
        function=extract_rois_task,
    )
    mapnode = MapNode(extractor, iterfield=list(per_source), name=name)

    for field, values in per_source.items():
        setattr(mapnode.inputs, field, values)
    # ``hemisphere`` is not iterated: the same value is used for every source.
    mapnode.inputs.hemisphere = "both"
    return mapnode
[docs]
def prepare_atlas_transform_mapnode(
    config: PipelineConfig,
    context: ProcessingContext,
    out_dir: Path,
    name: str = "atlas_transform",
) -> Optional[MapNode]:
    """Create the ``MapNode`` that calls ``transform_rois_task`` for every atlas source.

    Every atlas source becomes one iteration. A source whose transform spec
    lacks ``template_to_patient`` or ``reference_image`` is given an empty
    ``warp_field`` list and an empty ``reference_image`` string, which the task
    is expected to treat as a no-op passthrough.

    Args:
        config: Pipeline configuration holding the atlas sources and their
            transform specs.
        context: Processing context providing the trial root, ``input_dir``,
            ``data_dir`` and (optionally) ``patient_id``.
        out_dir: Base output directory; each source is given
            ``<out_dir>/rois_transformed/<source name>`` as its ``output_dir``.
        name: Nipype node name.

    Returns:
        The configured ``MapNode``, or ``None`` when there are no atlas sources
        or none of them has a usable transform (in which case the upstream
        extract MapNode connects directly to the validate MapNode).
    """
    sources = get_atlas_sources(config)
    if not sources:
        return None

    if context.input_dir:
        input_root = Path(context.input_dir).resolve()
    else:
        input_root = Path(".").resolve()

    # Transform and reference paths usually point at this run's own
    # registration outputs (e.g. outputs/{pid}/registration/transforms/...),
    # so relative paths are anchored at the trial root first. input_dir and
    # data_dir stay as fallbacks for precomputed transforms (standalone use).
    trial_root = get_trial_root(context)
    patient_id = getattr(context, "patient_id", "")

    def _resolve(raw_path):
        """Fill in the patient id, then resolve against trial root and fallbacks."""
        return str(
            resolve_with_fallback(
                format_patient_path(raw_path, patient_id),
                trial_root,
                [input_root, context.data_dir],
            )
        )

    # One list per configured iterated input; entry i belongs to atlas source i.
    per_source: dict[str, list] = {
        "warp_field": [],
        "reference_image": [],
        "output_dir": [],
    }
    found_transform = False

    for position, source in enumerate(sources, start=1):
        source_name = str(source.get("name", f"atlas_{position}"))
        spec = get_atlas_transform_spec(config, source) or {}
        warp_raw = spec.get("template_to_patient", "")
        ref_raw = spec.get("reference_image", "")

        if warp_raw and ref_raw:
            found_transform = True
            # A single warp is normalised to a one-element list.
            warp_values = warp_raw if isinstance(warp_raw, list) else [warp_raw]
            warps = [_resolve(value) for value in warp_values]
            reference = _resolve(ref_raw)
        else:
            # Incomplete spec: placeholders keep the lists aligned per source.
            warps = []
            reference = ""

        per_source["warp_field"].append(warps)
        per_source["reference_image"].append(reference)
        per_source["output_dir"].append(
            str(out_dir / "rois_transformed" / source_name)
        )

    if not found_transform:
        return None

    transformer = Function(
        input_names=[*_ROI_FIELDS, *per_source, "hemisphere", "_ordering_signal"],
        output_names=list(_ROI_FIELDS),
        function=transform_rois_task,
    )
    # The ROI fields are iterated too; they are not set here, so they are
    # presumably supplied by workflow connections from the upstream MapNode.
    mapnode = MapNode(
        transformer,
        iterfield=[*_ROI_FIELDS, *per_source],
        name=name,
    )

    for field, values in per_source.items():
        setattr(mapnode.inputs, field, values)

    # Default only: the workflow overrides it via
    # wf.connect(inputs.hemisphere -> .hemisphere) so that each iteration of
    # the outer hemisphere iterable scopes its outputs into
    # <out_dir>/rois_transformed/<source>/<hemisphere>/.
    mapnode.inputs.hemisphere = "both"
    return mapnode
[docs]
def prepare_atlas_validate_mapnode(
    thresholds: ValidationThresholds,
    reference_image: str,
    out_dir: Path,
    name: str = "atlas_validate",
) -> MapNode:
    """Create the ``MapNode`` that validates each per-source ROI bundle.

    The node wraps ``validate_warped_rois_passthrough_task`` and iterates over
    the five ROI fields; the reference image and the thresholds are shared by
    all iterations.

    Args:
        thresholds: Mapping providing ``min_voxels``, ``singularity_threshold``,
            ``volume_ratio_min`` and ``volume_ratio_max``.
        reference_image: Reference image path passed to every iteration.
        out_dir: Not referenced by this factory; part of the signature only.
        name: Nipype node name.

    Returns:
        The configured ``MapNode``.
    """
    threshold_keys = (
        "min_voxels",
        "singularity_threshold",
        "volume_ratio_min",
        "volume_ratio_max",
    )

    validator = Function(
        input_names=[*_ROI_FIELDS, "reference_image", *threshold_keys],
        output_names=list(_ROI_FIELDS),
        function=validate_warped_rois_passthrough_task,
    )
    mapnode = MapNode(validator, iterfield=list(_ROI_FIELDS), name=name)

    mapnode.inputs.reference_image = reference_image
    # Each threshold is copied onto the node input of the same name.
    for key in threshold_keys:
        setattr(mapnode.inputs, key, thresholds[key])
    return mapnode
[docs]
def prepare_atlas_merge_joinnode(
    joinsource_name: str,
    out_dir: Path,
    name: str = "atlas_merge",
) -> Node:
    """Create the ``Node`` that merges per-source ROI bundles into one bundle.

    Despite the function name, a regular ``Node`` is built. It is meant to be
    connected from an upstream ``MapNode`` (``atlas_validate``): the MapNode
    already aggregates its per-source outputs into lists, and those lists reach
    this node through Nipype's standard connection semantics.

    History: this used to be a ``JoinNode``. Under outer ``iterables`` on
    hemisphere (``--hemisphere both-separately``) Nipype's join machinery
    collected zero iterations from the MapNode joinsource, so
    ``merge_rois_task`` ran with empty lists for every field, produced
    all-empty outputs, and crashed downstream with ``probtrackx2.seed=''``.

    Args:
        joinsource_name: Unused; retained so existing workflow callers keep
            working.
        out_dir: Base output directory; merged ROIs go to
            ``<out_dir>/rois_merged``.
        name: Nipype node name.

    Returns:
        The configured ``Node``.
    """
    del joinsource_name  # accepted for backward compatibility only

    list_inputs = (
        "seeds",
        "waypoints_files",
        "stop_masks",
        "avoid_masks",
        "target_masks",
    )
    merger = Function(
        input_names=[*list_inputs, "output_dir", "hemisphere"],
        output_names=list(_ROI_FIELDS),
        function=merge_rois_task,
    )
    node = Node(merger, name=name)

    merged_dir = out_dir / "rois_merged"
    node.inputs.output_dir = str(merged_dir)
    node.inputs.hemisphere = "both"
    return node
[docs]
def prepare_final_merger(out_dir: Path, name: str = "final_merger") -> Node:
    """Create the two-input merger ``Node`` combining atlas and SynthSeg bundles.

    The node wraps ``binary_merge_rois_task`` and exposes one ``left_*`` and
    one ``right_*`` input per ROI kind (seed, waypoints_file, stop, avoid,
    target), plus ``output_dir`` and ``hemisphere``.

    Args:
        out_dir: Base output directory; results go to ``<out_dir>/rois_final``.
        name: Nipype node name.

    Returns:
        The configured ``Node``.
    """
    from ..operations.merging import binary_merge_rois_task

    roi_kinds = ("seed", "waypoints_file", "stop", "avoid", "target")
    # All left_* names first, then all right_* names, each in roi_kinds order.
    side_inputs = [
        f"{side}_{kind}" for side in ("left", "right") for kind in roi_kinds
    ]

    merger = Function(
        input_names=[*side_inputs, "output_dir", "hemisphere"],
        output_names=list(_ROI_FIELDS),
        function=binary_merge_rois_task,
    )
    node = Node(merger, name=name)

    final_dir = out_dir / "rois_final"
    node.inputs.output_dir = str(final_dir)
    node.inputs.hemisphere = "both"
    return node